Loading [MathJax]/extensions/TeX/AMSsymbols.js

DUNE-ACFEM (unstable)

transpose.hh
1#ifndef __DUNE_ACFEM_TENSORS_OPTIMIZATION_TRANSPOSE_HH__
2#define __DUNE_ACFEM_TENSORS_OPTIMIZATION_TRANSPOSE_HH__
3
4#include "../../expressions/optimization.hh"
5#include "../../mpl/sequencesetoperations.hh"
6#include "../operations/transpose.hh"
7#include "../operations/transposetraits.hh"
8#include "../modules.hh"
9
10namespace Dune {
11
12 namespace ACFem {
13
14 namespace Tensor::Optimization {
15
16 namespace Transpose {
17
18 using SubExprTag = Policy::SubExpressionOptimizationTag;
19 using DefaultTag = Policy::DefaultOptimizationTag;
20
22 template<class Permutation, class T, std::enable_if_t<IsTransposeExpression<T>::value, int> = 0>
23 constexpr decltype(auto) operate(DefaultTag, F<TransposeOperation<Permutation> >, T&& t)
24 {
25 DUNE_ACFEM_RECORD_OPTIMIZATION;
26
27 // permute the outer permutation with the inner
28 using TType = std::decay_t<T>;
29 using Perm = PermuteSequence<Permutation, typename TType::Permutation>;
30
31 DUNE_ACFEM_EXPRESSION_RESULT(
32 operate<TransposeOperation<Perm> >(std::forward<T>(t).operand(0_c))
33 , "nested transpositions"
34 );
35 }
36
// Drop a transposition which leaves its operand unchanged: when
// IsSelfTransposed<Permutation, T> holds, the operand `t` is handed back via
// forwardReturnValue<T>(t) instead of building a transpose expression.
// NOTE(review): listing line 44 is missing from this extraction (probably a
// comment, but it could be code) — confirm against the original header.
38 template<class Permutation, class T,
39 std::enable_if_t<IsSelfTransposed<Permutation, T>::value, int> = 0>
40 constexpr decltype(auto) operate(OptimizeTerminal1, F<TransposeOperation<Permutation> >, T&& t)
41 {
42 DUNE_ACFEM_RECORD_OPTIMIZATION;
43
45 return forwardReturnValue<T>(t);
46 }
47
49 template<class Permutation, class T,
50 std::enable_if_t<(IsUnaryMinusExpression<T>::value
51 && Permutation::size() == TensorTraits<T>::rank
52 ), int> = 0>
53 constexpr decltype(auto) operate(SubExprTag, F<TransposeOperation<Permutation> >, T&& t)
54 {
55 DUNE_ACFEM_RECORD_OPTIMIZATION;
56
57 DUNE_ACFEM_EXPRESSION_RESULT(
58 operate<MinusOperation>(
59 operate<TransposeOperation<Permutation> >(std::forward<T>(t).operand(0_c))
60 )
61 , "self transpose"
62 );
63 }
64
65 template<class Permutation, class T, class SFINAE = void>
66 struct IsTransposeOfLeftProductWithScalar
67 : BoolConstant<TensorTraits<Operand<0, T> >::rank == 0>
68 {};
69
70 template<class Permutation, class T>
71 struct IsTransposeOfLeftProductWithScalar<
72 Permutation, T,
73 std::enable_if_t<!IsProductExpression<T>::value> >
74 : FalseType
75 {};
76
80 template<class Permutation, class T,
81 std::enable_if_t<IsTransposeOfLeftProductWithScalar<Permutation, T>::value, int> = 0>
82 constexpr decltype(auto) operate(SubExprTag, F<TransposeOperation<Permutation> >, T&& t)
83 {
84 DUNE_ACFEM_RECORD_OPTIMIZATION;
85
86 DUNE_ACFEM_EXPRESSION_RESULT(
87 operate(
88 std::forward<T>(t).operation(),
89 std::forward<T>(t).operand(0_c),
90 operate<TransposeOperation<Permutation> >(
91 std::forward<T>(t).operand(1_c)
92 )
93 )
94 , "transpose inside product"
95 );
96 }
97
// Evaluate the transposition of a constant Kronecker delta directly. The
// result is a new KroneckerDelta: its Signature and PivotSequence are the
// operand's, permuted by the inverse of `Permutation`, and its ValueType is
// kept. Enabled for permutations of at most rank(T) indices.
// NOTE(review): listing line 115 is missing from this extraction (probably a
// comment) — confirm against the original header.
101 template<class Permutation, class T,
102 std::enable_if_t<(IsConstKroneckerDelta<T>::value
103 && Permutation::size() <= TensorTraits<T>::rank
104 ), int> = 0>
105 constexpr decltype(auto) operate(OptimizeTerminal0, F<TransposeOperation<Permutation> >, T&& t)
106 {
107 DUNE_ACFEM_RECORD_OPTIMIZATION;
108
109 using TType = std::decay_t<T>;
110 using Inverse = InversePermutation<Permutation>;
111 using Signature = PermuteSequence<typename TType::Signature, Inverse>;
// Equivalent to `typename TType::PivotSequence`; spelled via std::decay_t here.
112 using Pivots = PermuteSequence<typename std::decay_t<T>::PivotSequence, Inverse>;
113 using ValueType = typename TType::ValueType;
114
// In the visible lines `t` is never read: the result depends only on T's type.
116 return KroneckerDelta<Signature, Pivots, ValueType>{};
117 }
118
// Compile-time bookkeeping for transposing the einsum expression `T` by
// `Perm`. It records:
// - the result positions of each factor's indices,
// - the permutation which swaps those two blocks,
// - a per-factor normalization of `Perm`.
// NOTE(review): listing line 124 is missing. Judging by the specialization
// and the users below, it presumably reads `struct EinsumCommutationHelper`
// — confirm.
123 template<class T, class Perm, class SFINAE = void>
125 {
126 using Einsum = std::decay_t<T>;
// Number of result indices contributed by each factor: its rank minus
// `defectRank_` (presumably the number of contracted index pairs — confirm).
127 static constexpr std::size_t leftRank_ = TensorTraits<Operand<0, Einsum> >::rank - Einsum::defectRank_;
128 static constexpr std::size_t rightRank_ = TensorTraits<Operand<1, Einsum> >::rank - Einsum::defectRank_;
// Result positions of the left block, [0, leftRank_), and of the right
// block, [leftRank_, leftRank_ + rightRank_).
129 using LeftPos = MakeIndexSequence<leftRank_>;
130 using RightPos = MakeIndexSequence<rightRank_, leftRank_>; // offset
131
// Positions of the two blocks after swapping the factors: the left block
// moves to [rightRank_, rightRank_ + leftRank_), the right block to
// [0, rightRank_).
132 using LeftCommuted = MakeIndexSequence<leftRank_, rightRank_>;
133 using RightCommuted = MakeIndexSequence<rightRank_>;
134
// `Perm` equals CommutationPermutation iff it merely swaps the two blocks.
// NOTE(review): inside this struct the member alias InversePermutation
// shadows the InversePermutation<...> template used elsewhere in this file.
135 using CommutationPermutation = SequenceCat<LeftCommuted, RightCommuted>;
136 using InversePermutation = SequenceCat<RightPos, LeftPos>;
137
138 static constexpr bool isCommutation = std::is_same<Perm, CommutationPermutation>::value;
139
140 // we try to normalize permutations if the operands are
141 // invariant w.r.t. the permutation which results in
142 // ascending permutation indices.
143 using LeftPermutation = SequenceSlice<Perm, LeftPos>;
144 using RightPermutation = SequenceSlice<Perm, RightPos>;
145 using LeftNormalization = typename SortSequence<LeftPermutation>::Permutation;
146 using RightNormalization = typename SortSequence<RightPermutation>::Permutation;
147
// A factor is normalizable if the permutation that sorts its slice of `Perm`
// leaves that factor self-transposed and is not `isSimple`.
148 static constexpr bool leftNormalizable =
149 IsSelfTransposed<LeftNormalization, Operand<0, Einsum> >::value
150 &&
151 !isSimple(LeftNormalization{});
152 static constexpr bool rightNormalizable =
153 IsSelfTransposed<RightNormalization, Operand<1, Einsum> >::value
154 &&
155 !isSimple(RightNormalization{});
156
157 static constexpr bool isNormalizable = (leftNormalizable || rightNormalizable);
// NOTE(review): listing line 159 is missing. The visible tail selects
// OffsetSequence<leftRank_, RightNormalization> for the right block when
// rightNormalizable, and RightPos otherwise. The missing line presumably
// opens the enclosing sequence and holds the left-block counterpart — confirm.
158 using Normalization =
160 ConditionalType<rightNormalizable, OffsetSequence<leftRank_, RightNormalization>, RightPos> >;
161
// Debug-only cross-check: isCommutation must agree with the values of
// LeftPos/RightPos permuted by `Perm`.
162#ifndef NDEBUG
163 using LeftTransposed = PermuteSequenceValues<LeftPos, Perm>;
164 using RightTransposed = PermuteSequenceValues<RightPos, Perm>;
165
166 static_assert(isCommutation
167 ==
168 (std::is_same<LeftTransposed, LeftCommuted>::value
169 &&
170 std::is_same<RightTransposed, RightCommuted>::value),
171 "Inconsistent commutation permutation.");
172#endif
173 };
174
175 template<class A, class B>
176 struct EinsumCommutationHelper<A, B, std::enable_if_t<!IsEinsumExpression<A>::value> >
177 {};
178
179 template<class Permutation, class T, class SFINAE = void>
180 struct IsCommutationOfEinsumTranspose
181 : BoolConstant<(!IsSelfTransposed<Permutation, T>::value
182 && (IsEye<Operand<1, T> >::value // don't undo move eye left
183 ||
184 !IsEye<Operand<0, T> >::value)
185 && EinsumCommutationHelper<T, Permutation>::isCommutation)>
186 {};
187
188 template<class Permutation, class T>
189 struct IsCommutationOfEinsumTranspose<
190 Permutation, T,
191 std::enable_if_t<!IsEinsumExpression<T>::value> >
192 : FalseType
193 {};
194
196 template<class Permutation, class T,
197 std::enable_if_t<IsCommutationOfEinsumTranspose<Permutation, T>::value, int> = 0>
198 constexpr decltype(auto) operate(SubExprTag, F<TransposeOperation<Permutation> >, T&& t)
199 {
200 DUNE_ACFEM_RECORD_OPTIMIZATION;
201
202 DUNE_ACFEM_EXPRESSION_RESULT(
203 operate<Operation<T> >(
204 std::forward<T>(t).operand(1_c),
205 std::forward<T>(t).operand(0_c)
206 )
207 , "transpose commutation"
208 );
209 }
210
// Trait: true for an einsum expression `T` whose right factor is an eye and
// whose left factor is not, provided `Permutation` neither leaves `T`
// invariant nor is the plain factor swap (isCommutation). Non-einsum
// expressions yield false via the specialization.
// NOTE(review): listing lines 213 and 221 are missing. They presumably read
// `struct IsTransposeOfEinsumWithRightEye` (the name used by the operate()
// overload below) and the matching specialization head — confirm.
212 template<class Permutation, class T, class SFINAE = void>
214 : BoolConstant<(!IsSelfTransposed<Permutation, T>::value
215 && !IsEye<Operand<0, T> >::value // no point in doing so
216 && IsEye<Operand<1, T> >::value
217 && !EinsumCommutationHelper<T, Permutation>::isCommutation)>
218 {};
219
220 template<class Permutation, class T>
222 Permutation, T,
223 std::enable_if_t<!IsEinsumExpression<T>::value> >
224 : FalseType
225 {};
226
// Move a right-hand eye factor to the left. The einsum operation of `t` is
// applied to the swapped operands (operand 1, operand 0). The resulting
// index-block swap is then undone by an unoptimized (DontOptimize)
// transposition by Traits::InversePermutation.
// NOTE(review): listing line 240 is missing. The `)))` on line 248 closes one
// call more than the visible lines open, so the missing line presumably opens
// a further call around the shown expression (possibly the outer
// transposition by `Permutation`) — confirm.
231 template<class Permutation, class T,
232 std::enable_if_t<IsTransposeOfEinsumWithRightEye<Permutation, T>::value, int> = 0>
233 constexpr decltype(auto) operate(SubExprTag, F<TransposeOperation<Permutation> >, T&& t)
234 {
235 DUNE_ACFEM_RECORD_OPTIMIZATION;
236
237 using Traits = EinsumCommutationHelper<T, Permutation>;
238
239 DUNE_ACFEM_EXPRESSION_RESULT(
241 operate(
242 DontOptimize{},
243 F<TransposeOperation<typename Traits::InversePermutation> >{},
244 operate(
245 std::forward<T>(t).operation(),
246 std::forward<T>(t).operand(1_c),
247 std::forward<T>(t).operand(0_c)
248 )))
249 , "transpose normalize eyes"
250 );
251 }
252
253 template<class Permutation, class T, class SFINAE = void>
254 struct IsNormalizableTranspose
255 : BoolConstant<EinsumCommutationHelper<T, Permutation>::isNormalizable>
256 {};
257
258 template<class Permutation, class T>
259 struct IsNormalizableTranspose<
260 Permutation, T,
261 std::enable_if_t<(IsSelfTransposed<Permutation, T>::value
262 || !IsEinsumExpression<T>::value
263 )> >
264 : FalseType
265 {};
266
// Normalize a transposition of an einsum expression (enabled by
// IsNormalizableTranspose). `t` is first combined with
// F<TransposeOperation<Traits::Normalization>>, and that result is then
// transposed by `Permutation`.
// NOTE(review): listing line 278 is missing. The `))` on line 281 closes one
// call more than the visible lines open, so the missing line presumably holds
// the head of the inner call. By analogy with the "transpose normalize eyes"
// rewrite it is probably `operate(DontOptimize{},` — confirm.
268 template<class Permutation, class T,
269 std::enable_if_t<IsNormalizableTranspose<Permutation, T>::value, int> = 0>
270 constexpr decltype(auto) operate(DefaultTag, F<TransposeOperation<Permutation> >, T&& t)
271 {
272 DUNE_ACFEM_RECORD_OPTIMIZATION;
273
274 using Traits = EinsumCommutationHelper<T, Permutation>;
275
276 DUNE_ACFEM_EXPRESSION_RESULT(
277 operate<TransposeOperation<Permutation> >(
279 F<TransposeOperation<typename Traits::Normalization> >{},
280 std::forward<T>(t)
281 ))
282 , "sort transpositions"
283 );
284 }
285
287
301 template<class F, class T0, class T1, class SFINAE = void>
302 constexpr inline bool IsLeftSelfTransposedContractionV = HasInvariantValuesV<
303 MPL::SequenceSetMinus<MakeIndexSequence<TensorTraits<T0>::rank>, typename EinsumTraits<F>::LeftIndexPositions>,
304 typename TransposeTraits<T0>::Permutation>;
305
306 template<class F, class T0, class T1>
307 constexpr inline bool IsLeftSelfTransposedContractionV<
308 F, T0, T1,
309 std::enable_if_t<(!FunctorHas<IsEinsumOperation, F>::value
310 || std::is_same<F, ScalarEinsumFunctor>::value
311 || !IsTransposeExpression<T0>::value
312 )> > = false;
313
320 template<class F, class T0, class T1,
321 std::enable_if_t<IsLeftSelfTransposedContractionV<F, T0, T1>, int> = 0>
322 constexpr decltype(auto) operate(DefaultTag, F, T0&& t0, T1&& t1)
323 {
324 DUNE_ACFEM_RECORD_OPTIMIZATION;
325
326 using Inverse = InversePermutation<typename TransposeTraits<T0>::Permutation>;
327 using Traits = EinsumTraits<F>;
328 using Operation = EinsumOperation<
329 PermuteSequenceValues<typename Traits::LeftIndexPositions, Inverse>,
330 typename Traits::RightIndexPositions,
331 typename Traits::Dimensions>;
332
333 DUNE_ACFEM_EXPRESSION_RESULT(
334 operate<Operation>(
335 std::forward<T0>(t0).operand(0_c),
336 std::forward<T1>(t1))
337 , "left self transposed contraction"
338 );
339 }
340
342
347 template<class F, class T0, class T1, class SFINAE = void>
348 constexpr inline bool IsRightSelfTransposedContractionV = HasInvariantValuesV<
349 MPL::SequenceSetMinus<MakeIndexSequence<TensorTraits<T1>::rank>, typename EinsumTraits<F>::LeftIndexPositions>,
350 typename TransposeTraits<T1>::Permutation>;
351
352 template<class F, class T0, class T1>
353 constexpr inline bool IsRightSelfTransposedContractionV<
354 F, T0, T1,
355 std::enable_if_t<(!FunctorHas<IsEinsumOperation, F>::value
356 || std::is_same<F, ScalarEinsumFunctor>::value
357 || !IsTransposeExpression<T1>::value
358 )> > = false;
359
366 template<class F, class T0, class T1,
367 std::enable_if_t<IsRightSelfTransposedContractionV<F, T0, T1>, int> = 0>
368 constexpr decltype(auto) operate(OptimizeNext<DefaultTag>, F, T0&& t0, T1&& t1)
369 {
370 DUNE_ACFEM_RECORD_OPTIMIZATION;
371
372 using Inverse = InversePermutation<typename TransposeTraits<T1>::Permutation>;
373 using Traits = EinsumTraits<F>;
374 using Operation = EinsumOperation<
375 typename Traits::LeftIndexPositions,
376 PermuteSequenceValues<typename Traits::RightIndexPositions, Inverse>,
377 typename Traits::Dimensions>;
378
379 DUNE_ACFEM_EXPRESSION_RESULT(
380 operate<Operation>(
381 std::forward<T0>(t0),
382 std::forward<T1>(t1).operand(0_c)
383 )
384 , "right self transposed contraction"
385 );
386 }
387
389
390 } // Transpose::
391
392 } // Tensor::Optimization::
393
394 namespace Expressions {
395
397
398 } // Expressions
399
400 } // ACFem
401
402} // Dune
403
404#endif // __DUNE_ACFEM_TENSORS_OPTIMIZATION_TRANSPOSE_HH__
OptimizeTag< 0 > DontOptimize
Bottom level is overloaded to do nothing.
Definition: optimizationbase.hh:74
constexpr std::size_t size()
Gives the number of elements in tuple-likes and std::integer_sequence.
Definition: size.hh:73
MakeSequence< std::size_t, N, Offset, Stride, Repeat > MakeIndexSequence
Make a sequence of std::size_t elements.
Definition: generators.hh:34
typename SequenceCatHelper2< S... >::Type SequenceCat
Concatenate the given sequences, in order, to <S0, S1, ... >.
Definition: transform.hh:220
constexpr auto operate(Expressions::DontOptimize, OperationTraits< TransposeOperation< Perm > >, T &&t)
Definition: transpose.hh:175
constexpr bool HasInvariantValuesV
Evaluate to true if applying the permutation Perm to the values of Seq leaves the values invariant.
Definition: permutation.hh:112
constexpr bool isSimple(Sequence< T, V... > seq)
Definition: compare.hh:334
typename PermuteSequenceValuesHelper< Perm, Sequence >::Type PermuteSequenceValues
Apply the given permutation to the values of the given sequence.
Definition: permutation.hh:101
Constant< bool, V > BoolConstant
Short-cut for integral constant of type bool.
Definition: types.hh:48
BoolConstant< false > FalseType
Alias for std::false_type.
Definition: types.hh:110
STL namespace.
Permutation of index positions of tensors.
Definition: expressionoperations.hh:167
Creative Commons License   |  Legal Statements / Impressum  |  Hosted by TU Dresden & Uni Heidelberg  |  generated with Hugo v0.111.3 (Mar 12, 23:28, 2025)