Trait bounds of RandomForest::fit conflict with those of predict and other metrics #322

@DaGaMs

I am trying to perform cross validation of a RandomForest like this:

let cv_score = cross_validate(
    RandomForestClassifier::new(),
    x_train,
    y_train,
    RandomForestClassifierParameters::default()
        .with_criterion(SplitCriterion::ClassificationError)
        .with_n_trees(*n_tree)
        .with_m(*m_feat)
        .with_min_samples_split(*m_split)
        .with_min_samples_leaf(*m_leaf),
    &KFold::default().with_n_splits(5),
    &precision
).unwrap();

x_train is a DenseMatrix<f32> and y_train is a Vec<u16> containing only 0s and 1s.

The problem is that RandomForestClassifier::fit expects y to be Number + Ord, whereas precision expects y to be Number + RealNumber + FloatNumber. As far as I can see, a RealNumber can never be Ord, so precision, f1 and roc_auc_score cannot be used directly in cross_validate with RandomForestClassifier.

For now, I have worked around it by defining my own precision function that converts the u16 labels to f32 before calling precision, but I think this should be fixed in the framework. IMO, the logical fix would be to stop requiring Ord for y in RandomForestClassifier::fit?
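A minimal sketch of that workaround, for illustration only. It does not depend on smartcore: `precision_f32` is a hypothetical stand-in for a float-only metric such as the library's precision (whose exact signature is not reproduced here), and `precision_u16` is an assumed name for the adapter. The adapter keeps the integer labels that satisfy the Ord bound on the fit side and converts them to f32 only when the score is computed.

```rust
/// Hypothetical stand-in for a float-only metric (NOT smartcore's `precision`).
/// Treats 1.0 as the positive class and returns TP / (TP + FP).
fn precision_f32(y_true: &[f32], y_pred: &[f32]) -> f64 {
    assert_eq!(y_true.len(), y_pred.len(), "label vectors must have equal length");
    let mut tp = 0usize; // predicted positive, actually positive
    let mut fp = 0usize; // predicted positive, actually negative
    for (t, p) in y_true.iter().zip(y_pred.iter()) {
        if *p == 1.0 {
            if *t == 1.0 {
                tp += 1;
            } else {
                fp += 1;
            }
        }
    }
    if tp + fp == 0 {
        0.0 // no positive predictions: report 0.0 rather than dividing by zero
    } else {
        tp as f64 / (tp + fp) as f64
    }
}

/// Assumed adapter: accepts the Ord-friendly u16 labels used for fitting,
/// converts them losslessly to f32, then delegates to the float-only metric.
fn precision_u16(y_true: &Vec<u16>, y_pred: &Vec<u16>) -> f64 {
    let t: Vec<f32> = y_true.iter().map(|&v| f32::from(v)).collect();
    let p: Vec<f32> = y_pred.iter().map(|&v| f32::from(v)).collect();
    precision_f32(&t, &p)
}
```

In the call above, the idea is to pass &precision_u16 where &precision was passed, assuming the score argument of cross_validate accepts any function of this shape.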
