00001 00029 #ifndef GMM_H 00030 #define GMM_H 00031 00032 #include <itpp/base/mat.h> 00033 00034 00035 namespace itpp 00036 { 00037 00039 00045 class GMM 00046 { 00047 public: 00048 GMM(); 00049 GMM(int nomix, int dim); 00050 GMM(std::string filename); 00051 void init_from_vq(const vec &codebook, int dim); 00052 // void init(const vec &w_in, const vec &m_in, const vec &sigma_in); 00053 void init(const vec &w_in, const mat &m_in, const mat &sigma_in); 00054 void load(std::string filename); 00055 void save(std::string filename); 00056 void set_weight(const vec &weights, bool compflag = true); 00057 void set_weight(int i, double weight, bool compflag = true); 00058 void set_mean(const mat &m_in); 00059 void set_mean(const vec &means, bool compflag = true); 00060 void set_mean(int i, const vec &means, bool compflag = true); 00061 void set_covariance(const mat &sigma_in); 00062 void set_covariance(const vec &covariances, bool compflag = true); 00063 void set_covariance(int i, const vec &covariances, bool compflag = true); 00064 int get_no_mixtures(); 00065 int get_no_gaussians() const { return M; } 00066 int get_dimension(); 00067 vec get_weight(); 00068 double get_weight(int i); 00069 vec get_mean(); 00070 vec get_mean(int i); 00071 vec get_covariance(); 00072 vec get_covariance(int i); 00073 void marginalize(int d_new); 00074 void join(const GMM &newgmm); 00075 void clear(); 00076 double likelihood(const vec &x); 00077 double likelihood_aposteriori(const vec &x, int mixture); 00078 vec likelihood_aposteriori(const vec &x); 00079 vec draw_sample(); 00080 protected: 00081 vec m, sigma, w; 00082 int M, d; 00083 private: 00084 void compute_internals(); 00085 vec normweight, normexp; 00086 }; 00087 00088 inline void GMM::set_weight(const vec &weights, bool compflag) {w = weights; if (compflag) compute_internals(); } 00089 inline void GMM::set_weight(int i, double weight, bool compflag) {w(i) = weight; if (compflag) compute_internals(); } 00090 inline void GMM::set_mean(const vec &means, bool compflag) {m = means; if (compflag) compute_internals(); } 00091 inline void GMM::set_covariance(const vec &covariances, bool compflag) {sigma = covariances; if (compflag) compute_internals(); } 00092 inline int GMM::get_dimension() {return d;} 00093 inline vec GMM::get_weight() {return w;} 00094 inline double GMM::get_weight(int i) {return w(i);} 00095 inline vec GMM::get_mean() {return m;} 00096 inline vec GMM::get_mean(int i) {return m.mid(i*d, d);} 00097 inline vec GMM::get_covariance() {return sigma;} 00098 inline vec GMM::get_covariance(int i) {return sigma.mid(i*d, d);} 00099 00100 GMM gmmtrain(Array<vec> &TrainingData, int M, int NOITER = 30, bool VERBOSE = true); 00101 00103 00104 } // namespace itpp 00105 00106 #endif // #ifndef GMM_H
Generated on Wed Jul 27 2011 16:27:05 for IT++ by Doxygen 1.7.4