Merged
3 changes: 3 additions & 0 deletions .Rbuildignore
@@ -21,3 +21,6 @@
^CRAN-SUBMISSION$
^[\.]?air\.toml$
^\.vscode$

CLAUDE.md
^\.claude$
90 changes: 90 additions & 0 deletions CLAUDE.md
@@ -0,0 +1,90 @@
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Overview

dials is an R package that provides infrastructure for creating and managing tuning parameter values in the tidymodels ecosystem. It defines parameter objects, sets of parameters, and methods for generating parameter grids for model tuning.

## Key development commands

General advice:
* When running R from the console, always run it with `--quiet --vanilla`
* Always run `air format .` after generating code

### Testing

- Use `devtools::test()` to run all tests
- Use `devtools::test_file("tests/testthat/test-filename.R")` to run tests in a specific file
- DO NOT USE `devtools::test_active_file()`
- All testing functions automatically load code; you don't need to.

- All new code should have an accompanying test.
- Tests for `R/{name}.R` go in `tests/testthat/test-{name}.R`.
- If there are existing tests, place new tests next to similar existing tests.

### Documentation

- Run `devtools::document()` after changing any roxygen2 docs.
- Every user facing function should be exported and have roxygen2 documentation.
- Whenever you add a new documentation file, make sure to also add the topic name to `_pkgdown.yml`.
- Run `pkgdown::check_pkgdown()` to check that all topics are included in the reference index.
- Use sentence case for all headings


## Architecture

### Core Parameter System

The package is built around two main parameter types:

1. **`quant_param`**: Quantitative parameters (continuous or integer)
   - Created via `new_quant_param()` in `R/constructors.R`
   - Has `range` (lower/upper bounds), `inclusive`, optional `trans` (transformation), and `finalize` function
   - Examples: `penalty()`, `mtry()`, `learn_rate()`

2. **`qual_param`**: Qualitative parameters (categorical)
   - Created via `new_qual_param()` in `R/constructors.R`
   - Has discrete `values` (character or logical)
   - Examples: `activation()`, `weight_func()`
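
For illustration only (this sketch is not part of the committed file), the two parameter types could be exercised as follows, assuming dials is installed. The names `my_threshold` and `my_method` are hypothetical and chosen just for this example.

```r
library(dials)

# Built-in quantitative parameters
penalty() # double-valued
mtry()    # integer-valued; upper bound is unknown() until finalized

# A custom quantitative parameter (hypothetical name `my_threshold`)
my_threshold <- new_quant_param(
  type = "double",
  range = c(0, 1),
  inclusive = c(TRUE, TRUE),
  label = c(my_threshold = "Threshold")
)

# A custom qualitative parameter (hypothetical name `my_method`)
my_method <- new_qual_param(
  type = "character",
  values = c("a", "b"),
  label = c(my_method = "Method")
)

# Each constructor returns an object of the matching class
inherits(my_threshold, "quant_param")
inherits(my_method, "qual_param")
```

Note that `label` has to be a named character string, which is what `check_label()` in this PR enforces.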

### Parameter Organization

- **Individual parameters**: parameter definition files (`R/param_*.R`), each defining specific tuning parameters used across tidymodels
- **Parameter sets**: The `parameters` class (defined in `R/parameters.R`) groups multiple parameters into a data frame-like structure

### Grid Generation

Three main grid types (in `R/grids.R` and `R/space_filling.R`):

1. **Regular grids** (`grid_regular()`): Factorial designs with evenly-spaced values
2. **Random grids** (`grid_random()`): Random sampling from parameter ranges
3. **Space-filling grids** (`grid_space_filling()`): Experimental designs (Latin hypercube, max entropy, etc.) that efficiently cover the parameter space

All grid functions:
- Accept parameter objects or parameter sets
- Return tibbles with one column per parameter
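
As a rough sketch of the three grid types (again not part of the committed file, and assuming a dials version that exports `grid_space_filling()`), the calls might look like this. The sizes and levels are arbitrary choices for the example.

```r
library(dials)

# A parameter set with two quantitative parameters
params <- parameters(penalty(), mixture())

# Regular grid: 3 levels per parameter, so 3 x 3 combinations
regular <- grid_regular(params, levels = 3)

# Random grid: values sampled from each parameter's range
random <- grid_random(params, size = 5)

# Space-filling grid: a design that spreads points over the space
filled <- grid_space_filling(params, size = 5)

# One column per parameter in every case
names(regular)
nrow(regular)
```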

### Finalization System

Many parameters have `unknown()` ranges that depend on the dataset (e.g., `mtry()` depends on the number of predictors). The finalization system (`R/finalize.R`) resolves these:

- `finalize()`: Generic function that calls the parameter's embedded `finalize` function
- `get_*()`: Various functions that get and set parameter ranges based on data characteristics
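
A minimal sketch of finalization, assuming dials is installed and using the built-in `mtcars` data with its first column dropped to stand in for a predictor set (an arbitrary choice for this example, not something the file prescribes):

```r
library(dials)

p <- mtry()

# The upper bound is unknown() until the data are seen
has_unknowns(p)

# Finalize against the predictors: 10 columns remain after dropping `mpg`
predictors <- mtcars[, -1]
p_final <- finalize(p, predictors)

# The upper bound is now the number of predictor columns
range_get(p_final)
```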

### Infrastructure Files

Files prefixed with `aaa_` load first and define foundational classes:
- `R/aaa_ranges.R`: Handling and validation of parameter ranges
- `R/aaa_unknown.R`: The `unknown()` placeholder for unspecified parameter bounds
- `R/aaa_values.R`: Validation, generation, and transformation of parameter values

Files prefixed with `compat-` provide compatibility with dplyr and vctrs for parameter objects.

## Integration with tidymodels

dials is infrastructure-level; it defines parameters but doesn't perform tuning. The tune package uses dials for actual hyperparameter tuning. Parameter objects integrate with:
- **parsnip**: Model specifications reference dials parameters
- **recipes**: Preprocessing steps use dials parameters
- **workflows**: Workflows combine models and preprocessing that utilize dials parameters
- **tune**: Grid search and optimization consume parameter grids
6 changes: 6 additions & 0 deletions R/aaa_ranges.R
@@ -81,6 +81,12 @@ range_validate <- function(
if (!any(is_num)) {
cli::cli_abort("{.arg range} should be numeric.", call = call)
}
if (range[[1]] > range[[2]]) {
cli::cli_abort(
"The {.arg range} lower bound ({range[[1]]}) must not exceed upper bound ({range[[2]]}).",
call = call
)
}

# TODO check with transform
} else {
2 changes: 1 addition & 1 deletion R/constructors.R
@@ -100,7 +100,7 @@ new_quant_param <- function(

type <- arg_match0(type, values = c("double", "integer"))

check_values_quant(values, call = call)
check_values_quant(values, type = type, call = call)

if (!is.null(values)) {
# fill in range if user didn't supply one
108 changes: 85 additions & 23 deletions R/misc.R
@@ -35,31 +35,70 @@ format_bounds <- function(bnds) {

# checking functions -----------------------------------------------------------

check_label <- function(label, ..., call = caller_env()) {
check_label <- function(
x,
...,
arg = caller_arg(x),
call = caller_env()
) {
check_dots_empty()

check_string(label, allow_null = TRUE, call = call)
check_string(x, allow_null = TRUE, arg = arg, call = call)

if (!is.null(label) && length(names(label)) != 1) {
if (!is.null(x) && length(names(x)) != 1) {
cli::cli_abort(
"{.arg label} must be named.",
"{.arg {arg}} must be named.",
call = call
)
}

invisible(NULL)
}

check_range <- function(x, type, trans, ..., call = caller_env()) {
check_range <- function(
x,
type,
trans,
...,
arg = caller_arg(x),
call = caller_env()
) {
check_dots_empty()

if (length(x) != 2) {
cli::cli_abort(
"{.arg {arg}} must have 2 elements, not {length(x)}.",
call = call,
arg = arg
)
}

known <- !is_unknown(x)

if (any(known) && !all(map_lgl(x[known], is.numeric))) {
cli::cli_abort(
"{.arg {arg}} must be numeric (or {.fn unknown}).",
call = call
)
}

if (all(known) && !anyNA(x) && x[[1]] > x[[2]]) {
cli::cli_abort(
"The {.arg {arg}} lower bound ({x[[1]]}) must not exceed upper bound ({x[[2]]}).",
call = call
)
}

if (!is.null(trans)) {
return(invisible(x))
}

# only do this after `arg` is used but do it because
# this makes x0[known] <- as.integer(x0[known]) below work for e.g. c(1, 10)
Comment on lines +96 to +97
Member Author

x <- c(1, 10)
class(x)
#> [1] "numeric"
typeof(x)
#> [1] "double"
purrr::map(x, typeof)
#> [[1]]
#> [1] "double"
#> 
#> [[2]]
#> [1] "double"

x1 <- as.integer(x)
class(x1)
#> [1] "integer"
typeof(x1)
#> [1] "integer"
purrr::map(x1, typeof)
#> [[1]]
#> [1] "integer"
#> 
#> [[2]]
#> [1] "integer"

# but
x[c(TRUE, TRUE)] <- as.integer(x)
class(x)
#> [1] "numeric"
typeof(x)
#> [1] "double"
purrr::map(x, typeof)
#> [[1]]
#> [1] "double"
#> 
#> [[2]]
#> [1] "double"

Created on 2026-01-25 with reprex v2.1.1
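
The reprex comes down to how subassignment coerces: assigning into an atomic double vector keeps the container's type, while a list stores each element with its own type. A small base R sketch of that difference, which would explain why the code converts `x` to a list before the integer coercion further down (the variable names here are only for illustration):

```r
# Atomic vector: the assigned integers are coerced back to double
x <- c(1, 10)
x[c(TRUE, TRUE)] <- as.integer(x)
atomic_type <- typeof(x)

# List: each element keeps the type of the value assigned to it
y <- as.list(c(1, 10))
y[c(TRUE, TRUE)] <- as.integer(c(1, 10))
list_types <- vapply(y, typeof, character(1))

atomic_type
list_types
```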

if (!is.list(x)) {
x <- as.list(x)
}
x0 <- x
known <- !is_unknown(x)
x <- x[known]
x_type <- purrr::map_chr(x, typeof)
wrong_type <- any(x_type != type)
@@ -84,7 +123,7 @@ check_range <- function(x, type, trans, ..., call = caller_env()) {
x0[known] <- as.integer(x0[known])
} else {
cli::cli_abort(
"Since {.code type = \"{type}\"}, please use that data type for the
"Since {.code type = \"{type}\"}, please use that data type for the
range.",
call = call
)
@@ -93,45 +132,68 @@ check_range <- function(x, type, trans, ..., call = caller_env()) {
invisible(x0)
}

check_values_quant <- function(x, ..., call = caller_env()) {
check_values_quant <- function(
x,
type = NULL,
...,
arg = caller_arg(x),
call = caller_env()
) {
check_dots_empty()

if (is.null(x)) {
return(invisible(x))
}

if (!is.numeric(x)) {
cli::cli_abort("{.arg values} must be numeric.", call = call)
cli::cli_abort("{.arg {arg}} must be numeric.", call = call)
}

if (anyNA(x)) {
cli::cli_abort("{.arg values} can't be {.code NA}.", call = call)
cli::cli_abort("{.arg {arg}} can't contain {.code NA} values.", call = call)
}
if (length(x) == 0) {
cli::cli_abort("{.arg values} can't be empty.", call = call)
cli::cli_abort("{.arg {arg}} can't be empty.", call = call)
}
if (anyDuplicated(x)) {
cli::cli_abort("{.arg {arg}} can't contain duplicate values.", call = call)
}

if (!is.null(type) && type == "integer") {
# logic from ?is.integer
not_whole <- abs(x - round(x)) >= .Machine$double.eps^0.5
if (any(not_whole)) {
offenders <- x[not_whole]
cli::cli_abort(
c(
"{.arg {arg}} must contain whole numbers for integer parameters.",
x = "These are not whole numbers: {offenders}."
),
call = call
)
}
}

invisible(x)
}

check_inclusive <- function(x, ..., call = caller_env()) {
check_inclusive <- function(x, ..., arg = caller_arg(x), call = caller_env()) {
check_dots_empty()

check_logical(x, arg = arg, call = call)

if (any(is.na(x))) {
cli::cli_abort("{.arg inclusive} cannot contain missings.", call = call)
cli::cli_abort("{.arg {arg}} can't contain missing values.", call = call)
}

if (is_logical(x, n = 2)) {
return(invisible(NULL))
if (length(x) != 2) {
cli::cli_abort(
"{.arg {arg}} must have length 2, not {length(x)}.",
call = call
)
}

stop_input_type(
x,
"a logical vector of length 2",
allow_na = FALSE,
allow_null = FALSE,
arg = "inclusive",
call = call
)
invisible(NULL)
}

check_param <- function(
41 changes: 33 additions & 8 deletions tests/testthat/_snaps/constructors.md
@@ -27,8 +27,8 @@
Code
new_quant_param("double", range = 1, inclusive = c(TRUE, TRUE))
Condition
Error in `names(range) <- names(inclusive) <- c("lower", "upper")`:
! 'names' attribute [2] must be the same length as the vector [1]
Error:
! `range` must have 2 elements, not 1.

---

@@ -45,15 +45,15 @@
new_quant_param("double", range = c(1, NA), inclusive = TRUE)
Condition
Error:
! `inclusive` must be a logical vector of length 2, not `TRUE`.
! `inclusive` must have length 2, not 1.

---

Code
new_quant_param("double", range = c(1, NA), inclusive = c("(", "]"))
Condition
Error:
! `inclusive` must be a logical vector of length 2, not a character vector.
! `inclusive` must be a logical vector, not a character vector.

---

@@ -70,15 +70,15 @@
new_quant_param("integer", range = 1:2, inclusive = c(TRUE, NA))
Condition
Error:
! `inclusive` cannot contain missings.
! `inclusive` can't contain missing values.

---

Code
new_quant_param("integer", range = 1:2, inclusive = c(TRUE, unknown()))
Condition
Error:
! `inclusive` must be a logical vector of length 2, not a list.
! `inclusive` must be a logical vector, not a list.

---

@@ -248,7 +248,7 @@
mixture(letters[1:2])
Condition
Error in `mixture()`:
! Since `type = "double"`, please use that data type for the range.
! `range` must be numeric (or `unknown()`).

---

@@ -324,7 +324,7 @@
new_quant_param(type = "integer", values = NA_integer_, label = c(foo = "Foo"))
Condition
Error:
! `values` can't be `NA`.
! `values` can't contain `NA` values.

---

Expand Down Expand Up @@ -352,3 +352,28 @@
Error:
! The `default` argument of `new_qual_param()` was deprecated in dials 1.1.0 and is now defunct.

# range ordering is validated

Code
new_quant_param("integer", range = c(10L, 1L), inclusive = c(TRUE, TRUE))
Condition
Error:
! The `range` lower bound (10) must not exceed upper bound (1).

# duplicate values are rejected

Code
new_quant_param("double", values = c(1, 2, 2, 3))
Condition
Error:
! `values` can't contain duplicate values.

# integer type requires whole number values

Code
new_quant_param("integer", values = c(1.5, 2, 3))
Condition
Error:
! `values` must contain whole numbers for integer parameters.
x These are not whole numbers: 1.5.
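
The whole-number test behind the last snapshot can be tried in isolation with base R. This is a sketch of the same tolerance comparison used in `check_values_quant()`; the helper name `is_whole()` is hypothetical and does not exist in dials.

```r
# Hypothetical helper: TRUE where a value is within tolerance of a whole number
is_whole <- function(x, tol = .Machine$double.eps^0.5) {
  abs(x - round(x)) < tol
}

x <- c(1.5, 2, 3)

# Values that would be reported as offenders for an integer parameter
offenders <- x[!is_whole(x)]
offenders
```

A tolerance is used instead of `x == round(x)` so that values carrying small floating-point error, such as `0.1 * 30`, still count as whole.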
