Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 55 additions & 9 deletions ci/ci-util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,15 @@
import re
import subprocess as sp
import sys
from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import cache
from glob import glob
from inspect import cleandoc
from os import getenv
from pathlib import Path
from typing import TypedDict, Self

USAGE = cleandoc(
"""
USAGE = cleandoc("""
usage:

./ci/ci-util.py <COMMAND> [flags]
Expand All @@ -44,8 +43,7 @@
Exit with success if the pull request contains a line starting with
`ci: allow-regressions`, indicating that regressions in benchmarks should
be accepted. Otherwise, exit 1.
"""
)
""")

REPO_ROOT = Path(__file__).parent.parent
GIT = ["git", "-C", REPO_ROOT]
Expand Down Expand Up @@ -84,6 +82,8 @@ class PrCfg:
allow_regressions: bool = False
# Don't run extensive tests
skip_extensive: bool = False
# Add these extensive tests to the list
extra_extensive: list[str] = field(default_factory=list)

# Allow running a large number of extensive tests. If not set, this script
# will error out if a threshold is exceeded in order to avoid accidentally
Expand All @@ -101,11 +101,17 @@ class PrCfg:
DIR_SKIP_EXTENSIVE: str = "skip-extensive"
DIR_ALLOW_MANY_EXTENSIVE: str = "allow-many-extensive"
DIR_TEST_LIBM: str = "test-libm"
DIR_EXTRA_EXTENSIVE: str = "extra-extensive"

def __init__(self, body: str):
directives = re.finditer(r"^\s*ci:\s*(?P<dir_name>\S*)", body, re.MULTILINE)
directives = re.finditer(
r"^\s*ci:\s*(?P<dir_name>[^\s=]*)(?:\s*=\s*(?P<args>.*))?",
body,
re.MULTILINE,
)
for dir in directives:
name = dir.group("dir_name")
args = dir.group("args")
if name == self.DIR_ALLOW_REGRESSIONS:
self.allow_regressions = True
elif name == self.DIR_SKIP_EXTENSIVE:
Expand All @@ -114,10 +120,17 @@ def __init__(self, body: str):
self.allow_many_extensive = True
elif name == self.DIR_TEST_LIBM:
self.always_test_libm = True
elif name == self.DIR_EXTRA_EXTENSIVE:
self.extra_extensive = [x.strip() for x in args.split(",")]
args = None
else:
eprint(f"Found unexpected directive `{name}`")
exit(1)

if args is not None:
eprint("Found arguments where not expected")
exit(1)

pprint.pp(self)


Expand Down Expand Up @@ -276,29 +289,35 @@ def emit_workflow_output(self):

skip_tests = False
error_on_many_tests = False
extra_tests = {}

pr = PrInfo.from_env()
if pr is not None:
skip_tests = pr.cfg.skip_extensive
error_on_many_tests = not pr.cfg.allow_many_extensive
for fn_name in pr.cfg.extra_extensive:
extra_tests.setdefault(base_name(fn_name)[1], []).append(fn_name)

if skip_tests:
eprint("Skipping all extensive tests")

changed = self.changed_routines()
eprint(f"Changed: {changed}")

matrix = []
total_to_test = 0

# Figure out which extensive tests need to run
for ty in TYPES:
ty_changed = changed.get(ty, [])
ty_to_test = [] if skip_tests else ty_changed
ty_to_test = [] if skip_tests else ty_changed.copy()
ty_to_test.extend(extra_tests.get(ty, []))
total_to_test += len(ty_to_test)

item = {
"ty": ty,
"changed": ",".join(ty_changed),
"to_test": ",".join(ty_to_test),
"changed": ",".join(sorted(set(ty_changed))),
"to_test": ",".join(sorted(set(ty_to_test))),
}

matrix.append(item)
Expand All @@ -319,6 +338,33 @@ def emit_workflow_output(self):
exit(1)


def base_name(name: str) -> tuple[str, str]:
"""Return the basename and type from a full function name. Keep in sync with Rust's
`fn base_name`.
"""
known_mappings = [
("erff", ("erf", "f32")),
("erf", ("erf", "f64")),
("modff", ("modf", "f32")),
("modf", ("modf", "f64")),
("lgammaf_r", ("lgamma_r", "f32")),
("lgamma_r", ("lgamma_r", "f64")),
]

found = next((base for (full, base) in known_mappings if full == name), None)
if found is not None:
return found

if name.endswith("f"):
return (name.rstrip("f"), "f32")
elif name.endswith("f16"):
return (name.rstrip("f16"), "f16")
elif name.endswith("f128"):
return (name.rstrip("f128"), "f128")

return (name, "f64")


def locate_baseline(flags: list[str]) -> None:
"""Find the most recent baseline from CI, download it if specified.

Expand Down
Loading