IT++ Logo
mat.cpp
Go to the documentation of this file.
00001 
00029 #include <itpp/base/mat.h>
00030 
00031 #ifndef _MSC_VER
00032 #  include <itpp/config.h>
00033 #else
00034 #  include <itpp/config_msvc.h>
00035 #endif
00036 
00037 #if defined (HAVE_BLAS)
00038 #  include <itpp/base/blas.h>
00039 #endif
00040 
00042 
00043 namespace itpp
00044 {
00045 
00046 template<>
00047 cmat cmat::hermitian_transpose() const
00048 {
00049   cmat temp(no_cols, no_rows);
00050   for (int i = 0; i < no_rows; i++)
00051     for (int j = 0; j < no_cols; j++)
00052       temp(j, i) = std::conj(operator()(i,j));
00053 
00054   return temp;
00055 }
00056 
00057 
00058 // -------- Multiplication operator -------------
00059 
00060 #if defined(HAVE_BLAS)
00061 
00062 template<>
00063 mat& mat::operator*=(const mat &m)
00064 {
00065   it_assert_debug(no_cols == m.no_rows, "mat::operator*=(): Wrong sizes");
00066   mat r(no_rows, m.no_cols); // unnecessary memory??
00067   double alpha = 1.0;
00068   double beta = 0.0;
00069   char trans = 'n';
00070   blas::dgemm_(&trans, &trans, &no_rows, &m.no_cols, &no_cols, &alpha, data,
00071                &no_rows, m.data, &m.no_rows, &beta, r.data, &r.no_rows);
00072   operator=(r); // time consuming
00073   return *this;
00074 }
00075 
00076 template<>
00077 cmat& cmat::operator*=(const cmat &m)
00078 {
00079   it_assert_debug(no_cols == m.no_rows, "cmat::operator*=(): Wrong sizes");
00080   cmat r(no_rows, m.no_cols); // unnecessary memory??
00081   std::complex<double> alpha = std::complex<double>(1.0);
00082   std::complex<double> beta = std::complex<double>(0.0);
00083   char trans = 'n';
00084   blas::zgemm_(&trans, &trans, &no_rows, &m.no_cols, &no_cols, &alpha, data,
00085                &no_rows, m.data, &m.no_rows, &beta, r.data, &r.no_rows);
00086   operator=(r); // time consuming
00087   return *this;
00088 }
00089 #else
00090 template<>
00091 mat& mat::operator*=(const mat &m)
00092 {
00093   it_assert_debug(no_cols == m.no_rows, "Mat<>::operator*=(): Wrong sizes");
00094   mat r(no_rows, m.no_cols);
00095   int r_pos = 0, pos = 0, m_pos = 0;
00096 
00097   for (int i = 0; i < r.no_cols; i++) {
00098     for (int j = 0; j < r.no_rows; j++) {
00099       double tmp = 0.0;
00100       pos = 0;
00101       for (int k = 0; k < no_cols; k++) {
00102         tmp += data[pos+j] * m.data[m_pos+k];
00103         pos += no_rows;
00104       }
00105       r.data[r_pos+j] = tmp;
00106     }
00107     r_pos += r.no_rows;
00108     m_pos += m.no_rows;
00109   }
00110   operator=(r); // time consuming
00111   return *this;
00112 }
00113 
00114 template<>
00115 cmat& cmat::operator*=(const cmat &m)
00116 {
00117   it_assert_debug(no_cols == m.no_rows, "Mat<>::operator*=(): Wrong sizes");
00118   cmat r(no_rows, m.no_cols);
00119   int r_pos = 0, pos = 0, m_pos = 0;
00120 
00121   for (int i = 0; i < r.no_cols; i++) {
00122     for (int j = 0; j < r.no_rows; j++) {
00123       std::complex<double> tmp(0.0);
00124       pos = 0;
00125       for (int k = 0; k < no_cols; k++) {
00126         tmp += data[pos+j] * m.data[m_pos+k];
00127         pos += no_rows;
00128       }
00129       r.data[r_pos+j] = tmp;
00130     }
00131     r_pos += r.no_rows;
00132     m_pos += m.no_rows;
00133   }
00134   operator=(r); // time consuming
00135   return *this;
00136 }
00137 #endif // HAVE_BLAS
00138 
00139 
00140 #if defined(HAVE_BLAS)
00141 template<>
00142 mat operator*(const mat &m1, const mat &m2)
00143 {
00144   it_assert_debug(m1.no_cols == m2.no_rows, "mat::operator*(): Wrong sizes");
00145   mat r(m1.no_rows, m2.no_cols);
00146   double alpha = 1.0;
00147   double beta = 0.0;
00148   char trans = 'n';
00149   blas::dgemm_(&trans, &trans, &m1.no_rows, &m2.no_cols, &m1.no_cols, &alpha,
00150                m1.data, &m1.no_rows, m2.data, &m2.no_rows, &beta, r.data,
00151                &r.no_rows);
00152   return r;
00153 }
00154 
00155 template<>
00156 cmat operator*(const cmat &m1, const cmat &m2)
00157 {
00158   it_assert_debug(m1.no_cols == m2.no_rows, "cmat::operator*(): Wrong sizes");
00159   cmat r(m1.no_rows, m2.no_cols);
00160   std::complex<double> alpha = std::complex<double>(1.0);
00161   std::complex<double> beta = std::complex<double>(0.0);
00162   char trans = 'n';
00163   blas::zgemm_(&trans, &trans, &m1.no_rows, &m2.no_cols, &m1.no_cols, &alpha,
00164                m1.data, &m1.no_rows, m2.data, &m2.no_rows, &beta, r.data,
00165                &r.no_rows);
00166   return r;
00167 }
00168 #else
00169 template<>
00170 mat operator*(const mat &m1, const mat &m2)
00171 {
00172   it_assert_debug(m1.no_cols == m2.no_rows,
00173                   "Mat<>::operator*(): Wrong sizes");
00174   mat r(m1.no_rows, m2.no_cols);
00175   double *tr = r.data;
00176   double *t1;
00177   double *t2 = m2.data;
00178   for (int i = 0; i < r.no_cols; i++) {
00179     for (int j = 0; j < r.no_rows; j++) {
00180       double tmp = 0.0;
00181       t1 = m1.data + j;
00182       for (int k = m1.no_cols; k > 0; k--) {
00183         tmp += *(t1) * *(t2++);
00184         t1 += m1.no_rows;
00185       }
00186       *(tr++) = tmp;
00187       t2 -= m2.no_rows;
00188     }
00189     t2 += m2.no_rows;
00190   }
00191   return r;
00192 }
00193 
00194 template<>
00195 cmat operator*(const cmat &m1, const cmat &m2)
00196 {
00197   it_assert_debug(m1.no_cols == m2.no_rows,
00198                   "Mat<>::operator*(): Wrong sizes");
00199   cmat r(m1.no_rows, m2.no_cols);
00200   std::complex<double> *tr = r.data;
00201   std::complex<double> *t1;
00202   std::complex<double> *t2 = m2.data;
00203   for (int i = 0; i < r.no_cols; i++) {
00204     for (int j = 0; j < r.no_rows; j++) {
00205       std::complex<double> tmp(0.0);
00206       t1 = m1.data + j;
00207       for (int k = m1.no_cols; k > 0; k--) {
00208         tmp += *(t1) * *(t2++);
00209         t1 += m1.no_rows;
00210       }
00211       *(tr++) = tmp;
00212       t2 -= m2.no_rows;
00213     }
00214     t2 += m2.no_rows;
00215   }
00216   return r;
00217 }
00218 #endif // HAVE_BLAS
00219 
00220 
00221 #if defined(HAVE_BLAS)
00222 template<>
00223 vec operator*(const mat &m, const vec &v)
00224 {
00225   it_assert_debug(m.no_cols == v.size(), "mat::operator*(): Wrong sizes");
00226   vec r(m.no_rows);
00227   double alpha = 1.0;
00228   double beta = 0.0;
00229   char trans = 'n';
00230   int incr = 1;
00231   blas::dgemv_(&trans, &m.no_rows, &m.no_cols, &alpha, m.data, &m.no_rows,
00232                v._data(), &incr, &beta, r._data(), &incr);
00233   return r;
00234 }
00235 
00236 template<>
00237 cvec operator*(const cmat &m, const cvec &v)
00238 {
00239   it_assert_debug(m.no_cols == v.size(), "cmat::operator*(): Wrong sizes");
00240   cvec r(m.no_rows);
00241   std::complex<double> alpha = std::complex<double>(1.0);
00242   std::complex<double> beta = std::complex<double>(0.0);
00243   char trans = 'n';
00244   int incr = 1;
00245   blas::zgemv_(&trans, &m.no_rows, &m.no_cols, &alpha, m.data, &m.no_rows,
00246                v._data(), &incr, &beta, r._data(), &incr);
00247   return r;
00248 }
00249 #else
00250 template<>
00251 vec operator*(const mat &m, const vec &v)
00252 {
00253   it_assert_debug(m.no_cols == v.size(),
00254                   "Mat<>::operator*(): Wrong sizes");
00255   vec r(m.no_rows);
00256   for (int i = 0; i < m.no_rows; i++) {
00257     r(i) = 0.0;
00258     int m_pos = 0;
00259     for (int k = 0; k < m.no_cols; k++) {
00260       r(i) += m.data[m_pos+i] * v(k);
00261       m_pos += m.no_rows;
00262     }
00263   }
00264   return r;
00265 }
00266 
00267 template<>
00268 cvec operator*(const cmat &m, const cvec &v)
00269 {
00270   it_assert_debug(m.no_cols == v.size(),
00271                   "Mat<>::operator*(): Wrong sizes");
00272   cvec r(m.no_rows);
00273   for (int i = 0; i < m.no_rows; i++) {
00274     r(i) = std::complex<double>(0.0);
00275     int m_pos = 0;
00276     for (int k = 0; k < m.no_cols; k++) {
00277       r(i) += m.data[m_pos+i] * v(k);
00278       m_pos += m.no_rows;
00279     }
00280   }
00281   return r;
00282 }
00283 #endif // HAVE_BLAS
00284 
00285 
00286 //---------------------------------------------------------------------
00287 // Instantiations: explicit instantiation definitions, so the code for each specialization listed below is emitted in this translation unit
00288 //---------------------------------------------------------------------
00289 
00290 // class instantiations
00291 
00292 template class Mat<double>;
00293 template class Mat<std::complex<double> >;
00294 template class Mat<int>;
00295 template class Mat<short int>;
00296 template class Mat<bin>;
00297 
00298 // addition operators
00299 
00300 template mat operator+(const mat &m1, const mat &m2);
00301 template cmat operator+(const cmat &m1, const cmat &m2);
00302 template imat operator+(const imat &m1, const imat &m2);
00303 template smat operator+(const smat &m1, const smat &m2);
00304 template bmat operator+(const bmat &m1, const bmat &m2);
00305 
00306 template mat operator+(const mat &m, double t);
00307 template cmat operator+(const cmat &m, std::complex<double> t);
00308 template imat operator+(const imat &m, int t);
00309 template smat operator+(const smat &m, short t);
00310 template bmat operator+(const bmat &m, bin t);
00311 
00312 template mat operator+(double t, const mat &m);
00313 template cmat operator+(std::complex<double> t, const cmat &m);
00314 template imat operator+(int t, const imat &m);
00315 template smat operator+(short t, const smat &m);
00316 template bmat operator+(bin t, const bmat &m);
00317 
00318 // subtraction operators
00319 
00320 template mat operator-(const mat &m1, const mat &m2);
00321 template cmat operator-(const cmat &m1, const cmat &m2);
00322 template imat operator-(const imat &m1, const imat &m2);
00323 template smat operator-(const smat &m1, const smat &m2);
00324 template bmat operator-(const bmat &m1, const bmat &m2);
00325 
00326 template mat operator-(const mat &m, double t);
00327 template cmat operator-(const cmat &m, std::complex<double> t);
00328 template imat operator-(const imat &m, int t);
00329 template smat operator-(const smat &m, short t);
00330 template bmat operator-(const bmat &m, bin t);
00331 
00332 template mat operator-(double t, const mat &m);
00333 template cmat operator-(std::complex<double> t, const cmat &m);
00334 template imat operator-(int t, const imat &m);
00335 template smat operator-(short t, const smat &m);
00336 template bmat operator-(bin t, const bmat &m);
00337 
00338 // unary minus
00339 
00340 template mat operator-(const mat &m);
00341 template cmat operator-(const cmat &m);
00342 template imat operator-(const imat &m);
00343 template smat operator-(const smat &m);
00344 template bmat operator-(const bmat &m);
00345 
00346 // multiplication operators (mat*mat, cmat*cmat, mat*vec and cmat*cvec are explicit specializations defined earlier in this file, so they are not instantiated here)
00347 
00348 template imat operator*(const imat &m1, const imat &m2);
00349 template smat operator*(const smat &m1, const smat &m2);
00350 template bmat operator*(const bmat &m1, const bmat &m2);
00351 
00352 template ivec operator*(const imat &m, const ivec &v);
00353 template svec operator*(const smat &m, const svec &v);
00354 template bvec operator*(const bmat &m, const bvec &v);
00355 
00356 template mat operator*(const mat &m, double t);
00357 template cmat operator*(const cmat &m, std::complex<double> t);
00358 template imat operator*(const imat &m, int t);
00359 template smat operator*(const smat &m, short t);
00360 template bmat operator*(const bmat &m, bin t);
00361 
00362 template mat operator*(double t, const mat &m);
00363 template cmat operator*(std::complex<double> t, const cmat &m);
00364 template imat operator*(int t, const imat &m);
00365 template smat operator*(short t, const smat &m);
00366 template bmat operator*(bin t, const bmat &m);
00367 
00368 // elementwise multiplication
00369 
00370 template mat elem_mult(const mat &m1, const mat &m2);
00371 template cmat elem_mult(const cmat &m1, const cmat &m2);
00372 template imat elem_mult(const imat &m1, const imat &m2);
00373 template smat elem_mult(const smat &m1, const smat &m2);
00374 template bmat elem_mult(const bmat &m1, const bmat &m2);
00375 
00376 template void elem_mult_out(const mat &m1, const mat &m2, mat &out);
00377 template void elem_mult_out(const cmat &m1, const cmat &m2, cmat &out);
00378 template void elem_mult_out(const imat &m1, const imat &m2, imat &out);
00379 template void elem_mult_out(const smat &m1, const smat &m2, smat &out);
00380 template void elem_mult_out(const bmat &m1, const bmat &m2, bmat &out);
00381 
00382 template void elem_mult_out(const mat &m1, const mat &m2,
00383                             const mat &m3, mat &out);
00384 template void elem_mult_out(const cmat &m1, const cmat &m2,
00385                             const cmat &m3, cmat &out);
00386 template void elem_mult_out(const imat &m1, const imat &m2,
00387                             const imat &m3, imat &out);
00388 template void elem_mult_out(const smat &m1, const smat &m2,
00389                             const smat &m3, smat &out);
00390 template void elem_mult_out(const bmat &m1, const bmat &m2,
00391                             const bmat &m3, bmat &out);
00392 
00393 template void elem_mult_out(const mat &m1, const mat &m2, const mat &m3,
00394                             const mat &m4, mat &out);
00395 template void elem_mult_out(const cmat &m1, const cmat &m2,
00396                             const cmat &m3, const cmat &m4, cmat &out);
00397 template void elem_mult_out(const imat &m1, const imat &m2,
00398                             const imat &m3, const imat &m4, imat &out);
00399 template void elem_mult_out(const smat &m1, const smat &m2,
00400                             const smat &m3, const smat &m4, smat &out);
00401 template void elem_mult_out(const bmat &m1, const bmat &m2,
00402                             const bmat &m3, const bmat &m4, bmat &out);
00403 
00404 template void elem_mult_inplace(const mat &m1, mat &m2);
00405 template void elem_mult_inplace(const cmat &m1, cmat &m2);
00406 template void elem_mult_inplace(const imat &m1, imat &m2);
00407 template void elem_mult_inplace(const smat &m1, smat &m2);
00408 template void elem_mult_inplace(const bmat &m1, bmat &m2);
00409 
00410 template double elem_mult_sum(const mat &m1, const mat &m2);
00411 template std::complex<double> elem_mult_sum(const cmat &m1, const cmat &m2);
00412 template int elem_mult_sum(const imat &m1, const imat &m2);
00413 template short elem_mult_sum(const smat &m1, const smat &m2);
00414 template bin elem_mult_sum(const bmat &m1, const bmat &m2);
00415 
00416 // division operators
00417 
00418 template mat operator/(double t, const mat &m);
00419 template cmat operator/(std::complex<double> t, const cmat &m);
00420 template imat operator/(int t, const imat &m);
00421 template smat operator/(short t, const smat &m);
00422 template bmat operator/(bin t, const bmat &m);
00423 
00424 template mat operator/(const mat &m, double t);
00425 template cmat operator/(const cmat &m, std::complex<double> t);
00426 template imat operator/(const imat &m, int t);
00427 template smat operator/(const smat &m, short t);
00428 template bmat operator/(const bmat &m, bin t);
00429 
00430 // elementwise division
00431 
00432 template mat elem_div(const mat &m1, const mat &m2);
00433 template cmat elem_div(const cmat &m1, const cmat &m2);
00434 template imat elem_div(const imat &m1, const imat &m2);
00435 template smat elem_div(const smat &m1, const smat &m2);
00436 template bmat elem_div(const bmat &m1, const bmat &m2);
00437 
00438 template void elem_div_out(const mat &m1, const mat &m2, mat &out);
00439 template void elem_div_out(const cmat &m1, const cmat &m2, cmat &out);
00440 template void elem_div_out(const imat &m1, const imat &m2, imat &out);
00441 template void elem_div_out(const smat &m1, const smat &m2, smat &out);
00442 template void elem_div_out(const bmat &m1, const bmat &m2, bmat &out);
00443 
00444 template double elem_div_sum(const mat &m1, const mat &m2);
00445 template std::complex<double> elem_div_sum(const cmat &m1,
00446     const cmat &m2);
00447 template int elem_div_sum(const imat &m1, const imat &m2);
00448 template short elem_div_sum(const smat &m1, const smat &m2);
00449 template bin elem_div_sum(const bmat &m1, const bmat &m2);
00450 
00451 // concatenation
00452 
00453 template mat concat_horizontal(const mat &m1, const mat &m2);
00454 template cmat concat_horizontal(const cmat &m1, const cmat &m2);
00455 template imat concat_horizontal(const imat &m1, const imat &m2);
00456 template smat concat_horizontal(const smat &m1, const smat &m2);
00457 template bmat concat_horizontal(const bmat &m1, const bmat &m2);
00458 
00459 template mat concat_vertical(const mat &m1, const mat &m2);
00460 template cmat concat_vertical(const cmat &m1, const cmat &m2);
00461 template imat concat_vertical(const imat &m1, const imat &m2);
00462 template smat concat_vertical(const smat &m1, const smat &m2);
00463 template bmat concat_vertical(const bmat &m1, const bmat &m2);
00464 
00465 // I/O streams
00466 
00467 template std::ostream &operator<<(std::ostream &os, const mat  &m);
00468 template std::ostream &operator<<(std::ostream &os, const cmat &m);
00469 template std::ostream &operator<<(std::ostream &os, const imat  &m);
00470 template std::ostream &operator<<(std::ostream &os, const smat  &m);
00471 template std::ostream &operator<<(std::ostream &os, const bmat  &m);
00472 
00473 template std::istream &operator>>(std::istream &is, mat  &m);
00474 template std::istream &operator>>(std::istream &is, cmat &m);
00475 template std::istream &operator>>(std::istream &is, imat  &m);
00476 template std::istream &operator>>(std::istream &is, smat  &m);
00477 template std::istream &operator>>(std::istream &is, bmat  &m);
00478 
00479 } // namespace itpp
00480 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
SourceForge Logo

Generated on Wed Jul 27 2011 16:27:04 for IT++ by Doxygen 1.7.4