DUNE-ACFEM (unstable)

einsum.hh
1#ifndef __DUNE_ACFEM_TENSORS_OPTIMIZATION_EINSUM_HH__
2#define __DUNE_ACFEM_TENSORS_OPTIMIZATION_EINSUM_HH__
3
4#include "../../expressions/optimizegeneral.hh"
5#include "../expressionoperations.hh"
6#include "../operations/einsum.hh"
7#include "../operationtraits.hh"
8#include "../operations/transposetraits.hh"
9#include "../modules.hh"
10#include "policy.hh"
11
12namespace Dune {
13
14 namespace ACFem {
15
16 template<>
17 constexpr inline bool MultiplicationAdmitsScalarsV<EinsumOperation<Tensor::Seq<>, Tensor::Seq<>, Tensor::Seq<> > > = true;
18
19 namespace Tensor::Optimization::Einsum {
20
21 using DefaultTag = Tensor::Policy::DefaultOptimizationTag;
22 using CommuteTag = OptimizeNext<DefaultTag>;
23 using EinsumTag = OptimizeNext<CommuteTag, 2>;
24
49
50 template<class F, class T0, class T1, class SFINAE = void>
51 constexpr inline bool IsScalarEinsumV = false;
52
53 template<class Seq0, class Seq1, class Dims, class T0, class T1>
54 constexpr inline bool IsScalarEinsumV<
55 OperationTraits<EinsumOperation<Seq0, Seq1, Dims> >,
56 T0, T1,
57 std::enable_if_t<(EinsumRank<T0, Seq0, T1, Seq1> != 0
58 || !IsConstantExprArg<T0>::value
59 || !IsConstantExprArg<T1>::value
60 )> > = false;
61
62 template<class Op, class T0, class T1>
63 using ScalarEinsumResult = std::decay_t<decltype(operate(DontOptimize{}, OperationTraits<Op>{}, std::declval<T0>(), std::declval<T1>())(Seq<>{}))>;
64
65 template<class Seq0, class Seq1, class Dims, class T0, class T1>
66 constexpr inline bool IsScalarEinsumV<
67 OperationTraits<EinsumOperation<Seq0, Seq1, Dims> >,
68 T0, T1> = (!(ExamineOr<T0, IsRuntimeEqual>::value || ExamineOr<T1, IsRuntimeEqual>::value)
69 || IsFractionConstant<ScalarEinsumResult<EinsumOperation<Seq0, Seq1, Dims>, T0, T1> >::value);
70
72 template<class F, class T0, class T1, std::enable_if_t<IsScalarEinsumV<F, T0 , T1>, int> = 0>
73 constexpr decltype(auto) operate(OptimizeTerminal1, F&& f, T0&& t0, T1&& t1)
74 {
75 DUNE_ACFEM_RECORD_OPTIMIZATION;
76
77 DUNE_ACFEM_EXPRESSION_RESULT(
78 tensor(std::move(operate(DontOptimize{}, f, std::forward<T0>(t0), std::forward<T1>(t1))(Seq<>{})))
79 , "Compute Total Contractions"
80 );
81 }
82
84
85 template<class F, class T0, class T1>
86 constexpr inline bool ContractionOfFractionConstantsV = false;
87
88 template<class Seq0, class Seq1, class Dims,
89 class I0, I0 N0, I0 D0, class Signature0,
90 class I1, I1 N1, I1 D1, class Signature1>
91 constexpr inline bool ContractionOfFractionConstantsV<
92 OperationTraits<EinsumOperation<Seq0, Seq1, Dims> >,
93 ConstantTensor<FractionConstant<I0, N0, D0>, Signature0>,
94 ConstantTensor<FractionConstant<I1, N1, D1>, Signature1> > = EinsumRank<Signature0, Seq0, Signature1, Seq1> > 0;
95
96 template<class F, class T0, class T1, std::enable_if_t<ContractionOfFractionConstantsV<F, T0, T1>, int> = 0>
97 constexpr auto operate(OptimizeTerminal1, F&& f, T0&& t0, T1&& t1)
98 {
99 DUNE_ACFEM_RECORD_OPTIMIZATION;
100
101 using Signature = typename F::template Signature<T0, T1>;
102 constexpr auto contractionDim = intFraction<F::contractionDimension()>();
103 constexpr auto value = contractionDim * T0::data() * T1::data();
104
105 DUNE_ACFEM_EXPRESSION_RESULT(
106 constantTensor(value, Signature{}, Disclosure{})
107 , "Compile Time Constant Contraction"
108 );
109 }
110
112
113 template<class T>
114 constexpr inline std::size_t constness()
115 {
116 return (10*(std::size_t)IsConstantExprArg<T>::value
117 + 100*(std::size_t)IsRuntimeEqualExpression<T>::value
118 + 1000*(std::size_t)IsTypedValue<T>::value);
119 }
120
121 template<class T>
122 constexpr std::size_t leftAffinity()
123 {
124 constexpr std::size_t byKind = constness<T>();
125 constexpr std::size_t byWeight = (~0UL >> 16) - Expressions::weight<T>();
126
127 return (byKind << 48) + byWeight;
128 }
129
130 template<class F, class T0, class T1>
131 constexpr inline bool ShouldCommuteContractionV = false;
132
133 template<class Seq0, class Seq1, class Dims, class T0, class T1>
134 constexpr inline bool ShouldCommuteContractionV<OperationTraits<EinsumOperation<Seq0, Seq1, Dims> >, T0, T1> =
135 (Dims::size() > 0
136 && EinsumRank<T0, Seq0, T1, Seq1> == 0
137 && leftAffinity<T0>() < leftAffinity<T1>());
138
142 template<class F, class T0, class T1, std::enable_if_t<ShouldCommuteContractionV<F, T0, T1>, int> = 0>
143 constexpr decltype(auto) operate(CommuteTag, F&& f, T0&& t0, T1&& t1)
144 {
145 DUNE_ACFEM_RECORD_OPTIMIZATION;
146
147 DUNE_ACFEM_EXPRESSION_RESULT(
148 (operate(std::forward<F>(f), std::forward<T1>(t1), std::forward<T0>(t0)))
149 , "t0: " + t0.name() + "; t1: " + t1.name() + "; non scalar left affinity"
150 );
151 }
152
154
155 template<class F, class T0, class T1>
156 constexpr inline bool IsContractionOfConstWithEyeV = false;
157
161 template<class Seq0, class Seq1, class Dims, class T0, class T1>
162 constexpr inline bool IsContractionOfConstWithEyeV<OperationTraits<EinsumOperation<Seq0, Seq1, Dims> >, T0, T1> =
163 (Dims::size() > 0 // identify non-trivial contraction
164 && IsEye<T0>::value
165 && IsConstantTensor<T1>::value);
166
167 template<class F, class T0, class T1, std::enable_if_t<IsContractionOfConstWithEyeV<F, T0, T1>, int> = 0>
168 constexpr auto operate(DefaultTag, F&&, T0&& t0, T1&& t1)
169 {
170 DUNE_ACFEM_RECORD_OPTIMIZATION;
171
172 using EyeArg = std::decay_t<T0>;
173 using EyeRest = SequenceSliceComplement<typename EyeArg::Signature, typename F::LeftIndexPositions>;
174 using Const = std::decay_t<T1>;
175 using ConstRest = SequenceSliceComplement<typename Const::Signature, typename F::RightIndexPositions>;
176
177 DUNE_ACFEM_EXPRESSION_RESULT(
178 operate(ScalarEinsumFunctor{},
179 eye(EyeRest{}, Disclosure{}),
180 constantTensor(std::forward<T1>(t1).data(), ConstRest{}, Disclosure{}))
181 , "einsum eye-ones"
182 );
183 }
184
188 template<class F, class T0, class T1, std::enable_if_t<IsContractionOfConstWithEyeV<F, T1, T0>, int> = 0>
189 constexpr auto operate(DefaultTag, F&& f, T0&& t0, T1&& t1)
190 {
191 DUNE_ACFEM_RECORD_OPTIMIZATION;
192
193 using Const = std::decay_t<T0>;
194 using ConstRest = SequenceSliceComplement<typename Const::Signature, typename F::LeftIndexPositions>;
195 using EyeArg = std::decay_t<T1>;
196 using EyeRest = SequenceSliceComplement<typename EyeArg::Signature, typename F::RightIndexPositions>;
197
198 DUNE_ACFEM_EXPRESSION_RESULT(
199 operate(ScalarEinsumFunctor{},
200 constantTensor(std::forward<T0>(t0).data(), ConstRest{}, Disclosure{}),
201 eye(EyeRest{}, Disclosure{}))
202 , "einsum ones-eye"
203 );
204 }
205
206
207 } // Tensor::Optimization::Einsum::
208
210
211 namespace Expressions {
212
214
216 template<std::size_t Pos0, std::size_t Pos1, std::size_t D0, class Field, class T0>
217 constexpr inline bool ReturnFirstV<
218 OperationTraits<EinsumOperation<IndexSequence<Pos0>, IndexSequence<Pos1>, IndexSequence<D0> > >,
219 T0, Tensor::Eye<IndexSequence<D0, D0>, Field>
220 > = Tensor::IsSelfTransposed<Transposition<Pos0, TensorTraits<T0>::rank-1>, T0>::value;
221
224 template<std::size_t Pos0, std::size_t Pos1, std::size_t D0, class Field, class T1>
225 constexpr inline bool ReturnSecondV<
226 OperationTraits<EinsumOperation<IndexSequence<Pos0>, IndexSequence<Pos1>, IndexSequence<D0> > >,
227 Tensor::Eye<IndexSequence<D0, D0>, Field>, T1,
228 std::enable_if_t<!std::is_same<std::decay_t<T1>, Tensor::Eye<IndexSequence<D0, D0>, Field> >::value>
229 > = Tensor::IsSelfTransposed<Transposition<0, Pos1>, T1>::value;
230
233 template<class Seq0, class Seq1, class BlockSignature, class Field, class T0>
234 constexpr inline bool ReturnFirstV<
235 OperationTraits<EinsumOperation<Seq0, Seq1, BlockSignature> >,
236 T0, Tensor::BlockEye<2, BlockSignature, Field>,
237 std::enable_if_t<(BlockSignature::size() > 0
238 && isAlignedBlock<BlockSignature::size()>(Seq0{})
239 )> > = (isAlignedBlock<BlockSignature::size()>(Seq1{})
240 && Tensor::IsSelfTransposed<BlockTransposition<Seq0, MakeIndexSequence<BlockSignature::size(), TensorTraits<T0>::rank - BlockSignature::size()> >, T0>::value);
241
244 template<class Seq0, class Seq1, class BlockSignature, class Field, class T1>
245 constexpr inline bool ReturnSecondV<
246 OperationTraits<EinsumOperation<Seq0, Seq1, BlockSignature> >,
247 Tensor::BlockEye<2, BlockSignature, Field>, T1,
248 std::enable_if_t<(BlockSignature::size() > 0
249 && isAlignedBlock<BlockSignature::size()>(Seq1{})
250 && !std::is_same<T1, Tensor::BlockEye<2, BlockSignature, Field> >::value
251 )> > = (isAlignedBlock<BlockSignature::size()>(Seq0{})
252 && Tensor::IsSelfTransposed<BlockTransposition<MakeIndexSequence<BlockSignature::size()>, Seq1>, T1>::value);
253
254 } // Expressions::
255
257
258 namespace Tensor::Optimization::Einsum {
259
260 template<class F, class T0, class T1>
261 constexpr inline bool IsContractionOfKroneckersV = false;
262
263 template<class Seq0, class Seq1, class Dims, class T0, class T1>
264 constexpr inline bool IsContractionOfKroneckersV<
265 OperationTraits<EinsumOperation<Seq0, Seq1, Dims> >,
266 T0, T1
267 > = (IsConstKroneckerDelta<T0>::value
268 && IsConstKroneckerDelta<T1>::value);
269
270 template<class F, class T0, class T1, std::enable_if_t<IsContractionOfKroneckersV<F, T0, T1>, int> = 0>
271 constexpr decltype(auto) operate(OptimizeTerminal0, F&&, T0&&, T1&&)
272 {
273 using LeftPivots = typename T0::PivotSequence;
274 using RightPivots = typename T1::PivotSequence;
275 using LeftPos = typename F::LeftIndexPositions;
276 using RightPos = typename F::RightIndexPositions;
277 using Signature = typename F::template Signature<T0, T1>;
278 if constexpr (std::is_same<SequenceSlice<LeftPivots, LeftPos>, SequenceSlice<RightPivots, RightPos> >::value) {
279 DUNE_ACFEM_RECORD_OPTIMIZATION;
280
281 using LPivot = SequenceSliceComplement<LeftPivots, LeftPos>;
282 using RPivot = SequenceSliceComplement<RightPivots, RightPos>;
283 using PivotIndices = SequenceCat<LPivot, RPivot>;
284 DUNE_ACFEM_EXPRESSION_RESULT(
285 kroneckerDelta(Signature{}, PivotIndices{}, Disclosure{})
286 , "non-zero kronecker contraction"
287 );
288 } else {
289 DUNE_ACFEM_RECORD_OPTIMIZATION;
290
291 DUNE_ACFEM_EXPRESSION_RESULT(
292 zeros(Signature{}, Disclosure{})
293 , "zero kronecker contraction"
294 );
295 }
296 }
297
299
301 template<class F, class T0, class T1, class SFINAE = void>
303 : FalseType
304 {};
305
307 template<class Seq0, class Seq1, class Dims, class T0, class T1>
309 OperationTraits<EinsumOperation<Seq0, Seq1, Dims> >, T0, T1,
310 std::enable_if_t<(// identify nested multiplication
311 IsEinsumExpression<T1>::value
312 // identify non-scalars
313 && (TensorTraits<T0>::rank > 0)
314 && (TensorTraits<T1>::rank > 0)
315 // identify middel scalar
316 && TensorTraits<Operand<0, T1> >::rank == 0
317 )> >
318 : TrueType
319 {};
320
324 template<
325 class F, class T0, class T1,
326 std::enable_if_t<IsMiddleScalarNestedEinsum<F, T0, T1>::value, int> = 0>
327 constexpr auto operate(OptimizeNext<EinsumTag>, F&& f, T0&& t0, T1&& t1)
328 {
329 DUNE_ACFEM_RECORD_OPTIMIZATION;
330
331 DUNE_ACFEM_EXPRESSION_RESULT(
332 operate(ScalarEinsumFunctor{},
333 std::forward<T1>(t1).operand(0_c),
334 operate(
335 std::forward<F>(f),
336 std::forward<T0>(t0),
337 std::forward<T1>(t1).operand(1_c)
338 )
339 )
340 , "t0: " + t0.name() + "; t1: " + t1.name() + "; middle scalar einsum"
341 );
342 }
343
345
350
351 } // Tensor::Optimization::Einsum::
352
353 namespace Expressions {
354
356
357 } // NS Expressions
358
359 } // NS ACFem
360
361} // NS Dune
362
363#endif // __DUNE_ACFEM_TENSORS_OPTIMIZATION_EINSUM_HH__
OptimizeTag< 0 > DontOptimize
Bottom level is overloaded to do nothing.
Definition: optimizationbase.hh:74
constexpr std::size_t size()
Gives the number of elements in tuple-likes and std::integer_sequence.
Definition: size.hh:73
MakeSequence< std::size_t, N, Offset, Stride, Repeat > MakeIndexSequence
Make a sequence of std::size_t elements.
Definition: generators.hh:34
constexpr auto operate(OptimizeNext< EinsumTag >, F &&f, T0 &&t0, T1 &&t1)
Definition: einsum.hh:327
typename BlockTranspositionHelper< Seq0, Seq1, N >::Type BlockTransposition
Generate the block-transposition of the I0 and the I1-block as permutation.
Definition: permutation.hh:212
BoolConstant< false > FalseType
Alias for std::false_type.
Definition: types.hh:110
BoolConstant< true > TrueType
Alias for std::true_type.
Definition: types.hh:107
STL namespace.
Einstein summation, i.e.
Definition: expressionoperations.hh:308
Optimization pattern disambiguation struct.
Definition: optimizationbase.hh:42
Creative Commons License   |  Legal Statements / Impressum  |  Hosted by TU Dresden  |  generated with Hugo v0.111.3 (Jul 15, 22:36, 2024)