IT++ Logo
mog_generic.h
Go to the documentation of this file.
00001 
00029 #ifndef MOG_GENERIC_H
00030 #define MOG_GENERIC_H
00031 
00032 #include <itpp/base/vec.h>
00033 #include <itpp/base/mat.h>
00034 #include <itpp/base/array.h>
00035 
00036 
00037 namespace itpp
00038 {
00039 
00056 class MOG_generic
00057 {
00058 
00059 public:
00060 
00066   MOG_generic() { init(); }
00067 
00071   MOG_generic(const std::string &name_in) { load(name_in); }
00072 
00078   MOG_generic(const int &K_in, const int &D_in, bool full_in = false) { init(K_in, D_in, full_in); }
00079 
00087   MOG_generic(Array<vec> &means_in, bool full_in = false) { init(means_in, full_in); }
00088 
00095   MOG_generic(Array<vec> &means_in, Array<vec> &diag_covs_in, vec &weights_in) { init(means_in, diag_covs_in, weights_in); }
00096 
00103   MOG_generic(Array<vec> &means_in, Array<mat> &full_covs_in, vec &weights_in) { init(means_in, full_covs_in, weights_in); }
00104 
00106   virtual ~MOG_generic() { cleanup(); }
00107 
00112   void init();
00113 
00119   void init(const int &K_in, const int &D_in, bool full_in = false);
00120 
00128   void init(Array<vec> &means_in, bool full_in = false);
00129 
00136   void init(Array<vec> &means_in, Array<vec> &diag_covs_in, vec &weights_in);
00137 
00144   void init(Array<vec> &means_in, Array<mat> &full_covs_in, vec &weights_in);
00145 
00150   virtual void cleanup();
00151 
00153   bool is_valid() const { return valid; }
00154 
00156   bool is_full() const { return full; }
00157 
00159   int get_K() const { if (valid) return(K); else return(0); }
00160 
00162   int get_D() const { if (valid) return(D); else return(0); }
00163 
00165   vec get_weights() const { vec tmp;  if (valid) { tmp = weights; } return tmp; }
00166 
00168   Array<vec> get_means() const { Array<vec> tmp; if (valid) { tmp = means; } return tmp; }
00169 
00171   Array<vec> get_diag_covs() const { Array<vec> tmp; if (valid && !full) { tmp = diag_covs; } return tmp; }
00172 
00174   Array<mat> get_full_covs() const { Array<mat> tmp; if (valid && full) { tmp = full_covs; } return tmp; }
00175 
00179   void set_means(Array<vec> &means_in);
00180 
00184   void set_diag_covs(Array<vec> &diag_covs_in);
00185 
00189   void set_full_covs(Array<mat> &full_covs_in);
00190 
00194   void set_weights(vec &weights_in);
00195 
00197   void set_means_zero();
00198 
00200   void set_diag_covs_unity();
00201 
00203   void set_full_covs_unity();
00204 
00206   void set_weights_uniform();
00207 
00213   void set_checks(bool do_checks_in) { do_checks = do_checks_in; }
00214 
00218   void set_paranoid(bool paranoid_in) { paranoid = paranoid_in; }
00219 
00223   virtual void load(const std::string &name_in);
00224 
00228   virtual void save(const std::string &name_in) const;
00229 
00246   virtual void join(const MOG_generic &B_in);
00247 
00255   virtual void convert_to_diag();
00256 
00262   virtual void convert_to_full();
00263 
00265   virtual double log_lhood_single_gaus(const vec &x_in, const int k);
00266 
00268   virtual double log_lhood(const vec &x_in);
00269 
00271   virtual double lhood(const vec &x_in);
00272 
00274   virtual double avg_log_lhood(const Array<vec> &X_in);
00275 
00276 protected:
00277 
00279   bool do_checks;
00280 
00282   bool valid;
00283 
00285   bool full;
00286 
00288   bool paranoid;
00289 
00291   int K;
00292 
00294   int D;
00295 
00297   Array<vec> means;
00298 
00300   Array<vec> diag_covs;
00301 
00303   Array<mat> full_covs;
00304 
00306   vec weights;
00307 
00309   double log_max_K;
00310 
00316   vec log_det_etc;
00317 
00319   vec log_weights;
00320 
00322   Array<mat> full_covs_inv;
00323 
00325   Array<vec> diag_covs_inv_etc;
00326 
00328   bool check_size(const vec &x_in) const;
00329 
00331   bool check_size(const Array<vec> &X_in) const;
00332 
00334   bool check_array_uniformity(const Array<vec> & A) const;
00335 
00337   void set_means_internal(Array<vec> &means_in);
00339   void set_diag_covs_internal(Array<vec> &diag_covs_in);
00341   void set_full_covs_internal(Array<mat> &full_covs_in);
00343   void set_weights_internal(vec &_weigths);
00344 
00346   void set_means_zero_internal();
00348   void set_diag_covs_unity_internal();
00350   void set_full_covs_unity_internal();
00352   void set_weights_uniform_internal();
00353 
00355   void convert_to_diag_internal();
00357   void convert_to_full_internal();
00358 
00360   virtual void setup_means();
00361 
00363   virtual void setup_covs();
00364 
00366   virtual void setup_weights();
00367 
00369   virtual void setup_misc();
00370 
00372   virtual double log_lhood_single_gaus_internal(const vec &x_in, const int k);
00374   virtual double log_lhood_internal(const vec &x_in);
00376   virtual double lhood_internal(const vec &x_in);
00377 
00378 private:
00379   vec tmpvecD;
00380   vec tmpvecK;
00381 
00382 };
00383 
00384 } // namespace itpp
00385 
00386 #endif // #ifndef MOG_GENERIC_H
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
SourceForge Logo

Generated on Wed Jul 27 2011 16:27:05 for IT++ by Doxygen 1.7.4