DUNE-ACFEM (unstable)

einsum.hh
1 #ifndef __DUNE_ACFEM_TENSORS_OPERATIONS_EINSUM_HH__
2 #define __DUNE_ACFEM_TENSORS_OPERATIONS_EINSUM_HH__
3 
4 #include "../../common/types.hh"
5 //#include "../../common/ostream.hh"
6 #include "../../mpl/foreach.hh"
7 #include "../../mpl/insertat.hh"
8 #include "../../mpl/toarray.hh"
9 #include "../../expressions/storage.hh"
10 #include "../../expressions/expressionoperations.hh"
11 #include "../../expressions/constantoperations.hh"
12 #include "../tensorbase.hh"
13 #include "../modules/eye.hh"
14 #include "../optimization/policy.hh"
15 #include "restrictiondetail.hh"
16 
17 namespace Dune {
18 
19  namespace ACFem {
20 
21  // forward
// Operation tag for an Einstein summation (tensor contraction). It is only
// declared here; the definition lives in another header.
// Pos1/Pos2: contraction index positions of the left/right operand.
// ContractDims: dimensions of the contracted indices. In this file it is always
// instantiated with the sub-signature of the left operand at Pos1 (see
// EinsumOperationTraits::Operation and EinsteinSummation::FunctorType).
22  template<class Pos1, class Pos2, class ContractDims>
23  struct EinsumOperation;
24 
29  namespace Tensor {
30 
// Expression class that contracts LeftTensor and RightTensor over the index
// positions LeftIndices and RightIndices. Both are index sequences; the class
// static_asserts that they have equal length.
// IdenticalOperands defaults to Expressions::AreRuntimeEqual<LeftTensor,
// RightTensor>::value. When it is true, the evaluation code reads both factors
// from the first operand.
63  template<class LeftTensor, class LeftIndices, class RightTensor, class RightIndices, bool IdenticalOperands = Expressions::AreRuntimeEqual<LeftTensor, RightTensor>::value>
64  class EinsteinSummation;
65 
// Traits describing the result of an einsum.
// Primary template: LDims/RDims are tensor types (anything TensorTraits
// accepts). They are replaced by their Signature, and the lookup is forwarded
// to the Seq<...> specialization below.
// NOTE(review): the SFINAE parameter is not used by the visible
// specialization; presumably it is a hook for further specializations.
66  template<class LDims, class LPos, class RDims, class RPos, class SFINAE = void>
67  struct EinsumOperationTraits
68  : EinsumOperationTraits<typename TensorTraits<LDims>::Signature, LPos,
69  typename TensorTraits<RDims>::Signature, RPos>
70  {};
71 
// Specialization for explicit dimension sequences.
// LDims.../RDims...: signatures of the operands.
// LPos.../RPos...: contraction positions.
// Members:
//  - Signature: the left dimensions not at LPos, followed by the right
//    dimensions not at RPos;
//  - rank_: length of Signature;
//  - Operation: the EinsumOperation tag, carrying both position lists and the
//    contracted dimensions (taken from the left signature at LPos);
//  - Functor: the OperationTraits wrapper of Operation.
72  template<std::size_t... LDims, std::size_t... LPos, std::size_t... RDims, std::size_t... RPos>
73  struct EinsumOperationTraits<Seq<LDims...>, Seq<LPos...>, Seq<RDims...>, Seq<RPos...> >
74  {
75  using Signature = SequenceCat<SubSequenceComplement<Seq<LDims...>, LPos...>,
76  SubSequenceComplement<Seq<RDims...>, RPos...> >;
77  static constexpr std::size_t rank_ = Signature::size();
78  using Operation = EinsumOperation<Seq<LPos...>, Seq<RPos...>, SubSequence<Seq<LDims...>, LPos...> >;
79  using Functor = OperationTraits<Operation>;
80  };
81 
// Signature (dimension sequence) of the result of contracting LDims at the
// positions LPos with RDims at the positions RPos; see EinsumOperationTraits.
83  template<class LDims, class LPos, class RDims, class RPos>
84  using EinsumSignature = typename EinsumOperationTraits<LDims, LPos, RDims, RPos>::Signature;
85 
// Functor type (OperationTraits<EinsumOperation<...>>) describing the
// contraction. einsum() uses it to build the expression.
90  template<class LDims, class LPos, class RDims, class RPos>
91  using EinsumFunctor = typename EinsumOperationTraits<LDims, LPos, RDims, RPos>::Functor;
92 
// Einsum operation/functor with empty contraction-position lists: nothing is
// summed over, so both operands' signatures are concatenated.
// NOTE(review): the "Scalar" prefix suggests use with scalar operands;
// confirm at the call sites.
93  using ScalarEinsumOperation = EinsumOperation<Seq<>, Seq<>, Seq<> >;
94  using ScalarEinsumFunctor = OperationTraits<ScalarEinsumOperation>;
95 
// Expression type of a contraction-free einsum of T0 and T1. Unlike a plain
// EinsteinSummation<...> spelling, IdenticalOperands must be supplied
// explicitly here.
96  template<class T0, class T1, bool IdenticalOperands>
97  using ScalarEinsumExpression = EinsteinSummation<T0, Seq<>, T1, Seq<>, IdenticalOperands>;
98 
// Rank of the einsum result, i.e. the length of EinsumSignature<...>.
99  template<class LDims, class LPos, class RDims, class RPos>
100  inline constexpr std::size_t EinsumRank = EinsumOperationTraits<LDims, LPos, RDims, RPos>::rank_;
101 
// The entries of the dimension sequence Dims at the positions Pos, i.e. the
// dimensions of the contracted indices when Pos are contraction positions.
// Not referenced elsewhere in this listing.
102  template<class Dims, class Pos>
103  using EinsumDimensions = SequenceSlice<Dims, Pos>;
104 
105  template<class LeftTensor, class LeftIndices, class RightTensor, class RightIndices, bool IdenticalOperands>
107  : public TensorBase<FloatingPointClosure<typename FieldPromotion<LeftTensor, RightTensor>::Type>,
108  EinsumSignature<LeftTensor, LeftIndices,
109  RightTensor, RightIndices >,
110  EinsteinSummation<LeftTensor, LeftIndices, RightTensor, RightIndices > >
111  , public Expressions::Storage<EinsumFunctor<LeftTensor, LeftIndices, RightTensor, RightIndices >,
112  LeftTensor, RightTensor>
113  // Expressions model rvalues. We are constant if the underlying
114  // expression claims to be so or if we store a copy and the
115  // underlying expression is "independent" (i.e. determined at
116  // runtime only by the contained data)
117  {
118  using ThisType = EinsteinSummation;
119  using LeftType = LeftTensor;
120  using RightType = RightTensor;
121  using LeftTensorType = std::decay_t<LeftTensor>;
122  using RightTensorType = std::decay_t<RightTensor>;
123  public:
124  using LeftIndexPositions = LeftIndices;
125  using RightIndexPositions = RightIndices;
126  private:
127  // Mmmh. Should we really support this?
128  using LeftSorted = SortSequence<LeftIndexPositions>;
129  using RightSorted = SortSequence<RightIndexPositions>;
130  using LeftPermutation = typename LeftSorted::Permutation;
131  using RightPermutation = typename RightSorted::Permutation;
132  using LeftSortedPos = typename LeftSorted::Result;
133  using RightSortedPos = typename RightSorted::Result;
134  static constexpr bool leftIsSorted_ = isSimple(LeftPermutation{});
135  static constexpr bool rightIsSorted_ = isSimple(RightPermutation{});
137  using LeftSignature = typename LeftTensorType::Signature;
138  using RightSignature = typename RightTensorType::Signature;
139  using LeftDefectSignature = SequenceSlice<LeftSignature, LeftIndices>;
140  using RightDefectSignature = SequenceSlice<RightSignature, RightIndices>;
141  public:
142  using LeftSignatureRest = SequenceSliceComplement<LeftSignature, LeftIndices>;
143  using RightSignatureRest = SequenceSliceComplement<RightSignature, RightIndices>;
146  using FunctorType = OperationTraits<EinsumOperation<LeftIndexPositions, RightIndexPositions, LeftDefectSignature> >;
147  private:
149  using StorageType = Expressions::Storage<FunctorType, LeftTensor, RightTensor>;
150  public:
151  static const std::size_t leftRank_ = LeftSignatureRest::size();
152  static const std::size_t rightRank_ = RightSignatureRest::size();
153  using LeftArgs = MakeIndexSequence<leftRank_>;
154  using RightArgs = MakeIndexSequence<rightRank_, leftRank_>; // offset
155  using DefectSignature = LeftDefectSignature;
156  static constexpr std::size_t defectRank_ = DefectSignature::size();
157  private:
158  static_assert(LeftIndices::size() == RightIndices::size(),
159  "Number of contraction indices must coincide.");
160  static_assert(LeftTensorType::rank >= LeftIndices::size(),
161  "Left: Number of index-positions must fit into tensor rank.");
162  static_assert(RightTensorType::rank >= RightIndices::size(),
163  "Right: Number of index-positions must fit into tensor rank.");
164  static_assert(std::is_same<LeftDefectSignature, RightDefectSignature>::value,
165  "Left- and right defect-dimensions must coincide.");
166  static_assert(LeftIndexPositions{} < LeftTensorType::rank,
167  "Left index positions out of range.");
168  static_assert(RightIndexPositions{} < RightTensorType::rank,
169  "Right index positions out of range.");
170 
171  public:
172  using StorageType::operation;
173  using StorageType::operand;
174  using BaseType::rank;
175 
// Construct the expression from its two operands.
// Each argument is perfect-forwarded into the Expressions::Storage base
// (StorageType). The overload participates only if LeftType and RightType are
// constructible from the given arguments. These are the LeftTensor/RightTensor
// template arguments and may be reference types.
176  template<class LeftArg, class RightArg,
177  std::enable_if_t<std::is_constructible<LeftType, LeftArg>::value && std::is_constructible<RightType, RightArg>::value, int> = 0>
178  EinsteinSummation(LeftArg&& left, RightArg&& right)
179  : StorageType(std::forward<LeftArg>(left), std::forward<RightArg>(right))
180  {}
181 
182  template<
183  class... Dummy,
184  std::enable_if_t<(sizeof...(Dummy) == 0
187  ), int> = 0>
188  EinsteinSummation(Dummy&&...)
189  : StorageType(LeftType{}, RightType{})
190  {}
191 
192 #if 0
193  static_assert(!(IsTypedValue<LeftType>::value && std::is_reference<LeftType>::value)
194  &&
195  !(IsTypedValue<RightType>::value && std::is_reference<RightType>::value),
196  "Typed values should not be stored as references.");
197 #endif
198 
199 #if DUNE_ACFEM_TENSOR_WORKAROUND_GCC(7)
// Named function object that replaces the generic lambda in the run-time
// operator()(Dims...) when the GCC-7 workaround is enabled. Its loop body
// matches that lambda.
// forLoop invokes it once per flattened index I of the contracted ("defect")
// index range. Each call adds the corresponding product to accu.
200  class GCCBugCompensator1
201  {
202  public:
203  template<class LeftIndexArg, class RightIndexArg>
// leftIndices/rightIndices: the free (non-contracted) run-time indices of the
// left/right operand. They are copied into fixed-size arrays.
204  GCCBugCompensator1(FieldType& accu,
205  const LeftTensorType& left, const RightTensorType& right,
206  LeftIndexArg&& leftIndices, RightIndexArg&& rightIndices)
207  : accu_(accu), left_(left), right_(right), leftIndices_(toArray<std::size_t>(leftIndices)), rightIndices_(toArray<std::size_t>(rightIndices))
208  {}
209 
210  template<class I>
211  void operator()(I)
212  {
// Decode the flattened index I into a multi-index over DefectSignature.
213  using DefectIndices = MultiIndex<I::value, DefectSignature>;
// Skip terms for which either operand reports a zero entry at compile time.
214  if (!LeftTensorType::isZero(DefectIndices{}, LeftIndexPositions{}) && !RightTensorType::isZero(DefectIndices{}, RightIndexPositions{})) {
215  if constexpr (IdenticalOperands) {
// Same operand, same positions, rank-0 result: the code treats both factors
// as the same entry. It evaluates that entry once and squares it.
216  if constexpr (rank == 0 && std::is_same<LeftIndexPositions, RightIndexPositions>::value) {
217  auto val = tensorValue(left_, insertAt(leftIndices_, PermuteSequence<DefectIndices, LeftPermutation>{}, LeftSortedPos{}));
218  accu_ += val * val;
219  } else {
// Same operand, different positions: both factors are read from left_.
220  accu_ +=
221  tensorValue(left_, insertAt(leftIndices_, PermuteSequence<DefectIndices, LeftPermutation>{}, LeftSortedPos{}))
222  *
223  tensorValue(left_, insertAt(rightIndices_, PermuteSequence<DefectIndices, RightPermutation>{}, RightSortedPos{}));
224  }
225  } else {
// General case: one factor from each operand.
// The contracted indices are permuted with the sorting permutation of the
// index positions. They are then inserted into the free indices at the sorted
// positions (Left/RightSortedPos), giving the full multi-index of each operand.
226  accu_ +=
227  tensorValue(left_, insertAt(leftIndices_, PermuteSequence<DefectIndices, LeftPermutation>{}, LeftSortedPos{}))
228  *
229  tensorValue(right_, insertAt(rightIndices_, PermuteSequence<DefectIndices, RightPermutation>{}, RightSortedPos{}));
230  }
231  }
232  }
233  private:
// References, not copies: the accumulator and both operands must outlive this
// object.
234  FieldType& accu_;
235  const LeftTensorType& left_;
236  const RightTensorType& right_;
237  const std::array<std::size_t, leftRank_> leftIndices_;
238  const std::array<std::size_t, rightRank_> rightIndices_;
239  };
240 #endif
241 
242  template<class... Dims,
243  std::enable_if_t<(sizeof...(Dims) == rank
244  &&
246  , int> = 0>
247  auto operator()(Dims... indices) const
248  {
249  auto leftIndices = forwardSubTuple(std::forward_as_tuple(indices...), LeftArgs{});
250  auto rightIndices = forwardSubTuple(std::forward_as_tuple(indices...), RightArgs{});
251  (void)rightIndices;
252 
253  auto accu = FieldType(0);
254 #if DUNE_ACFEM_TENSOR_WORKAROUND_GCC(7)
255  forLoop<multiDim(DefectSignature{})>(GCCBugCompensator1(accu, operand(0_c), operand(1_c), leftIndices, rightIndices));
256 #else
257  forLoop<multiDim(DefectSignature{})>([&](auto i) {
258  using I = decltype(i);
259  using DefectIndices = MultiIndex<I::value, DefectSignature>;
260  if (!LeftTensorType::isZero(DefectIndices{}, LeftIndexPositions{}) && !RightTensorType::isZero(DefectIndices{}, RightIndexPositions{})) {
261  if constexpr (IdenticalOperands) {
262  if constexpr (rank == 0 && std::is_same<LeftIndexPositions, RightIndexPositions>::value) {
263  auto val = tensorValue(operand(0_c), insertAt(leftIndices, PermuteSequence<DefectIndices, LeftPermutation>{}, LeftSortedPos{}));
264  accu += val * val;
265  } else {
266  accu +=
267  tensorValue(operand(0_c), insertAt(leftIndices, PermuteSequence<DefectIndices, LeftPermutation>{}, LeftSortedPos{}))
268  *
269  tensorValue(operand(0_c), insertAt(rightIndices, PermuteSequence<DefectIndices, RightPermutation>{}, RightSortedPos{}));
270  }
271  } else {
272  accu +=
273  tensorValue(operand(0_c), insertAt(leftIndices, PermuteSequence<DefectIndices, LeftPermutation>{}, LeftSortedPos{}))
274  *
275  tensorValue(operand(1_c), insertAt(rightIndices, PermuteSequence<DefectIndices, RightPermutation>{}, RightSortedPos{}));
276  }
277  }
278  });
279 #endif
280  return accu;
281  }
282 
283 #if DUNE_ACFEM_TENSOR_WORKAROUND_GCC(7)
284  template<class LeftIndexArg, class RightIndexArg>
285  class GCCBugCompensator2
286  {
287  public:
288  GCCBugCompensator2(FieldType& accu, const LeftTensorType& left, const RightTensorType& right)
289  : accu_(accu), left_(left), right_(right)
290  {}
291 
292  template<class I>
293  void operator()(I)
294  {
295  using DefectIndices = MultiIndex<I::value, DefectSignature>;
298  if constexpr (IdenticalOperands) {
299  if constexpr (rank == 0 && std::is_same<LeftIndexPositions, RightIndexPositions>::value) {
300  if (!LeftTensorType::isZero(LeftArg{})) {
301  auto val = left_(LeftArg{});
302  accu_ += val * val;
303  }
304  } else {
305  if (!LeftTensorType::isZero(LeftArg{}) && !LeftTensorType::isZero(RightArg{})) {
306  accu_ += left_(LeftArg{}) * left_(RightArg{});
307  }
308  }
309  } else {
310  if (!LeftTensorType::isZero(LeftArg{}) && !RightTensorType::isZero(RightArg{})) {
311  accu_ += left_(LeftArg{}) * right_(RightArg{});
312  }
313  }
314  }
315  private:
316  FieldType& accu_;
317  const LeftTensorType& left_;
318  const RightTensorType& right_;
319  };
320 
321  template<class LeftIndexArg, class RightIndexArg>
322  class GCCBugCompensator3
323  {
324  public:
325  GCCBugCompensator3(const LeftTensorType& left, const RightTensorType& right)
326  : left_(left), right_(right)
327  {}
328 
329  template<class I>
330  constexpr auto operator()(I) const
331  {
332  using DefectIndices = MultiIndex<I::value, DefectSignature>;
335  if constexpr (IdenticalOperands) {
336  if constexpr (rank == 0 && std::is_same<LeftIndexPositions, RightIndexPositions>::value) {
337  auto val = IntFraction<(int)!LeftTensorType::isZero(LeftArg{})>{} * left_(LeftArg{});
338  return val * val;
339  } else {
340  return
341  (IntFraction<(int)!LeftTensorType::isZero(LeftArg{})>{} * left_(LeftArg{}))
342  *
343  (IntFraction<(int)!RightTensorType::isZero(RightArg{})>{} * left_(RightArg{}));
344  }
345  } else {
346  return
347  (IntFraction<(int)!LeftTensorType::isZero(LeftArg{})>{} * left_(LeftArg{}))
348  *
349  (IntFraction<(int)!RightTensorType::isZero(RightArg{})>{} * right_(RightArg{}));
350  }
351  }
352  private:
353  const LeftTensorType& left_;
354  const RightTensorType& right_;
355  };
356 #endif
357 
// Compile-time access for an index tuple at which isZero() proves the entry
// vanishes. Returns the typed constant IntFraction<0> instead of a run-time
// sum.
360  template<std::size_t... Indices,
361  std::enable_if_t<(sizeof...(Indices) == rank
362  && ThisType::template isZero<Indices...>()
363  ), int> = 0>
364  auto constexpr operator()(Seq<Indices...>) const
365  {
366  return IntFraction<0>{};
367  }
368 
// Catch-all for index sequences whose length differs from rank.
// The static_assert depends on Indices, so it only fires when this overload is
// actually instantiated. That turns a wrong-length access into a readable
// compile error.
369  template<std::size_t... Indices, std::enable_if_t<sizeof...(Indices) != rank, int> = 0>
370  auto constexpr operator()(Seq<Indices...>) const
371  {
372  static_assert(sizeof...(Indices) == rank, "This should not happen");
373  }
374 
377  template<std::size_t... Indices,
378  std::enable_if_t<(sizeof...(Indices) == rank
379  && !ThisType::template isZero<Indices...>()
380  && multiDim(DefectSignature{}) >= 0
381  && multiDim(DefectSignature{}) <= Policy::TemplateForEachLimit::value
382  ), int> = 0>
383  auto constexpr operator()(Seq<Indices...>) const
384  {
385  using IndexArg = Seq<Indices...>;
386  using LeftIndexArg = SequenceSlice<IndexArg, LeftArgs>;
387  using RightIndexArg = SequenceSlice<IndexArg, RightArgs>;
388  static_assert(multiDim(DefectSignature{}) > 0, "BUG");
389 #if 1
390  return
391  addLoop<multiDim(DefectSignature{})>(
392 #if DUNE_ACFEM_TENSOR_WORKAROUND_GCC(7)
393  GCCBugCompensator3<LeftIndexArg, RightIndexArg>(operand(0_c), operand(1_c)),
394 #else
395  [&](auto i) {
396  using DefectIndices = MultiIndex<i(), DefectSignature>;
399  if constexpr (IdenticalOperands) {
400  if constexpr (rank == 0 && std::is_same<LeftIndexPositions, RightIndexPositions>::value) {
401  auto val = intFraction<!LeftTensorType::isZero(LeftArg{})>() * operand(0_c)(LeftArg{});
402  return val * val;
403  } else {
404  return
405  (intFraction<!LeftTensorType::isZero(LeftArg{})>() * operand(0_c)(LeftArg{}))
406  *
407  (intFraction<!RightTensorType::isZero(RightArg{})>() * operand(0_c)(RightArg{}));
408  }
409  } else {
410  return
411  (intFraction<!LeftTensorType::isZero(LeftArg{})>() * operand(0_c)(LeftArg{}))
412  *
413  (intFraction<!RightTensorType::isZero(RightArg{})>() * operand(1_c)(RightArg{}));
414  }
415  },
416 #endif
417  intFraction<0>());
418 #else // if 0
419  auto accu = FieldType(0);
420 #if DUNE_ACFEM_TENSOR_WORKAROUND_GCC(7)
421  forLoop<multiDim(DefectSignature{})>(GCCBugCompensator2<LeftIndexArg, RightIndexArg>(accu, operand(0_c), operand(1_c)));
422 #else
423  forLoop<multiDim(DefectSignature{})>([&](auto i) {
424  using DefectIndices = MultiIndex<i, DefectSignature>;
427 
428  assert(!LeftTensorType::isZero(LeftArg{}) || operand(0_c)(LeftArg{}) == 0);
429  assert(!RightTensorType::isZero(RightArg{}) || operand(1_c)(RightArg{}) == 0);
430 
431  if (!LeftTensorType::isZero(LeftArg{}) && !RightTensorType::isZero(RightArg{})) {
432  accu += operand(0_c)(LeftArg{}) * operand(1_c)(RightArg{});
433  }
434  });
435 #endif
436  return accu;
437 #endif
438  }
439 
442  template<std::size_t... Indices,
443  std::enable_if_t<(sizeof...(Indices) == rank
444  && !ThisType::template isZero<Indices...>()
445  && multiDim(DefectSignature{}) > Policy::TemplateForEachLimit::value
446  ), int> = 0>
447  auto constexpr operator()(Seq<Indices...>) const
448  {
449 #if 0
450  // Just forward to the run-time method. Perhaps the compiler
451  // still can optimize something ...
452  return (*this)(Indices...);
453 #else
454  // Still a compile-time loop, but with pack expansion
455  // instead of recursion. Still quite costly ... This will
456  // never return an integral constant, as the pack-expansion
457  // loop cannot return values.
458  using IndexArg = Seq<Indices...>;
459  using LeftIndexArg = SequenceSlice<IndexArg, LeftArgs>;
460  using RightIndexArg = SequenceSlice<IndexArg, RightArgs>;
461  auto accu = FieldType(0);
462 #if DUNE_ACFEM_TENSOR_WORKAROUND_GCC(7)
463  forLoop<multiDim(DefectSignature{})>(GCCBugCompensator2<LeftIndexArg, RightIndexArg>(accu, operand(0_c), operand(1_c)));
464 # else
465  forLoop<multiDim(DefectSignature{})>([&](auto i) {
466  using I = decltype(i);
467  using DefectIndices = MultiIndex<I::value, DefectSignature>;
470 
471  assert(!LeftTensorType::isZero(LeftArg{}) || operand(0_c)(LeftArg{}) == 0);
472  assert(!RightTensorType::isZero(RightArg{}) || operand(1_c)(RightArg{}) == 0);
473 
474  if (!LeftTensorType::isZero(LeftArg{}) && !RightTensorType::isZero(RightArg{})) {
475  accu += operand(0_c)(LeftArg{}) * operand(1_c)(RightArg{});
476  }
477  });
478 # endif
479  return accu;
480 #endif
481  }
482 
483  private:
484 
// Helper for isZeroWorker().
// Returns true iff, for every flattened contraction index N in Seq<N...>, the
// left or the right factor is known to be zero at compile time.
// For each N:
//  - the contracted multi-index is permuted as in the evaluation code;
//  - the fixed indices LeftInd/RightInd are inserted at LeftInj/RightInj;
//  - the operand's isZero() is queried with the positions LeftPos/RightPos.
// An empty N... yields true (empty &&-fold).
485  template<class LeftInd, class LeftInj, class LeftPos,
486  class RightInd, class RightInj, class RightPos,
487  std::size_t... N>
488  static bool constexpr isZeroExpander(LeftInd, LeftInj, LeftPos,
489  RightInd, RightInj, RightPos,
490  Seq<N...>)
491  {
492  return (... && (LeftTensorType::isZero(InsertAt<PermuteSequence<MultiIndex<N, DefectSignature>, LeftPermutation>, LeftInd, LeftInj>{}, LeftPos{})
493  ||
494  RightTensorType::isZero(InsertAt<PermuteSequence<MultiIndex<N, DefectSignature>, RightPermutation>, RightInd, RightInj>{}, RightPos{})));
495  }
496 
// Compile-time zero test behind isZero().
// Indices...: values of some of the result's indices.
// Pos: the result positions of those values (default 0, 1, ...).
// Result positions in [0, leftRank_) belong to the left operand. Positions in
// [leftRank_, leftRank_ + rightRank_) belong to the right operand and are
// shifted by -leftRank_.
// Returns whether every term of the contraction sum vanishes for these values.
497  template<std::size_t... Indices, class Pos = MakeIndexSequence<sizeof...(Indices)> >
498  static bool constexpr isZeroWorker(Seq<Indices...> = Seq<Indices...>{}, Pos = Pos{})
499  {
500  // This ain't pretty ;)
501 
502  // Truncate Pos to size of Indices and sort it, otherwise
503  // the JoinedDefects stuff will not work properly
504  using TruncatedPos = HeadPart<sizeof...(Indices), Pos>;
505 
506  using RealPos = typename SortSequence<TruncatedPos>::Result;
507  using Permutation = typename SortSequence<TruncatedPos>::Permutation;
// Reorder the index values in the same way as their positions.
508  using RealIndices = PermuteSequence<Seq<Indices...>, Permutation>;
509 
510  // Extract the indices referring to the left Tensor
511  using LeftPosInd =
512  TransformSequence<RealPos, Seq<>, IndexFunctor, AcceptInputInRangeFunctor<std::size_t, 0, leftRank_> >;
513  using LeftPos = SequenceSlice<RealPos, LeftPosInd>;
514  using LeftInd = SequenceSlice<RealIndices, LeftPosInd>;
515 
516  // The total set of left index positions, contraction and taken from Indices...
517  using LeftTotalPos = JoinedDefects<LeftSortedPos, LeftPos>;
518  using LeftInj = JoinedInjections<LeftSortedPos, LeftPos>;
519 
520  // Extract the indices referring to the right Tensor
521  using RightPosInd =
522  TransformSequence<RealPos, Seq<>, IndexFunctor, AcceptInputInRangeFunctor<std::size_t, leftRank_, leftRank_ + rightRank_> >;
// NOTE(review): ssize_t is a POSIX typedef, not standard C++. std::ptrdiff_t
// together with static_cast would be the portable spelling.
523  using RightPos = TransformSequence<SequenceSlice<RealPos, RightPosInd>, Seq<>, OffsetFunctor<-(ssize_t)leftRank_> >;
524  using RightInd = SequenceSlice<RealIndices, RightPosInd>;
525 
526  // The total set of right index positions, contraction and taken from Indices...
527  using RightTotalPos = JoinedDefects<RightSortedPos, RightPos>;
528  using RightInj = JoinedInjections<RightSortedPos, RightPos>;
529 
// Check every flattened contraction index.
530  return isZeroExpander(LeftInd{}, LeftInj{}, LeftTotalPos{},
531  RightInd{}, RightInj{}, RightTotalPos{},
532  MakeIndexSequence<multiDim(DefectSignature{})>{});
533  }
534 
535  public:
// Public compile-time zero test for the result entries selected by Indices...
// at the positions Pos (see isZeroWorker()).
// Returns true immediately when either operand's ExpressionTraits report it
// as zero. Otherwise the decision is deferred to isZeroWorker().
536  template<std::size_t... Indices, class Pos = MakeIndexSequence<sizeof...(Indices)> >
537  static bool constexpr isZero(Seq<Indices...> = Seq<Indices...>{}, Pos = Pos{})
538  {
539  if constexpr (ExpressionTraits<LeftTensorType>::isZero || ExpressionTraits<RightTensorType>::isZero) {
540  return true;
541  } else {
542  return isZeroWorker(Seq<Indices...>{}, Pos{});
543  }
544  }
545 
// Human-readable description of the expression: operationName() applied to
// the operation and both operand names.
// Each operand name gets the prefix "ref" when the corresponding template
// argument is a reference type, or "cref" when RefersConst holds for it.
// Otherwise no prefix is added.
546  std::string name() const
547  {
548  using TL = LeftTensor;
549  using TR = RightTensor;
550  std::string pfxL = std::is_reference<TL>::value ? (RefersConst<TL>::value ? "cref" : "ref") : "";
551  std::string pfxR = std::is_reference<TR>::value ? (RefersConst<TR>::value ? "cref" : "ref") : "";
552 
553  return operationName(operation(), pfxL+operand(0_c).name(), pfxR+operand(1_c).name());
554  }
555 
556  };
557 
// Contract t1 and t2 over the index positions Seq1 (of t1) and Seq2 (of t2).
// Enabled only when AreProperTensors<T1, T2> holds.
// The result is whatever finalize() builds from the matching EinsumFunctor.
// NOTE(review): presumably finalize() runs the expression optimizer, so the
// result need not be an EinsteinSummation. The DontOptimize operate() overload
// below is the plain fallback.
559  template<class Seq1, class Seq2, class T1, class T2,
560  std::enable_if_t<AreProperTensors<T1, T2>::value, int> = 0>
561  constexpr decltype(auto) einsum(T1&& t1, T2&& t2)
562  {
563  return finalize(EinsumFunctor<T1, Seq1, T2, Seq2>{}, std::forward<T1>(t1), std::forward<T2>(t2));
564  }
565 
// Greedy contraction of two proper tensors.
// numDims is the number of identical dimensions at the head of both
// signatures (CommonHead). Positions 0, ..., numDims-1 of both operands are
// contracted via the explicit-position einsum() above.
// If the signatures share no leading dimension, numDims is 0 and no index is
// contracted.
569  template<class T1, class T2,
570  std::enable_if_t<AreProperTensors<T1, T2>::value, int> = 0>
571  constexpr decltype(auto) einsum(T1&& t1, T2&& t2)
572  {
573  constexpr std::size_t numDims = CommonHead<typename TensorTraits<T1>::Signature,
574  typename TensorTraits<T2>::Signature>::value;
575  using ContractPos = MakeIndexSequence<numDims>;
576  return einsum<ContractPos, ContractPos>(std::forward<T1>(t1), std::forward<T2>(t2));
577  }
578 
// Bottom-level operate() overload for the einsum operation (DontOptimize
// tag). It builds the EinsteinSummation expression directly from the
// forwarded operands.
// The operation's Dims argument is not used: EinsteinSummation recomputes the
// contracted dimensions from the operand signatures and static_asserts that
// both sides agree. IdenticalOperands takes its default value.
582  template<class Seq0, class Seq1, class Dims, class T0, class T1>
583  constexpr auto operate(Expressions::DontOptimize, OperationTraits<EinsumOperation<Seq0, Seq1, Dims> >, T0&& t0, T1&& t1)
584  {
585  DUNE_ACFEM_RECORD_OPTIMIZATION;
586 
587  return EinsteinSummation<T0, Seq0, T1, Seq1>(std::forward<T0>(t0), std::forward<T1>(t1));
588  }
589 
590  } // NS Tensor
591 
592  using Tensor::operate;
593  using Tensor::einsum;
594 
596 
597  } // NS ACFem
598 
// FieldTraits for EinsteinSummation. The field is the floating-point closure
// of the promoted field of both operands, the same expression the class passes
// to TensorBase.
// IdenticalOperands is an explicit parameter so that the specialization
// matches every instantiation. If it were omitted, the pattern would only
// match when the flag equals its declared default,
// Expressions::AreRuntimeEqual<LeftTensor, RightTensor>::value. Expressions
// with an explicitly passed, different flag (e.g. via ScalarEinsumExpression)
// would then silently fall back to the generic FieldTraits.
// NOTE(review): the CRTP argument of the class's TensorBase base (listing line
// 110) omits IdenticalOperands in the same way; check it as well.
599  template<class LeftTensor, class LeftIndices, class RightTensor, class RightIndices, bool IdenticalOperands>
600  struct FieldTraits<ACFem::Tensor::EinsteinSummation<LeftTensor, LeftIndices, RightTensor, RightIndices, IdenticalOperands> >
601  : FieldTraits<ACFem::FloatingPointClosure<typename ACFem::FieldPromotion<LeftTensor, RightTensor>::Type> >
602  {};
603 
605 
606 } // NS Dune
607 
608 #endif // __DUNE_ACFEM_TENSORS_OPERATIONS_EINSUM_HH__
Contraction of two tensors over a selected set of indices.
Definition: einsum.hh:117
constexpr auto operator()(Seq< Indices... >) const
Constant access from index-sequence.
Definition: einsum.hh:383
OptimizeTag< 0 > DontOptimize
Bottom level is overloaded to do nothing.
Definition: optimizationbase.hh:74
std::string operationName(F &&f, const std::string &arg)
Verbose print of an operation, helper function to produce noise.
Definition: operationtraits.hh:601
typename FloatingPointClosureHelper< T >::Type FloatingPointClosure
Template alias.
Definition: fieldpromotion.hh:74
BoolConstant< ExpressionTraits< T >::isTypedValue > IsTypedValue
Compile-time true if T is a "typed value", e.g. a std::integral_constant.
Definition: expressiontraits.hh:90
constexpr auto addLoop(F &&f, T &&init, IndexConstant< N >=IndexConstant< N >{})
Version with just a plain number as argument.
Definition: foreach.hh:110
typename GetHeadPartHelper< Cnt, Seq >::Type HeadPart
Extract Cnt many consecutive elements from the front of Seq.
Definition: access.hh:217
constexpr std::size_t size()
Gives the number of elements in tuple-likes and std::integer_sequence.
Definition: size.hh:73
IndexConstant< CommonHeadHelper< 0UL, Seq1, Seq2 >::value > CommonHead
Compute the number of identical indices at the head of the sequence.
Definition: filter.hh:173
MakeSequence< std::size_t, N, Offset, Stride, Repeat > MakeIndexSequence
Make a sequence of std::size_t elements.
Definition: generators.hh:34
MultiplyOffsetFunctor< 1, Offset > OffsetFunctor
Offset functor.
Definition: transform.hh:335
typename SequenceCatHelper2< S... >::Type SequenceCat
Concatenate the given sequences, in order, to <S0, S1, ...
Definition: transform.hh:220
typename InsertAtHelper< Sequence< typename Input::value_type >, Input, Inject, Pos, AssumeSorted >::Type InsertAt
Insert Inject into the sequence Input at the position specified by Pos.
Definition: insertat.hh:87
auto insertAt(SrcTuple &&src, DataTuple &&data, BoolConstant< AssumeSorted >=BoolConstant< AssumeSorted >{})
Insert the elements of data at positions given by pos into src in turn.
Definition: subtuple.hh:106
auto forwardSubTuple(const TupleLike &t, IndexSequence< I... >=IndexSequence< I... >{})
Like subTuple() but forwards the arguments, e.g. so that they can be expanded as parameters to another ...
Definition: subtuple.hh:162
constexpr auto operate(OptimizeNext< EinsumTag >, F &&f, T0 &&t0, T1 &&t1)
Definition: einsum.hh:327
constexpr decltype(auto) einsum(T1 &&t1, T2 &&t2)
Greedy tensor contraction for proper tensors, summing over all matching dimensions.
Definition: einsum.hh:571
typename EinsumOperationTraits< LDims, LPos, RDims, RPos >::Signature EinsumSignature
Compute the signature of the einsum tensor.
Definition: einsum.hh:84
typename EinsumOperationTraits< LDims, LPos, RDims, RPos >::Functor EinsumFunctor
Generate an einsum-functor.
Definition: einsum.hh:91
constexpr bool isSimple(Sequence< T, V... > seq)
Definition: compare.hh:334
decltype(isIntegralPack(std::declval< T >()...)) IsIntegralPack
Decide whether the given parameter pack contains only integral types.
Definition: compare.hh:377
std::decay_t< decltype(permute(Sequence{}, Perm{}))> PermuteSequence
Apply the given permutation to the positions of the given sequence.
Definition: permutation.hh:137
Sequence< typename Seq::value_type, Get< Indices, Seq >::value... > SubSequence
Create a subsequence containing all values designated by the positions given by Indices.
Definition: sequenceslice.hh:16
static constexpr std::size_t multiDim(IndexSequence< Dimensions... >, IndexSequence< IndexPositions... >=IndexSequence< IndexPositions... >{})
Compute the "dimension" corresponding to the given signature, i.e.
Definition: multiindex.hh:171
std::decay_t< decltype(multiIndex< I >(DimSeq{}))> MultiIndex
Generate the multi-index corresponding to the flattened index I where the multiindex varies between t...
Definition: multiindex.hh:68
Einstein summation, i.e.
Definition: expressionoperations.hh:308
Base class for all tensors.
Definition: tensorbase.hh:144
Creative Commons License   |  Legal Statements / Impressum  |  Hosted by TU Dresden  |  generated with Hugo v0.80.0 (May 5, 22:29, 2024)