1#ifndef __DUNE_ACFEM_TENSORS_OPTIMIZATION_TRANSPOSE_HH__
2#define __DUNE_ACFEM_TENSORS_OPTIMIZATION_TRANSPOSE_HH__
4#include "../../expressions/optimization.hh"
5#include "../../mpl/sequencesetoperations.hh"
6#include "../operations/transpose.hh"
7#include "../operations/transposetraits.hh"
8#include "../modules.hh"
14 namespace Tensor::Optimization {
// Short aliases for the policy optimization-tag types. They are used as the
// leading tag argument of the operate() overloads in this namespace.
18 using SubExprTag = Policy::SubExpressionOptimizationTag;
19 using DefaultTag = Policy::DefaultOptimizationTag;
// Nested transpositions: transpose<Permutation>(t) where t is itself a
// transpose expression. The outer Permutation and the inner permutation of t
// are combined with PermuteSequence, and a single transposition with the
// combined permutation is applied to t's operand.
// Only enabled when T is a transpose expression.
22 template<class Permutation, class T, std::enable_if_t<IsTransposeExpression<T>::value,
int> = 0>
23 constexpr decltype(
auto)
operate(DefaultTag, F<TransposeOperation<Permutation> >, T&& t)
25 DUNE_ACFEM_RECORD_OPTIMIZATION;
// TType::Permutation is presumably the permutation of the inner transpose
// expression; the composition order is fixed by PermuteSequence (defined elsewhere).
28 using TType = std::decay_t<T>;
29 using Perm = PermuteSequence<Permutation, typename TType::Permutation>;
31 DUNE_ACFEM_EXPRESSION_RESULT(
32 operate<TransposeOperation<Perm> >(std::forward<T>(t).operand(0_c))
33 ,
"nested transpositions"
// Self-transposed operand: when IsSelfTransposed<Permutation, T> holds, the
// transposition is dropped and t is returned unchanged via forwardReturnValue<T>.
38 template<
class Permutation,
class T,
39 std::enable_if_t<IsSelfTransposed<Permutation, T>::value,
int> = 0>
40 constexpr decltype(
auto)
operate(OptimizeTerminal1, F<TransposeOperation<Permutation> >, T&& t)
42 DUNE_ACFEM_RECORD_OPTIMIZATION;
45 return forwardReturnValue<T>(t);
// Pulls a unary minus out of a transposition:
//   transpose<Permutation>(-x)  ->  -(transpose<Permutation>(x))
// NOTE(review): the enable_if condition continues on lines not visible in
// this extract; only the IsUnaryMinusExpression<T> part is shown.
49 template<
class Permutation,
class T,
50 std::enable_if_t<(IsUnaryMinusExpression<T>::value
53 constexpr decltype(
auto)
operate(SubExprTag, F<TransposeOperation<Permutation> >, T&& t)
55 DUNE_ACFEM_RECORD_OPTIMIZATION;
57 DUNE_ACFEM_EXPRESSION_RESULT(
58 operate<MinusOperation>(
59 operate<TransposeOperation<Permutation> >(std::forward<T>(t).operand(0_c))
// Trait: true if the left operand Operand<0, T> has rank 0, i.e. T is a
// product whose left factor is a scalar. Non-product T is handled by the
// partial specialization that follows.
65 template<
class Permutation,
class T,
class SFINAE =
void>
66 struct IsTransposeOfLeftProductWithScalar
67 : BoolConstant<TensorTraits<Operand<0, T> >::rank == 0>
// Partial specialization selected when T is not a product expression.
// NOTE(review): its base class is on lines not visible in this extract;
// presumably it yields false, so the primary template (which inspects
// Operand<0, T>) is never used for non-products. TODO confirm.
70 template<
class Permutation,
class T>
71 struct IsTransposeOfLeftProductWithScalar<
73 std::enable_if_t<!IsProductExpression<T>::value> >
// Moves a transposition past a scalar left factor:
//   transpose<Permutation>(s * x)  ->  s * transpose<Permutation>(x)   (rank(s) == 0)
// The result is built from t's own operation(), the untouched left operand
// and the transposed right operand.
// NOTE(review): the line that consumes these three arguments is not visible here.
80 template<
class Permutation,
class T,
81 std::enable_if_t<IsTransposeOfLeftProductWithScalar<Permutation, T>::value,
int> = 0>
82 constexpr decltype(
auto)
operate(SubExprTag, F<TransposeOperation<Permutation> >, T&& t)
84 DUNE_ACFEM_RECORD_OPTIMIZATION;
86 DUNE_ACFEM_EXPRESSION_RESULT(
88 std::forward<T>(t).operation(),
89 std::forward<T>(t).operand(0_c),
90 operate<TransposeOperation<Permutation> >(
91 std::forward<T>(t).operand(1_c)
94 ,
"transpose inside product"
// Transposing a constant Kronecker delta yields a new KroneckerDelta object.
// Its Signature and PivotSequence are those of T permuted by
// InversePermutation<Permutation>; the ValueType is kept.
// NOTE(review): the enable_if condition continues on lines not visible here.
101 template<
class Permutation,
class T,
102 std::enable_if_t<(IsConstKroneckerDelta<T>::value
105 constexpr decltype(
auto)
operate(OptimizeTerminal0, F<TransposeOperation<Permutation> >, T&& t)
107 DUNE_ACFEM_RECORD_OPTIMIZATION;
109 using TType = std::decay_t<T>;
110 using Inverse = InversePermutation<Permutation>;
111 using Signature = PermuteSequence<typename TType::Signature, Inverse>;
// std::decay_t<T> below is the same type as TType.
112 using Pivots = PermuteSequence<typename std::decay_t<T>::PivotSequence, Inverse>;
113 using ValueType =
typename TType::ValueType;
116 return KroneckerDelta<Signature, Pivots, ValueType>{};
// Helper traits describing a transposition by Perm of the einsum expression T.
// Used below as EinsumCommutationHelper<T, Permutation>.
// NOTE(review): the struct head and several members referenced here
// (LeftPos, RightPos, CommutationPermutation, LeftTransposed, ...) are on
// lines not visible in this extract.
123 template<
class T,
class Perm,
class SFINAE =
void>
126 using Einsum = std::decay_t<T>;
// Ranks of the operands minus Einsum::defectRank_ (presumably the number of
// contracted indices, i.e. each operand's share of the result's indices).
127 static constexpr std::size_t leftRank_ = TensorTraits<Operand<0, Einsum> >::rank - Einsum::defectRank_;
128 static constexpr std::size_t rightRank_ = TensorTraits<Operand<1, Einsum> >::rank - Einsum::defectRank_;
// True iff Perm is exactly CommutationPermutation (defined on a line not
// visible here; presumably the swap of the left and right index blocks).
138 static constexpr bool isCommutation = std::is_same<Perm, CommutationPermutation>::value;
// The parts of Perm acting on the left and right index positions.
143 using LeftPermutation = SequenceSlice<Perm, LeftPos>;
144 using RightPermutation = SequenceSlice<Perm, RightPos>;
// The permutations that sort each part.
145 using LeftNormalization =
typename SortSequence<LeftPermutation>::Permutation;
146 using RightNormalization =
typename SortSequence<RightPermutation>::Permutation;
// A side counts as normalizable when the respective operand is invariant
// under its sorting permutation (further conditions are not visible here).
148 static constexpr bool leftNormalizable =
149 IsSelfTransposed<LeftNormalization, Operand<0, Einsum> >::value
152 static constexpr bool rightNormalizable =
153 IsSelfTransposed<RightNormalization, Operand<1, Einsum> >::value
157 static constexpr bool isNormalizable = (leftNormalizable || rightNormalizable);
// Combined normalization permutation. The right part is shifted by leftRank_
// so that it addresses the right index block; the left part is not visible here.
158 using Normalization =
160 ConditionalType<rightNormalizable, OffsetSequence<leftRank_, RightNormalization>, RightPos> >;
// Compile-time consistency check of isCommutation. The full condition spans
// lines not visible in this extract.
166 static_assert(isCommutation
168 (std::is_same<LeftTransposed, LeftCommuted>::value
170 std::is_same<RightTransposed, RightCommuted>::value),
171 "Inconsistent commutation permutation.");
// NOTE(review): only the template header of this definition is visible in
// this extract; its name and body are on missing lines.
175 template<
class A,
class B>
// Trait for einsum expressions T. It is true when all of the following hold:
//  - Permutation is not a no-op for T;
//  - the condition on the eye (identity) operands holds;
//  - EinsumCommutationHelper classifies Permutation as a commutation.
// NOTE(review): the operator joining the two IsEye terms is on a line not
// visible in this extract.
179 template<
class Permutation,
class T,
class SFINAE =
void>
180 struct IsCommutationOfEinsumTranspose
181 :
BoolConstant<(!IsSelfTransposed<Permutation, T>::value
182 && (IsEye<Operand<1, T> >::value
184 !IsEye<Operand<0, T> >::value)
185 && EinsumCommutationHelper<T, Permutation>::isCommutation)>
// Partial specialization for non-einsum T. It keeps the primary template (and
// hence EinsumCommutationHelper) from being instantiated for such T.
// NOTE(review): the base class is not visible here; presumably false.
188 template<
class Permutation,
class T>
189 struct IsCommutationOfEinsumTranspose<
191 std::enable_if_t<!IsEinsumExpression<T>::value> >
// Transpose commutation: a transposition that merely swaps the index blocks
// of an einsum is replaced by an expression over the swapped operands
// (operand 1 first, operand 0 second).
// NOTE(review): the function signature and the operation applied to the
// swapped operands are on lines not visible in this extract.
196 template<
class Permutation,
class T,
197 std::enable_if_t<IsCommutationOfEinsumTranspose<Permutation, T>::value,
int> = 0>
200 DUNE_ACFEM_RECORD_OPTIMIZATION;
202 DUNE_ACFEM_EXPRESSION_RESULT(
204 std::forward<T>(t).operand(1_c),
205 std::forward<T>(t).operand(0_c)
207 ,
"transpose commutation"
// Trait, referred to below as IsTransposeOfEinsumWithRightEye (its name line is
// not visible here). It is true when all of the following hold:
//  - Permutation is not a no-op for T;
//  - the left operand is not an eye;
//  - the right operand is an eye;
//  - Permutation is not a plain commutation (per EinsumCommutationHelper).
212 template<
class Permutation,
class T,
class SFINAE =
void>
214 :
BoolConstant<(!IsSelfTransposed<Permutation, T>::value
215 && !IsEye<Operand<0, T> >::value
216 && IsEye<Operand<1, T> >::value
217 && !EinsumCommutationHelper<T, Permutation>::isCommutation)>
// Partial specialization for non-einsum T (struct head not visible here).
// NOTE(review): the base class is not visible; presumably false.
220 template<
class Permutation,
class T>
223 std::enable_if_t<!IsEinsumExpression<T>::value> >
// "Normalize eyes": used when a transposed einsum has an eye as its right
// operand. Judging by the visible lines, the rewrite does two things:
//  - rebuilds the einsum from t.operation() with the operands swapped
//    (eye first);
//  - wraps the result in a transposition by Traits::InversePermutation.
// NOTE(review): the signature and several argument lines are not visible in
// this extract.
231 template<
class Permutation,
class T,
232 std::enable_if_t<IsTransposeOfEinsumWithRightEye<Permutation, T>::value,
int> = 0>
235 DUNE_ACFEM_RECORD_OPTIMIZATION;
237 using Traits = EinsumCommutationHelper<T, Permutation>;
239 DUNE_ACFEM_EXPRESSION_RESULT(
243 F<TransposeOperation<typename Traits::InversePermutation> >{},
245 std::forward<T>(t).operation(),
246 std::forward<T>(t).operand(1_c),
247 std::forward<T>(t).operand(0_c)
249 ,
"transpose normalize eyes"
// Trait: forwards EinsumCommutationHelper<T, Permutation>::isNormalizable, i.e.
// whether the left or the right part of Permutation can be normalized (sorted).
253 template<
class Permutation,
class T,
class SFINAE =
void>
254 struct IsNormalizableTranspose
255 :
BoolConstant<EinsumCommutationHelper<T, Permutation>::isNormalizable>
// Partial specialization that applies when either:
//  - Permutation is a no-op for T, or
//  - T is not an einsum,
// plus further conditions on lines not visible here.
// NOTE(review): the base class is not visible; presumably false.
258 template<
class Permutation,
class T>
259 struct IsNormalizableTranspose<
261 std::enable_if_t<(IsSelfTransposed<Permutation, T>::value
262 || !IsEinsumExpression<T>::value
// "Sort transpositions": when the transposition of an einsum is normalizable,
// Permutation is re-applied to an expression built with a transposition by
// Traits::Normalization (the sorting permutation).
// NOTE(review): the remaining arguments of the inner expression are on lines
// not visible in this extract.
268 template<
class Permutation,
class T,
269 std::enable_if_t<IsNormalizableTranspose<Permutation, T>::value,
int> = 0>
270 constexpr decltype(
auto)
operate(DefaultTag, F<TransposeOperation<Permutation> >, T&& t)
272 DUNE_ACFEM_RECORD_OPTIMIZATION;
274 using Traits = EinsumCommutationHelper<T, Permutation>;
276 DUNE_ACFEM_EXPRESSION_RESULT(
277 operate<TransposeOperation<Permutation> >(
279 F<TransposeOperation<typename Traits::Normalization> >{},
282 ,
"sort transpositions"
// Variable-template trait for einsum(t0, t1) where t0 is a transposition.
// The visible arguments are:
//  - the index positions of T0 not listed in EinsumTraits<F>::LeftIndexPositions
//    (presumably T0's uncontracted indices);
//  - T0's transpose permutation.
// The declaration head is not visible here. It presumably wraps these
// arguments in HasInvariantValuesV, i.e. the permutation must leave those
// positions invariant. TODO confirm.
301 template<
class F,
class T0,
class T1,
class SFINAE =
void>
303 MPL::SequenceSetMinus<MakeIndexSequence<TensorTraits<T0>::rank>,
typename EinsumTraits<F>::LeftIndexPositions>,
304 typename TransposeTraits<T0>::Permutation>;
// Specialization for the cases where the trait must not inspect its arguments:
//  - F is not an einsum functor;
//  - F is ScalarEinsumFunctor;
//  - T0 is not a transpose expression;
// plus further conditions on lines not visible here.
// NOTE(review): the value is not visible; presumably false.
306 template<
class F,
class T0,
class T1>
307 constexpr inline bool IsLeftSelfTransposedContractionV<
309 std::enable_if_t<(!FunctorHas<IsEinsumOperation, F>::value
310 || std::is_same<F, ScalarEinsumFunctor>::value
311 || !IsTransposeExpression<T0>::value
// Left self-transposed contraction: the transposition of the left factor is
// absorbed into the contraction.
//   einsum<L, R, D>(transpose<P>(a), b)  ->  einsum<P^-1(L), R, D>(a, b)
// LeftIndexPositions are remapped through the inverse of T0's transpose
// permutation, while RightIndexPositions and Dimensions are kept.
320 template<
class F,
class T0,
class T1,
321 std::enable_if_t<IsLeftSelfTransposedContractionV<F, T0, T1>,
int> = 0>
322 constexpr decltype(
auto)
operate(DefaultTag, F, T0&& t0, T1&& t1)
324 DUNE_ACFEM_RECORD_OPTIMIZATION;
326 using Inverse = InversePermutation<typename TransposeTraits<T0>::Permutation>;
327 using Traits = EinsumTraits<F>;
328 using Operation = EinsumOperation<
329 PermuteSequenceValues<typename Traits::LeftIndexPositions, Inverse>,
330 typename Traits::RightIndexPositions,
331 typename Traits::Dimensions>;
333 DUNE_ACFEM_EXPRESSION_RESULT(
// The transposed left operand is replaced by its own operand (untransposed).
335 std::forward<T0>(t0).operand(0_c),
336 std::forward<T1>(t1))
337 ,
"left self transposed contraction"
// Variable-template trait for einsum(t0, t1) where t1 is a transposition.
// It is the right-hand counterpart of IsLeftSelfTransposedContractionV, with
// these visible arguments:
//  - the index positions of T1 not listed in EinsumTraits<F>::RightIndexPositions
//    (presumably T1's uncontracted indices);
//  - T1's transpose permutation.
// The declaration head is not visible here. It presumably wraps these
// arguments in HasInvariantValuesV. TODO confirm.
// Why RightIndexPositions: they address T1's index space, because the right
// rewrite remaps them through T1's inverse permutation. LeftIndexPositions
// address T0, so subtracting them from T1's positions would mix unrelated
// index spaces.
347 template<
class F,
class T0,
class T1,
class SFINAE =
void>
349 MPL::SequenceSetMinus<MakeIndexSequence<TensorTraits<T1>::rank>,
typename EinsumTraits<F>::RightIndexPositions>,
350 typename TransposeTraits<T1>::Permutation>;
// Specialization for the cases where the trait must not inspect its arguments:
//  - F is not an einsum functor;
//  - F is ScalarEinsumFunctor;
//  - T1 is not a transpose expression;
// plus further conditions on lines not visible here.
// NOTE(review): the value is not visible; presumably false.
352 template<
class F,
class T0,
class T1>
353 constexpr inline bool IsRightSelfTransposedContractionV<
355 std::enable_if_t<(!FunctorHas<IsEinsumOperation, F>::value
356 || std::is_same<F, ScalarEinsumFunctor>::value
357 || !IsTransposeExpression<T1>::value
// Right self-transposed contraction: the transposition of the right factor is
// absorbed into the contraction.
//   einsum<L, R, D>(a, transpose<P>(b))  ->  einsum<L, P^-1(R), D>(a, b)
// RightIndexPositions are remapped through the inverse of T1's transpose
// permutation, while LeftIndexPositions and Dimensions are kept.
// NOTE(review): this overload is tagged OptimizeNext<DefaultTag>, whereas the
// left variant uses DefaultTag. This presumably orders the two overloads when
// both traits hold, which would otherwise be ambiguous. Confirm this is intended.
366 template<
class F,
class T0,
class T1,
367 std::enable_if_t<IsRightSelfTransposedContractionV<F, T0, T1>,
int> = 0>
368 constexpr decltype(
auto)
operate(OptimizeNext<DefaultTag>, F, T0&& t0, T1&& t1)
370 DUNE_ACFEM_RECORD_OPTIMIZATION;
372 using Inverse = InversePermutation<typename TransposeTraits<T1>::Permutation>;
373 using Traits = EinsumTraits<F>;
374 using Operation = EinsumOperation<
375 typename Traits::LeftIndexPositions,
376 PermuteSequenceValues<typename Traits::RightIndexPositions, Inverse>,
377 typename Traits::Dimensions>;
379 DUNE_ACFEM_EXPRESSION_RESULT(
381 std::forward<T0>(t0),
// The transposed right operand is replaced by its own operand (untransposed).
382 std::forward<T1>(t1).operand(0_c)
384 ,
"right self transposed contraction"
394 namespace Expressions {
OptimizeTag< 0 > DontOptimize
Bottom level is overloaded to do nothing.
Definition: optimizationbase.hh:74
constexpr std::size_t size()
Gives the number of elements in tuple-likes and std::integer_sequence.
Definition: size.hh:73
MakeSequence< std::size_t, N, Offset, Stride, Repeat > MakeIndexSequence
Make a sequence of std::size_t elements.
Definition: generators.hh:34
constexpr auto operate(Expressions::DontOptimize, OperationTraits< TransposeOperation< Perm > >, T &&t)
Definition: transpose.hh:175
constexpr bool HasInvariantValuesV
Evaluate to true if applying the permutation Perm to the values of Seq leaves the values invariant.
Definition: permutation.hh:112
constexpr bool isSimple(Sequence< T, V... > seq)
Definition: compare.hh:334
typename PermuteSequenceValuesHelper< Perm, Sequence >::Type PermuteSequenceValues
Apply the given permutation to the values of the given sequence.
Definition: permutation.hh:101
Constant< bool, V > BoolConstant
Short-cut for integral constant of type bool.
Definition: types.hh:48
BoolConstant< false > FalseType
Alias for std::false_type.
Definition: types.hh:110
Definition: transpose.hh:125
Definition: transpose.hh:218
Permutation of index positions of tensors.
Definition: expressionoperations.hh:167