1#ifndef __DUNE_ACFEM_TENSORS_OPTIMIZATION_EINSUM_HH__
2#define __DUNE_ACFEM_TENSORS_OPTIMIZATION_EINSUM_HH__
4#include "../../expressions/optimizegeneral.hh"
5#include "../expressionoperations.hh"
6#include "../operations/einsum.hh"
7#include "../operationtraits.hh"
8#include "../operations/transposetraits.hh"
9#include "../modules.hh"
// Sets MultiplicationAdmitsScalarsV to true for the EinsumOperation whose
// three index sequences are all empty.
// NOTE(review): the line that precedes this declaration in the listing
// (presumably `template<>` for the explicit specialisation) is not visible
// here -- confirm against the upstream header.
17 constexpr inline bool MultiplicationAdmitsScalarsV<EinsumOperation<Tensor::Seq<>, Tensor::Seq<>, Tensor::Seq<> > > =
true;
19 namespace Tensor::Optimization::Einsum {
21 using DefaultTag = Tensor::Policy::DefaultOptimizationTag;
22 using CommuteTag = OptimizeNext<DefaultTag>;
23 using EinsumTag = OptimizeNext<CommuteTag, 2>;
/**Primary template: evaluates to false for every combination of
 * operation traits F and operand types T0, T1 that is not matched by a
 * partial specialisation.
 *
 * @tparam F      Operation traits type.
 * @tparam T0     Type of the left operand.
 * @tparam T1     Type of the right operand.
 * @tparam SFINAE Hook for enable_if-guarded specialisations; defaults
 *                to void.
 */
template<class F, class T0, class T1, class SFINAE = void>
constexpr inline bool IsScalarEinsumV = false;
// Partial specialisation of IsScalarEinsumV for einsum operation traits.
// The visible part of its enable_if condition is a disjunction: the
// EinsumRank of the operands is non-zero, or T0 is not a constant
// expression argument, or T1 is not one.
// NOTE(review): this listing omits the operand arguments that precede the
// enable_if as well as everything after the last condition, including the
// value the specialisation is set to -- consult the upstream header.
53 template<
class Seq0,
class Seq1,
class Dims,
class T0,
class T1>
54 constexpr inline bool IsScalarEinsumV<
55 OperationTraits<EinsumOperation<Seq0, Seq1, Dims> >,
57 std::enable_if_t<(EinsumRank<T0, Seq0, T1, Seq1> != 0
58 || !IsConstantExprArg<T0>::value
59 || !IsConstantExprArg<T1>::value
62 template<
class Op,
class T0,
class T1>
63 using ScalarEinsumResult = std::decay_t<
decltype(
operate(
DontOptimize{}, OperationTraits<Op>{}, std::declval<T0>(), std::declval<T1>())(Seq<>{}))>;
65 template<
class Seq0,
class Seq1,
class Dims,
class T0,
class T1>
66 constexpr inline bool IsScalarEinsumV<
67 OperationTraits<EinsumOperation<Seq0, Seq1, Dims> >,
68 T0, T1> = (!(ExamineOr<T0, IsRuntimeEqual>::value || ExamineOr<T1, IsRuntimeEqual>::value)
69 || IsFractionConstant<ScalarEinsumResult<EinsumOperation<Seq0, Seq1, Dims>, T0, T1> >::value);
// Overload that only participates when IsScalarEinsumV<F, T0, T1> is true.
// The visible body runs the operation with the DontOptimize tag, invokes the
// result with an empty index sequence and wraps that value via tensor().
// NOTE(review): the function signature between the template header and the
// body, the braces and the end of the macro call are missing from this
// listing -- consult the upstream header.
72 template<
class F,
class T0,
class T1, std::enable_if_t<IsScalarEinsumV<F, T0 , T1>,
int> = 0>
75 DUNE_ACFEM_RECORD_OPTIMIZATION;
77 DUNE_ACFEM_EXPRESSION_RESULT(
78 tensor(std::move(
operate(
DontOptimize{}, f, std::forward<T0>(t0), std::forward<T1>(t1))(Seq<>{})))
79 ,
"Compute Total Contractions"
/**Primary template: evaluates to false unless a partial specialisation
 * matches the operation traits F and the operand types T0 and T1.
 *
 * @tparam F  Operation traits type.
 * @tparam T0 Type of the left operand.
 * @tparam T1 Type of the right operand.
 */
template<class F, class T0, class T1>
constexpr inline bool ContractionOfFractionConstantsV = false;
88 template<
class Seq0,
class Seq1,
class Dims,
89 class I0, I0 N0, I0 D0,
class Signature0,
90 class I1, I1 N1, I1 D1,
class Signature1>
91 constexpr inline bool ContractionOfFractionConstantsV<
92 OperationTraits<EinsumOperation<Seq0, Seq1, Dims> >,
93 ConstantTensor<FractionConstant<I0, N0, D0>, Signature0>,
94 ConstantTensor<FractionConstant<I1, N1, D1>, Signature1> > = EinsumRank<Signature0, Seq0, Signature1, Seq1> > 0;
// Overload that only participates when
// ContractionOfFractionConstantsV<F, T0, T1> is true.
// The visible body builds a constant tensor with the result signature of F
// whose value is F::contractionDimension() (as an integer fraction)
// multiplied by the data of both operands.
// NOTE(review): the function signature, the braces and the end of the macro
// call are missing from this listing; T0::data() is called on the
// undecayed template parameter -- confirm how T0/T1 are deduced upstream.
96 template<
class F,
class T0,
class T1, std::enable_if_t<ContractionOfFractionConstantsV<F, T0, T1>,
int> = 0>
99 DUNE_ACFEM_RECORD_OPTIMIZATION;
101 using Signature =
typename F::template Signature<T0, T1>;
102 constexpr auto contractionDim = intFraction<F::contractionDimension()>();
103 constexpr auto value = contractionDim * T0::data() * T1::data();
105 DUNE_ACFEM_EXPRESSION_RESULT(
106 constantTensor(value, Signature{}, Disclosure{})
107 ,
"Compile Time Constant Contraction"
// Returns a score for the type T: 10 if IsConstantExprArg<T> holds, plus 100
// if IsRuntimeEqualExpression<T> holds, plus 1000 if IsTypedValue<T> holds.
// NOTE(review): the template header (presumably `template<class T>`) and
// the function braces are missing from this listing.
114 constexpr inline std::size_t constness()
116 return (10*(std::size_t)IsConstantExprArg<T>::value
117 + 100*(std::size_t)IsRuntimeEqualExpression<T>::value
118 + 1000*(std::size_t)IsTypedValue<T>::value);
// Returns a sort key for the type T: the constness<T>() score shifted left
// by 48 bits, plus (~0UL >> 16) minus Expressions::weight<T>(). A larger
// weight therefore yields a smaller low part of the key.
// NOTE(review): the shift by 48 and the `~0UL >> 16` mask presume a 64-bit
// unsigned long / std::size_t -- confirm for targets where these are 32 bits.
// NOTE(review): the template header (presumably `template<class T>`) and
// the function braces are missing from this listing.
122 constexpr std::size_t leftAffinity()
124 constexpr std::size_t byKind = constness<T>();
125 constexpr std::size_t byWeight = (~0UL >> 16) - Expressions::weight<T>();
127 return (byKind << 48) + byWeight;
/**Primary template: evaluates to false unless a partial specialisation
 * matches the operation traits F and the operand types T0 and T1.
 *
 * @tparam F  Operation traits type.
 * @tparam T0 Type of the left operand.
 * @tparam T1 Type of the right operand.
 */
template<class F, class T0, class T1>
constexpr inline bool ShouldCommuteContractionV = false;
// Specialisation of ShouldCommuteContractionV for einsum operation traits.
// The visible conjuncts require the EinsumRank of the operands to be zero
// and leftAffinity<T0>() to be strictly smaller than leftAffinity<T1>().
// NOTE(review): the first conjunct (and the opening parenthesis) of the
// initialiser is missing from this listing -- consult the upstream header.
133 template<
class Seq0,
class Seq1,
class Dims,
class T0,
class T1>
134 constexpr inline bool ShouldCommuteContractionV<OperationTraits<EinsumOperation<Seq0, Seq1, Dims> >, T0, T1> =
136 && EinsumRank<T0, Seq0, T1, Seq1> == 0
137 && leftAffinity<T0>() < leftAffinity<T1>());
// Overload that only participates when ShouldCommuteContractionV<F, T0, T1>
// is true. The visible body calls operate() again with the two operands in
// swapped order and attaches a message built from both operands' names.
// NOTE(review): the function signature, the braces and the end of the macro
// call are missing from this listing. Whether t0.name()/t1.name() run before
// or after the forwarded operands are consumed depends on how
// DUNE_ACFEM_EXPRESSION_RESULT expands -- verify.
142 template<
class F,
class T0,
class T1, std::enable_if_t<ShouldCommuteContractionV<F, T0, T1>,
int> = 0>
145 DUNE_ACFEM_RECORD_OPTIMIZATION;
147 DUNE_ACFEM_EXPRESSION_RESULT(
148 (
operate(std::forward<F>(f), std::forward<T1>(t1), std::forward<T0>(t0)))
149 ,
"t0: " + t0.name() +
"; t1: " + t1.name() +
"; non scalar left affinity"
/**Primary template: evaluates to false unless a partial specialisation
 * matches the operation traits F and the operand types T0 and T1.
 *
 * @tparam F  Operation traits type.
 * @tparam T0 Type of the left operand.
 * @tparam T1 Type of the right operand.
 */
template<class F, class T0, class T1>
constexpr inline bool IsContractionOfConstWithEyeV = false;
// Specialisation of IsContractionOfConstWithEyeV for einsum operation
// traits. The only visible conjunct requires T1 to be a constant tensor.
// NOTE(review): the leading conjuncts of the initialiser are missing from
// this listing (presumably a test on T0 being an eye tensor, going by the
// variable name) -- consult the upstream header.
161 template<
class Seq0,
class Seq1,
class Dims,
class T0,
class T1>
162 constexpr inline bool IsContractionOfConstWithEyeV<OperationTraits<EinsumOperation<Seq0, Seq1, Dims> >, T0, T1> =
165 && IsConstantTensor<T1>::value);
// Overload that only participates when
// IsContractionOfConstWithEyeV<F, T0, T1> is true (T0 in the eye role, T1
// in the constant role).
// EyeRest / ConstRest are the operand signatures with the left / right
// index positions of F removed; the visible arguments build an eye tensor
// over EyeRest and a constant tensor over ConstRest carrying t1's data.
// NOTE(review): the function signature, the braces and the expression that
// combines the two tensors are missing from this listing.
167 template<
class F,
class T0,
class T1, std::enable_if_t<IsContractionOfConstWithEyeV<F, T0, T1>,
int> = 0>
170 DUNE_ACFEM_RECORD_OPTIMIZATION;
172 using EyeArg = std::decay_t<T0>;
173 using EyeRest = SequenceSliceComplement<typename EyeArg::Signature, typename F::LeftIndexPositions>;
174 using Const = std::decay_t<T1>;
175 using ConstRest = SequenceSliceComplement<typename Const::Signature, typename F::RightIndexPositions>;
177 DUNE_ACFEM_EXPRESSION_RESULT(
179 eye(EyeRest{}, Disclosure{}),
180 constantTensor(std::forward<T1>(t1).data(), ConstRest{}, Disclosure{}))
// Overload of operate() for DefaultTag with the operand roles mirrored: it
// is enabled when IsContractionOfConstWithEyeV<F, T1, T0> is true, i.e. T0
// is in the constant role and T1 in the eye role.
// ConstRest / EyeRest are the operand signatures with the left / right
// index positions of F removed; the visible arguments build a constant
// tensor over ConstRest carrying t0's data and an eye tensor over EyeRest.
// NOTE(review): the braces and the expression that combines the two tensors
// are missing from this listing. `typename F::...` is applied to the
// undecayed forwarding parameter F -- confirm how F is deduced upstream.
188 template<
class F,
class T0,
class T1, std::enable_if_t<IsContractionOfConstWithEyeV<F, T1, T0>,
int> = 0>
189 constexpr auto operate(DefaultTag, F&& f, T0&& t0, T1&& t1)
191 DUNE_ACFEM_RECORD_OPTIMIZATION;
193 using Const = std::decay_t<T0>;
194 using ConstRest = SequenceSliceComplement<typename Const::Signature, typename F::LeftIndexPositions>;
195 using EyeArg = std::decay_t<T1>;
196 using EyeRest = SequenceSliceComplement<typename EyeArg::Signature, typename F::RightIndexPositions>;
198 DUNE_ACFEM_EXPRESSION_RESULT(
200 constantTensor(std::forward<T0>(t0).data(), ConstRest{}, Disclosure{}),
201 eye(EyeRest{}, Disclosure{}))
211 namespace Expressions {
216 template<std::
size_t Pos0, std::
size_t Pos1, std::
size_t D0,
class Field,
class T0>
217 constexpr inline bool ReturnFirstV<
218 OperationTraits<EinsumOperation<IndexSequence<Pos0>, IndexSequence<Pos1>, IndexSequence<D0> > >,
219 T0, Tensor::Eye<IndexSequence<D0, D0>, Field>
220 > = Tensor::IsSelfTransposed<Transposition<Pos0, TensorTraits<T0>::rank-1>, T0>::value;
224 template<std::
size_t Pos0, std::
size_t Pos1, std::
size_t D0,
class Field,
class T1>
225 constexpr inline bool ReturnSecondV<
226 OperationTraits<EinsumOperation<IndexSequence<Pos0>, IndexSequence<Pos1>, IndexSequence<D0> > >,
227 Tensor::Eye<IndexSequence<D0, D0>, Field>, T1,
228 std::enable_if_t<!std::is_same<std::decay_t<T1>, Tensor::Eye<IndexSequence<D0, D0>, Field> >::value>
229 > = Tensor::IsSelfTransposed<Transposition<0, Pos1>, T1>::value;
// Specialisation of ReturnFirstV for an einsum of T0 with a
// BlockEye<2, BlockSignature, Field>, contracting over BlockSignature.
// The visible part of the guard requires Seq0 to be an aligned block of
// size BlockSignature::size(); the visible part of the value tests the same
// alignment for Seq1.
// NOTE(review): the start of the guard (presumably an enable_if) and the
// remainder of the value expression are missing from this listing.
233 template<
class Seq0,
class Seq1,
class BlockSignature,
class Field,
class T0>
234 constexpr inline bool ReturnFirstV<
235 OperationTraits<EinsumOperation<Seq0, Seq1, BlockSignature> >,
236 T0, Tensor::BlockEye<2, BlockSignature, Field>,
238 && isAlignedBlock<BlockSignature::size()>(Seq0{})
239 )> > = (isAlignedBlock<BlockSignature::size()>(Seq1{})
// Specialisation of ReturnSecondV for an einsum of a
// BlockEye<2, BlockSignature, Field> with T1, contracting over
// BlockSignature.
// The visible part of the guard requires Seq1 to be an aligned block of
// size BlockSignature::size() and T1 not to be that BlockEye type itself;
// the visible part of the value tests the same alignment for Seq0.
// NOTE(review): the start of the guard (presumably an enable_if) and the
// remainder of the value expression are missing from this listing.
244 template<
class Seq0,
class Seq1,
class BlockSignature,
class Field,
class T1>
245 constexpr inline bool ReturnSecondV<
246 OperationTraits<EinsumOperation<Seq0, Seq1, BlockSignature> >,
247 Tensor::BlockEye<2, BlockSignature, Field>, T1,
249 && isAlignedBlock<BlockSignature::size()>(Seq1{})
250 && !std::is_same<T1, Tensor::BlockEye<2, BlockSignature, Field> >::value
251 )> > = (isAlignedBlock<BlockSignature::size()>(Seq0{})
258 namespace Tensor::Optimization::Einsum {
/**Primary template: evaluates to false unless a partial specialisation
 * matches the operation traits F and the operand types T0 and T1.
 *
 * @tparam F  Operation traits type.
 * @tparam T0 Type of the left operand.
 * @tparam T1 Type of the right operand.
 */
template<class F, class T0, class T1>
constexpr inline bool IsContractionOfKroneckersV = false;
// Specialisation of IsContractionOfKroneckersV for einsum operation traits:
// true when both T0 and T1 satisfy IsConstKroneckerDelta.
// NOTE(review): the line carrying the operand arguments of the
// specialisation (presumably `T0, T1`) is missing from this listing.
263 template<
class Seq0,
class Seq1,
class Dims,
class T0,
class T1>
264 constexpr inline bool IsContractionOfKroneckersV<
265 OperationTraits<EinsumOperation<Seq0, Seq1, Dims> >,
267 > = (IsConstKroneckerDelta<T0>::value
268 && IsConstKroneckerDelta<T1>::value);
// Overload of operate() for OptimizeTerminal0, enabled when
// IsContractionOfKroneckersV<F, T0, T1> is true. The function arguments are
// unnamed; the result is determined from the types alone.
// NOTE(review): the function braces, the `else` branch opener and the ends
// of both macro calls are missing from this listing. `typename T0::...` and
// `typename F::...` are applied to the undecayed forwarding parameters --
// confirm how they are deduced upstream.
270 template<
class F,
class T0,
class T1, std::enable_if_t<IsContractionOfKroneckersV<F, T0, T1>,
int> = 0>
271 constexpr decltype(
auto)
operate(OptimizeTerminal0, F&&, T0&&, T1&&)
273 using LeftPivots =
typename T0::PivotSequence;
274 using RightPivots =
typename T1::PivotSequence;
275 using LeftPos =
typename F::LeftIndexPositions;
276 using RightPos =
typename F::RightIndexPositions;
277 using Signature =
typename F::template Signature<T0, T1>;
// If the pivot entries at the contracted positions coincide on both sides,
// the result is a Kronecker delta whose pivot is formed from the remaining
// (uncontracted) pivot entries of the left operand followed by the right's.
278 if constexpr (std::is_same<SequenceSlice<LeftPivots, LeftPos>, SequenceSlice<RightPivots, RightPos> >::value) {
279 DUNE_ACFEM_RECORD_OPTIMIZATION;
281 using LPivot = SequenceSliceComplement<LeftPivots, LeftPos>;
282 using RPivot = SequenceSliceComplement<RightPivots, RightPos>;
283 using PivotIndices = SequenceCat<LPivot, RPivot>;
284 DUNE_ACFEM_EXPRESSION_RESULT(
285 kroneckerDelta(Signature{}, PivotIndices{}, Disclosure{})
286 ,
"non-zero kronecker contraction"
// Second result path (presumably the else branch): a zero tensor with the
// result signature.
289 DUNE_ACFEM_RECORD_OPTIMIZATION;
291 DUNE_ACFEM_EXPRESSION_RESULT(
292 zeros(Signature{}, Disclosure{})
293 ,
"zero kronecker contraction"
// Two template headers followed by part of a boolean condition.
// NOTE(review): the declarations these headers belong to are missing from
// this listing; going by the enable_if on the operate() overload that
// follows, they presumably declare and specialise
// IsMiddleScalarNestedEinsum -- consult the upstream header.
301 template<
class F,
class T0,
class T1,
class SFINAE =
void>
307 template<
class Seq0,
class Seq1,
class Dims,
class T0,
class T1>
// Visible conjuncts: T1 is itself an einsum expression, both T0 and T1 have
// positive tensor rank, and operand 0 of T1 has rank zero. Further
// conjuncts may be missing between the visible lines.
311 IsEinsumExpression<T1>::value
313 && (TensorTraits<T0>::rank > 0)
314 && (TensorTraits<T1>::rank > 0)
316 && TensorTraits<Operand<0, T1> >::rank == 0
// Overload that only participates when
// IsMiddleScalarNestedEinsum<F, T0, T1>::value is true.
// The visible arguments of the result expression are operand 0 of t1,
// followed by t0 and operand 1 of t1; the attached message is built from
// both operands' names.
// NOTE(review): the `template<` opener, the function signature, the braces
// and the calls that combine these arguments are missing from this listing.
325 class F,
class T0,
class T1,
326 std::enable_if_t<IsMiddleScalarNestedEinsum<F, T0, T1>::value,
int> = 0>
329 DUNE_ACFEM_RECORD_OPTIMIZATION;
331 DUNE_ACFEM_EXPRESSION_RESULT(
333 std::forward<T1>(t1).operand(0_c),
336 std::forward<T0>(t0),
337 std::forward<T1>(t1).operand(1_c)
340 ,
"t0: " + t0.name() +
"; t1: " + t1.name() +
"; middle scalar einsum"
353 namespace Expressions {
OptimizeTag< 0 > DontOptimize
Bottom level is overloaded to do nothing.
Definition: optimizationbase.hh:74
constexpr std::size_t size()
Gives the number of elements in tuple-likes and std::integer_sequence.
Definition: size.hh:73
MakeSequence< std::size_t, N, Offset, Stride, Repeat > MakeIndexSequence
Make a sequence of std::size_t elements.
Definition: generators.hh:34
constexpr auto operate(OptimizeNext< EinsumTag >, F &&f, T0 &&t0, T1 &&t1)
Definition: einsum.hh:327
typename BlockTranspositionHelper< Seq0, Seq1, N >::Type BlockTransposition
Generate the block-transposition of the I0 and the I1-block as permutation.
Definition: permutation.hh:212
BoolConstant< false > FalseType
Alias for std::false_type.
Definition: types.hh:110
BoolConstant< true > TrueType
Alias for std::true_type.
Definition: types.hh:107
Einstein summation, i.e.
Definition: expressionoperations.hh:308
Optimization pattern disambiguation struct.
Definition: optimizationbase.hh:42
Definition: einsum.hh:304