diff --git a/.cirrus.star b/.cirrus.star index 9056082c84bf..5bd79b42f275 100644 --- a/.cirrus.star +++ b/.cirrus.star @@ -17,7 +17,7 @@ def main(ctx): # - commit message containing [wheel build] ###################################################################### - if env.get("CIRRUS_REPO_FULL_NAME") != "scipy/scipy": + if env.get("CIRRUS_REPO_FULL_NAME") != "tylerjereddy/scipy": return [] if env.get("CIRRUS_CRON", "") == "nightly": diff --git a/ci/cirrus_general_ci.yml b/ci/cirrus_general_ci.yml index b7f8fc0e52e2..0b6094439cf7 100644 --- a/ci/cirrus_general_ci.yml +++ b/ci/cirrus_general_ci.yml @@ -143,4 +143,6 @@ macos_arm64_test_task: python -m pip install click rich_click doit pydevtool python -m pip install pytest pooch export DYLD_LIBRARY_PATH=/usr/local/gfortran/lib:/opt/arm64-builds/lib - python dev.py test + python -m pip install torch array-api-compat + # spot check array API torch MPS device compliance + SCIPY_TORCH_DEVICE=mps python dev.py test -b numpy -b pytorch -s cluster diff --git a/scipy/conftest.py b/scipy/conftest.py index 3d2601f5ef18..d0fc6d7f3141 100644 --- a/scipy/conftest.py +++ b/scipy/conftest.py @@ -132,6 +132,10 @@ def check_fpu_mode(request): array_api_compatible = pytest.mark.parametrize("xp", array_api_backends) +if "pytorch" in array_api_available_backends: + torch_device_setting = os.environ.get("SCIPY_TORCH_DEVICE", "cpu") + torch.set_default_device(torch_device_setting) + skip_if_array_api = pytest.mark.skipif( SCIPY_ARRAY_API, reason="do not run with Array API on",