- Home
- About DUNE
- Download
- Documentation
- Community
- Development
00001 #ifndef DUNE_MultiTypeMATRIX_HH 00002 #define DUNE_MultiTypeMATRIX_HH 00003 00004 #include<cmath> 00005 #include<iostream> 00006 00007 #include "istlexception.hh" 00008 00009 #ifdef HAVE_BOOST_FUSION 00010 00011 #include <boost/fusion/sequence.hpp> 00012 #include <boost/fusion/container.hpp> 00013 #include <boost/fusion/iterator.hpp> 00014 #include <boost/typeof/typeof.hpp> 00015 #include <boost/fusion/algorithm.hpp> 00016 00017 namespace mpl=boost::mpl; 00018 namespace fusion=boost::fusion; 00019 00020 // forward decl 00021 namespace Dune 00022 { 00023 template<typename T1, typename T2=fusion::void_, typename T3=fusion::void_, typename T4=fusion::void_, 00024 typename T5=fusion::void_, typename T6=fusion::void_, typename T7=fusion::void_, 00025 typename T8=fusion::void_, typename T9=fusion::void_> 00026 class MultiTypeBlockMatrix; 00027 00028 template<int I, int crow, int remain_row> 00029 class MultiTypeBlockMatrix_Solver; 00030 } 00031 00032 #include "gsetc.hh" 00033 00034 namespace Dune { 00035 00053 template<int crow, int remain_rows, int ccol, int remain_cols, 00054 typename TMatrix> 00055 class MultiTypeBlockMatrix_Print { 00056 public: 00057 00061 static void print(const TMatrix& m) { 00062 std::cout << "\t(" << crow << ", " << ccol << "): \n" << fusion::at_c<ccol>( fusion::at_c<crow>(m)); 00063 MultiTypeBlockMatrix_Print<crow,remain_rows,ccol+1,remain_cols-1,TMatrix>::print(m); //next column 00064 } 00065 }; 00066 template<int crow, int remain_rows, int ccol, typename TMatrix> //specialization for remain_cols=0 00067 class MultiTypeBlockMatrix_Print<crow,remain_rows,ccol,0,TMatrix> { 00068 public: static void print(const TMatrix& m) { 00069 static const int xlen = mpl::size< typename mpl::at_c<TMatrix,crow>::type >::value; 00070 MultiTypeBlockMatrix_Print<crow+1,remain_rows-1,0,xlen,TMatrix>::print(m); //next row 00071 } 00072 }; 00073 00074 template<int crow, int ccol, int remain_cols, typename TMatrix> //recursion end: specialization for remain_rows=0 00075 class MultiTypeBlockMatrix_Print<crow,0,ccol,remain_cols,TMatrix> { 00076 public: 00077 static void print(const TMatrix& m) 00078 {std::cout << std::endl;} 00079 }; 00080 00081 00082 00083 //make MultiTypeBlockVector_Ident known (for MultiTypeBlockMatrix_Ident) 00084 template<int count, typename T1, typename T2> 00085 class MultiTypeBlockVector_Ident; 00086 00087 00100 template<int rowcount, typename T1, typename T2> 00101 class MultiTypeBlockMatrix_Ident { 00102 public: 00103 00108 static void equalize(T1& a, const T2& b) { 00109 MultiTypeBlockVector_Ident< mpl::size< typename mpl::at_c<T1,rowcount-1>::type >::value ,T1,T2>::equalize(a,b); //rows are cvectors 00110 MultiTypeBlockMatrix_Ident<rowcount-1,T1,T2>::equalize(a,b); //iterate over rows 00111 } 00112 }; 00113 00114 //recursion end for rowcount=0 00115 template<typename T1, typename T2> 00116 class MultiTypeBlockMatrix_Ident<0,T1,T2> { 00117 public: 00118 static void equalize (T1& a, const T2& b) 00119 {} 00120 }; 00121 00127 template<int crow, int remain_rows, int ccol, int remain_cols, 00128 typename TVecY, typename TMatrix, typename TVecX> 00129 class MultiTypeBlockMatrix_VectMul { 00130 public: 00131 00135 static void umv(TVecY& y, const TMatrix& A, const TVecX& x) { 00136 fusion::at_c<ccol>( fusion::at_c<crow>(A) ).umv( fusion::at_c<ccol>(x), fusion::at_c<crow>(y) ); 00137 MultiTypeBlockMatrix_VectMul<crow,remain_rows,ccol+1,remain_cols-1,TVecY,TMatrix,TVecX>::umv(y, A, x); 00138 } 00139 00143 static void mmv(TVecY& y, const TMatrix& A, const TVecX& x) { 00144 fusion::at_c<ccol>( fusion::at_c<crow>(A) ).mmv( fusion::at_c<ccol>(x), fusion::at_c<crow>(y) ); 00145 MultiTypeBlockMatrix_VectMul<crow,remain_rows,ccol+1,remain_cols-1,TVecY,TMatrix,TVecX>::mmv(y, A, x); 00146 } 00147 00148 template<typename AlphaType> 00149 static void usmv(const AlphaType& alpha, TVecY& y, const TMatrix& A, const TVecX& x) { 00150 fusion::at_c<ccol>( fusion::at_c<crow>(A) ).usmv(alpha, fusion::at_c<ccol>(x), fusion::at_c<crow>(y) ); 00151 MultiTypeBlockMatrix_VectMul<crow,remain_rows,ccol+1,remain_cols-1,TVecY,TMatrix,TVecX>::usmv(alpha,y, A, x); 00152 } 00153 00154 00155 }; 00156 00157 //specialization for remain_cols = 0 00158 template<int crow, int remain_rows,int ccol, typename TVecY, 00159 typename TMatrix, typename TVecX> 00160 class MultiTypeBlockMatrix_VectMul<crow,remain_rows,ccol,0,TVecY,TMatrix,TVecX> { //start iteration over next row 00161 00162 public: 00166 static void umv(TVecY& y, const TMatrix& A, const TVecX& x) { 00167 static const int rowlen = mpl::size< typename mpl::at_c<TMatrix,crow>::type >::value; 00168 MultiTypeBlockMatrix_VectMul<crow+1,remain_rows-1,0,rowlen,TVecY,TMatrix,TVecX>::umv(y, A, x); 00169 } 00170 00174 static void mmv(TVecY& y, const TMatrix& A, const TVecX& x) { 00175 static const int rowlen = mpl::size< typename mpl::at_c<TMatrix,crow>::type >::value; 00176 MultiTypeBlockMatrix_VectMul<crow+1,remain_rows-1,0,rowlen,TVecY,TMatrix,TVecX>::mmv(y, A, x); 00177 } 00178 00179 template <typename AlphaType> 00180 static void usmv(const AlphaType& alpha, TVecY& y, const TMatrix& A, const TVecX& x) { 00181 static const int rowlen = mpl::size< typename mpl::at_c<TMatrix,crow>::type >::value; 00182 MultiTypeBlockMatrix_VectMul<crow+1,remain_rows-1,0,rowlen,TVecY,TMatrix,TVecX>::usmv(alpha,y, A, x); 00183 } 00184 }; 00185 00186 //specialization for remain_rows = 0 00187 template<int crow, int ccol, int remain_cols, typename TVecY, 00188 typename TMatrix, typename TVecX> 00189 class MultiTypeBlockMatrix_VectMul<crow,0,ccol,remain_cols,TVecY,TMatrix,TVecX> { 00190 //end recursion 00191 public: 00192 static void umv(TVecY& y, const TMatrix& A, const TVecX& x) {} 00193 static void mmv(TVecY& y, const TMatrix& A, const TVecX& x) {} 00194 00195 template<typename AlphaType> 00196 static void usmv(const AlphaType& alpha, TVecY& y, const TMatrix& A, const TVecX& x) {} 00197 }; 00198 00199 00200 00201 00202 00203 00209 template<typename T1, typename T2, typename T3, typename T4, 00210 typename T5, typename T6, typename T7, typename T8, typename T9> 00211 class MultiTypeBlockMatrix : public fusion::vector<T1, T2, T3, T4, T5, T6, T7, T8, T9> { 00212 00213 public: 00214 00218 typedef MultiTypeBlockMatrix<T1, T2, T3, T4, T5, T6, T7, T8, T9> type; 00219 00220 typedef typename mpl::at_c<T1,0>::type field_type; 00221 00225 template<typename T> 00226 void operator= (const T& newval) {MultiTypeBlockMatrix_Ident<mpl::size<type>::value,type,T>::equalize(*this, newval); } 00227 00231 template<typename X, typename Y> 00232 void mv (const X& x, Y& y) const { 00233 BOOST_STATIC_ASSERT(mpl::size<X>::value == mpl::size<T1>::value); //make sure x's length matches row length 00234 BOOST_STATIC_ASSERT(mpl::size<Y>::value == mpl::size<type>::value); //make sure y's length matches row count 00235 00236 y = 0; //reset y (for mv uses umv) 00237 MultiTypeBlockMatrix_VectMul<0,mpl::size<type>::value,0,mpl::size<T1>::value,Y,type,X>::umv(y, *this, x); //iterate over all matrix elements 00238 } 00239 00243 template<typename X, typename Y> 00244 void umv (const X& x, Y& y) const { 00245 BOOST_STATIC_ASSERT(mpl::size<X>::value == mpl::size<T1>::value); //make sure x's length matches row length 00246 BOOST_STATIC_ASSERT(mpl::size<Y>::value == mpl::size<type>::value); //make sure y's length matches row count 00247 00248 MultiTypeBlockMatrix_VectMul<0,mpl::size<type>::value,0,mpl::size<T1>::value,Y,type,X>::umv(y, *this, x); //iterate over all matrix elements 00249 } 00250 00254 template<typename X, typename Y> 00255 void mmv (const X& x, Y& y) const { 00256 BOOST_STATIC_ASSERT(mpl::size<X>::value == mpl::size<T1>::value); //make sure x's length matches row length 00257 BOOST_STATIC_ASSERT(mpl::size<Y>::value == mpl::size<type>::value); //make sure y's length matches row count 00258 00259 MultiTypeBlockMatrix_VectMul<0,mpl::size<type>::value,0,mpl::size<T1>::value,Y,type,X>::mmv(y, *this, x); //iterate over all matrix elements 00260 } 00261 00263 template<typename AlphaType, typename X, typename Y> 00264 void usmv (const AlphaType& alpha, const X& x, Y& y) const { 00265 BOOST_STATIC_ASSERT(mpl::size<X>::value == mpl::size<T1>::value); //make sure x's length matches row length 00266 BOOST_STATIC_ASSERT(mpl::size<Y>::value == mpl::size<type>::value); //make sure y's length matches row count 00267 00268 MultiTypeBlockMatrix_VectMul<0,mpl::size<type>::value,0,mpl::size<T1>::value,Y,type,X>::usmv(alpha,y, *this, x); //iterate over all matrix elements 00269 00270 } 00271 00272 00273 00274 }; 00275 00276 00277 00283 template<typename T1, typename T2, typename T3, typename T4, typename T5, 00284 typename T6, typename T7, typename T8, typename T9> 00285 std::ostream& operator<< (std::ostream& s, const MultiTypeBlockMatrix<T1,T2,T3,T4,T5,T6,T7,T8,T9>& m) { 00286 static const int i = mpl::size<MultiTypeBlockMatrix<T1,T2,T3,T4,T5,T6,T7,T8,T9> >::value; //row count 00287 static const int j = mpl::size< typename mpl::at_c<MultiTypeBlockMatrix<T1,T2,T3,T4,T5,T6,T7,T8,T9>,0>::type >::value; //col count of first row 00288 MultiTypeBlockMatrix_Print<0,i,0,j,MultiTypeBlockMatrix<T1,T2,T3,T4,T5,T6,T7,T8,T9> >::print(m); 00289 return s; 00290 } 00291 00292 00293 00294 00295 00296 //make algmeta_itsteps known 00297 template<int I> 00298 struct algmeta_itsteps; 00299 00300 00301 00302 00303 00304 00311 template<int I, int crow, int ccol, int remain_col> //MultiTypeBlockMatrix_Solver_Col: iterating over one row 00312 class MultiTypeBlockMatrix_Solver_Col { //calculating b- A[i][j]*x[j] 00313 public: 00317 template <typename Trhs, typename TVector, typename TMatrix, typename K> 00318 static void calc_rhs(const TMatrix& A, TVector& x, TVector& v, Trhs& b, const K& w) { 00319 fusion::at_c<ccol>( fusion::at_c<crow>(A) ).mmv( fusion::at_c<ccol>(x), b ); 00320 MultiTypeBlockMatrix_Solver_Col<I, crow, ccol+1, remain_col-1>::calc_rhs(A,x,v,b,w); //next column element 00321 } 00322 00323 }; 00324 template<int I, int crow, int ccol> //MultiTypeBlockMatrix_Solver_Col recursion end 00325 class MultiTypeBlockMatrix_Solver_Col<I,crow,ccol,0> { 00326 public: 00327 template <typename Trhs, typename TVector, typename TMatrix, typename K> 00328 static void calc_rhs(const TMatrix& A, TVector& x, TVector& v, Trhs& b, const K& w) {} 00329 }; 00330 00331 00332 00339 template<int I, int crow, int remain_row> 00340 class MultiTypeBlockMatrix_Solver { 00341 public: 00342 00346 template <typename TVector, typename TMatrix, typename K> 00347 static void dbgs(const TMatrix& A, TVector& x, const TVector& b, const K& w) { 00348 TVector xold(x); 00349 xold=x; //store old x values 00350 MultiTypeBlockMatrix_Solver<I,crow,remain_row>::dbgs(A,x,x,b,w); 00351 x *= w; 00352 x.axpy(1-w,xold); //improve x 00353 } 00354 template <typename TVector, typename TMatrix, typename K> 00355 static void dbgs(const TMatrix& A, TVector& x, TVector& v, const TVector& b, const K& w) { 00356 typename mpl::at_c<TVector,crow>::type rhs; 00357 rhs = fusion::at_c<crow> (b); 00358 00359 MultiTypeBlockMatrix_Solver_Col<I,crow,0, mpl::size<typename mpl::at_c<TMatrix,crow>::type>::value>::calc_rhs(A,x,v,rhs,w); // calculate right side of equation 00360 //solve on blocklevel I-1 00361 algmeta_itsteps<I-1>::dbgs(fusion::at_c<crow>( fusion::at_c<crow>(A)), fusion::at_c<crow>(x),rhs,w); 00362 MultiTypeBlockMatrix_Solver<I,crow+1,remain_row-1>::dbgs(A,x,v,b,w); //next row 00363 } 00364 00365 00366 00370 template <typename TVector, typename TMatrix, typename K> 00371 static void bsorf(const TMatrix& A, TVector& x, const TVector& b, const K& w) { 00372 TVector v; 00373 v=x; //use latest x values in right side calculation 00374 MultiTypeBlockMatrix_Solver<I,crow,remain_row>::bsorf(A,x,v,b,w); 00375 00376 } 00377 template <typename TVector, typename TMatrix, typename K> //recursion over all matrix rows (A) 00378 static void bsorf(const TMatrix& A, TVector& x, TVector& v, const TVector& b, const K& w) { 00379 typename mpl::at_c<TVector,crow>::type rhs; 00380 rhs = fusion::at_c<crow> (b); 00381 00382 MultiTypeBlockMatrix_Solver_Col<I,crow,0, mpl::size<typename mpl::at_c<TMatrix,crow>::type>::value>::calc_rhs(A,x,v,rhs,w); // calculate right side of equation 00383 //solve on blocklevel I-1 00384 algmeta_itsteps<I-1>::bsorf(fusion::at_c<crow>( fusion::at_c<crow>(A)), fusion::at_c<crow>(v),rhs,w); 00385 fusion::at_c<crow>(x).axpy(w,fusion::at_c<crow>(v)); 00386 MultiTypeBlockMatrix_Solver<I,crow+1,remain_row-1>::bsorf(A,x,v,b,w); //next row 00387 } 00388 00392 template <typename TVector, typename TMatrix, typename K> 00393 static void bsorb(const TMatrix& A, TVector& x, const TVector& b, const K& w) { 00394 TVector v; 00395 v=x; //use latest x values in right side calculation 00396 MultiTypeBlockMatrix_Solver<I,crow,remain_row>::bsorb(A,x,v,b,w); 00397 00398 } 00399 template <typename TVector, typename TMatrix, typename K> //recursion over all matrix rows (A) 00400 static void bsorb(const TMatrix& A, TVector& x, TVector& v, const TVector& b, const K& w) { 00401 typename mpl::at_c<TVector,crow>::type rhs; 00402 rhs = fusion::at_c<crow> (b); 00403 00404 MultiTypeBlockMatrix_Solver_Col<I,crow,0, mpl::size<typename mpl::at_c<TMatrix,crow>::type>::value>::calc_rhs(A,x,v,rhs,w); // calculate right side of equation 00405 //solve on blocklevel I-1 00406 algmeta_itsteps<I-1>::bsorb(fusion::at_c<crow>( fusion::at_c<crow>(A)), fusion::at_c<crow>(v),rhs,w); 00407 fusion::at_c<crow>(x).axpy(w,fusion::at_c<crow>(v)); 00408 MultiTypeBlockMatrix_Solver<I,crow-1,remain_row-1>::bsorb(A,x,v,b,w); //next row 00409 } 00410 00411 00415 template <typename TVector, typename TMatrix, typename K> 00416 static void dbjac(const TMatrix& A, TVector& x, const TVector& b, const K& w) { 00417 TVector v(x); 00418 v=0; //calc new x in v 00419 MultiTypeBlockMatrix_Solver<I,crow,remain_row>::dbjac(A,x,v,b,w); 00420 x.axpy(w,v); //improve x 00421 } 00422 template <typename TVector, typename TMatrix, typename K> 00423 static void dbjac(const TMatrix& A, TVector& x, TVector& v, const TVector& b, const K& w) { 00424 typename mpl::at_c<TVector,crow>::type rhs; 00425 rhs = fusion::at_c<crow> (b); 00426 00427 MultiTypeBlockMatrix_Solver_Col<I,crow,0, mpl::size<typename mpl::at_c<TMatrix,crow>::type>::value>::calc_rhs(A,x,v,rhs,w); // calculate right side of equation 00428 //solve on blocklevel I-1 00429 algmeta_itsteps<I-1>::dbjac(fusion::at_c<crow>( fusion::at_c<crow>(A)), fusion::at_c<crow>(v),rhs,w); 00430 MultiTypeBlockMatrix_Solver<I,crow+1,remain_row-1>::dbjac(A,x,v,b,w); //next row 00431 } 00432 00433 00434 00435 00436 }; 00437 template<int I, int crow> //recursion end for remain_row = 0 00438 class MultiTypeBlockMatrix_Solver<I,crow,0> { 00439 public: 00440 template <typename TVector, typename TMatrix, typename K> 00441 static void dbgs(const TMatrix& A, TVector& x, TVector& v, 00442 const TVector& b, const K& w) {} 00443 00444 template <typename TVector, typename TMatrix, typename K> 00445 static void bsorf(const TMatrix& A, TVector& x, TVector& v, 00446 const TVector& b, const K& w) {} 00447 00448 template <typename TVector, typename TMatrix, typename K> 00449 static void bsorb(const TMatrix& A, TVector& x, TVector& v, 00450 const TVector& b, const K& w) {} 00451 00452 template <typename TVector, typename TMatrix, typename K> 00453 static void dbjac(const TMatrix& A, TVector& x, TVector& v, 00454 const TVector& b, const K& w) {} 00455 }; 00456 00457 } // end namespace 00458 00459 #endif // HAVE_BOOST_FUSION 00460 00461 #endif 00462
Generated on Fri Apr 29 2011 with Doxygen (ver 1.7.1) [doxygen-log,error-log].