Skip to content

Commit 11ec401

Browse files
authored
Merge pull request #414 from tidymodels/claude-code
Improve type-checking
2 parents d5888de + 7c1cba5 commit 11ec401

File tree

9 files changed

+249
-42
lines changed

9 files changed

+249
-42
lines changed

.Rbuildignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,6 @@
2121
^CRAN-SUBMISSION$
2222
^[\.]?air\.toml$
2323
^\.vscode$
24+
25+
CLAUDE.md
26+
^\.claude$

CLAUDE.md

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# CLAUDE.md
2+
3+
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
4+
5+
## Overview
6+
7+
dials is an R package that provides infrastructure for creating and managing tuning parameter values in the tidymodels ecosystem. It defines parameter objects, sets of parameters, and methods for generating parameter grids for model tuning.
8+
9+
## Key development commands
10+
11+
General advice:
12+
* When running R from the console, always run it with `--quiet --vanilla`
13+
* Always run `air format .` after generating code
14+
15+
### Testing
16+
17+
- Use `devtools::test()` to run all tests
18+
- Use `devtools::test_file("tests/testthat/test-filename.R")` to run tests in a specific file
19+
- DO NOT USE `devtools::test_active_file()`
20+
- All testing functions automatically load code; you don't need to.
21+
22+
- All new code should have an accompanying test.
23+
- Tests for `R/{name}.R` go in `tests/testthat/test-{name}.R`.
24+
- If there are existing tests, place new tests next to similar existing tests.
25+
26+
### Documentation
27+
28+
- Run `devtools::document()` after changing any roxygen2 docs.
29+
- Every user facing function should be exported and have roxygen2 documentation.
30+
- Whenever you add a new documentation file, make sure to also add the topic name to `_pkgdown.yml`.
31+
- Run `pkgdown::check_pkgdown()` to check that all topics are included in the reference index.
32+
- Use sentence case for all headings
33+
34+
35+
## Architecture
36+
37+
### Core Parameter System
38+
39+
The package is built around two main parameter types:
40+
41+
1. **`quant_param`**: Quantitative parameters (continuous or integer)
42+
- Created via `new_quant_param()` in `R/constructors.R`
43+
- Has `range` (lower/upper bounds), `inclusive`, optional `trans` (transformation), and `finalize` function
44+
- Examples: `penalty()`, `mtry()`, `learn_rate()`
45+
46+
2. **`qual_param`**: Qualitative parameters (categorical)
47+
- Created via `new_qual_param()` in `R/constructors.R`
48+
- Has discrete `values` (character or logical)
49+
- Examples: `activation()`, `weight_func()`
50+
51+
### Parameter Organization
52+
53+
- **Individual parameters**: parameter definition files (`R/param_*.R`), each defining specific tuning parameters used across tidymodels
54+
- **Parameter sets**: The `parameters` class (defined in `R/parameters.R`) groups multiple parameters into a data frame-like structure
55+
56+
### Grid Generation
57+
58+
Three main grid types (in `R/grids.R` and `R/space_filling.R`):
59+
60+
1. **Regular grids** (`grid_regular()`): Factorial designs with evenly-spaced values
61+
2. **Random grids** (`grid_random()`): Random sampling from parameter ranges
62+
3. **Space-filling grids** (`grid_space_filling()`): Experimental designs (Latin hypercube, max entropy, etc.) that efficiently cover the parameter space
63+
64+
All grid functions:
65+
- Accept parameter objects or parameter sets
66+
- Return tibbles with one column per parameter
67+
68+
### Finalization System
69+
70+
Many parameters have `unknown()` ranges that depend on the dataset (e.g., `mtry()` depends on the number of predictors). The finalization system (`R/finalize.R`) resolves these:
71+
72+
- `finalize()`: Generic function that calls the parameter's embedded `finalize` function
73+
- `get_*()`: Various functions that get and set parameter ranges based on data characteristics
74+
75+
### Infrastructure Files
76+
77+
Files prefixed with `aaa_` load first and define foundational classes:
78+
- `R/aaa_ranges.R`: Handling and validation of parameter ranges
79+
- `R/aaa_unknown.R`: The `unknown()` placeholder for unspecified parameter bounds
80+
- `R/aaa_values.R`: Validation, generation, and transformation of parameter values
81+
82+
Files prefixed with `compat-` provide compatibility with dplyr and vctrs for parameter objects.
83+
84+
## Integration with tidymodels
85+
86+
dials is infrastructure-level; it defines parameters but doesn't perform tuning. The tune package uses dials for actual hyperparameter tuning. Parameter objects integrate with:
87+
- **parsnip**: Model specifications reference dials parameters
88+
- **recipes**: Preprocessing steps use dials parameters
89+
- **workflows**: Workflows combine models and preprocessing that utilize dials parameters
90+
- **tune**: Grid search and optimization consume parameter grids

