Skip to content

Commit 3145af5

Browse files
committed
Enable the option to select cpu or cuda for the torch library on linux machines
1 parent 90b6831 commit 3145af5

File tree

5 files changed

+84
-19
lines changed

5 files changed

+84
-19
lines changed

Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ ENV POETRY_CACHE_DIR=/tmp/poetry \
1616

1717
COPY poetry.lock pyproject.toml /build/
1818

19-
RUN poetry export > requirements.txt && rm --recursive --force -- "${POETRY_CACHE_DIR}"
19+
RUN poetry export --extras cuda > requirements.txt && rm --recursive --force -- "${POETRY_CACHE_DIR}"
2020

2121
# Step -- 2.
2222
FROM nvidia/cuda:12.6.3-cudnn-devel-ubuntu22.04 AS runtime

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ lock:
2828

2929
bootstrap: poetry.lock poetry.toml pyproject.toml
3030
@poetry check
31-
@poetry install -vv --compile --no-cache --with dev --with docs --with tests
31+
@poetry install -vv --compile --extras cpu --no-cache --with dev --with docs --with tests
3232

3333
build: bootstrap
3434
@poetry build --clean

poetry.lock

Lines changed: 63 additions & 13 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,15 @@ keywords = []
2121
dependencies = [ # see information: https://python-poetry.org/docs/dependency-specification
2222
"loguru (>=0.7,<1.0)",
2323
"numpy (>=2.2,<3.0)",
24-
"torch (>=2.6,<2.7)",
24+
]
25+
26+
[project.optional-dependencies]
27+
cpu = [
28+
"torch (>=2.6,<2.7) ; sys_platform == 'darwin'",
29+
"torch (>=2.6+cpu,<2.7+cpu) ; sys_platform == 'linux'",
30+
]
31+
cuda = [
32+
"torch (>=2.6+cu126,<2.7+cu126) ; sys_platform == 'linux'",
2533
]
2634

2735
[project.urls]
@@ -39,8 +47,10 @@ include = ["CHANGELOG.md", "LICENSE", "README.md"]
3947
python = ">=3.10,<3.13"
4048
# ---
4149
torch = [
50+
{markers = "extra == 'cpu' and extra != 'cuda'", platform = "linux", source = "pytorch_cpu"},
51+
{markers = "extra == 'cuda' and extra != 'cpu'", platform = "linux", source = "pytorch_cuda"},
52+
# ---
4253
{platform = "darwin", source = "pypi_public"},
43-
{platform = "linux", source = "pytorch"},
4454
]
4555

4656
[tool.poetry.group.docs]
@@ -84,7 +94,12 @@ priority = "primary"
8494
url = "https://pypi.org/simple/"
8595

8696
[[tool.poetry.source]]
87-
name = "pytorch"
97+
name = "pytorch_cpu"
98+
priority = "explicit"
99+
url = "https://download.pytorch.org/whl/cpu"
100+
101+
[[tool.poetry.source]]
102+
name = "pytorch_cuda"
88103
priority = "explicit"
89104
url = "https://download.pytorch.org/whl/cu126"
90105

tox.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ skip_missing_interpreters = {env:TOX_SKIP_MISSING_INTERPRETERS:True}
1010
[testenv]
1111
description = "Execute the test driver using {basepython}."
1212
allowlist_externals = poetry
13-
commands_pre = poetry install --with tests
13+
commands_pre = poetry install --extras cpu --with tests
1414
commands = poetry run pytest
1515
skip_install = true
1616

0 commit comments

Comments
 (0)