Skip to content

Commit 3a7110f

Browse files
committed
ci: Allow specifying extra extensive tests to run
For cases where we would like to run tests that aren't automatically detected, allow the following syntax in PR descriptions: ci: extra-extensive=copysignf,sqrtf16
1 parent 34a9616 commit 3a7110f

File tree

1 file changed

+54
-8
lines changed

1 file changed

+54
-8
lines changed

ci/ci-util.py

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,15 @@
1111
import re
1212
import subprocess as sp
1313
import sys
14-
from dataclasses import dataclass
14+
from dataclasses import dataclass, field
1515
from functools import cache
1616
from glob import glob
1717
from inspect import cleandoc
1818
from os import getenv
1919
from pathlib import Path
2020
from typing import TypedDict, Self
2121

22-
USAGE = cleandoc(
23-
"""
22+
USAGE = cleandoc("""
2423
usage:
2524
2625
./ci/ci-util.py <COMMAND> [flags]
@@ -44,8 +43,7 @@
4443
Exit with success if the pull request contains a line starting with
4544
`ci: allow-regressions`, indicating that regressions in benchmarks should
4645
be accepted. Otherwise, exit 1.
47-
"""
48-
)
46+
""")
4947

5048
REPO_ROOT = Path(__file__).parent.parent
5149
GIT = ["git", "-C", REPO_ROOT]
@@ -84,6 +82,8 @@ class PrCfg:
8482
allow_regressions: bool = False
8583
# Don't run extensive tests
8684
skip_extensive: bool = False
85+
# Add these extensive tests to the list
86+
extra_extensive: list[str] = field(default_factory=list)
8787

8888
# Allow running a large number of extensive tests. If not set, this script
8989
# will error out if a threshold is exceeded in order to avoid accidentally
@@ -101,11 +101,17 @@ class PrCfg:
101101
DIR_SKIP_EXTENSIVE: str = "skip-extensive"
102102
DIR_ALLOW_MANY_EXTENSIVE: str = "allow-many-extensive"
103103
DIR_TEST_LIBM: str = "test-libm"
104+
DIR_EXTRA_EXTENSIVE: str = "extra-extensive"
104105

105106
def __init__(self, body: str):
106-
directives = re.finditer(r"^\s*ci:\s*(?P<dir_name>\S*)", body, re.MULTILINE)
107+
directives = re.finditer(
108+
r"^\s*ci:\s*(?P<dir_name>[^\s=]*)(?:\s*=\s*(?P<args>.*))?",
109+
body,
110+
re.MULTILINE,
111+
)
107112
for dir in directives:
108113
name = dir.group("dir_name")
114+
args = dir.group("args")
109115
if name == self.DIR_ALLOW_REGRESSIONS:
110116
self.allow_regressions = True
111117
elif name == self.DIR_SKIP_EXTENSIVE:
@@ -114,10 +120,17 @@ def __init__(self, body: str):
114120
self.allow_many_extensive = True
115121
elif name == self.DIR_TEST_LIBM:
116122
self.always_test_libm = True
123+
elif name == self.DIR_EXTRA_EXTENSIVE:
124+
self.extra_extensive = [x.strip() for x in args.split(",")]
125+
args = None
117126
else:
118127
eprint(f"Found unexpected directive `{name}`")
119128
exit(1)
120129

130+
if args is not None:
131+
eprint("Found arguments where not expected")
132+
exit(1)
133+
121134
pprint.pp(self)
122135

123136

@@ -276,29 +289,35 @@ def emit_workflow_output(self):
276289

277290
skip_tests = False
278291
error_on_many_tests = False
292+
extra_tests = {}
279293

280294
pr = PrInfo.from_env()
281295
if pr is not None:
282296
skip_tests = pr.cfg.skip_extensive
283297
error_on_many_tests = not pr.cfg.allow_many_extensive
298+
for fn_name in pr.cfg.extra_extensive:
299+
extra_tests.setdefault(base_name(fn_name)[1], []).append(fn_name)
284300

285301
if skip_tests:
286302
eprint("Skipping all extensive tests")
287303

288304
changed = self.changed_routines()
305+
eprint(f"Changed: {changed}")
306+
289307
matrix = []
290308
total_to_test = 0
291309

292310
# Figure out which extensive tests need to run
293311
for ty in TYPES:
294312
ty_changed = changed.get(ty, [])
295313
ty_to_test = [] if skip_tests else ty_changed
314+
ty_to_test.extend(extra_tests[ty])
296315
total_to_test += len(ty_to_test)
297316

298317
item = {
299318
"ty": ty,
300-
"changed": ",".join(ty_changed),
301-
"to_test": ",".join(ty_to_test),
319+
"changed": ",".join(sorted(set(ty_changed))),
320+
"to_test": ",".join(sorted(set(ty_to_test))),
302321
}
303322

304323
matrix.append(item)
@@ -319,6 +338,33 @@ def emit_workflow_output(self):
319338
exit(1)
320339

321340

341+
def base_name(name: str) -> tuple[str, str]:
342+
"""Return the basename and type from a full function name. Keep in sync with Rust's
343+
`fn base_name`.
344+
"""
345+
known_mappings = [
346+
("erff", ("erf", "f32")),
347+
("erf", ("erf", "f64")),
348+
("modff", ("modf", "f32")),
349+
("modf", ("modf", "f64")),
350+
("lgammaf_r", ("lgamma_r", "f32")),
351+
("lgamma_r", ("lgamma_r", "f64")),
352+
]
353+
354+
found = next((base for (full, base) in known_mappings if full == name), None)
355+
if found is not None:
356+
return found
357+
358+
if name.endswith("f"):
359+
return (name.rstrip("f"), "f32")
360+
elif name.endswith("f16"):
361+
return (name.rstrip("f16"), "f16")
362+
elif name.endswith("f128"):
363+
return (name.rstrip("f128"), "f128")
364+
365+
return (name, "f64")
366+
367+
322368
def locate_baseline(flags: list[str]) -> None:
323369
"""Find the most recent baseline from CI, download it if specified.
324370

0 commit comments

Comments
 (0)