4 #include "../../expressions/storage.hh"
5 #include "../../expressions/expressionoperations.hh" // FieldPromotion
6 #include "../../mpl/insertat.hh"
8 #include "../tensorbase.hh"
9 #include "../modules.hh"
10 #include "restrictiondetail.hh"
12 namespace Dune {
14  namespace ACFem {
16  // Forward declaration of the product-operation tag. It is used below
// as OperationTraits<TensorProductOperation<Pos1, Pos2, ProductDims> >
// (see ProductTensor::FunctorType and the operate() overload).
17  template<class Pos1, class Pos2, class ProductDims>
18  struct TensorProductOperation;
28  namespace Tensor {
65  template<class LeftTensor, class LeftIndices, class RightTensor, class RightIndices>
// Primary template: compute the signature and functor of the product
// of Left and Right, where LPos are the index positions of Left that
// are paired with the index positions RPos of Right. In this primary
// template Left and Right are treated as signature sequences; the
// specialization below maps tensor operands to their signatures. The
// SFINAE parameter exists only to enable that specialization.
68  template<class Left, class LPos, class Right, class RPos, class SFINAE = void>
69  struct ProductOperationTraits
70  {
// Result signature: the dimensions of Left at LPos first, then the
// remaining dimensions of Left, then the remaining dimensions of Right
// (the dimensions at RPos do not appear a second time).
71  using Signature = SequenceCat<SequenceSlice<Left, LPos>,
72  SequenceSliceComplement<Left, LPos>,
73  SequenceSliceComplement<Right, RPos> >;
// NOTE(review): `Operation` is not declared in the visible lines;
// presumably a TensorProductOperation<...> alias defined next to this
// line in the full file. Confirm.
75  using Functor = OperationTraits<Operation>;
76  };
78  template<class Left, class LPos, class Right, class RPos>
79  struct ProductOperationTraits<Left, LPos, Right, RPos,
80  std::enable_if_t<(IsTensorOperand<Left>::value
81  && IsTensorOperand<Right>::value
82  )> >
83  : ProductOperationTraits<typename TensorTraits<Left>::Signature, LPos,
84  typename TensorTraits<Right>::Signature, RPos>
85  {};
91  template<class Left, class LPos, class Right, class RPos>
92  using ProductSignature = typename ProductOperationTraits<typename TensorTraits<Left>::TensorType, LPos,
93  typename TensorTraits<Right>::TensorType, RPos>::Signature;
95  template<class Left, class LPos, class Right, class RPos>
96  using ProductFunctor = typename ProductOperationTraits<typename TensorTraits<Left>::TensorType, LPos,
97  typename TensorTraits<Right>::TensorType, RPos>::Functor;
99  template<class T>
103  typename TensorTraits<T>::Signature
104  >;
// ProductTensor: expression node for the product of two tensors.
//
// The index positions LeftIndices... of LeftTensor are paired with the
// positions RightIndices... of RightTensor. The paired ("front")
// indices become the leading indices of the result. They are followed
// by the remaining indices of the left operand and then the remaining
// indices of the right operand. Evaluation (operator() below) feeds the
// same front indices to both operands and multiplies the two values; no
// summation over the paired positions is performed in this class.
//
// The node derives from TensorBase (field type = promotion of both
// operand fields, signature = ProductSignature) and from
// Expressions::Storage, which holds the functor and both operands.
//
// NOTE(review): several names used below (BaseType, LeftArgs,
// RightArgs, LeftArg, RightArg) are not declared in the visible lines,
// and some enable_if conditions look truncated in this view. Confirm
// against the full file.
106  template<class LeftTensor, std::size_t... LeftIndices, class RightTensor, std::size_t... RightIndices>
107  class ProductTensor<LeftTensor, Seq<LeftIndices...>, RightTensor, Seq<RightIndices...> >
108  : public TensorBase<typename FieldPromotion<LeftTensor, RightTensor>::Type,
109  ProductSignature<LeftTensor, Seq<LeftIndices...>, RightTensor, Seq<RightIndices...> >,
110  ProductTensor<LeftTensor, Seq<LeftIndices...>, RightTensor, Seq<RightIndices...> > >
111  , public Expressions::Storage<ProductFunctor<LeftTensor, Seq<LeftIndices...>, RightTensor, Seq<RightIndices...> >,
112  LeftTensor, RightTensor>
113  // Expressions model rvalues. We are constant if the underlying
114  // expression claims to be so or if we store a copy and the
115  // underlying expression is "independent" (i.e. determined at
116  // runtime only by the contained data)
117  {
118  using ThisType = ProductTensor;
// LeftType/RightType keep the (possibly reference-qualified) storage
// types; the *TensorType variants are the decayed tensor types.
119  using LeftType = LeftTensor;
120  using RightType = RightTensor;
121  using LeftTensorType = std::decay_t<LeftTensor>;
122  using RightTensorType = std::decay_t<RightTensor>;
123  using LeftSignature = typename LeftTensorType::Signature;
124  using RightSignature = typename RightTensorType::Signature;
// "Defect" signatures: the dimensions at the paired positions. They
// must agree between both operands (see the static_asserts below).
125  using LeftDefectSignature = SubSequence<LeftSignature, LeftIndices...>;
126  using RightDefectSignature = SubSequence<RightSignature, RightIndices...>;
// The dimensions of each operand that are not paired.
127  using LeftSignatureRest = SubSequenceComplement<LeftSignature, LeftIndices...>;
128  using RightSignatureRest = SubSequenceComplement<RightSignature, RightIndices...>;
129  public:
130  using LeftIndexPositions = Seq<LeftIndices...>;
131  using RightIndexPositions = Seq<RightIndices...>;
133  using FieldType = typename FieldPromotion<LeftType, RightType>::Type;
134  using FunctorType = OperationTraits<TensorProductOperation<LeftIndexPositions, RightIndexPositions, LeftDefectSignature> >;
135  private:
// The paired positions need not be given in ascending order. Sort them
// and keep the permutation, so that operator() can insert the front
// indices at ascending positions in the right order.
136  // Mmmh. Should we really support this?
137  using LeftSorted = SortSequence<LeftIndexPositions>;
138  using RightSorted = SortSequence<RightIndexPositions>;
139  using LeftPermutation = typename LeftSorted::Permutation;
140  using RightPermutation = typename RightSorted::Permutation;
141  using LeftSortedPos = typename LeftSorted::Result;
142  using RightSortedPos = typename RightSorted::Result;
// NOTE(review): these two flags are not referenced in the visible code.
143  static constexpr bool leftIsSorted_ = isSimple(LeftPermutation{});
144  static constexpr bool rightIsSorted_ = isSimple(RightPermutation{});
147  using StorageType = Expressions::Storage<FunctorType, LeftTensor, RightTensor>;
148  public:
149  using StorageType::operation;
150  using StorageType::operand;
151  using BaseType::rank;
// Index layout of the result: [front | left rest | right rest].
// frontRank_ equals the number of paired positions.
152  static constexpr std::size_t leftRank_ = LeftSignatureRest::size();
153  static constexpr std::size_t rightRank_ = RightSignatureRest::size();
154  static constexpr std::size_t frontRank_ = rank - leftRank_ - rightRank_;
155  using FrontArgs = MakeIndexSequence<frontRank_>;
158  private:
160  static_assert(sizeof...(LeftIndices) == sizeof...(RightIndices),
161  "Number of contraction indices must coincide.");
162  static_assert(LeftTensorType::rank >= sizeof...(LeftIndices),
163  "Left: Number of index-positions must fit into tensor rank.");
164  static_assert(RightTensorType::rank >= sizeof...(RightIndices),
165  "Right: Number of index-positions must fit into tensor rank.");
166  static_assert(std::is_same<LeftDefectSignature, RightDefectSignature>::value,
167  "Left- and right defect-dimensions must coincide.");
169  public:
170  using DefectSignature = LeftDefectSignature;
171  static constexpr std::size_t defectRank_ = DefectSignature::size();
// Construct from two operands; each argument is forwarded into the
// storage base when it is constructible into the respective storage type.
173  template<class LeftArg, class RightArg,
174  std::enable_if_t<std::is_constructible<LeftType, LeftArg>::value && std::is_constructible<RightType, RightArg>::value, int> = 0>
175  ProductTensor(LeftArg&& left, RightArg&& right)
176  : StorageType(std::forward<LeftArg>(left), std::forward<RightArg>(right))
177  {}
// Argument-less construction (the Dummy pack must be empty): both
// operands are value-initialized. Enabled only for IsTypedValue operand
// types.
// NOTE(review): the visible condition only checks RightType; a matching
// LeftType check appears to be missing from this view. Confirm.
180  template<
181  class... Dummy,
182  std::enable_if_t<(sizeof...(Dummy) == 0
184  && IsTypedValue<RightType>::value), int> = 0>
185  ProductTensor(Dummy&&...)
186  : StorageType(LeftType{}, RightType{})
187  {}
// Runtime evaluation with rank() indices. The argument list is split
// into front, left-rest and right-rest indices. The front indices are
// permuted and inserted at the sorted paired positions of each operand,
// and the two operand values are multiplied.
// NOTE(review): the enable_if condition ends in a dangling `&&` in this
// view. Confirm the second condition against the full file.
192  template<class... Dims,
193  std::enable_if_t<(sizeof...(Dims) == rank
194  &&
196  , int> = 0>
197  auto operator()(Dims... indices) const
198  {
199  auto frontIndices = forwardSubTuple(std::forward_as_tuple(indices...), FrontArgs{});
200  auto leftIndices = forwardSubTuple(std::forward_as_tuple(indices...), LeftArgs{});
201  auto rightIndices = forwardSubTuple(std::forward_as_tuple(indices...), RightArgs{});
203  return
204  tensorValue(operand(0_c), insertAt(leftIndices, permute(frontIndices, LeftPermutation{}), LeftSortedPos{}))
205  *
206  tensorValue(operand(1_c), insertAt(rightIndices, permute(frontIndices, RightPermutation{}), RightSortedPos{}));
207  }
// Compile-time evaluation at a position known to be zero: yields the
// compile-time constant IntFraction<0>.
210  template<std::size_t... Indices,
211  std::enable_if_t<(sizeof...(Indices) == rank
212  && ThisType::template isZero<Indices...>()
213  ), int> = 0>
214  auto constexpr operator()(Seq<Indices...>) const
215  {
216  return IntFraction<0>{};
217  }
// Compile-time evaluation at a position not known to be zero: multiply
// the compile-time lookups of both operands.
// NOTE(review): LeftArg/RightArg are built from the Arg*Indices aliases
// on lines not visible here. The Arg*Indices aliases themselves are
// unused in the visible code.
220  template<std::size_t... Indices,
221  std::enable_if_t<(sizeof...(Indices) == rank
222  && !ThisType::template isZero<Indices...>()
223  ), int> = 0>
224  auto constexpr operator()(Seq<Indices...>) const
225  {
226  using IndexSeq = Seq<Indices...>;
227  using ArgFrontIndices = SequenceSlice<IndexSeq, FrontArgs>;
228  using ArgLeftIndices = SequenceSlice<IndexSeq, LeftArgs>;
229  using ArgRightIndices = SequenceSlice<IndexSeq, RightArgs>;
233  return operand(0_c)(LeftArg{}) * operand(1_c)(RightArg{});
234  }
236  private:
// Map the partially specified index (Indices... at result positions
// Pos) to index/position pairs of each operand. The product is reported
// as zero if either operand is zero there.
237  template<std::size_t... Indices, class Pos = MakeIndexSequence<sizeof...(Indices)> >
238  static bool constexpr isZeroWorker(Seq<Indices...> = Seq<Indices...>{}, Pos = Pos{})
239  {
240  // This ain't pretty ;)
242  // Truncate Pos to size of Indices and sort it, otherwise
243  // the JoinedDefects stuff will not work properly
244  using TruncatedPos = HeadPart<sizeof...(Indices), Pos>;
245  using RealPos = typename SortSequence<TruncatedPos>::Result;
246  using Permutation = typename SortSequence<TruncatedPos>::Permutation;
247  using RealIndices = PermuteSequence<Seq<Indices...>, Permutation>;
249  // Extract leading product indices
// FrontPosInd: the slots of RealPos whose position is in [0, frontRank_).
250  using FrontPosInd =
251  TransformSequence<RealPos, Seq<>, IndexFunctor, AcceptInputInRangeFunctor<std::size_t, 0, frontRank_> >;
252  using FrontInd = SequenceSlice<RealIndices, FrontPosInd>;
253  using FrontPos = SequenceSlice<RealPos, FrontPosInd>;
// Front position k corresponds to operand position LeftIndexPositions[k]
// resp. RightIndexPositions[k].
255  using LeftFrontPos = SequenceSlice<LeftIndexPositions, FrontPos>;
256  using RightFrontPos = SequenceSlice<RightIndexPositions, FrontPos>;
258  // Extract the indices referring to the left Tensor
// (result positions in [frontRank_, frontRank_ + leftRank_), shifted to
// start at 0)
259  using LeftPosInd =
260  TransformSequence<RealPos, Seq<>, IndexFunctor, AcceptInputInRangeFunctor<std::size_t, frontRank_, frontRank_ + leftRank_> >;
261  using LeftPos = TransformSequence<SequenceSlice<RealPos, LeftPosInd>, Seq<>, OffsetFunctor<-(std::ptrdiff_t)frontRank_> >;
262  using LeftInd = SequenceSlice<RealIndices, LeftPosInd>;
264  // The total set of left index positions, contraction and taken from Indices...
// LeftLookup lists the non-paired positions of the left operand, used to
// map a "left rest" slot to the operand's real index position.
265  using LeftLookup = SequenceSliceComplement<MakeIndexSequence<frontRank_+leftRank_>, LeftSortedPos>;
266  using LeftTotalPos = SequenceCat<LeftFrontPos, TransformedSequence<MapSequenceFunctor<LeftLookup>, LeftPos> >;
267  using LeftTotalInd = SequenceCat<FrontInd, LeftInd>;
269  // Extract the indices referring to the right Tensor
// (result positions in [frontRank_ + leftRank_, rank), shifted to
// start at 0)
270  using RightPosInd =
271  TransformSequence<RealPos, Seq<>, IndexFunctor, AcceptInputInRangeFunctor<std::size_t, frontRank_ + leftRank_, frontRank_ + leftRank_ + rightRank_> >;
272  using RightPos = TransformSequence<SequenceSlice<RealPos, RightPosInd>, Seq<>, OffsetFunctor<-(std::ptrdiff_t)(frontRank_+leftRank_)> >;
273  using RightInd = SequenceSlice<RealIndices, RightPosInd>;
275  // The total set of right index positions, contraction and taken from Indices...
276  using RightLookup = SequenceSliceComplement<MakeIndexSequence<frontRank_+rightRank_>, RightSortedPos>;
277  using RightTotalPos = SequenceCat<RightFrontPos, TransformedSequence<MapSequenceFunctor<RightLookup>, RightPos> >;
278  using RightTotalInd = SequenceCat<FrontInd, RightInd>;
280  return
281  LeftTensorType::isZero(LeftTotalInd{}, LeftTotalPos{})
282  ||
283  RightTensorType::isZero(RightTotalInd{}, RightTotalPos{});
284  }
286  public:
// Compile-time zero test for the index values Indices... placed at the
// result positions Pos (default: the leading positions 0..N-1).
// Currently delegates unconditionally to isZeroWorker(); a shortcut via
// ExpressionTraits is commented out.
287  template<std::size_t... Indices, class Pos = MakeIndexSequence<sizeof...(Indices)> >
288  static bool constexpr isZero(Seq<Indices...> = Seq<Indices...>{}, Pos = Pos{})
289  {
290 // if constexpr (ExpressionTraits<LeftTensorType>::isZero || ExpressionTraits<RightTensorType>::isZero) {
291 // return true;
292 // } else {
293  return isZeroWorker(Seq<Indices...>{}, Pos{});
294 // }
295  }
// Human-readable name built via operationName(). Operands stored by
// reference get the prefix "ref", or "cref" when the reference refers to
// a const object.
297  std::string name() const
298  {
299  using TL = LeftTensor;
300  using TR = RightTensor;
301  std::string pfxL = std::is_reference<TL>::value ? (RefersConst<TL>::value ? "cref" : "ref") : "";
302  std::string pfxR = std::is_reference<TR>::value ? (RefersConst<TR>::value ? "cref" : "ref") : "";
304  return operationName(operation(), pfxL+operand(0_c).name(), pfxR+operand(1_c).name());
305  }
307  }; // ProductTensor class
310  template<class Seq1, class Seq2, class T1, class T2,
311  std::enable_if_t<AreProperTensors<T1, T2>::value, int> = 0>
312  constexpr decltype(auto) multiply(T1&& t1, T2&& t2)
313  {
314  return Expressions::finalize(ProductFunctor<T1, Seq1, T2, Seq2>{}, std::forward<T1>(t1), std::forward<T2>(t2));
315  }
320  template<class T1, class T2,
321  std::enable_if_t<AreProperTensors<T1, T2>::value, int> = 0>
322  constexpr decltype(auto) multiply(T1&& t1, T2&& t2)
323  {
324  constexpr std::size_t numDims = CommonHead<typename TensorTraits<T1>::Signature,
325  typename TensorTraits<T2>::Signature>::value;
326  using ContractPos = MakeIndexSequence<numDims>;
327  return multiply<ContractPos, ContractPos>(std::forward<T1>(t1), std::forward<T2>(t2));
328  }
333  template<class Seq0, class Seq1, class Dims, class T0, class T1>
334  constexpr auto operate(Expressions::DontOptimize, OperationTraits<TensorProductOperation<Seq0, Seq1, Dims> >, T0&& t0, T1&& t1)
335  {
338  return ProductTensor<T0, Seq0, T1, Seq1>(std::forward<T0>(t0), std::forward<T1>(t1));
339  }
341  } // NS Tensor
343  // The operations defined in ./expressionoperations.hh need the name
344  // "multiply"; this using-declaration makes Tensor::multiply visible
// in namespace ACFem for them.
345  using Tensor::multiply;
347  } // NS ACFem
349  template<class LeftTensor, class LeftIndices, class RightTensor, class RightIndices>
350  struct FieldTraits<ACFem::Tensor::ProductTensor<LeftTensor, LeftIndices, RightTensor, RightIndices> >
351  : FieldTraits<typename ACFem::FieldPromotion<LeftTensor, RightTensor>::Type>
352  {};
358 } // NS Dune
