00001
00002 #ifndef DUNE_MPICOLLECTIVECOMMUNICATION_HH
00003 #define DUNE_MPICOLLECTIVECOMMUNICATION_HH
00004
00005 #include<iostream>
00006 #include<complex>
00007 #include<algorithm>
00008 #include<functional>
00009
00010 #include"exceptions.hh"
00011 #include"collectivecommunication.hh"
00012 #include"binaryfunctions.hh"
00013 #include"shared_ptr.hh"
00014
00015 #if HAVE_MPI
00016
00017 #include<mpi.h>
00018
00019 namespace Dune
00020 {
00021
00022
00023
00024
00025
00026
00027
00028 template<typename T>
00029 class Generic_MPI_Datatype
00030 {
00031 public:
00032 static MPI_Datatype get ()
00033 {
00034 if (!type)
00035 {
00036 type = shared_ptr<MPI_Datatype>(new MPI_Datatype);
00037 MPI_Type_contiguous(sizeof(T),MPI_BYTE,type.get());
00038 MPI_Type_commit(type.get());
00039 }
00040 return *type;
00041 }
00042 private:
00043 Generic_MPI_Datatype () {}
00044 Generic_MPI_Datatype (const Generic_MPI_Datatype& ) {}
00045 static shared_ptr<MPI_Datatype> type;
00046 };
00047
00048 template<typename T>
00049 shared_ptr<MPI_Datatype> Generic_MPI_Datatype<T>::type = shared_ptr<MPI_Datatype>(static_cast<MPI_Datatype*>(0));
00050
00051
00052 #define ComposeMPITraits(p,m) \
00053 template<> \
00054 struct Generic_MPI_Datatype<p>{ \
00055 static inline MPI_Datatype get(){ \
00056 return m; \
00057 } \
00058 }
00059
00060
00061 ComposeMPITraits(char, MPI_CHAR);
00062 ComposeMPITraits(unsigned char,MPI_UNSIGNED_CHAR);
00063 ComposeMPITraits(short,MPI_SHORT);
00064 ComposeMPITraits(unsigned short,MPI_UNSIGNED_SHORT);
00065 ComposeMPITraits(int,MPI_INT);
00066 ComposeMPITraits(unsigned int,MPI_UNSIGNED);
00067 ComposeMPITraits(long,MPI_LONG);
00068 ComposeMPITraits(unsigned long,MPI_UNSIGNED_LONG);
00069 ComposeMPITraits(float,MPI_FLOAT);
00070 ComposeMPITraits(double,MPI_DOUBLE);
00071 ComposeMPITraits(long double,MPI_LONG_DOUBLE);
00072
00073 #undef ComposeMPITraits
00074
00075
00076
00077
00078
00079
00080
00081 template<typename Type, typename BinaryFunction>
00082 class Generic_MPI_Op
00083 {
00084
00085 public:
00086 static MPI_Op get ()
00087 {
00088 if (!op)
00089 {
00090 op = shared_ptr<MPI_Op>(new MPI_Op);
00091 MPI_Op_create((void (*)(void*, void*, int*, MPI_Datatype*))&operation,true,op.get());
00092 }
00093 return *op;
00094 }
00095 private:
00096 static void operation (Type *in, Type *inout, int *len, MPI_Datatype *dptr)
00097 {
00098 BinaryFunction func;
00099
00100 for (int i=0; i< *len; ++i, ++in, ++inout){
00101 Type temp;
00102 temp = func(*in, *inout);
00103 *inout = temp;
00104 }
00105 }
00106 Generic_MPI_Op () {}
00107 Generic_MPI_Op (const Generic_MPI_Op& ) {}
00108 static shared_ptr<MPI_Op> op;
00109 };
00110
00111
00112 template<typename Type, typename BinaryFunction>
00113 shared_ptr<MPI_Op> Generic_MPI_Op<Type,BinaryFunction>::op = shared_ptr<MPI_Op>(static_cast<MPI_Op*>(0));
00114
00115 #define ComposeMPIOp(type,func,op) \
00116 template<> \
00117 class Generic_MPI_Op<type, func<type> >{ \
00118 public:\
00119 static MPI_Op get(){ \
00120 return op; \
00121 } \
00122 private:\
00123 Generic_MPI_Op () {}\
00124 Generic_MPI_Op (const Generic_MPI_Op& ) {}\
00125 }
00126
00127
00128 ComposeMPIOp(char, std::plus, MPI_SUM);
00129 ComposeMPIOp(unsigned char, std::plus, MPI_SUM);
00130 ComposeMPIOp(short, std::plus, MPI_SUM);
00131 ComposeMPIOp(unsigned short, std::plus, MPI_SUM);
00132 ComposeMPIOp(int, std::plus, MPI_SUM);
00133 ComposeMPIOp(unsigned int, std::plus, MPI_SUM);
00134 ComposeMPIOp(long, std::plus, MPI_SUM);
00135 ComposeMPIOp(unsigned long, std::plus, MPI_SUM);
00136 ComposeMPIOp(float, std::plus, MPI_SUM);
00137 ComposeMPIOp(double, std::plus, MPI_SUM);
00138 ComposeMPIOp(long double, std::plus, MPI_SUM);
00139
00140 ComposeMPIOp(char, std::multiplies, MPI_PROD);
00141 ComposeMPIOp(unsigned char, std::multiplies, MPI_PROD);
00142 ComposeMPIOp(short, std::multiplies, MPI_PROD);
00143 ComposeMPIOp(unsigned short, std::multiplies, MPI_PROD);
00144 ComposeMPIOp(int, std::multiplies, MPI_PROD);
00145 ComposeMPIOp(unsigned int, std::multiplies, MPI_PROD);
00146 ComposeMPIOp(long, std::multiplies, MPI_PROD);
00147 ComposeMPIOp(unsigned long, std::multiplies, MPI_PROD);
00148 ComposeMPIOp(float, std::multiplies, MPI_PROD);
00149 ComposeMPIOp(double, std::multiplies, MPI_PROD);
00150 ComposeMPIOp(long double, std::multiplies, MPI_PROD);
00151
00152 ComposeMPIOp(char, Min, MPI_MIN);
00153 ComposeMPIOp(unsigned char, Min, MPI_MIN);
00154 ComposeMPIOp(short, Min, MPI_MIN);
00155 ComposeMPIOp(unsigned short, Min, MPI_MIN);
00156 ComposeMPIOp(int, Min, MPI_MIN);
00157 ComposeMPIOp(unsigned int, Min, MPI_MIN);
00158 ComposeMPIOp(long, Min, MPI_MIN);
00159 ComposeMPIOp(unsigned long, Min, MPI_MIN);
00160 ComposeMPIOp(float, Min, MPI_MIN);
00161 ComposeMPIOp(double, Min, MPI_MIN);
00162 ComposeMPIOp(long double, Min, MPI_MIN);
00163
00164 ComposeMPIOp(char, Max, MPI_MAX);
00165 ComposeMPIOp(unsigned char, Max, MPI_MAX);
00166 ComposeMPIOp(short, Max, MPI_MAX);
00167 ComposeMPIOp(unsigned short, Max, MPI_MAX);
00168 ComposeMPIOp(int, Max, MPI_MAX);
00169 ComposeMPIOp(unsigned int, Max, MPI_MAX);
00170 ComposeMPIOp(long, Max, MPI_MAX);
00171 ComposeMPIOp(unsigned long, Max, MPI_MAX);
00172 ComposeMPIOp(float, Max, MPI_MAX);
00173 ComposeMPIOp(double, Max, MPI_MAX);
00174 ComposeMPIOp(long double, Max, MPI_MAX);
00175
00176 #undef ComposeMPIOp
00177
00178
00179
00180
00181
00182
00183
00187 template<>
00188 class CollectiveCommunication<MPI_Comm>
00189 {
00190 public:
00192 CollectiveCommunication (const MPI_Comm& c)
00193 : communicator(c)
00194 {
00195 if(communicator!=MPI_COMM_NULL){
00196 MPI_Comm_rank(communicator,&me);
00197 MPI_Comm_size(communicator,&procs);
00198 }else{
00199 procs=0;
00200 me=-1;
00201 }
00202 }
00203
00205 int rank () const
00206 {
00207 return me;
00208 }
00209
00211 int size () const
00212 {
00213 return procs;
00214 }
00215
00217 template<typename T>
00218 T sum (T& in) const
00219 {
00220 T out;
00221 allreduce<std::plus<T> >(&in,&out,1);
00222 return out;
00223 }
00224
00226 template<typename T>
00227 int sum (T* inout, int len) const
00228 {
00229 return allreduce<std::plus<T> >(inout,len);
00230 }
00231
00233 template<typename T>
00234 T prod (T& in) const
00235 {
00236 T out;
00237 allreduce<std::multiplies<T> >(&in,&out,1);
00238 return out;
00239 }
00240
00242 template<typename T>
00243 int prod (T* inout, int len) const
00244 {
00245 return allreduce<std::plus<T> >(inout,len);
00246 }
00247
00249 template<typename T>
00250 T min (T& in) const
00251 {
00252 T out;
00253 allreduce<Min<T> >(&in,&out,1);
00254 return out;
00255 }
00256
00258 template<typename T>
00259 int min (T* inout, int len) const
00260 {
00261 return allreduce<Min<T> >(inout,len);
00262 }
00263
00264
00266 template<typename T>
00267 T max (T& in) const
00268 {
00269 T out;
00270 allreduce<Max<T> >(&in,&out,1);
00271 return out;
00272 }
00273
00275 template<typename T>
00276 int max (T* inout, int len) const
00277 {
00278 return allreduce<Max<T> >(inout,len);
00279 }
00280
00282 int barrier () const
00283 {
00284 return MPI_Barrier(communicator);
00285 }
00286
00288 template<typename T>
00289 int broadcast (T* inout, int len, int root) const
00290 {
00291 return MPI_Bcast(inout,len,Generic_MPI_Datatype<T>::get(),root,communicator);
00292 }
00293
00295 template<typename T>
00296 int gather (T* in, T* out, int len, int root) const
00297 {
00298 return MPI_Gather(in,len,Generic_MPI_Datatype<T>::get(),
00299 out,len,Generic_MPI_Datatype<T>::get(),
00300 root,communicator);
00301 }
00302
00303 operator MPI_Comm () const
00304 {
00305 return communicator;
00306 }
00307
00308 template<typename BinaryFunction, typename Type>
00309 int allreduce(Type* inout, int len) const
00310 {
00311 Type* out = new Type[len];
00312 int ret = allreduce<BinaryFunction>(inout,out,len);
00313 std::copy(out, out+len, inout);
00314 delete[] out;
00315 return ret;
00316 }
00317
00318 template<typename BinaryFunction, typename Type>
00319 int allreduce(Type* in, Type* out, int len) const
00320 {
00321 return MPI_Allreduce(in, out, len, Generic_MPI_Datatype<Type>::get(),
00322 Generic_MPI_Op<Type, BinaryFunction>::get(),communicator);
00323 }
00324
00325 private:
00326 MPI_Comm communicator;
00327 int me;
00328 int procs;
00329 };
00330 }
00331
00332 #endif
00333 #endif