36 #ifndef VIGRA_RF_PREPROCESSING_HXX
37 #define VIGRA_RF_PREPROCESSING_HXX
40 #include <vigra/mathutil.hxx>
41 #include "rf_common.hxx"
// Template parameter list with six type parameters. The entity it declares is
// not visible in this excerpt. NOTE(review): presumably this is the primary
// Processor template, because the two specializations further down in this
// file (for ClassificationTag and RegressionTag) supply exactly these six
// arguments in this order -- confirm against the full file.
62 template<
class Tag,
class LabelType,
class T1,
class C1,
class T2,
class C2>
// Derives the effective values ext_param.actual_mtry_ and
// ext_param.actual_msample_ from the settings stored in options.
// NOTE(review): the enclosing function header, most case labels and the
// break statements are not visible in this excerpt; presumably this is the
// body of detail::fill_external_parameters, which both Processor
// specializations call -- confirm against the full file.
//
// First switch: options.mtry_switch_ selects how actual_mtry_ is computed.
77 switch(options.mtry_switch_)
// Variant based on the square root of the column count (the rest of the
// expression is not visible here).
80 ext_param.actual_mtry_ =
82 std::sqrt(
double(ext_param.column_count_))
// Variant based on 1 + log(column count) (the rest of the expression is not
// visible here), converted to int.
87 ext_param.actual_mtry_ =
88 int(1+(std::log(
double(ext_param.column_count_))
// Variant that delegates to the callable options.mtry_func_, passing the
// column count.
92 ext_param.actual_mtry_ =
93 options.mtry_func_(ext_param.column_count_);
// Variant that uses every column.
96 ext_param.actual_mtry_ = ext_param.column_count_;
// Remaining variant; its right-hand side is not visible in this excerpt.
99 ext_param.actual_mtry_ =
// Second switch: options.training_set_calc_switch_ selects how
// actual_msample_ is computed.
103 switch(options.training_set_calc_switch_)
// Variant that takes the sample count directly from options.
106 ext_param.actual_msample_ =
107 options.training_set_size_;
// Proportional variant: training_set_proportion_ * row_count_, rounded up
// with std::ceil and then converted to int.
109 case RF_PROPORTIONAL:
110 ext_param.actual_msample_ =
111 static_cast<int>(std::ceil(options.training_set_proportion_ *
112 ext_param.row_count_));
// Variant that delegates to the callable options.training_set_func_, passing
// the row count.
115 ext_param.actual_msample_ =
116 options.training_set_func_(ext_param.row_count_);
// The condition 1 != 1 is always false, so this precondition check fails
// whenever it is reached. NOTE(review): presumably the fallback branch for an
// unrecognised switch value -- confirm.
119 vigra_precondition(1!= 1,
"unexpected error");
// Scans a MultiArrayView for NaN entries and reports the result as a bool.
// Each element is converted with NumericTraits<T>::toRealPromote before being
// tested with isnan.
// NOTE(review): the loop header and the return statements are not visible in
// this excerpt; presumably the function returns true at the first NaN and
// false otherwise -- confirm against the full file.
127 template<
unsigned int N,
class T,
class C>
128 bool contains_nan(MultiArrayView<N, T, C>
const & in)
// Iterator pair covering all elements of the view.
131 Iter i = in.begin(), end = in.end();
// Per-element NaN test on the promoted value.
133 if(isnan(NumericTraits<T>::toRealPromote(*i)))
// Scans a MultiArrayView for infinite entries and reports the result as a
// bool. Taking abs(*i) before the comparison means a single test against
// +infinity matches both signs.
// NOTE(review): the loop header and the return statements are not visible in
// this excerpt; presumably the function returns true at the first infinite
// element and false otherwise -- confirm against the full file.
140 template<
unsigned int N,
class T,
class C>
141 bool contains_inf(MultiArrayView<N, T, C>
const & in)
// Special case for element types that have no infinity representation; the
// statement executed in this branch is not visible here.
143 if(!std::numeric_limits<T>::has_infinity)
// Iterator pair covering all elements of the view.
146 Iter i = in.begin(), end = in.end();
// Per-element test: magnitude equal to the type's infinity value.
148 if(
abs(*i) == std::numeric_limits<T>::infinity())
// Partial specialization of Processor for ClassificationTag.
// The visible fragments validate the inputs, record the problem dimensions in
// ext_param, determine the class list, map each response label to an integer
// index, and derive the remaining parameters.
// NOTE(review): member declarations, the constructor header, loop headers and
// braces are not visible in this excerpt.
161 template<
class LabelType,
class T1,
class C1,
class T2,
class C2>
162 class Processor<ClassificationTag, LabelType, T1, C1, T2, C2>
// Integer type alias for labels.
165 typedef Int32 LabelInt;
// Input validation: neither features nor response may contain NaN or
// infinite values (the message strings are cut off in this excerpt).
181 vigra_precondition(!detail::contains_nan(features),
"RandomForest(): Feature matrix "
183 vigra_precondition(!detail::contains_nan(response),
"RandomForest(): Response "
185 vigra_precondition(!detail::contains_inf(features),
"RandomForest(): Feature matrix "
187 vigra_precondition(!detail::contains_inf(response),
"RandomForest(): Response "
// Record the problem dimensions: shape(1) of the feature matrix is stored as
// the column count and shape(0) as the row count.
190 ext_param.column_count_ = features.
shape(1);
191 ext_param.row_count_ = features.
shape(0);
192 ext_param.problem_type_ = CLASSIFICATION;
193 ext_param.used_ =
true;
// Branch taken when ext_param.class_count_ is zero: distinct response values
// are collected in a std::set (sorted, duplicates removed), copied into a
// vector and handed to ext_param.classes_().
197 if(ext_param.class_count_ == 0)
201 std::set<T2> labelToInt;
203 labelToInt.insert(response(k,0));
204 std::vector<T2> tmp_(labelToInt.begin(), labelToInt.end());
205 ext_param.
classes_(tmp_.begin(), tmp_.end());
// Reject any response value that does not occur in ext_param.classes.
209 if(std::find(ext_param.classes.begin(), ext_param.classes.end(), response(k,0)) == ext_param.classes.end())
211 throw std::runtime_error(
"RandomForest(): invalid label in training data.");
// Store the position of the label within ext_param.classes as its integer
// label. NOTE(review): this repeats the linear std::find already done by the
// validity check above; the iterator could be reused.
214 intLabels_(k, 0) = std::find(ext_param.classes.begin(), ext_param.classes.end(), response(k,0))
215 - ext_param.classes.begin();
// Branch taken when no class weights are set: builds tmp with class_count_
// entries, each NumericTraits<T2>::one(). What is done with tmp afterwards is
// not visible in this excerpt.
218 if(ext_param.class_weights_.size() == 0)
221 tmp(
static_cast<std::size_t
>(ext_param.class_count_),
222 NumericTraits<T2>::one());
// Derive the remaining parameters from options.
227 detail::fill_external_parameters(options, ext_param);
// The integer labels are also used as strata.
230 strata_ = intLabels_;
// Partial specialization of Processor for RegressionTag.
// The visible fragments record the problem dimensions in ext_param, derive
// the remaining parameters, validate the inputs, and set the response size
// and class list.
// NOTE(review): member declarations, the constructor header and braces are
// not visible in this excerpt.
268 template<
class LabelType,
class T1,
class C1,
class T2,
class C2>
269 class Processor<RegressionTag,LabelType, T1, C1, T2, C2>
// Member initializer: ext_param_ is initialized from the ext_param argument.
292 ext_param_(ext_param)
// Record the problem dimensions: shape(1) of the feature matrix is stored as
// the column count and shape(0) as the row count.
295 ext_param.column_count_ = features.
shape(1);
296 ext_param.row_count_ = features.
shape(0);
297 ext_param.problem_type_ = REGRESSION;
298 ext_param.used_ =
true;
// Derive the remaining parameters from options. Note that here this happens
// before the NaN/infinity checks on the inputs.
299 detail::fill_external_parameters(options, ext_param);
// Input validation: neither features nor response may contain NaN or
// infinite values (the message strings are cut off in this excerpt).
300 vigra_precondition(!detail::contains_nan(features),
"Processor(): Feature Matrix "
302 vigra_precondition(!detail::contains_nan(response),
"Processor(): Response "
304 vigra_precondition(!detail::contains_inf(features),
"Processor(): Feature Matrix "
306 vigra_precondition(!detail::contains_inf(response),
"Processor(): Response "
// Both response_size_ and class_count_ are set to shape(1) of the response.
// NOTE(review): the first statement reads the argument `response`, the second
// reads `response_`; presumably both refer to the same data -- confirm.
309 ext_param.response_size_ = response.
shape(1);
310 ext_param.class_count_ = response_.
shape(1);
// The class list is filled with class_count_ zero-valued entries.
311 std::vector<T2> tmp_(ext_param.class_count_, 0);
312 ext_param.
classes_(tmp_.begin(), tmp_.end());
337 #endif //VIGRA_RF_PREPROCESSING_HXX