From 96e0ccea472a5de22d565cb1e622c5c53595b4a4 Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Tue, 16 May 2023 10:42:09 -0600 Subject: [PATCH 1/4] TST: support torch device testing --- ci/cirrus_general_ci.yml | 3 ++- scipy/conftest.py | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/ci/cirrus_general_ci.yml b/ci/cirrus_general_ci.yml index b7f8fc0e52e2..d0fed56dab2e 100644 --- a/ci/cirrus_general_ci.yml +++ b/ci/cirrus_general_ci.yml @@ -143,4 +143,5 @@ macos_arm64_test_task: python -m pip install click rich_click doit pydevtool python -m pip install pytest pooch export DYLD_LIBRARY_PATH=/usr/local/gfortran/lib:/opt/arm64-builds/lib - python dev.py test + # spot check array API torch MPS device compliance + SCIPY_TORCH_DEVICE=mps python dev.py test -b numpy -b pytorch -s cluster diff --git a/scipy/conftest.py b/scipy/conftest.py index 3d2601f5ef18..d0fc6d7f3141 100644 --- a/scipy/conftest.py +++ b/scipy/conftest.py @@ -132,6 +132,10 @@ def check_fpu_mode(request): array_api_compatible = pytest.mark.parametrize("xp", array_api_backends) +if "pytorch" in array_api_available_backends: + torch_device_setting = os.environ.get("SCIPY_TORCH_DEVICE", "cpu") + torch.set_default_device(torch_device_setting) + skip_if_array_api = pytest.mark.skipif( SCIPY_ARRAY_API, reason="do not run with Array API on", From 2aead3e607f6f8454425ba5523440e25a7371d50 Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Tue, 16 May 2023 10:52:43 -0600 Subject: [PATCH 2/4] CI: try activating cirrus for my fork. --- .cirrus.star | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.cirrus.star b/.cirrus.star index 9056082c84bf..5bd79b42f275 100644 --- a/.cirrus.star +++ b/.cirrus.star @@ -17,7 +17,7 @@ def main(ctx): # - commit message containing [wheel build] ###################################################################### - if env.get("CIRRUS_REPO_FULL_NAME") != "scipy/scipy": + if env.get("CIRRUS_REPO_FULL_NAME") != "tylerjereddy/scipy": return [] if env.get("CIRRUS_CRON", "") == "nightly": From e1e348067add184a8271c51307c29edb1b7672eb Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Tue, 16 May 2023 10:56:56 -0600 Subject: [PATCH 3/4] CI: add torch --- ci/cirrus_general_ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/ci/cirrus_general_ci.yml b/ci/cirrus_general_ci.yml index d0fed56dab2e..33dc424eb9d1 100644 --- a/ci/cirrus_general_ci.yml +++ b/ci/cirrus_general_ci.yml @@ -143,5 +143,6 @@ macos_arm64_test_task: python -m pip install click rich_click doit pydevtool python -m pip install pytest pooch export DYLD_LIBRARY_PATH=/usr/local/gfortran/lib:/opt/arm64-builds/lib + python -m pip install torch # spot check array API torch MPS device compliance SCIPY_TORCH_DEVICE=mps python dev.py test -b numpy -b pytorch -s cluster From 4bb004d4b9c16a6b1ef914fd514913c66d3b8e24 Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Tue, 16 May 2023 11:04:44 -0600 Subject: [PATCH 4/4] CI: add array-api-compat --- ci/cirrus_general_ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/cirrus_general_ci.yml b/ci/cirrus_general_ci.yml index 33dc424eb9d1..0b6094439cf7 100644 --- a/ci/cirrus_general_ci.yml +++ b/ci/cirrus_general_ci.yml @@ -143,6 +143,6 @@ macos_arm64_test_task: python -m pip install click rich_click doit pydevtool python -m pip install pytest pooch export DYLD_LIBRARY_PATH=/usr/local/gfortran/lib:/opt/arm64-builds/lib - python -m pip install torch + python -m pip install torch array-api-compat # spot check array API torch MPS device compliance SCIPY_TORCH_DEVICE=mps python dev.py test -b numpy -b pytorch -s cluster