Description
I am trying to cross-validate a RandomForestClassifier like this:
```rust
// n_tree, m_feat, m_split and m_leaf are references defined elsewhere (not shown)
let cv_score = cross_validate(
    RandomForestClassifier::new(),
    x_train,
    y_train,
    RandomForestClassifierParameters::default()
        .with_criterion(SplitCriterion::ClassificationError)
        .with_n_trees(*n_tree)
        .with_m(*m_feat)
        .with_min_samples_split(*m_split)
        .with_min_samples_leaf(*m_leaf),
    &KFold::default().with_n_splits(5), // 5-fold cross validation
    &precision, // scoring metric: this is where the label-type bounds clash
)
.unwrap();
```

x_train is a DenseMatrix<f32> and y_train is a Vec<u16> containing only 0s and 1s.
The problem is that RandomForestClassifier::fit requires y to be Number + Ord, whereas precision requires y to be Number + RealNumber + FloatNumber. As far as I can see, a RealNumber type can never implement Ord (floating-point types only implement PartialOrd, because NaN cannot be compared), so no label type satisfies both bounds. As a result, precision, f1 and roc_auc_score cannot be used in cross_validate with RandomForestClassifier directly.
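The clash can be reproduced without smartcore. In the sketch below, fit_labels and score_labels are hypothetical functions whose bounds only mimic the ones described above (a total order for fitting, a floating-point type for scoring); they are not smartcore's API, and FloatLabel is a hypothetical local trait standing in for RealNumber + FloatNumber.

```rust
use std::cmp::Ordering;

/// Hypothetical stand-in for smartcore's RealNumber + FloatNumber bounds.
/// Only the floating-point types implement it.
pub trait FloatLabel: Copy + PartialOrd {
    fn to_f64(self) -> f64;
}

impl FloatLabel for f32 {
    fn to_f64(self) -> f64 {
        self as f64
    }
}

impl FloatLabel for f64 {
    fn to_f64(self) -> f64 {
        self
    }
}

/// Mimics an `Ord` bound on fit: returns the sorted, distinct class labels.
pub fn fit_labels<T: Ord + Copy>(y: &[T]) -> Vec<T> {
    let mut classes = y.to_vec();
    classes.sort();
    classes.dedup();
    classes
}

/// Mimics a float bound on a metric: fraction of exactly matching labels.
pub fn score_labels<T: FloatLabel>(y_true: &[T], y_pred: &[T]) -> f64 {
    let hits = y_true
        .iter()
        .zip(y_pred)
        .filter(|(t, p)| t.to_f64() == p.to_f64())
        .count();
    hits as f64 / y_true.len() as f64
}

/// f32 only has a partial order: a comparison involving NaN has no answer,
/// which is why no floating-point type can implement `Ord`.
pub fn nan_comparison() -> Option<Ordering> {
    f32::NAN.partial_cmp(&0.0)
}

// u16 satisfies fit_labels but not score_labels, and f32 the reverse,
// so no single label type works for both:
// fit_labels(&[0.0f32, 1.0]);        // error: f32 does not implement Ord
// score_labels(&[0u16, 1], &[0, 1]); // error: u16 does not implement FloatLabel
```

Uncommenting either of the last two lines should fail to compile, which is the same situation cross_validate runs into with a single Y type.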
For now, I have worked around it by defining my own precision function that converts the u16 labels to f32 before calling precision, but I think this should be fixed in the library itself. In my opinion, the logical fix would be to stop requiring Ord for y in RandomForestClassifier::fit. Would that make sense?
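A minimal sketch of that workaround, assuming cross_validate accepts a scorer of the shape Fn(&Y, &Y) -> f64 with Y = Vec<u16>. Because smartcore is not pulled in here, binary_precision_f32 is a hypothetical stand-in for its float-only precision metric (true positives divided by predicted positives, with 1 as the positive class); precision_u16 is the hypothetical wrapper that converts the labels first. In real code the wrapper would call smartcore's precision instead.

```rust
/// Hypothetical stand-in for a float-only precision metric:
/// TP / (TP + FP), treating 1.0 as the positive class.
pub fn binary_precision_f32(y_true: &[f32], y_pred: &[f32]) -> f64 {
    assert_eq!(y_true.len(), y_pred.len(), "label vectors must have equal length");
    let mut true_pos = 0usize;
    let mut false_pos = 0usize;
    for (&t, &p) in y_true.iter().zip(y_pred) {
        if p == 1.0 {
            if t == 1.0 {
                true_pos += 1;
            } else {
                false_pos += 1;
            }
        }
    }
    let predicted_pos = true_pos + false_pos;
    if predicted_pos == 0 {
        // No positive predictions: this stand-in reports 0.0 instead of dividing by zero.
        0.0
    } else {
        true_pos as f64 / predicted_pos as f64
    }
}

/// Hypothetical wrapper for integer labels: converts u16 -> f32, then delegates
/// to the float-only metric. Takes &Vec<u16> (not &[u16]) to match the assumed
/// scorer signature Fn(&Y, &Y) -> f64 with Y = Vec<u16>.
pub fn precision_u16(y_true: &Vec<u16>, y_pred: &Vec<u16>) -> f64 {
    // u16 -> f32 is lossless, so 0 and 1 map exactly to 0.0 and 1.0.
    let to_f32 = |v: &Vec<u16>| v.iter().map(|&x| f32::from(x)).collect::<Vec<f32>>();
    binary_precision_f32(&to_f32(y_true), &to_f32(y_pred))
}
```

With a wrapper like this, the cross_validate call would pass &precision_u16 in place of &precision, so the labels can stay Vec<u16> and satisfy the Ord bound on fit.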