1111import re
1212import subprocess as sp
1313import sys
14- from dataclasses import dataclass
14+ from dataclasses import dataclass , field
1515from functools import cache
1616from glob import glob
1717from inspect import cleandoc
1818from os import getenv
1919from pathlib import Path
2020from typing import TypedDict , Self
2121
22- USAGE = cleandoc (
23- """
22+ USAGE = cleandoc ("""
2423 usage:
2524
2625 ./ci/ci-util.py <COMMAND> [flags]
4443 Exit with success if the pull request contains a line starting with
4544 `ci: allow-regressions`, indicating that regressions in benchmarks should
4645 be accepted. Otherwise, exit 1.
47- """
48- )
46+ """ )
4947
5048REPO_ROOT = Path (__file__ ).parent .parent
5149GIT = ["git" , "-C" , REPO_ROOT ]
@@ -84,6 +82,8 @@ class PrCfg:
8482 allow_regressions : bool = False
8583 # Don't run extensive tests
8684 skip_extensive : bool = False
85+ # Add these extensive tests to the list
86+ extra_extensive : list [str ] = field (default_factory = list )
8787
8888 # Allow running a large number of extensive tests. If not set, this script
8989 # will error out if a threshold is exceeded in order to avoid accidentally
@@ -101,11 +101,17 @@ class PrCfg:
101101 DIR_SKIP_EXTENSIVE : str = "skip-extensive"
102102 DIR_ALLOW_MANY_EXTENSIVE : str = "allow-many-extensive"
103103 DIR_TEST_LIBM : str = "test-libm"
104+ DIR_EXTRA_EXTENSIVE : str = "extra-extensive"
104105
105106 def __init__ (self , body : str ):
106- directives = re .finditer (r"^\s*ci:\s*(?P<dir_name>\S*)" , body , re .MULTILINE )
107+ directives = re .finditer (
108+ r"^\s*ci:\s*(?P<dir_name>[^\s=]*)(?:\s*=\s*(?P<args>.*))?" ,
109+ body ,
110+ re .MULTILINE ,
111+ )
107112 for dir in directives :
108113 name = dir .group ("dir_name" )
114+ args = dir .group ("args" )
109115 if name == self .DIR_ALLOW_REGRESSIONS :
110116 self .allow_regressions = True
111117 elif name == self .DIR_SKIP_EXTENSIVE :
@@ -114,10 +120,17 @@ def __init__(self, body: str):
114120 self .allow_many_extensive = True
115121 elif name == self .DIR_TEST_LIBM :
116122 self .always_test_libm = True
123+ elif name == self .DIR_EXTRA_EXTENSIVE :
124+ self .extra_extensive = [x .strip () for x in args .split ("," )]
125+ args = None
117126 else :
118127 eprint (f"Found unexpected directive `{ name } `" )
119128 exit (1 )
120129
130+ if args is not None :
131+ eprint ("Found arguments where not expected" )
132+ exit (1 )
133+
121134 pprint .pp (self )
122135
123136
@@ -276,29 +289,35 @@ def emit_workflow_output(self):
276289
277290 skip_tests = False
278291 error_on_many_tests = False
292+ extra_tests = {}
279293
280294 pr = PrInfo .from_env ()
281295 if pr is not None :
282296 skip_tests = pr .cfg .skip_extensive
283297 error_on_many_tests = not pr .cfg .allow_many_extensive
298+ for fn_name in pr .cfg .extra_extensive :
299+ extra_tests .setdefault (base_name (fn_name )[1 ], []).append (fn_name )
284300
285301 if skip_tests :
286302 eprint ("Skipping all extensive tests" )
287303
288304 changed = self .changed_routines ()
305+ eprint (f"Changed: { changed } " )
306+
289307 matrix = []
290308 total_to_test = 0
291309
292310 # Figure out which extensive tests need to run
293311 for ty in TYPES :
294312 ty_changed = changed .get (ty , [])
295313 ty_to_test = [] if skip_tests else ty_changed
314+ ty_to_test .extend (extra_tests [ty ])
296315 total_to_test += len (ty_to_test )
297316
298317 item = {
299318 "ty" : ty ,
300- "changed" : "," .join (ty_changed ),
301- "to_test" : "," .join (ty_to_test ),
319+ "changed" : "," .join (sorted ( set ( ty_changed )) ),
320+ "to_test" : "," .join (sorted ( set ( ty_to_test )) ),
302321 }
303322
304323 matrix .append (item )
@@ -319,6 +338,33 @@ def emit_workflow_output(self):
319338 exit (1 )
320339
321340
341+ def base_name (name : str ) -> tuple [str , str ]:
342+ """Return the basename and type from a full function name. Keep in sync with Rust's
343+ `fn base_name`.
344+ """
345+ known_mappings = [
346+ ("erff" , ("erf" , "f32" )),
347+ ("erf" , ("erf" , "f64" )),
348+ ("modff" , ("modf" , "f32" )),
349+ ("modf" , ("modf" , "f64" )),
350+ ("lgammaf_r" , ("lgamma_r" , "f32" )),
351+ ("lgamma_r" , ("lgamma_r" , "f64" )),
352+ ]
353+
354+ found = next ((base for (full , base ) in known_mappings if full == name ), None )
355+ if found is not None :
356+ return found
357+
358+ if name .endswith ("f" ):
359+ return (name .rstrip ("f" ), "f32" )
360+ elif name .endswith ("f16" ):
361+ return (name .rstrip ("f16" ), "f16" )
362+ elif name .endswith ("f128" ):
363+ return (name .rstrip ("f128" ), "f128" )
364+
365+ return (name , "f64" )
366+
367+
322368def locate_baseline (flags : list [str ]) -> None :
323369 """Find the most recent baseline from CI, download it if specified.
324370
0 commit comments