augment predictions are different from collect_predictions #1110

@rvalieris

Hello, it seems that augment() returns different predictions than collect_predictions() for the same last_fit() result.

In the example below, the kappa metric is calculated as 0.907, and the predictions from collect_predictions() agree with that (1,887 of 2,044 correct). However, when I compare the predictions from augment() with the ground truth, most of them are wrong (only 105 of 2,044 correct).

The dataset returned by augment() seems to be in the same order as testing(split); it is just the prediction column that seems to be wrong.

All packages are up to date, as far as I can tell.

library(tidymodels)
library(beans)

set.seed(42)
split <- initial_split(beans, prop = 0.85, strata = "class")

rec <- recipe(class ~ ., data = training(split)) |>
  step_normalize(all_predictors()) |>
  step_pca(all_predictors(), keep_original_cols = TRUE, threshold = 0.75)

spec <- rand_forest(
  mode = 'classification',
  mtry = 11,
  trees = 100,
  min_n = 2
) |>
  set_engine("ranger", num.threads = 8)

res <- last_fit(spec, rec, split, metrics = metric_set(kap))

collect_metrics(res)
#> # A tibble: 1 × 4
#>   .metric .estimator .estimate .config        
#>   <chr>   <chr>          <dbl> <chr>          
#> 1 kap     multiclass     0.907 pre0_mod0_post0

collect_predictions(res) |>
  count(.pred_class == class)
#> # A tibble: 2 × 2
#>   `.pred_class == class`     n
#>   <lgl>                  <int>
#> 1 FALSE                    157
#> 2 TRUE                    1887

augment(res) |>
  count(.pred_class == class)
#> # A tibble: 2 × 2
#>   `.pred_class == class`     n
#>   <lgl>                  <int>
#> 1 FALSE                   1939
#> 2 TRUE                     105

augment(res)
#> # A tibble: 2,044 × 18
#>    .pred_class  area perimeter major_axis_length minor_axis_length aspect_ratio
#>    <fct>       <dbl>     <dbl>             <dbl>             <dbl>        <dbl>
#>  1 sira        30477      670.              211.              184.         1.15
#>  2 barbunya    30519      630.              213.              183.         1.17
#>  3 sira        31637      657.              230.              176.         1.31
#>  4 barbunya    31821      650.              214.              190.         1.13
#>  5 seker       32015      654.              213.              192.         1.11
#>  6 sira        32026      654.              231.              177.         1.31
#>  7 sira        32044      653.              216.              189.         1.14
#>  8 barbunya    32066      669.              236.              173.         1.36
#>  9 barbunya    32262      652.              223.              185.         1.21
#> 10 barbunya    32829      658.              222.              189.         1.17
#> # ℹ 2,034 more rows
#> # ℹ 12 more variables: eccentricity <dbl>, convex_area <dbl>,
#> #   equiv_diameter <dbl>, extent <dbl>, solidity <dbl>, roundness <dbl>,
#> #   compactness <dbl>, shape_factor_1 <dbl>, shape_factor_2 <dbl>,
#> #   shape_factor_3 <dbl>, shape_factor_4 <dbl>, class <fct>

collect_predictions(res)
#> # A tibble: 2,044 × 5
#>    .pred_class id               class  .row .config        
#>    <fct>       <chr>            <fct> <int> <chr>          
#>  1 seker       train/test split seker     7 pre0_mod0_post0
#>  2 seker       train/test split seker     8 pre0_mod0_post0
#>  3 dermason    train/test split seker    24 pre0_mod0_post0
#>  4 seker       train/test split seker    31 pre0_mod0_post0
#>  5 seker       train/test split seker    38 pre0_mod0_post0
#>  6 dermason    train/test split seker    39 pre0_mod0_post0
#>  7 seker       train/test split seker    41 pre0_mod0_post0
#>  8 dermason    train/test split seker    43 pre0_mod0_post0
#>  9 seker       train/test split seker    52 pre0_mod0_post0
#> 10 seker       train/test split seker    78 pre0_mod0_post0
#> # ℹ 2,034 more rows

all(testing(split)$class == augment(res)$class)
#> [1] TRUE

all(testing(split)$perimeter == augment(res)$perimeter)
#> [1] TRUE

all(testing(split)$class == collect_predictions(res)$class)
#> [1] TRUE

Created on 2025-11-03 with reprex v2.1.1
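
A further check that might help narrow this down is to predict directly from the fitted workflow and see which of the two functions it agrees with. The sketch below was not run as part of the reprex above. `agreement()` and `compare_last_fit_predictions()` are hypothetical helper names, not tidymodels functions, and the comparison assumes all three prediction vectors are in the same row order as testing(split), which the checks above suggest but do not prove.

```r
# Hypothetical helper: share of positions where two prediction vectors agree.
# Compared as character so differing factor level sets do not raise an error.
agreement <- function(a, b) mean(as.character(a) == as.character(b))

# Hypothetical helper: compare three ways of getting test-set predictions
# from a last_fit() result. Assumes identical row order across all three.
compare_last_fit_predictions <- function(res, split) {
  test_data <- rsample::testing(split)

  # Predict directly from the workflow that last_fit() trained
  wf <- tune::extract_workflow(res)
  direct <- predict(wf, new_data = test_data)$.pred_class

  collected <- tune::collect_predictions(res)$.pred_class
  augmented <- generics::augment(res)$.pred_class

  c(
    direct_vs_collected    = agreement(direct, collected),
    direct_vs_augmented    = agreement(direct, augmented),
    collected_vs_augmented = agreement(collected, augmented)
  )
}

# Usage with the objects from the reprex above:
# compare_last_fit_predictions(res, split)
```

If direct_vs_collected is close to 1 while direct_vs_augmented is not, that would point at augment() rather than at the fitted model.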

Labels: bug