Open
Labels: bug (an unexpected problem or unintended behavior)
Description
Hello, it seems that augment() returns different predictions than collect_predictions().
In the example below, the kappa metric is 0.907 and collect_predictions() agrees with the ground truth on 1,887 of the 2,044 test rows. However, when I compare the predictions from augment() with the ground truth, only 105 rows match.
The dataset returned by augment() appears to be in the same order as testing(split); only the prediction column seems to be wrong.
All packages are up to date, as far as I can tell.
library(tidymodels)
library(beans)
set.seed(42)
split <- initial_split(beans, prop = 0.85, strata = "class")
rec <- recipe(class ~ ., data = training(split)) |>
  step_normalize(all_predictors()) |>
  step_pca(all_predictors(), keep_original_cols = TRUE, threshold = 0.75)
spec <- rand_forest(
  mode = 'classification',
  mtry = 11,
  trees = 100,
  min_n = 2
) |>
  set_engine("ranger", num.threads = 8)
res <- last_fit(spec, rec, split, metrics = metric_set(kap))
collect_metrics(res)
#> # A tibble: 1 × 4
#> .metric .estimator .estimate .config
#> <chr> <chr> <dbl> <chr>
#> 1 kap multiclass 0.907 pre0_mod0_post0
collect_predictions(res) |>
  count(.pred_class == class)
#> # A tibble: 2 × 2
#> `.pred_class == class` n
#> <lgl> <int>
#> 1 FALSE 157
#> 2 TRUE 1887
augment(res) |>
  count(.pred_class == class)
#> # A tibble: 2 × 2
#> `.pred_class == class` n
#> <lgl> <int>
#> 1 FALSE 1939
#> 2 TRUE 105
augment(res)
#> # A tibble: 2,044 × 18
#> .pred_class area perimeter major_axis_length minor_axis_length aspect_ratio
#> <fct> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 sira 30477 670. 211. 184. 1.15
#> 2 barbunya 30519 630. 213. 183. 1.17
#> 3 sira 31637 657. 230. 176. 1.31
#> 4 barbunya 31821 650. 214. 190. 1.13
#> 5 seker 32015 654. 213. 192. 1.11
#> 6 sira 32026 654. 231. 177. 1.31
#> 7 sira 32044 653. 216. 189. 1.14
#> 8 barbunya 32066 669. 236. 173. 1.36
#> 9 barbunya 32262 652. 223. 185. 1.21
#> 10 barbunya 32829 658. 222. 189. 1.17
#> # ℹ 2,034 more rows
#> # ℹ 12 more variables: eccentricity <dbl>, convex_area <dbl>,
#> # equiv_diameter <dbl>, extent <dbl>, solidity <dbl>, roundness <dbl>,
#> # compactness <dbl>, shape_factor_1 <dbl>, shape_factor_2 <dbl>,
#> # shape_factor_3 <dbl>, shape_factor_4 <dbl>, class <fct>
collect_predictions(res)
#> # A tibble: 2,044 × 5
#> .pred_class id class .row .config
#> <fct> <chr> <fct> <int> <chr>
#> 1 seker train/test split seker 7 pre0_mod0_post0
#> 2 seker train/test split seker 8 pre0_mod0_post0
#> 3 dermason train/test split seker 24 pre0_mod0_post0
#> 4 seker train/test split seker 31 pre0_mod0_post0
#> 5 seker train/test split seker 38 pre0_mod0_post0
#> 6 dermason train/test split seker 39 pre0_mod0_post0
#> 7 seker train/test split seker 41 pre0_mod0_post0
#> 8 dermason train/test split seker 43 pre0_mod0_post0
#> 9 seker train/test split seker 52 pre0_mod0_post0
#> 10 seker train/test split seker 78 pre0_mod0_post0
#> # ℹ 2,034 more rows
all(testing(split)$class == augment(res)$class)
#> [1] TRUE
all(testing(split)$perimeter == augment(res)$perimeter)
#> [1] TRUE
all(testing(split)$class == collect_predictions(res)$class)
#> [1] TRUE

Created on 2025-11-03 with reprex v2.1.1
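To narrow down which of the two outputs is off, one option is to predict directly from the fitted workflow and compare all three sets of predictions row by row. The following is only a sketch under a few assumptions: that extract_workflow() on the last_fit() result gives the workflow fitted on the training set, and that collect_predictions() is in the same row order as testing(split) (which the class check above suggests). The object names wf_fit, direct, and comparison are made up for illustration, and no output is shown because I have not established what it would be.

```r
library(tidymodels)
library(beans)

# Same setup as the reprex above
set.seed(42)
split <- initial_split(beans, prop = 0.85, strata = "class")

rec <- recipe(class ~ ., data = training(split)) |>
  step_normalize(all_predictors()) |>
  step_pca(all_predictors(), keep_original_cols = TRUE, threshold = 0.75)

spec <- rand_forest(mode = "classification", mtry = 11, trees = 100, min_n = 2) |>
  set_engine("ranger", num.threads = 8)

res <- last_fit(spec, rec, split, metrics = metric_set(kap))

# Predict on the raw test set with the fitted workflow,
# bypassing both augment() and collect_predictions()
wf_fit <- extract_workflow(res)
direct <- predict(wf_fit, new_data = testing(split))

# Line the three prediction vectors up against the truth
# (assumes all three share the row order of testing(split))
comparison <- tibble(
  truth        = testing(split)$class,
  from_collect = collect_predictions(res)$.pred_class,
  from_augment = augment(res)$.pred_class,
  from_predict = direct$.pred_class
)

# Share of rows on which each pair agrees
comparison |>
  summarise(
    collect_vs_truth   = mean(from_collect == truth),
    augment_vs_truth   = mean(from_augment == truth),
    predict_vs_truth   = mean(from_predict == truth),
    augment_vs_predict = mean(from_augment == from_predict),
    collect_vs_predict = mean(from_collect == from_predict)
  )
```

If from_predict agrees with from_collect but not with from_augment, that would point at how augment() builds its predictions for last_fit() results rather than at the fitted model itself.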