diff --git a/.Rbuildignore b/.Rbuildignore index 5f4a347a..6c470e26 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -21,3 +21,6 @@ ^CRAN-SUBMISSION$ ^[\.]?air\.toml$ ^\.vscode$ + +CLAUDE.md +^\.claude$ diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..d61ca5f6 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,90 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Overview + +dials is an R package that provides infrastructure for creating and managing tuning parameter values in the tidymodels ecosystem. It defines parameter objects, sets of parameters, and methods for generating parameter grids for model tuning. + +## Key development commands + +General advice: +* When running R from the console, always run it with `--quiet --vanilla` +* Always run `air format .` after generating code + +### Testing + +- Use `devtools::test()` to run all tests +- Use `devtools::test_file("tests/testthat/test-filename.R")` to run tests in a specific file +- DO NOT USE `devtools::test_active_file()` +- All testing functions automatically load code; you don't need to. + +- All new code should have an accompanying test. +- Tests for `R/{name}.R` go in `tests/testthat/test-{name}.R`. +- If there are existing tests, place new tests next to similar existing tests. + +### Documentation + +- Run `devtools::document()` after changing any roxygen2 docs. +- Every user-facing function should be exported and have roxygen2 documentation. +- Whenever you add a new documentation file, make sure to also add the topic name to `_pkgdown.yml`. +- Run `pkgdown::check_pkgdown()` to check that all topics are included in the reference index. +- Use sentence case for all headings. + + +## Architecture + +### Core parameter system + +The package is built around two main parameter types: + +1. 
**`quant_param`**: Quantitative parameters (continuous or integer) + - Created via `new_quant_param()` in `R/constructors.R` + - Has `range` (lower/upper bounds), `inclusive`, optional `trans` (transformation), and a `finalize` function + - Examples: `penalty()`, `mtry()`, `learn_rate()` + +2. **`qual_param`**: Qualitative parameters (categorical) + - Created via `new_qual_param()` in `R/constructors.R` + - Has discrete `values` (character or logical) + - Examples: `activation()`, `weight_func()` + +### Parameter organization + +- **Individual parameters**: parameter definition files (`R/param_*.R`), each defining specific tuning parameters used across tidymodels +- **Parameter sets**: The `parameters` class (defined in `R/parameters.R`) groups multiple parameters into a data frame-like structure + +### Grid generation + +Three main grid types (in `R/grids.R` and `R/space_filling.R`): + +1. **Regular grids** (`grid_regular()`): Factorial designs with evenly spaced values +2. **Random grids** (`grid_random()`): Random sampling from parameter ranges +3. **Space-filling grids** (`grid_space_filling()`): Experimental designs (Latin hypercube, max entropy, etc.) that efficiently cover the parameter space + +All grid functions: +- Accept parameter objects or parameter sets +- Return tibbles with one column per parameter + +### Finalization system + +Many parameters have `unknown()` ranges that depend on the dataset (e.g., `mtry()` depends on the number of predictors). 
The finalization system (`R/finalize.R`) resolves these: + +- `finalize()`: Generic function that calls the parameter's embedded `finalize` function +- `get_*()`: Various functions that get and set parameter ranges based on data characteristics + +### Infrastructure files + +Files prefixed with `aaa_` load first and define foundational classes: +- `R/aaa_ranges.R`: Handling and validation of parameter ranges +- `R/aaa_unknown.R`: The `unknown()` placeholder for unspecified parameter bounds +- `R/aaa_values.R`: Validation, generation, and transformation of parameter values + +Files prefixed with `compat-` provide compatibility with dplyr and vctrs for parameter objects. + +## Integration with tidymodels + +dials is infrastructure-level; it defines parameters but doesn't perform tuning. The tune package uses dials for actual hyperparameter tuning. Parameter objects integrate with: +- **parsnip**: Model specifications reference dials parameters +- **recipes**: Preprocessing steps use dials parameters +- **workflows**: Workflows combine models and preprocessing that utilize dials parameters +- **tune**: Grid search and optimization consume parameter grids diff --git a/R/aaa_ranges.R b/R/aaa_ranges.R index bdf9e676..5b9d1c9d 100644 --- a/R/aaa_ranges.R +++ b/R/aaa_ranges.R @@ -81,6 +81,12 @@ range_validate <- function( if (!any(is_num)) { cli::cli_abort("{.arg range} should be numeric.", call = call) } + if (range[[1]] > range[[2]]) { + cli::cli_abort( + "The {.arg range} lower bound ({range[[1]]}) must not exceed upper bound ({range[[2]]}).", + call = call + ) + } # TODO check with transform } else { diff --git a/R/constructors.R b/R/constructors.R index 6dca8e4a..4e8501a1 100644 --- a/R/constructors.R +++ b/R/constructors.R @@ -100,7 +100,7 @@ new_quant_param <- function( type <- arg_match0(type, values = c("double", "integer")) - check_values_quant(values, call = call) + check_values_quant(values, type = type, call = call) if (!is.null(values)) { # fill in range if 
user didn't supply one diff --git a/R/misc.R b/R/misc.R index eb6efcde..4d7458d2 100644 --- a/R/misc.R +++ b/R/misc.R @@ -35,14 +35,19 @@ format_bounds <- function(bnds) { # checking functions ----------------------------------------------------------- -check_label <- function(label, ..., call = caller_env()) { +check_label <- function( + x, + ..., + arg = caller_arg(x), + call = caller_env() +) { check_dots_empty() - check_string(label, allow_null = TRUE, call = call) + check_string(x, allow_null = TRUE, arg = arg, call = call) - if (!is.null(label) && length(names(label)) != 1) { + if (!is.null(x) && length(names(x)) != 1) { cli::cli_abort( - "{.arg label} must be named.", + "{.arg {arg}} must be named.", call = call ) } @@ -50,16 +55,50 @@ check_label <- function(label, ..., call = caller_env()) { invisible(NULL) } -check_range <- function(x, type, trans, ..., call = caller_env()) { +check_range <- function( + x, + type, + trans, + ..., + arg = caller_arg(x), + call = caller_env() +) { check_dots_empty() + + if (length(x) != 2) { + cli::cli_abort( + "{.arg {arg}} must have 2 elements, not {length(x)}.", + call = call, + arg = arg + ) + } + + known <- !is_unknown(x) + + if (any(known) && !all(map_lgl(x[known], is.numeric))) { + cli::cli_abort( + "{.arg {arg}} must be numeric (or {.fn unknown}).", + call = call + ) + } + + if (all(known) && !anyNA(x) && x[[1]] > x[[2]]) { + cli::cli_abort( + "The {.arg {arg}} lower bound ({x[[1]]}) must not exceed upper bound ({x[[2]]}).", + call = call + ) + } + if (!is.null(trans)) { return(invisible(x)) } + + # only do this after `arg` is used but do it because + # this makes x0[known] <- as.integer(x0[known]) below work for e.g. 
c(1, 10) if (!is.list(x)) { x <- as.list(x) } x0 <- x - known <- !is_unknown(x) x <- x[known] x_type <- purrr::map_chr(x, typeof) wrong_type <- any(x_type != type) @@ -84,7 +123,7 @@ check_range <- function(x, type, trans, ..., call = caller_env()) { x0[known] <- as.integer(x0[known]) } else { cli::cli_abort( - "Since {.code type = \"{type}\"}, please use that data type for the + "Since {.code type = \"{type}\"}, please use that data type for the range.", call = call ) @@ -93,7 +132,13 @@ check_range <- function(x, type, trans, ..., call = caller_env()) { invisible(x0) } -check_values_quant <- function(x, ..., call = caller_env()) { +check_values_quant <- function( + x, + type = NULL, + ..., + arg = caller_arg(x), + call = caller_env() +) { check_dots_empty() if (is.null(x)) { @@ -101,37 +146,54 @@ check_values_quant <- function(x, ..., call = caller_env()) { } if (!is.numeric(x)) { - cli::cli_abort("{.arg values} must be numeric.", call = call) + cli::cli_abort("{.arg {arg}} must be numeric.", call = call) } + if (anyNA(x)) { - cli::cli_abort("{.arg values} can't be {.code NA}.", call = call) + cli::cli_abort("{.arg {arg}} can't contain {.code NA} values.", call = call) } if (length(x) == 0) { - cli::cli_abort("{.arg values} can't be empty.", call = call) + cli::cli_abort("{.arg {arg}} can't be empty.", call = call) + } + if (anyDuplicated(x)) { + cli::cli_abort("{.arg {arg}} can't contain duplicate values.", call = call) + } + + if (!is.null(type) && type == "integer") { + # logic from ?is.integer + not_whole <- abs(x - round(x)) >= .Machine$double.eps^0.5 + if (any(not_whole)) { + offenders <- x[not_whole] + cli::cli_abort( + c( + "{.arg {arg}} must contain whole numbers for integer parameters.", + x = "These are not whole numbers: {offenders}." 
+ ), + call = call + ) + } } invisible(x) } -check_inclusive <- function(x, ..., call = caller_env()) { +check_inclusive <- function(x, ..., arg = caller_arg(x), call = caller_env()) { check_dots_empty() + check_logical(x, arg = arg, call = call) + if (any(is.na(x))) { - cli::cli_abort("{.arg inclusive} cannot contain missings.", call = call) + cli::cli_abort("{.arg {arg}} can't contain missing values.", call = call) } - if (is_logical(x, n = 2)) { - return(invisible(NULL)) + if (length(x) != 2) { + cli::cli_abort( + "{.arg {arg}} must have length 2, not {length(x)}.", + call = call + ) } - stop_input_type( - x, - "a logical vector of length 2", - allow_na = FALSE, - allow_null = FALSE, - arg = "inclusive", - call = call - ) + invisible(NULL) } check_param <- function( diff --git a/tests/testthat/_snaps/constructors.md b/tests/testthat/_snaps/constructors.md index 599ff672..992839c8 100644 --- a/tests/testthat/_snaps/constructors.md +++ b/tests/testthat/_snaps/constructors.md @@ -27,8 +27,8 @@ Code new_quant_param("double", range = 1, inclusive = c(TRUE, TRUE)) Condition - Error in `names(range) <- names(inclusive) <- c("lower", "upper")`: - ! 'names' attribute [2] must be the same length as the vector [1] + Error: + ! `range` must have 2 elements, not 1. --- @@ -45,7 +45,7 @@ new_quant_param("double", range = c(1, NA), inclusive = TRUE) Condition Error: - ! `inclusive` must be a logical vector of length 2, not `TRUE`. + ! `inclusive` must have length 2, not 1. --- @@ -53,7 +53,7 @@ new_quant_param("double", range = c(1, NA), inclusive = c("(", "]")) Condition Error: - ! `inclusive` must be a logical vector of length 2, not a character vector. + ! `inclusive` must be a logical vector, not a character vector. --- @@ -70,7 +70,7 @@ new_quant_param("integer", range = 1:2, inclusive = c(TRUE, NA)) Condition Error: - ! `inclusive` cannot contain missings. + ! `inclusive` can't contain missing values. 
--- @@ -78,7 +78,7 @@ new_quant_param("integer", range = 1:2, inclusive = c(TRUE, unknown())) Condition Error: - ! `inclusive` must be a logical vector of length 2, not a list. + ! `inclusive` must be a logical vector, not a list. --- @@ -248,7 +248,7 @@ mixture(letters[1:2]) Condition Error in `mixture()`: - ! Since `type = "double"`, please use that data type for the range. + ! `range` must be numeric (or `unknown()`). --- @@ -324,7 +324,7 @@ new_quant_param(type = "integer", values = NA_integer_, label = c(foo = "Foo")) Condition Error: - ! `values` can't be `NA`. + ! `values` can't contain `NA` values. --- @@ -352,3 +352,28 @@ Error: ! The `default` argument of `new_qual_param()` was deprecated in dials 1.1.0 and is now defunct. +# range ordering is validated + + Code + new_quant_param("integer", range = c(10L, 1L), inclusive = c(TRUE, TRUE)) + Condition + Error: + ! The `range` lower bound (10) must not exceed upper bound (1). + +# duplicate values are rejected + + Code + new_quant_param("double", values = c(1, 2, 2, 3)) + Condition + Error: + ! `values` can't contain duplicate values. + +# integer type requires whole number values + + Code + new_quant_param("integer", values = c(1.5, 2, 3)) + Condition + Error: + ! `values` must contain whole numbers for integer parameters. + x These are not whole numbers: 1.5. + diff --git a/tests/testthat/_snaps/misc.md b/tests/testthat/_snaps/misc.md index f14e4982..51350764 100644 --- a/tests/testthat/_snaps/misc.md +++ b/tests/testthat/_snaps/misc.md @@ -4,7 +4,7 @@ check_label("unnamed label") Condition Error: - ! `label` must be named. + ! `"unnamed label"` must be named. --- @@ -12,7 +12,7 @@ check_label(c("more", "than", "one", "label")) Condition Error: - ! `label` must be a single string or `NULL`, not a character vector. + ! `c("more", "than", "one", "label")` must be a single string or `NULL`, not a character vector. 
# check_values_quant() @@ -20,7 +20,7 @@ check_values_quant("should have been a numeric") Condition Error: - ! `values` must be numeric. + ! `"should have been a numeric"` must be numeric. --- @@ -28,7 +28,7 @@ check_values_quant(c(1, NA)) Condition Error: - ! `values` can't be `NA`. + ! `c(1, NA)` can't contain `NA` values. --- @@ -36,7 +36,7 @@ check_values_quant(numeric()) Condition Error: - ! `values` can't be empty. + ! `numeric()` can't be empty. # check_inclusive() @@ -44,7 +44,7 @@ check_inclusive(TRUE) Condition Error: - ! `inclusive` must be a logical vector of length 2, not `TRUE`. + ! `TRUE` must have length 2, not 1. --- @@ -52,7 +52,7 @@ check_inclusive(NULL) Condition Error: - ! `inclusive` must be a logical vector of length 2, not `NULL`. + ! `NULL` must be a logical vector, not `NULL`. --- @@ -60,7 +60,7 @@ check_inclusive(c(TRUE, NA)) Condition Error: - ! `inclusive` cannot contain missings. + ! `c(TRUE, NA)` can't contain missing values. --- @@ -68,7 +68,7 @@ check_inclusive(1:2) Condition Error: - ! `inclusive` must be a logical vector of length 2, not an integer vector. + ! `1:2` must be a logical vector, not an integer vector. 
# vctrs-helpers-parameters diff --git a/tests/testthat/test-constructors.R b/tests/testthat/test-constructors.R index 3a679386..3e072bf4 100644 --- a/tests/testthat/test-constructors.R +++ b/tests/testthat/test-constructors.R @@ -274,3 +274,24 @@ test_that("`default` arg is deprecated", { ) }) }) + +test_that("range ordering is validated", { + expect_snapshot(error = TRUE, { + new_quant_param("integer", range = c(10L, 1L), inclusive = c(TRUE, TRUE)) + }) + expect_no_error({ + new_quant_param("integer", range = c(5L, 5L), inclusive = c(TRUE, TRUE)) + }) +}) + +test_that("duplicate values are rejected", { + expect_snapshot(error = TRUE, { + new_quant_param("double", values = c(1, 2, 2, 3)) + }) +}) + +test_that("integer type requires whole number values", { + expect_snapshot(error = TRUE, { + new_quant_param("integer", values = c(1.5, 2, 3)) + }) +}) diff --git a/tests/testthat/test-params.R b/tests/testthat/test-params.R index 431faf52..23924694 100644 --- a/tests/testthat/test-params.R +++ b/tests/testthat/test-params.R @@ -135,7 +135,7 @@ test_that("param ranges", { list(lower = 0.1, upper = 0.4) ) expect_equal(target_weight(c(0.1, 0.4))$range, list(lower = 0.1, upper = 0.4)) - expect_equal(lower_limit(c(Inf, 0))$range, list(lower = Inf, upper = 0)) + expect_equal(lower_limit(c(-Inf, 0))$range, list(lower = -Inf, upper = 0)) expect_equal(upper_limit(c(0, Inf))$range, list(lower = 0, upper = Inf)) expect_equal(max_num_terms(c(31, 100))$range, list(lower = 31, upper = 100)) expect_equal(max_nodes(c(31, 100))$range, list(lower = 31, upper = 100))