1#ifndef __DUNE_ACFEM_TENSORS_OPERATIONS_ASSOCIATIVITY_HH__
2#define __DUNE_ACFEM_TENSORS_OPERATIONS_ASSOCIATIVITY_HH__
4#include "../tensorbase.hh"
5#include "../expressionoperations.hh"
15 namespace ProductOperations {
17 template<
class F,
class T0,
class T1,
class SFINAE =
void>
18 struct IsRightAssociative
22 template<
class F,
class T0,
class T1>
23 struct IsRightAssociative<F, T0, T1,
std::enable_if_t<!IsDecay<F>::value> >
24 : IsRightAssociative<std::decay_t<F>, T0, T1>
30 template<
class Seq0,
class Seq1,
class Dims,
class T0,
class T1>
31 struct IsRightAssociative<OperationTraits<
EinsumOperation<Seq0, Seq1, Dims> >, T0, T1>
32 : FunctorHas<IsEinsumOperation, Functor<T0> >
35 template<
class Seq0,
class Seq1,
class Dims,
class T0,
class T1>
36 struct IsRightAssociative<
38 std::enable_if_t<(FunctorHas<IsTensorProductOperation, Functor<T0> >::value
40 && std::is_same<Seq0, Seq1>::value
41 && std::is_same<typename TensorProductTraits<Operation<T0> >::LeftIndexPositions,
42 typename TensorProductTraits<Operation<T0> >::RightIndexPositions>::value
44 typename TensorProductTraits<Operation<T0> >::RightIndexPositions>::value
50 template<
class F,
class T0,
class T1,
class SFINAE =
void>
55 template<
class F,
class T0,
class T1>
63 template<
class Seq0,
class Seq1,
class Dims,
class T0,
class T1>
65 : FunctorHas<IsEinsumOperation, Functor<T1> >
68 template<
class Seq0,
class Seq1,
class Dims,
class T0,
class T1>
71 std::enable_if_t<(FunctorHas<IsTensorProductOperation, Functor<T1> >::value
73 && std::is_same<Seq0, Seq1>::value
74 && std::is_same<typename TensorProductTraits<Operation<T1> >::LeftIndexPositions,
75 typename TensorProductTraits<Operation<T1> >::RightIndexPositions>::value
77 typename TensorProductTraits<Operation<T1> >::RightIndexPositions>::value
89 template<
class InnerLeftIndexPositions,
class InnerRightIndexPositions,
class InnerDims,
90 class LeftIndexPositions,
class RightIndexPositions,
class OuterDims,
91 std::size_t Rank1, std::size_t Rank2, std::size_t Rank3,
92 std::enable_if_t<Rank1 * Rank2 * Rank3 != 0, int> = 0>
93 constexpr auto associateRightOperations(
101 constexpr std::ptrdiff_t rightRank = Rank3;
103 constexpr std::ptrdiff_t innerLeftRank = Rank1;
104 constexpr std::ptrdiff_t innerLeftDefectRank = innerLeftRank - innerDefectRank;
105 constexpr std::ptrdiff_t innerRightRank = Rank2;
106 constexpr std::ptrdiff_t innerRightDefectRank = innerRightRank - innerDefectRank;
113 using InnerLeftPos = SequenceSlice<LeftIndexPositions, InnerLeftPosIndices>;
114 using InnerRightPos = OffsetSequence<-innerLeftDefectRank, SequenceSlice<LeftIndexPositions, InnerRightPosIndices> >;
117 using OuterLeftPos = SequenceSlice<RightIndexPositions, InnerLeftPosIndices>;
118 using OuterRightPos = SequenceSlice<RightIndexPositions, InnerRightPosIndices>;
121 using InnerLeftLookup = SequenceSliceComplement<MakeIndexSequence<innerLeftRank>, InnerLeftIndexPositions>;
122 using InnerRightLookup = SequenceSliceComplement<MakeIndexSequence<innerRightRank>, InnerRightIndexPositions>;
136 using TwoPos = InnerRightMappedPos;
139 using ThreePos = OuterRightPos;
140 using ThreeDims = SequenceSlice<OuterDims, InnerRightPosIndices>;
146 using TwoThreeTwoPosLookup = SequenceSliceComplement<MakeIndexSequence<innerRightRank>, TwoPos>;
149 using TwoThreeThreePosLookup = SequenceSliceComplement<MakeIndexSequence<rightRank>, ThreePos>;
152 using TwoThreePos =
SequenceCat<TwoThreeTwoMappedPos, OffsetSequence<innerRightRank-
TwoPos::size() ,TwoThreeThreeMappedPos> >;
155 std::clog <<
"One: " << OnePos{} << std::endl;
156 std::clog <<
"TwoThree: " << TwoThreePos{} << std::endl;
157 std::clog <<
"Two: " << TwoPos{} << std::endl;
158 std::clog <<
"Three: " << ThreePos{} << std::endl;
161 return std::make_pair(OperationTraits<EinsumOperation<OnePos, TwoThreePos, OneDims> >{},
162 OperationTraits<EinsumOperation<TwoPos, ThreePos, ThreeDims> >{});
165 template<
class InnerLeftIndexPositions,
class InnerRightIndexPositions,
class InnerDims,
166 class LeftIndexPositions,
class RightIndexPositions,
class OuterDims,
167 std::size_t Rank1, std::size_t Rank2, std::size_t Rank3,
168 std::enable_if_t<Rank1 * Rank2 * Rank3 == 0, int> = 0>
169 constexpr auto associateRightOperations(
170 OperationTraits<EinsumOperation<InnerLeftIndexPositions, InnerRightIndexPositions, InnerDims> > f1,
171 OperationTraits<EinsumOperation<LeftIndexPositions, RightIndexPositions, OuterDims> > f2,
172 IndexSequence<Rank1, Rank2, Rank3>)
178 if constexpr (Rank1 == 0) {
179 return std::make_pair(f1, f2);
180 }
else if constexpr (Rank2 == 0) {
181 return std::make_pair(f2, f1);
183 return std::make_pair(f1, f2);
194 template<
class LeftIndexPositions,
class RightIndexPositions,
class OuterDims,
195 class InnerLeftIndexPositions,
class InnerRightIndexPositions,
class InnerDims,
196 std::size_t Rank1, std::size_t Rank2, std::size_t Rank3,
197 std::enable_if_t<Rank1 * Rank2 * Rank3 != 0, int> = 0>
198 constexpr auto associateLeftOperations(
199 OperationTraits<EinsumOperation<LeftIndexPositions, RightIndexPositions, OuterDims> >,
200 OperationTraits<EinsumOperation<InnerLeftIndexPositions, InnerRightIndexPositions, InnerDims> >,
201 IndexSequence<Rank1, Rank2, Rank3>)
205 constexpr std::ptrdiff_t leftRank = Rank1;
207 constexpr std::ptrdiff_t innerLeftRank = Rank2;
208 constexpr std::ptrdiff_t innerLeftDefectRank = innerLeftRank - innerDefectRank;
209 constexpr std::ptrdiff_t innerRightRank = Rank3;
210 constexpr std::ptrdiff_t innerRightDefectRank = innerRightRank - innerDefectRank;
213 using InnerLeftPosIndices = TransformSequence<RightIndexPositions, Seq<>, IndexFunctor, AcceptInputInRangeFunctor<std::size_t, 0, innerLeftDefectRank> >;
214 using InnerRightPosIndices = TransformSequence<RightIndexPositions, Seq<>, IndexFunctor, AcceptInputInRangeFunctor<std::size_t, innerLeftDefectRank, innerLeftDefectRank+innerRightDefectRank> >;
217 using InnerLeftPos = SequenceSlice<RightIndexPositions, InnerLeftPosIndices>;
218 using InnerRightPos = OffsetSequence<-innerLeftDefectRank, SequenceSlice<RightIndexPositions, InnerRightPosIndices> >;
221 using OuterLeftPos = SequenceSlice<LeftIndexPositions, InnerLeftPosIndices>;
222 using OuterRightPos = SequenceSlice<LeftIndexPositions, InnerRightPosIndices>;
225 using InnerLeftLookup = SequenceSliceComplement<MakeIndexSequence<innerLeftRank>, InnerLeftIndexPositions>;
226 using InnerRightLookup = SequenceSliceComplement<MakeIndexSequence<innerRightRank>, InnerRightIndexPositions>;
229 using InnerLeftMappedPos = TransformedSequence<MapSequenceFunctor<InnerLeftLookup>, InnerLeftPos>;
230 using InnerRightMappedPos = TransformedSequence<MapSequenceFunctor<InnerRightLookup>, InnerRightPos>;
233 using OnePos = OuterLeftPos;
234 using OneDims = SequenceSlice<OuterDims, InnerLeftPosIndices>;
237 using TwoPos = InnerLeftMappedPos;
242 using ThreePos = SequenceCat<InnerRightMappedPos, InnerRightIndexPositions>;
243 using ThreeDims = SequenceCat<SequenceSlice<OuterDims, InnerRightPosIndices>, InnerDims>;
251 using OneTwoOnePosLookup = SequenceSliceComplement<MakeIndexSequence<leftRank>, OnePos>;
252 using OneTwoOneMappedPos = TransformedSequence<InverseMapSequenceFunctor<OneTwoOnePosLookup>, OuterRightPos>;
255 using OneTwoTwoPosLookup = SequenceSliceComplement<MakeIndexSequence<innerLeftRank>, TwoPos>;
256 using OneTwoTwoMappedPos = TransformedSequence<InverseMapSequenceFunctor<OneTwoTwoPosLookup>, InnerLeftIndexPositions>;
258 using OneTwoPos =
SequenceCat<OneTwoOneMappedPos, OffsetSequence<leftRank-
OnePos::size(), OneTwoTwoMappedPos> >;
260 return std::make_pair(
261 OperationTraits<EinsumOperation<OnePos, TwoPos, OneDims> >{},
262 OperationTraits<EinsumOperation<OneTwoPos, ThreePos, ThreeDims> >{});
268 template<
class LeftIndexPositions,
class RightIndexPositions,
class OuterDims,
269 class InnerLeftIndexPositions,
class InnerRightIndexPositions,
class InnerDims,
270 std::size_t Rank1, std::size_t Rank2, std::size_t Rank3,
271 std::enable_if_t<Rank1 * Rank2 * Rank3 == 0, int> = 0>
272 constexpr auto associateLeftOperations(
273 OperationTraits<EinsumOperation<LeftIndexPositions, RightIndexPositions, OuterDims> > f1,
274 OperationTraits<EinsumOperation<InnerLeftIndexPositions, InnerRightIndexPositions, InnerDims> > f2,
275 IndexSequence<Rank1, Rank2, Rank3>)
281 if constexpr (Rank1 == 0) {
282 return std::make_pair(f1, f2);
283 }
else if constexpr (Rank2 == 0) {
284 return std::make_pair(f2, f1);
286 return std::make_pair(f1, f2);
295 template<
class Defects,
class Dims,
296 std::size_t Rank1, std::size_t Rank2, std::size_t Rank3>
297 constexpr auto associateRightOperations(
298 OperationTraits<TensorProductOperation<Defects, Defects, Dims> > f1,
299 OperationTraits<TensorProductOperation<Defects, Defects, Dims> > f2,
300 IndexSequence<Rank1, Rank2, Rank3>)
302 return std::make_pair(f1, f2);
312 template<
class Defects,
class Dims,
313 std::size_t Rank1, std::size_t Rank2, std::size_t Rank3>
314 constexpr auto associateLeftOperations(
315 OperationTraits<TensorProductOperation<Defects, Defects, Dims> > f1,
316 OperationTraits<TensorProductOperation<Defects, Defects, Dims> > f2,
317 IndexSequence<Rank1, Rank2, Rank3>)
319 return std::make_pair(f1, f2);
322 template<
class F,
class T0,
class T1,
324 std::enable_if_t<IsRightAssociative<F, T0, T1>::value,
int> = 0>
325 constexpr auto associateRightExpression(F&& f, T0&& t0, T1&& t1, OptimizeInner = OptimizeInner{})
327 auto operations = associateRightOperations(
328 std::forward<T0>(t0).operation(), std::forward<F>(f),
329 IndexSequence<TensorTraits<Operand<0, T0> >::rank, TensorTraits<Operand<1, T0> >::rank, TensorTraits<T1>::rank>{});
332 std::move(operations.first),
333 std::forward<T0>(t0).operand(0_c),
336 std::move(operations.second),
337 std::forward<T0>(t0).operand(1_c),
343 template<
class F,
class T0,
class T1,
344 std::enable_if_t<(!IsRightAssociative<F, T0, T1>::value
345 && FunctorHas<IsProductOperation, F>::value
346 && IsProductExpression<T1>::value
354 template<
class F,
class T0,
class T1,
357 std::enable_if_t<AreProperTensors<T0, T1>::value,
int> = 0>
358 constexpr decltype(
auto) associateRight(F&& f, T0&& t0, T1&& t1,
359 OptimizeOuter = OptimizeOuter{}, OptimizeInner = OptimizeInner{})
361 auto expr = associateRightExpression(std::forward<F>(f), std::forward<T0>(t0), std::forward<T1>(t1), OptimizeInner{});
364 std::move(expr).operation(),
365 std::move(expr).operand(0_c),
366 std::move(expr).operand(1_c));
369 template<
class F,
class T0,
class T1,
373 IsRightAssociative<F, T0, T1>::value),
int> = 0>
374 constexpr decltype(
auto) associateRight(F&& f, T0&& t0, T1&& t1,
375 OptimizeOuter = OptimizeOuter{}, OptimizeInner = OptimizeInner{})
377 auto operations = associateRightOperations(
378 std::forward<T0>(t0).operation(), std::forward<F>(f),
379 IndexSequence<TensorTraits<Operand<0, T0> >::rank, TensorTraits<Operand<1, T0> >::rank, TensorTraits<T1>::rank>{});
384 std::move(operations.first),
385 std::forward<T0>(t0).operand(0_c),
388 std::move(operations.second),
389 std::forward<T0>(t0).operand(1_c),
395 template<
class F,
class T0,
class T1,
397 std::enable_if_t<(!IsRightAssociative<F, T0, T1>::value
398 && FunctorHas<IsProductOperation, F>::value
399 && IsProductExpression<T1>::value
401 constexpr decltype(
auto) associateRight(F&& f, T0&& t0, T1&& t1,
404 return operate(OptimizeOuter{}, std::forward<F>(f), std::forward<T0>(t0), std::forward<T1>(t1));
408 template<
class F,
class T0,
class T1,
410 std::enable_if_t<IsLeftAssociative<F, T0, T1>::value,
int> = 0>
411 constexpr auto associateLeftExpression(F&& f, T0&& t0, T1&& t1,
412 OptimizeInner = OptimizeInner{})
415 auto operations = associateLeftOperations(
416 std::forward<F>(f), std::forward<T1>(t1).operation(),
417 IndexSequence<TensorTraits<T0>::rank, TensorTraits<Operand<0, T1> >::rank, TensorTraits<Operand<1, T1> >::rank>{});
420 std::move(operations.second),
423 std::move(operations.first),
424 std::forward<T0>(t0),
425 std::forward<T1>(t1).operand(0_c)
427 std::forward<T1>(t1).operand(1_c)
435 template<
class F,
class T0,
class T1,
436 std::enable_if_t<(!IsLeftAssociative<F, T0, T1>::value
437 && IsProductExpression<T0>::value
438 && FunctorHas<IsProductOperation, F>::value
440 constexpr auto associateLeftExpression(F&& f, T0&& t0, T1&& t1,
447 template<
class F,
class T0,
class T1,
449 std::enable_if_t<AreProperTensors<T0, T1>::value,
int> = 0>
450 constexpr decltype(
auto) associateLeft(F&& f, T0&& t0, T1&& t1,
451 OptimizeOuter = OptimizeOuter{}, OptimizeInner = OptimizeInner{})
453 auto expr = associateLeftExpression(std::forward<F>(f), std::forward<T0>(t0), std::forward<T1>(t1), OptimizeInner{});
457 std::move(expr).operation(),
458 std::move(expr).operand(0_c),
459 std::move(expr).operand(1_c));
462 template<
class F,
class T0,
class T1,
465 IsLeftAssociative<F, T0, T1>::value),
int> = 0>
466 constexpr decltype(
auto) associateLeft(F&& f, T0&& t0, T1&& t1,
467 OptimizeOuter = OptimizeOuter{}, OptimizeInner = OptimizeInner{})
470 auto operations = associateLeftOperations(
471 std::forward<F>(f), std::forward<T1>(t1).operation(),
472 IndexSequence<TensorTraits<T0>::rank, TensorTraits<Operand<0, T1> >::rank, TensorTraits<Operand<1, T1> >::rank>{});
476 std::move(operations.second),
479 std::move(operations.first),
480 std::forward<T0>(t0),
481 std::forward<T1>(t1).operand(0_c)
483 std::forward<T1>(t1).operand(1_c)
491 template<
class F,
class T0,
class T1,
494 !IsLeftAssociative<F, T0, T1>::value
495 && IsProductExpression<T0>::value
496 && FunctorHas<IsProductOperation, F>::value
498 constexpr decltype(
auto) associateLeft(F&& f, T0&& t0, T1&& t1,
501 return operate(OptimizeOuter{}, std::forward<F>(f), std::forward<T0>(t0), std::forward<T1>(t1));
505 template<std::
size_t N,
class F,
class T0,
class T1>
506 using AssociateRightOperation =
509 associateRightOperations(
510 std::declval<Functor<T0> >(), std::declval<F>(),
512 TensorTraits<Operand<1, T0> >::rank,
513 TensorTraits<T1>::rank>{})
516 template<std::
size_t N,
class F,
class T0,
class T1>
517 using AssociateLeftOperation =
520 associateLeftOperations(
521 std::declval<F>(), std::declval<Functor<T1> >(),
523 TensorTraits<Operand<0, T1> >::rank,
524 TensorTraits<Operand<1, T1> >::rank>{})
536 constexpr decltype(
auto) rightMostFactor(T&& t)
538 if constexpr (!IsProductExpression<T>::value) {
539 return std::forward<T>(t);
540 }
else if constexpr (!IsLeftAssociative<Functor<T>, Operand<0, T>, Operand<1, T> >::value) {
541 return std::forward<T>(t).operand(1_c);
543 return rightMostFactor(std::forward<T>(t).operand(1_c));
556 constexpr decltype(
auto) leftMostFactor(T&& t)
558 if constexpr (!IsProductExpression<T>::value) {
559 return std::forward<T>(t);
560 }
else if constexpr (!IsRightAssociative<Functor<T>, Operand<0, T>, Operand<1, T> >::value) {
561 return std::forward<T>(t).operand(0_c);
563 return leftMostFactor(std::forward<T>(t).operand(0_c));
572 template<
class F,
class T0,
class T1,
class Optimize =
OptimizeTop,
573 std::enable_if_t<FunctorHas<IsProductOperation, F>::value,
int> = 0>
574 constexpr auto factorOutRight(
const F& f, T0&& t0, T1&& t1, Optimize = Optimize{})
576 if constexpr (!IsLeftAssociative<F, T0, T1>::value) {
577 return Expressions::Storage<F, T0, T1>(f, std::forward<T0>(t0), std::forward<T1>(t1));
580 auto operations = associateLeftOperations(
581 f, std::forward<T1>(t1).operation(),
582 IndexSequence<TensorTraits<T0>::rank, TensorTraits<Operand<0, T1> >::rank, TensorTraits<Operand<1, T1> >::rank>{});
584 return factorOutRight(std::move(operations.second),
587 std::move(operations.first),
588 std::forward<T0>(t0),
589 std::forward<T1>(t1).operand(0_c)
591 std::forward<T1>(t1).operand(1_c));
606 template<
class F,
class T0,
class T1,
class Optimize =
OptimizeTop,
607 std::enable_if_t<FunctorHas<IsProductOperation, F>::value,
int> = 0>
608 constexpr auto factorOutLeft(
const F& f, T0&& t0, T1&& t1, Optimize = Optimize{})
610 if constexpr (!IsRightAssociative<F, T0, T1>::value) {
611 return Expressions::Storage<F, T0, T1>(f, std::forward<T0>(t0), std::forward<T1>(t1));
614 auto operations = associateRightOperations(
615 std::forward<T0>(t0).operation(), f,
616 IndexSequence<TensorTraits<Operand<0, T0> >::rank, TensorTraits<Operand<1, T0> >::rank, TensorTraits<T1>::rank>{});
618 return factorOutLeft(std::move(operations.first),
619 std::forward<T0>(t0).operand(0_c),
622 std::move(operations.second),
623 std::forward<T0>(t0).operand(1_c),
624 std::forward<T1>(t1)));
628 template<
class F,
class T0,
class T1>
629 using FactorOutRight = std::decay_t<decltype(factorOutRight(std::declval<F>(), std::declval<T0>(), std::declval<T1>()))>;
631 template<
class F,
class T0,
class T1>
632 using FactorOutLeft = std::decay_t<decltype(factorOutLeft(std::declval<F>(), std::declval<T0>(), std::declval<T1>()))>;
636 using RightMostFactor =
decltype(rightMostFactor(std::declval<T>()));
639 using LeftMostFactor =
decltype(leftMostFactor(std::declval<T>()));
642 template<
class T,
class SFINAE =
void>
643 struct RightMostFactorHelper
649 struct RightMostFactorHelper<T,
std::enable_if_t<IsProductExpression<T>::value> >
651 using Type = ConditionalType<IsLeftAssociative<Functor<T>, Operand<0, T>, Operand<1, T> >::value,
652 typename RightMostFactorHelper<Operand<1, T> >::Type,
656 template<
class T,
class SFINAE =
void>
657 struct LeftMostFactorHelper
663 struct LeftMostFactorHelper<T,
std::enable_if_t<IsProductExpression<T>::value> >
665 using Type = ConditionalType<IsRightAssociative<Functor<T>, Operand<0, T>, Operand<1, T> >::value,
666 typename LeftMostFactorHelper<Operand<0, T> >::Type,
672 using RightMostFactor =
typename RightMostFactorHelper<T>::Type;
675 using LeftMostFactor =
typename LeftMostFactorHelper<T>::Type;
680 using ProductOperations::IsRightAssociative;
681 using ProductOperations::IsLeftAssociative;
682 using ProductOperations::AssociateRightOperation;
683 using ProductOperations::associateRightExpression;
684 using ProductOperations::associateRight;
685 using ProductOperations::associateRight;
686 using ProductOperations::AssociateLeftOperation;
687 using ProductOperations::associateLeftOperations;
688 using ProductOperations::associateLeftExpression ;
689 using ProductOperations::associateLeft;
690 using ProductOperations::factorOutRight;
691 using ProductOperations::rightMostFactor;
692 using ProductOperations::RightMostFactor;
693 using ProductOperations::FactorOutRight;
694 using ProductOperations::factorOutLeft;
695 using ProductOperations::leftMostFactor;
696 using ProductOperations::LeftMostFactor;
697 using ProductOperations::FactorOutLeft;
OptimizeTag< Policy::OptimizationLevelMax::value > OptimizeTop
The top-level optimization tag.
Definition: optimizationbase.hh:59
OptimizeTag< 0 > DontOptimize
Bottom level is overloaded to do nothing.
Definition: optimizationbase.hh:74
constexpr auto storage(const F &f, T &&... t)
Generate an expression storage container.
Definition: storage.hh:704
std::tuple_element_t< N, std::decay_t< TupleLike > > TupleElement
Forward to std::tuple_element<N, std::decay_t<T> >
Definition: access.hh:125
constexpr std::size_t size()
Gives the number of elements in tuple-likes and std::integer_sequence.
Definition: size.hh:73
Sequence< std::size_t, V... > IndexSequence
Sequence of std::size_t values.
Definition: types.hh:64
BoolConstant< false > FalseType
Alias for std::false_type.
Definition: types.hh:110
BoolConstant< true > TrueType
Alias for std::true_type.
Definition: types.hh:107
Einstein summation, i.e.
Definition: expressionoperations.hh:308
Index functor.
Definition: transform.hh:346
Component-wise product over given index-set.
Definition: expressionoperations.hh:389
Definition: associativity.hh:53