R/aaa_ranges.R

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,12 @@ range_validate <- function(
8181
if (!any(is_num)) {
8282
cli::cli_abort("{.arg range} should be numeric.", call = call)
8383
}
84+
if (range[[1]] > range[[2]]) {
85+
cli::cli_abort(
86+
"The {.arg range} lower bound ({range[[1]]}) must not exceed upper bound ({range[[2]]}).",
87+
call = call
88+
)
89+
}
8490

8591
# TODO check with transform
8692
} else {

R/constructors.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ new_quant_param <- function(
100100

101101
type <- arg_match0(type, values = c("double", "integer"))
102102

103-
check_values_quant(values, call = call)
103+
check_values_quant(values, type = type, call = call)
104104

105105
if (!is.null(values)) {
106106
# fill in range if user didn't supply one

R/misc.R

Lines changed: 85 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -35,31 +35,70 @@ format_bounds <- function(bnds) {
3535

3636
# checking functions -----------------------------------------------------------
3737

38-
check_label <- function(label, ..., call = caller_env()) {
38+
check_label <- function(
39+
x,
40+
...,
41+
arg = caller_arg(x),
42+
call = caller_env()
43+
) {
3944
check_dots_empty()
4045

41-
check_string(label, allow_null = TRUE, call = call)
46+
check_string(x, allow_null = TRUE, arg = arg, call = call)
4247

43-
if (!is.null(label) && length(names(label)) != 1) {
48+
if (!is.null(x) && length(names(x)) != 1) {
4449
cli::cli_abort(
45-
"{.arg label} must be named.",
50+
"{.arg {arg}} must be named.",
4651
call = call
4752
)
4853
}
4954

5055
invisible(NULL)
5156
}
5257

53-
check_range <- function(x, type, trans, ..., call = caller_env()) {
58+
check_range <- function(
59+
x,
60+
type,
61+
trans,
62+
...,
63+
arg = caller_arg(x),
64+
call = caller_env()
65+
) {
5466
check_dots_empty()
67+
68+
if (length(x) != 2) {
69+
cli::cli_abort(
70+
"{.arg {arg}} must have 2 elements, not {length(x)}.",
71+
call = call,
72+
arg = arg
73+
)
74+
}
75+
76+
known <- !is_unknown(x)
77+
78+
if (any(known) && !all(map_lgl(x[known], is.numeric))) {
79+
cli::cli_abort(
80+
"{.arg {arg}} must be numeric (or {.fn unknown}).",
81+
call = call
82+
)
83+
}
84+
85+
if (all(known) && !anyNA(x) && x[[1]] > x[[2]]) {
86+
cli::cli_abort(
87+
"The {.arg {arg}} lower bound ({x[[1]]}) must not exceed upper bound ({x[[2]]}).",
88+
call = call
89+
)
90+
}
91+
5592
if (!is.null(trans)) {
5693
return(invisible(x))
5794
}
95+
96+
# only do this after `arg` is used but do it because
97+
# this makes x0[known] <- as.integer(x0[known]) below work for e.g. c(1, 10)
5898
if (!is.list(x)) {
5999
x <- as.list(x)
60100
}
61101
x0 <- x
62-
known <- !is_unknown(x)
63102
x <- x[known]
64103
x_type <- purrr::map_chr(x, typeof)
65104
wrong_type <- any(x_type != type)
@@ -84,7 +123,7 @@ check_range <- function(x, type, trans, ..., call = caller_env()) {
84123
x0[known] <- as.integer(x0[known])
85124
} else {
86125
cli::cli_abort(
87-
"Since {.code type = \"{type}\"}, please use that data type for the
126+
"Since {.code type = \"{type}\"}, please use that data type for the
88127
range.",
89128
call = call
90129
)
@@ -93,45 +132,68 @@ check_range <- function(x, type, trans, ..., call = caller_env()) {
93132
invisible(x0)
94133
}
95134

96-
check_values_quant <- function(x, ..., call = caller_env()) {
135+
check_values_quant <- function(
136+
x,
137+
type = NULL,
138+
...,
139+
arg = caller_arg(x),
140+
call = caller_env()
141+
) {
97142
check_dots_empty()
98143

99144
if (is.null(x)) {
100145
return(invisible(x))
101146
}
102147

103148
if (!is.numeric(x)) {
104-
cli::cli_abort("{.arg values} must be numeric.", call = call)
149+
cli::cli_abort("{.arg {arg}} must be numeric.", call = call)
105150
}
151+
106152
if (anyNA(x)) {
107-
cli::cli_abort("{.arg values} can't be {.code NA}.", call = call)
153+
cli::cli_abort("{.arg {arg}} can't contain {.code NA} values.", call = call)
108154
}
109155
if (length(x) == 0) {
110-
cli::cli_abort("{.arg values} can't be empty.", call = call)
156+
cli::cli_abort("{.arg {arg}} can't be empty.", call = call)
157+
}
158+
if (anyDuplicated(x)) {
159+
cli::cli_abort("{.arg {arg}} can't contain duplicate values.", call = call)
160+
}
161+
162+
if (!is.null(type) && type == "integer") {
163+
# logic from from ?is.integer
164+
not_whole <- abs(x - round(x)) >= .Machine$double.eps^0.5
165+
if (any(not_whole)) {
166+
offenders <- x[not_whole]
167+
cli::cli_abort(
168+
c(
169+
"{.arg {arg}} must contain whole numbers for integer parameters.",
170+
x = "These are not whole numbers: {offenders}."
171+
),
172+
call = call
173+
)
174+
}
111175
}
112176

113177
invisible(x)
114178
}
115179

116-
check_inclusive <- function(x, ..., call = caller_env()) {
180+
check_inclusive <- function(x, ..., arg = caller_arg(x), call = caller_env()) {
117181
check_dots_empty()
118182

183+
check_logical(x, arg = arg, call = call)
184+
119185
if (any(is.na(x))) {
120-
cli::cli_abort("{.arg inclusive} cannot contain missings.", call = call)
186+
cli::cli_abort("{.arg {arg}} can't contain missing values.", call = call)
121187
}
122188

123-
if (is_logical(x, n = 2)) {
124-
return(invisible(NULL))
189+
if (length(x) != 2) {
190+
cli::cli_abort(
191+
"{.arg {arg}} must have length 2, not {length(x)}.",
192+
call = call
193+
)
125194
}
126195

127-
stop_input_type(
128-
x,
129-
"a logical vector of length 2",
130-
allow_na = FALSE,
131-
allow_null = FALSE,
132-
arg = "inclusive",
133-
call = call
134-
)
196+
invisible(NULL)
135197
}
136198

137199
check_param <- function(

tests/testthat/_snaps/constructors.md

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
Code
2828
new_quant_param("double", range = 1, inclusive = c(TRUE, TRUE))
2929
Condition
30-
Error in `names(range) <- names(inclusive) <- c("lower", "upper")`:
31-
! 'names' attribute [2] must be the same length as the vector [1]
30+
Error:
31+
! `range` must have 2 elements, not 1.
3232

3333
---
3434

@@ -45,15 +45,15 @@
4545
new_quant_param("double", range = c(1, NA), inclusive = TRUE)
4646
Condition
4747
Error:
48-
! `inclusive` must be a logical vector of length 2, not `TRUE`.
48+
! `inclusive` must have length 2, not 1.
4949

5050
---
5151

5252
Code
5353
new_quant_param("double", range = c(1, NA), inclusive = c("(", "]"))
5454
Condition
5555
Error:
56-
! `inclusive` must be a logical vector of length 2, not a character vector.
56+
! `inclusive` must be a logical vector, not a character vector.
5757

5858
---
5959

@@ -70,15 +70,15 @@
7070
new_quant_param("integer", range = 1:2, inclusive = c(TRUE, NA))
7171
Condition
7272
Error:
73-
! `inclusive` cannot contain missings.
73+
! `inclusive` can't contain missing values.
7474

7575
---
7676

7777
Code
7878
new_quant_param("integer", range = 1:2, inclusive = c(TRUE, unknown()))
7979
Condition
8080
Error:
81-
! `inclusive` must be a logical vector of length 2, not a list.
81+
! `inclusive` must be a logical vector, not a list.
8282

8383
---
8484

@@ -248,7 +248,7 @@
248248
mixture(letters[1:2])
249249
Condition
250250
Error in `mixture()`:
251-
! Since `type = "double"`, please use that data type for the range.
251+
! `range` must be numeric (or `unknown()`).
252252

253253
---
254254

@@ -324,7 +324,7 @@
324324
new_quant_param(type = "integer", values = NA_integer_, label = c(foo = "Foo"))
325325
Condition
326326
Error:
327-
! `values` can't be `NA`.
327+
! `values` can't contain `NA` values.
328328

329329
---
330330

@@ -352,3 +352,28 @@
352352
Error:
353353
! The `default` argument of `new_qual_param()` was deprecated in dials 1.1.0 and is now defunct.
354354

355+
# range ordering is validated
356+
357+
Code
358+
new_quant_param("integer", range = c(10L, 1L), inclusive = c(TRUE, TRUE))
359+
Condition
360+
Error:
361+
! The `range` lower bound (10) must not exceed upper bound (1).
362+
363+
# duplicate values are rejected
364+
365+
Code
366+
new_quant_param("double", values = c(1, 2, 2, 3))
367+
Condition
368+
Error:
369+
! `values` can't contain duplicate values.
370+
371+
# integer type requires whole number values
372+
373+
Code
374+
new_quant_param("integer", values = c(1.5, 2, 3))
375+
Condition
376+
Error:
377+
! `values` must contain whole numbers for integer parameters.
378+
x These are not whole numbers: 1.5.
379+

0 commit comments

Comments
 (0)