IT++ Logo
gmm.cpp
Go to the documentation of this file.
00001 
00029 #include <itpp/srccode/gmm.h>
00030 #include <itpp/srccode/vqtrain.h>
00031 #include <itpp/base/math/elem_math.h>
00032 #include <itpp/base/matfunc.h>
00033 #include <itpp/base/specmat.h>
00034 #include <itpp/base/random.h>
00035 #include <itpp/base/timing.h>
00036 #include <iostream>
00037 #include <fstream>
00038 
00040 
00041 namespace itpp
00042 {
00043 
00044 GMM::GMM()
00045 {
00046   d = 0;
00047   M = 0;
00048 }
00049 
00050 GMM::GMM(std::string filename)
00051 {
00052   load(filename);
00053 }
00054 
00055 GMM::GMM(int M_in, int d_in)
00056 {
00057   M = M_in;
00058   d = d_in;
00059   m = zeros(M * d);
00060   sigma = zeros(M * d);
00061   w = 1. / M * ones(M);
00062 
00063   for (int i = 0;i < M;i++) {
00064     w(i) = 1.0 / M;
00065   }
00066   compute_internals();
00067 }
00068 
00069 void GMM::init_from_vq(const vec &codebook, int dim)
00070 {
00071 
00072   mat  C(dim, dim);
00073   int  i;
00074   vec  v;
00075 
00076   d = dim;
00077   M = codebook.length() / dim;
00078 
00079   m = codebook;
00080   w = ones(M) / double(M);
00081 
00082   C.clear();
00083   for (i = 0;i < M;i++) {
00084     v = codebook.mid(i * d, d);
00085     C = C + outer_product(v, v);
00086   }
00087   C = 1. / M * C;
00088   sigma.set_length(M*d);
00089   for (i = 0;i < M;i++) {
00090     sigma.replace_mid(i*d, diag(C));
00091   }
00092 
00093   compute_internals();
00094 }
00095 
00096 void GMM::init(const vec &w_in, const mat &m_in, const mat &sigma_in)
00097 {
00098   int  i, j;
00099   d = m_in.rows();
00100   M = m_in.cols();
00101 
00102   m.set_length(M*d);
00103   sigma.set_length(M*d);
00104   for (i = 0;i < M;i++) {
00105     for (j = 0;j < d;j++) {
00106       m(i*d + j) = m_in(j, i);
00107       sigma(i*d + j) = sigma_in(j, i);
00108     }
00109   }
00110   w = w_in;
00111 
00112   compute_internals();
00113 }
00114 
00115 void GMM::set_mean(const mat &m_in)
00116 {
00117   int  i, j;
00118 
00119   d = m_in.rows();
00120   M = m_in.cols();
00121 
00122   m.set_length(M*d);
00123   for (i = 0;i < M;i++) {
00124     for (j = 0;j < d;j++) {
00125       m(i*d + j) = m_in(j, i);
00126     }
00127   }
00128   compute_internals();
00129 }
00130 
00131 void GMM::set_mean(int i, const vec &means, bool compflag)
00132 {
00133   m.replace_mid(i*length(means), means);
00134   if (compflag) compute_internals();
00135 }
00136 
00137 void GMM::set_covariance(const mat &sigma_in)
00138 {
00139   int  i, j;
00140 
00141   d = sigma_in.rows();
00142   M = sigma_in.cols();
00143 
00144   sigma.set_length(M*d);
00145   for (i = 0;i < M;i++) {
00146     for (j = 0;j < d;j++) {
00147       sigma(i*d + j) = sigma_in(j, i);
00148     }
00149   }
00150   compute_internals();
00151 }
00152 
00153 void GMM::set_covariance(int i, const vec &covariances, bool compflag)
00154 {
00155   sigma.replace_mid(i*length(covariances), covariances);
00156   if (compflag) compute_internals();
00157 }
00158 
00159 void GMM::marginalize(int d_new)
00160 {
00161   it_error_if(d_new > d, "GMM.marginalize: cannot change to a larger dimension");
00162 
00163   vec  mnew(d_new*M), sigmanew(d_new*M);
00164   int  i, j;
00165 
00166   for (i = 0;i < M;i++) {
00167     for (j = 0;j < d_new;j++) {
00168       mnew(i*d_new + j) = m(i * d + j);
00169       sigmanew(i*d_new + j) = sigma(i * d + j);
00170     }
00171   }
00172   m = mnew;
00173   sigma = sigmanew;
00174   d = d_new;
00175 
00176   compute_internals();
00177 }
00178 
00179 void GMM::join(const GMM &newgmm)
00180 {
00181   if (d == 0) {
00182     w = newgmm.w;
00183     m = newgmm.m;
00184     sigma = newgmm.sigma;
00185     d = newgmm.d;
00186     M = newgmm.M;
00187   }
00188   else {
00189     it_error_if(d != newgmm.d, "GMM.join: cannot join GMMs of different dimension");
00190 
00191     w = concat(double(M) / (M + newgmm.M) * w, double(newgmm.M) / (M + newgmm.M) * newgmm.w);
00192     w = w / sum(w);
00193     m = concat(m, newgmm.m);
00194     sigma = concat(sigma, newgmm.sigma);
00195 
00196     M = M + newgmm.M;
00197   }
00198   compute_internals();
00199 }
00200 
00201 void GMM::clear()
00202 {
00203   w.set_length(0);
00204   m.set_length(0);
00205   sigma.set_length(0);
00206   d = 0;
00207   M = 0;
00208 }
00209 
00210 void GMM::save(std::string filename)
00211 {
00212   std::ofstream f(filename.c_str());
00213   int   i, j;
00214 
00215   f << M << " " << d << std::endl ;
00216   for (i = 0;i < w.length();i++) {
00217     f << w(i) << std::endl ;
00218   }
00219   for (i = 0;i < M;i++) {
00220     f << m(i*d) ;
00221     for (j = 1;j < d;j++) {
00222       f << " " << m(i*d + j) ;
00223     }
00224     f << std::endl ;
00225   }
00226   for (i = 0;i < M;i++) {
00227     f << sigma(i*d) ;
00228     for (j = 1;j < d;j++) {
00229       f << " " << sigma(i*d + j) ;
00230     }
00231     f << std::endl ;
00232   }
00233 }
00234 
00235 void GMM::load(std::string filename)
00236 {
00237   std::ifstream GMMFile(filename.c_str());
00238   int   i, j;
00239 
00240   it_error_if(!GMMFile, std::string("GMM::load : cannot open file ") + filename);
00241 
00242   GMMFile >> M >> d ;
00243 
00244 
00245   w.set_length(M);
00246   for (i = 0;i < M;i++) {
00247     GMMFile >> w(i) ;
00248   }
00249   m.set_length(M*d);
00250   for (i = 0;i < M;i++) {
00251     for (j = 0;j < d;j++) {
00252       GMMFile >> m(i*d + j) ;
00253     }
00254   }
00255   sigma.set_length(M*d);
00256   for (i = 0;i < M;i++) {
00257     for (j = 0;j < d;j++) {
00258       GMMFile >> sigma(i*d + j) ;
00259     }
00260   }
00261   compute_internals();
00262   std::cout << "  mixtures:" << M << "  dim:" << d << std::endl ;
00263 }
00264 
00265 double GMM::likelihood(const vec &x)
00266 {
00267   double fx = 0;
00268   int  i;
00269 
00270   for (i = 0;i < M;i++) {
00271     fx += w(i) * likelihood_aposteriori(x, i);
00272   }
00273   return fx;
00274 }
00275 
00276 vec GMM::likelihood_aposteriori(const vec &x)
00277 {
00278   vec  v(M);
00279   int  i;
00280 
00281   for (i = 0;i < M;i++) {
00282     v(i) = w(i) * likelihood_aposteriori(x, i);
00283   }
00284   return v;
00285 }
00286 
00287 double GMM::likelihood_aposteriori(const vec &x, int mixture)
00288 {
00289   int  j;
00290   double s;
00291 
00292   it_error_if(d != x.length(), "GMM::likelihood_aposteriori : dimensions does not match");
00293   s = 0;
00294   for (j = 0;j < d;j++) {
00295     s += normexp(mixture * d + j) * sqr(x(j) - m(mixture * d + j));
00296   }
00297   return normweight(mixture)*std::exp(s);;
00298 }
00299 
00300 void GMM::compute_internals()
00301 {
00302   int  i, j;
00303   double s;
00304   double constant = 1.0 / std::pow(2 * pi, d / 2.0);
00305 
00306   normweight.set_length(M);
00307   normexp.set_length(M*d);
00308 
00309   for (i = 0;i < M;i++) {
00310     s = 1;
00311     for (j = 0;j < d;j++) {
00312       normexp(i*d + j) = -0.5 / sigma(i * d + j);  // check time
00313       s *= sigma(i * d + j);
00314     }
00315     normweight(i) = constant / std::sqrt(s);
00316   }
00317 
00318 }
00319 
00320 vec GMM::draw_sample()
00321 {
00322   static bool first = true;
00323   static vec cumweight;
00324   double u = randu();
00325   int  k;
00326 
00327   if (first) {
00328     first = false;
00329     cumweight = cumsum(w);
00330     it_error_if(std::abs(cumweight(length(cumweight) - 1) - 1) > 1e-6, "weight does not sum to 0");
00331     cumweight(length(cumweight) - 1) = 1;
00332   }
00333   k = 0;
00334   while (u > cumweight(k)) k++;
00335 
00336   return elem_mult(sqrt(sigma.mid(k*d, d)), randn(d)) + m.mid(k*d, d);
00337 }
00338 
00339 GMM gmmtrain(Array<vec> &TrainingData, int M, int NOITER, bool VERBOSE)
00340 {
00341   mat   mean;
00342   int   i, j, d = TrainingData(0).length();
00343   vec   sig;
00344   GMM   gmm(M, d);
00345   vec   m(d*M);
00346   vec   sigma(d*M);
00347   vec   w(M);
00348   vec   normweight(M);
00349   vec   normexp(d*M);
00350   double  LL = 0, LLold, fx;
00351   double  constant = 1.0 / std::pow(2 * pi, d / 2.0);
00352   int   T = TrainingData.length();
00353   vec   x1;
00354   int   t, n;
00355   vec   msum(d*M);
00356   vec   sigmasum(d*M);
00357   vec   wsum(M);
00358   vec   p_aposteriori(M);
00359   vec   x2;
00360   double  s;
00361   vec   temp1, temp2;
00362   //double  MINIMUM_VARIANCE=0.03;
00363 
00364   //-----------initialization-----------------------------------
00365 
00366   mean = vqtrain(TrainingData, M, 200000, 0.5, VERBOSE);
00367   for (i = 0;i < M;i++) gmm.set_mean(i, mean.get_col(i), false);
00368   // for (i=0;i<M;i++) gmm.set_mean(i,TrainingData(randi(0,TrainingData.length()-1)),false);
00369   sig = zeros(d);
00370   for (i = 0;i < TrainingData.length();i++) sig += sqr(TrainingData(i));
00371   sig /= TrainingData.length();
00372   for (i = 0;i < M;i++) gmm.set_covariance(i, 0.5*sig, false);
00373 
00374   gmm.set_weight(1.0 / M*ones(M));
00375 
00376   //-----------optimization-----------------------------------
00377 
00378   tic();
00379   for (i = 0;i < M;i++) {
00380     temp1 = gmm.get_mean(i);
00381     temp2 = gmm.get_covariance(i);
00382     for (j = 0;j < d;j++) {
00383       m(i*d + j) = temp1(j);
00384       sigma(i*d + j) = temp2(j);
00385     }
00386     w(i) = gmm.get_weight(i);
00387   }
00388   for (n = 0;n < NOITER;n++) {
00389     for (i = 0;i < M;i++) {
00390       s = 1;
00391       for (j = 0;j < d;j++) {
00392         normexp(i*d + j) = -0.5 / sigma(i * d + j);  // check time
00393         s *= sigma(i * d + j);
00394       }
00395       normweight(i) = constant * w(i) / std::sqrt(s);
00396     }
00397     LLold = LL;
00398     wsum.clear();
00399     msum.clear();
00400     sigmasum.clear();
00401     LL = 0;
00402     for (t = 0;t < T;t++) {
00403       x1 = TrainingData(t);
00404       x2 = sqr(x1);
00405       fx = 0;
00406       for (i = 0;i < M;i++) {
00407         s = 0;
00408         for (j = 0;j < d;j++) {
00409           s += normexp(i * d + j) * sqr(x1(j) - m(i * d + j));
00410         }
00411         p_aposteriori(i) = normweight(i) * std::exp(s);
00412         fx += p_aposteriori(i);
00413       }
00414       p_aposteriori /= fx;
00415       LL = LL + std::log(fx);
00416 
00417       for (i = 0;i < M;i++) {
00418         wsum(i) += p_aposteriori(i);
00419         for (j = 0;j < d;j++) {
00420           msum(i*d + j) += p_aposteriori(i) * x1(j);
00421           sigmasum(i*d + j) += p_aposteriori(i) * x2(j);
00422         }
00423       }
00424     }
00425     for (i = 0;i < M;i++) {
00426       for (j = 0;j < d;j++) {
00427         m(i*d + j) = msum(i * d + j) / wsum(i);
00428         sigma(i*d + j) = sigmasum(i * d + j) / wsum(i) - sqr(m(i * d + j));
00429       }
00430       w(i) = wsum(i) / T;
00431     }
00432     LL = LL / T;
00433 
00434     if (std::abs((LL - LLold) / LL) < 1e-6) break;
00435     if (VERBOSE) {
00436       std::cout << n << ":   " << LL << "   " << std::abs((LL - LLold) / LL) << "   " << toc() <<  std::endl ;
00437       std::cout << "---------------------------------------" << std::endl ;
00438       tic();
00439     }
00440     else {
00441       std::cout << n << ": LL =  " << LL << "   " << std::abs((LL - LLold) / LL) << "\r" ;
00442       std::cout.flush();
00443     }
00444   }
00445   for (i = 0;i < M;i++) {
00446     gmm.set_mean(i, m.mid(i*d, d), false);
00447     gmm.set_covariance(i, sigma.mid(i*d, d), false);
00448   }
00449   gmm.set_weight(w);
00450   return gmm;
00451 }
00452 
00453 } // namespace itpp
00454 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
SourceForge Logo

Generated on Wed Jul 27 2011 16:27:05 for IT++ by Doxygen 1.7.4