multitypeblockmatrix.hh

Go to the documentation of this file.
00001 #ifndef DUNE_MultiTypeMATRIX_HH
00002 #define DUNE_MultiTypeMATRIX_HH
00003 
00004 #include<cmath>
00005 #include<iostream>
00006 
00007 #include "istlexception.hh"
00008 
00009 #ifdef HAVE_BOOST_FUSION
00010 
00011 #include <boost/fusion/sequence.hpp>
00012 #include <boost/fusion/container.hpp>
00013 #include <boost/fusion/iterator.hpp>
00014 #include <boost/typeof/typeof.hpp>
00015 #include <boost/fusion/algorithm.hpp>
00016 
00017 namespace mpl=boost::mpl;
00018 namespace fusion=boost::fusion;
00019 
00020 // forward decl
00021 namespace Dune
00022 {
00023     template<typename T1, typename T2=fusion::void_, typename T3=fusion::void_, typename T4=fusion::void_,
00024              typename T5=fusion::void_, typename T6=fusion::void_, typename T7=fusion::void_,
00025              typename T8=fusion::void_, typename T9=fusion::void_>
00026     class MultiTypeBlockMatrix;
00027 
00028     template<int I, int crow, int remain_row>
00029     class MultiTypeBlockMatrix_Solver;
00030 }
00031 
00032 #include "gsetc.hh"
00033 
00034 namespace Dune {
00035 
00053   template<int crow, int remain_rows, int ccol, int remain_cols, 
00054            typename TMatrix>
00055   class MultiTypeBlockMatrix_Print {
00056   public:
00057 
00061     static void print(const TMatrix& m) {
00062       std::cout << "\t(" << crow << ", " << ccol << "): \n" << fusion::at_c<ccol>( fusion::at_c<crow>(m));
00063       MultiTypeBlockMatrix_Print<crow,remain_rows,ccol+1,remain_cols-1,TMatrix>::print(m);         //next column
00064     }
00065   };
00066   template<int crow, int remain_rows, int ccol, typename TMatrix> //specialization for remain_cols=0
00067   class MultiTypeBlockMatrix_Print<crow,remain_rows,ccol,0,TMatrix> {
00068   public: static void print(const TMatrix& m) {
00069     static const int xlen = mpl::size< typename mpl::at_c<TMatrix,crow>::type >::value;
00070     MultiTypeBlockMatrix_Print<crow+1,remain_rows-1,0,xlen,TMatrix>::print(m);                   //next row
00071   }
00072   };
00073 
00074   template<int crow, int ccol, int remain_cols, typename TMatrix> //recursion end: specialization for remain_rows=0
00075   class MultiTypeBlockMatrix_Print<crow,0,ccol,remain_cols,TMatrix> {
00076   public: 
00077     static void print(const TMatrix& m) 
00078     {std::cout << std::endl;} 
00079   };
00080 
00081 
00082 
00083   //make MultiTypeBlockVector_Ident known (for MultiTypeBlockMatrix_Ident)
00084   template<int count, typename T1, typename T2>
00085   class MultiTypeBlockVector_Ident;
00086 
00087 
00100   template<int rowcount, typename T1, typename T2>
00101   class MultiTypeBlockMatrix_Ident {
00102   public:
00103 
00108     static void equalize(T1& a, const T2& b) {
00109       MultiTypeBlockVector_Ident< mpl::size< typename mpl::at_c<T1,rowcount-1>::type >::value ,T1,T2>::equalize(a,b);              //rows are cvectors
00110       MultiTypeBlockMatrix_Ident<rowcount-1,T1,T2>::equalize(a,b);         //iterate over rows
00111     }
00112   };
00113 
00114   //recursion end for rowcount=0
00115   template<typename T1, typename T2>
00116   class MultiTypeBlockMatrix_Ident<0,T1,T2> {
00117   public: 
00118     static void equalize (T1& a, const T2& b) 
00119     {} 
00120   };
00121 
00127   template<int crow, int remain_rows, int ccol, int remain_cols, 
00128            typename TVecY, typename TMatrix, typename TVecX>
00129   class MultiTypeBlockMatrix_VectMul {
00130   public:
00131 
00135     static void umv(TVecY& y, const TMatrix& A, const TVecX& x) {                   
00136       fusion::at_c<ccol>( fusion::at_c<crow>(A) ).umv( fusion::at_c<ccol>(x), fusion::at_c<crow>(y) );
00137       MultiTypeBlockMatrix_VectMul<crow,remain_rows,ccol+1,remain_cols-1,TVecY,TMatrix,TVecX>::umv(y, A, x);
00138     }
00139 
00143     static void mmv(TVecY& y, const TMatrix& A, const TVecX& x) {                   
00144       fusion::at_c<ccol>( fusion::at_c<crow>(A) ).mmv( fusion::at_c<ccol>(x), fusion::at_c<crow>(y) );
00145       MultiTypeBlockMatrix_VectMul<crow,remain_rows,ccol+1,remain_cols-1,TVecY,TMatrix,TVecX>::mmv(y, A, x);
00146     }
00147 
00148     template<typename AlphaType>
00149     static void usmv(const AlphaType& alpha, TVecY& y, const TMatrix& A, const TVecX& x) {                  
00150       fusion::at_c<ccol>( fusion::at_c<crow>(A) ).usmv(alpha, fusion::at_c<ccol>(x), fusion::at_c<crow>(y) );
00151       MultiTypeBlockMatrix_VectMul<crow,remain_rows,ccol+1,remain_cols-1,TVecY,TMatrix,TVecX>::usmv(alpha,y, A, x);
00152     }
00153 
00154                                 
00155   };
00156 
00157   //specialization for remain_cols = 0
00158   template<int crow, int remain_rows,int ccol, typename TVecY, 
00159            typename TMatrix, typename TVecX>
00160   class MultiTypeBlockMatrix_VectMul<crow,remain_rows,ccol,0,TVecY,TMatrix,TVecX> {                                    //start iteration over next row
00161         
00162   public:
00166     static void umv(TVecY& y, const TMatrix& A, const TVecX& x) {
00167       static const int rowlen = mpl::size< typename mpl::at_c<TMatrix,crow>::type >::value;
00168       MultiTypeBlockMatrix_VectMul<crow+1,remain_rows-1,0,rowlen,TVecY,TMatrix,TVecX>::umv(y, A, x);
00169     }
00170 
00174     static void mmv(TVecY& y, const TMatrix& A, const TVecX& x) {
00175       static const int rowlen = mpl::size< typename mpl::at_c<TMatrix,crow>::type >::value;
00176       MultiTypeBlockMatrix_VectMul<crow+1,remain_rows-1,0,rowlen,TVecY,TMatrix,TVecX>::mmv(y, A, x);
00177     }
00178 
00179     template <typename AlphaType>
00180     static void usmv(const AlphaType& alpha, TVecY& y, const TMatrix& A, const TVecX& x) {
00181       static const int rowlen = mpl::size< typename mpl::at_c<TMatrix,crow>::type >::value;
00182       MultiTypeBlockMatrix_VectMul<crow+1,remain_rows-1,0,rowlen,TVecY,TMatrix,TVecX>::usmv(alpha,y, A, x);
00183     }
00184   };
00185 
00186    //specialization for remain_rows = 0
00187   template<int crow, int ccol, int remain_cols, typename TVecY, 
00188            typename TMatrix, typename TVecX>
00189   class MultiTypeBlockMatrix_VectMul<crow,0,ccol,remain_cols,TVecY,TMatrix,TVecX> { 
00190     //end recursion
00191   public:
00192     static void umv(TVecY& y, const TMatrix& A, const TVecX& x) {}
00193     static void mmv(TVecY& y, const TMatrix& A, const TVecX& x) {}
00194 
00195     template<typename AlphaType>
00196     static void usmv(const AlphaType& alpha, TVecY& y, const TMatrix& A, const TVecX& x) {}
00197   };
00198 
00199 
00200 
00201 
00202 
00203 
00209   template<typename T1, typename T2, typename T3, typename T4,
00210            typename T5, typename T6, typename T7, typename T8, typename T9>
00211   class MultiTypeBlockMatrix : public fusion::vector<T1, T2, T3, T4, T5, T6, T7, T8, T9> {
00212 
00213   public:
00214 
00218     typedef MultiTypeBlockMatrix<T1, T2, T3, T4, T5, T6, T7, T8, T9> type;
00219 
00220     typedef typename mpl::at_c<T1,0>::type field_type;
00221 
00225     template<typename T>
00226     void operator= (const T& newval) {MultiTypeBlockMatrix_Ident<mpl::size<type>::value,type,T>::equalize(*this, newval); }
00227 
00231     template<typename X, typename Y>
00232     void mv (const X& x, Y& y) const {
00233       BOOST_STATIC_ASSERT(mpl::size<X>::value == mpl::size<T1>::value);       //make sure x's length matches row length
00234       BOOST_STATIC_ASSERT(mpl::size<Y>::value == mpl::size<type>::value);     //make sure y's length matches row count
00235 
00236       y = 0;                                                                  //reset y (for mv uses umv)
00237       MultiTypeBlockMatrix_VectMul<0,mpl::size<type>::value,0,mpl::size<T1>::value,Y,type,X>::umv(y, *this, x);    //iterate over all matrix elements
00238     }
00239 
00243     template<typename X, typename Y>
00244     void umv (const X& x, Y& y) const {
00245       BOOST_STATIC_ASSERT(mpl::size<X>::value == mpl::size<T1>::value);       //make sure x's length matches row length
00246       BOOST_STATIC_ASSERT(mpl::size<Y>::value == mpl::size<type>::value);     //make sure y's length matches row count
00247 
00248       MultiTypeBlockMatrix_VectMul<0,mpl::size<type>::value,0,mpl::size<T1>::value,Y,type,X>::umv(y, *this, x);    //iterate over all matrix elements
00249     }
00250 
00254     template<typename X, typename Y>
00255     void mmv (const X& x, Y& y) const {
00256       BOOST_STATIC_ASSERT(mpl::size<X>::value == mpl::size<T1>::value);       //make sure x's length matches row length
00257       BOOST_STATIC_ASSERT(mpl::size<Y>::value == mpl::size<type>::value);     //make sure y's length matches row count
00258 
00259       MultiTypeBlockMatrix_VectMul<0,mpl::size<type>::value,0,mpl::size<T1>::value,Y,type,X>::mmv(y, *this, x);    //iterate over all matrix elements
00260     }
00261 
00263     template<typename AlphaType, typename X, typename Y>
00264     void usmv (const AlphaType& alpha, const X& x, Y& y) const {
00265       BOOST_STATIC_ASSERT(mpl::size<X>::value == mpl::size<T1>::value);       //make sure x's length matches row length
00266       BOOST_STATIC_ASSERT(mpl::size<Y>::value == mpl::size<type>::value);     //make sure y's length matches row count
00267 
00268       MultiTypeBlockMatrix_VectMul<0,mpl::size<type>::value,0,mpl::size<T1>::value,Y,type,X>::usmv(alpha,y, *this, x);     //iterate over all matrix elements
00269         
00270     }
00271 
00272 
00273 
00274   };
00275 
00276 
00277 
00283   template<typename T1, typename T2, typename T3, typename T4, typename T5, 
00284            typename T6, typename T7, typename T8, typename T9>
00285   std::ostream& operator<< (std::ostream& s, const MultiTypeBlockMatrix<T1,T2,T3,T4,T5,T6,T7,T8,T9>& m) {
00286     static const int i = mpl::size<MultiTypeBlockMatrix<T1,T2,T3,T4,T5,T6,T7,T8,T9> >::value;            //row count
00287     static const int j = mpl::size< typename mpl::at_c<MultiTypeBlockMatrix<T1,T2,T3,T4,T5,T6,T7,T8,T9>,0>::type >::value;       //col count of first row
00288     MultiTypeBlockMatrix_Print<0,i,0,j,MultiTypeBlockMatrix<T1,T2,T3,T4,T5,T6,T7,T8,T9> >::print(m);
00289     return s;
00290   }
00291 
00292 
00293 
00294 
00295 
00296   //make algmeta_itsteps known
00297   template<int I>
00298   struct algmeta_itsteps;
00299 
00300 
00301 
00302 
00303 
00304 
00311   template<int I, int crow, int ccol, int remain_col>                             //MultiTypeBlockMatrix_Solver_Col: iterating over one row
00312   class MultiTypeBlockMatrix_Solver_Col {                                                      //calculating b- A[i][j]*x[j]
00313   public:
00317     template <typename Trhs, typename TVector, typename TMatrix, typename K>
00318     static void calc_rhs(const TMatrix& A, TVector& x, TVector& v, Trhs& b, const K& w) {
00319       fusion::at_c<ccol>( fusion::at_c<crow>(A) ).mmv( fusion::at_c<ccol>(x), b );
00320       MultiTypeBlockMatrix_Solver_Col<I, crow, ccol+1, remain_col-1>::calc_rhs(A,x,v,b,w); //next column element
00321     }
00322 
00323   };
00324   template<int I, int crow, int ccol>                                             //MultiTypeBlockMatrix_Solver_Col recursion end
00325   class MultiTypeBlockMatrix_Solver_Col<I,crow,ccol,0> {
00326   public:
00327     template <typename Trhs, typename TVector, typename TMatrix, typename K>
00328     static void calc_rhs(const TMatrix& A, TVector& x, TVector& v, Trhs& b, const K& w) {}
00329   };
00330 
00331 
00332 
00339   template<int I, int crow, int remain_row>
00340   class MultiTypeBlockMatrix_Solver {
00341   public:
00342 
00346     template <typename TVector, typename TMatrix, typename K>
00347     static void dbgs(const TMatrix& A, TVector& x, const TVector& b, const K& w) {
00348       TVector xold(x);
00349       xold=x;                                                         //store old x values
00350       MultiTypeBlockMatrix_Solver<I,crow,remain_row>::dbgs(A,x,x,b,w);
00351       x *= w;
00352       x.axpy(1-w,xold);                                                       //improve x
00353     }
00354     template <typename TVector, typename TMatrix, typename K>
00355     static void dbgs(const TMatrix& A, TVector& x, TVector& v, const TVector& b, const K& w) {
00356       typename mpl::at_c<TVector,crow>::type rhs;
00357       rhs = fusion::at_c<crow> (b);
00358 
00359       MultiTypeBlockMatrix_Solver_Col<I,crow,0, mpl::size<typename mpl::at_c<TMatrix,crow>::type>::value>::calc_rhs(A,x,v,rhs,w);  // calculate right side of equation
00360       //solve on blocklevel I-1
00361       algmeta_itsteps<I-1>::dbgs(fusion::at_c<crow>( fusion::at_c<crow>(A)), fusion::at_c<crow>(x),rhs,w);
00362       MultiTypeBlockMatrix_Solver<I,crow+1,remain_row-1>::dbgs(A,x,v,b,w); //next row
00363     }
00364 
00365 
00366 
00370     template <typename TVector, typename TMatrix, typename K>
00371     static void bsorf(const TMatrix& A, TVector& x, const TVector& b, const K& w) {
00372       TVector v;
00373       v=x;                                                            //use latest x values in right side calculation
00374       MultiTypeBlockMatrix_Solver<I,crow,remain_row>::bsorf(A,x,v,b,w);
00375                 
00376     }
00377     template <typename TVector, typename TMatrix, typename K>               //recursion over all matrix rows (A)
00378     static void bsorf(const TMatrix& A, TVector& x, TVector& v, const TVector& b, const K& w) {
00379       typename mpl::at_c<TVector,crow>::type rhs;
00380       rhs = fusion::at_c<crow> (b);
00381 
00382       MultiTypeBlockMatrix_Solver_Col<I,crow,0, mpl::size<typename mpl::at_c<TMatrix,crow>::type>::value>::calc_rhs(A,x,v,rhs,w);  // calculate right side of equation
00383       //solve on blocklevel I-1
00384       algmeta_itsteps<I-1>::bsorf(fusion::at_c<crow>( fusion::at_c<crow>(A)), fusion::at_c<crow>(v),rhs,w);
00385       fusion::at_c<crow>(x).axpy(w,fusion::at_c<crow>(v));
00386       MultiTypeBlockMatrix_Solver<I,crow+1,remain_row-1>::bsorf(A,x,v,b,w);        //next row
00387     }
00388 
00392     template <typename TVector, typename TMatrix, typename K>
00393     static void bsorb(const TMatrix& A, TVector& x, const TVector& b, const K& w) {
00394       TVector v;
00395       v=x;                                                            //use latest x values in right side calculation
00396       MultiTypeBlockMatrix_Solver<I,crow,remain_row>::bsorb(A,x,v,b,w);
00397                 
00398     }
00399     template <typename TVector, typename TMatrix, typename K>               //recursion over all matrix rows (A)
00400     static void bsorb(const TMatrix& A, TVector& x, TVector& v, const TVector& b, const K& w) {
00401       typename mpl::at_c<TVector,crow>::type rhs;
00402       rhs = fusion::at_c<crow> (b);
00403 
00404       MultiTypeBlockMatrix_Solver_Col<I,crow,0, mpl::size<typename mpl::at_c<TMatrix,crow>::type>::value>::calc_rhs(A,x,v,rhs,w);  // calculate right side of equation
00405       //solve on blocklevel I-1
00406       algmeta_itsteps<I-1>::bsorb(fusion::at_c<crow>( fusion::at_c<crow>(A)), fusion::at_c<crow>(v),rhs,w);
00407       fusion::at_c<crow>(x).axpy(w,fusion::at_c<crow>(v));
00408       MultiTypeBlockMatrix_Solver<I,crow-1,remain_row-1>::bsorb(A,x,v,b,w);        //next row
00409     }
00410 
00411 
00415     template <typename TVector, typename TMatrix, typename K>
00416     static void dbjac(const TMatrix& A, TVector& x, const TVector& b, const K& w) {
00417       TVector v(x);
00418       v=0;                                                            //calc new x in v
00419       MultiTypeBlockMatrix_Solver<I,crow,remain_row>::dbjac(A,x,v,b,w);
00420       x.axpy(w,v);                                                    //improve x
00421     }
00422     template <typename TVector, typename TMatrix, typename K>
00423     static void dbjac(const TMatrix& A, TVector& x, TVector& v, const TVector& b, const K& w) {
00424       typename mpl::at_c<TVector,crow>::type rhs;
00425       rhs = fusion::at_c<crow> (b);
00426 
00427       MultiTypeBlockMatrix_Solver_Col<I,crow,0, mpl::size<typename mpl::at_c<TMatrix,crow>::type>::value>::calc_rhs(A,x,v,rhs,w);  // calculate right side of equation
00428       //solve on blocklevel I-1
00429       algmeta_itsteps<I-1>::dbjac(fusion::at_c<crow>( fusion::at_c<crow>(A)), fusion::at_c<crow>(v),rhs,w);
00430       MultiTypeBlockMatrix_Solver<I,crow+1,remain_row-1>::dbjac(A,x,v,b,w);        //next row
00431     }
00432 
00433 
00434 
00435 
00436   };
00437   template<int I, int crow>                                                       //recursion end for remain_row = 0
00438   class MultiTypeBlockMatrix_Solver<I,crow,0> {
00439   public:
00440     template <typename TVector, typename TMatrix, typename K>
00441     static void dbgs(const TMatrix& A, TVector& x, TVector& v, 
00442                      const TVector& b, const K& w) {}
00443 
00444     template <typename TVector, typename TMatrix, typename K>
00445     static void bsorf(const TMatrix& A, TVector& x, TVector& v, 
00446                       const TVector& b, const K& w) {}
00447 
00448     template <typename TVector, typename TMatrix, typename K>
00449     static void bsorb(const TMatrix& A, TVector& x, TVector& v, 
00450                       const TVector& b, const K& w) {}
00451 
00452     template <typename TVector, typename TMatrix, typename K>
00453     static void dbjac(const TMatrix& A, TVector& x, TVector& v, 
00454                       const TVector& b, const K& w) {}
00455   };
00456 
00457 } // end namespace
00458 
00459 #endif // HAVE_BOOST_FUSION
00460 
00461 #endif
00462 

Generated on Fri Apr 29 2011 with Doxygen (ver 1.7.1) [doxygen-log,error-log].