Skip to content

Commit b84b74e

Browse files
Feat: Add Python extras support in GithubImport (#403)
# The problem

GitHub import did not support Python extras.

https://zapatacomputing.atlassian.net/browse/ORQSDK-915

# This PR's solution

Add a new API to support extras in `GithubImport`.

# Checklist

_Check that this PR satisfies the following items:_

- [x] Tests have been added for new features/changed behavior (if no new features have been added, check the box).
- [x] The [changelog file](CHANGELOG.md) has been updated with a user-readable description of the changes (if the change isn't visible to the user in any way, check the box).
- [x] The PR's title is prefixed with `<feat/fix/chore/imp[rovement]/int[ernal]/docs>[!]:`
- [x] The PR is linked to a JIRA ticket (if there's no suitable ticket, check the box).
1 parent 8220823 commit b84b74e

File tree

8 files changed

+249
-50
lines changed

8 files changed

+249
-50
lines changed

projects/orquestra-sdk/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
🔥 *Features*
88

99
* Add `WorkflowDef.with_head_node_resources()` function to programmatically set head node resources for a workflow definition
10+
* Add Python extras support in `GithubImport` object
1011

1112
🧟 *Deprecations*
1213

projects/orquestra-sdk/src/orquestra/sdk/_client/_base/_dsl.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ class GitImportWithAuth:
133133
git_ref: str
134134
username: Optional[str]
135135
auth_secret: Optional[Secret]
136+
package_name: Optional[str] = None
137+
extras: Optional[Tuple[str, ...]] = None
136138

137139

138140
@dataclass(frozen=True, eq=True)
@@ -166,6 +168,8 @@ def GithubImport(
166168
git_ref: str = "main",
167169
username: Optional[str] = None,
168170
personal_access_token: Optional[Secret] = None,
171+
package_name: Optional[str] = None,
172+
extras: Optional[Union[List[str], str]] = None,
169173
):
170174
"""Helper to create GitImports from Github repos.
171175
@@ -176,6 +180,11 @@ def GithubImport(
176180
username: the username used to access GitHub
177181
personal_access_token: must be configured in GitHub for access to the specified
178182
repo.
183+
package_name: package name that will be used during pip install. Example:
184+
my_package @ git+https://github.com/my_repo@main
185+
extras: name of an extra (or list of extra names) to install from the repo. Ex:
186+
my_package[extra] @ git+https://github.com/my_repo@main
187+
Due to pip restrictions, passing extras requires package_name to be set.
179188
180189
Raises:
181190
TypeError: when a value that is not a `sdk.Secret` is passed as
@@ -207,12 +216,26 @@ def GithubImport(
207216
" Support for default workspaces will be sunset in the future.",
208217
FutureWarning,
209218
)
219+
if extras is not None and package_name is None:
220+
raise TypeError(
221+
"Due to PIP syntax restrictions, passing extras require" " package name."
222+
)
223+
224+
_extras: Optional[Tuple[str, ...]]
225+
if extras is None:
226+
_extras = None
227+
elif isinstance(extras, str):
228+
_extras = (extras,)
229+
else:
230+
_extras = tuple(extras)
210231

211232
return GitImportWithAuth(
212233
repo_url=f"https://github.com/{repo}.git",
213234
git_ref=git_ref,
214235
username=username,
215236
auth_secret=personal_access_token,
237+
extras=_extras,
238+
package_name=package_name,
216239
)
217240

218241

projects/orquestra-sdk/src/orquestra/sdk/_client/_base/_traversal.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
################################################################################
2-
# © Copyright 2022-2023 Zapata Computing Inc.
2+
# © Copyright 2022-2024 Zapata Computing Inc.
33
################################################################################
44
"""Transforms a DSL-based workflow into Intermediate Representation format.
55
@@ -417,6 +417,8 @@ def _make_import_model(imp: _dsl.Import):
417417
id=id_,
418418
repo_url=url,
419419
git_ref=imp.git_ref,
420+
package_name=imp.package_name,
421+
extras=imp.extras,
420422
)
421423
elif isinstance(imp, _dsl.InlineImport):
422424
return ir.InlineImport(id=id_)

projects/orquestra-sdk/src/orquestra/sdk/_runtime/_ray/_build_workflow.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,14 @@ def _(imp: ir.GitImport):
396396
if not protocol.startswith("git+"):
397397
protocol = f"git+{protocol}"
398398
url = _build_git_url(imp.repo_url, protocol)
399-
return [f"{url}@{imp.git_ref}"]
399+
400+
url_string = f"{url}@{imp.git_ref}"
401+
extras_string = "" if imp.extras is None else f"[{','.join(imp.extras)}]"
402+
package_name_string = (
403+
"" if imp.package_name is None else f"{imp.package_name}{extras_string} @ "
404+
)
405+
406+
return [f"{package_name_string}{url_string}"]
400407

401408

402409
def _import_pip_env(

projects/orquestra-sdk/src/orquestra/sdk/_shared/schema/ir.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ class GitImport(BaseModel):
5454
id: ImportId
5555
repo_url: GitURL
5656
git_ref: str
57+
package_name: t.Optional[str] = None
58+
extras: t.Optional[t.Tuple[str, ...]] = None
5759

5860
# we need this in the JSON to know which class to use when deserializing
5961
type: t.Literal["GIT_IMPORT"] = "GIT_IMPORT"

projects/orquestra-sdk/tests/runtime/ray/test_build_workflow.py

Lines changed: 86 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
################################################################################
2-
# © Copyright 2023 Zapata Computing Inc.
2+
# © Copyright 2024 Zapata Computing Inc.
33
################################################################################
44

55
import re
@@ -199,32 +199,91 @@ class TestGitImports:
199199
def patch_env(self, monkeypatch: pytest.MonkeyPatch):
200200
monkeypatch.setenv("ORQ_RAY_DOWNLOAD_GIT_IMPORTS", "1")
201201

202-
def test_http(self, patch_env):
203-
imp = ir.GitImport(
204-
id="mock-import",
205-
repo_url=_git_url_utils.parse_git_url("https://mock/mock/mock"),
206-
git_ref="mock",
207-
)
208-
pip = _build_workflow._pip_string(imp)
209-
assert pip == ["git+https://mock/mock/mock@mock"]
210-
211-
def test_pip_ssh_format(self, patch_env):
212-
imp = ir.GitImport(
213-
id="mock-import",
214-
repo_url=_git_url_utils.parse_git_url("ssh://git@mock/mock/mock"),
215-
git_ref="mock",
216-
)
217-
pip = _build_workflow._pip_string(imp)
218-
assert pip == ["git+ssh://git@mock/mock/mock@mock"]
219-
220-
def test_usual_ssh_format(self, patch_env):
221-
imp = ir.GitImport(
222-
id="mock-import",
223-
repo_url=_git_url_utils.parse_git_url("git@mock:mock/mock"),
224-
git_ref="mock",
225-
)
226-
pip = _build_workflow._pip_string(imp)
227-
assert pip == ["git+ssh://git@mock/mock/mock@mock"]
202+
@pytest.mark.parametrize(
203+
"imp, expected",
204+
[
205+
(
206+
ir.GitImport(
207+
id="mock-import",
208+
repo_url=_git_url_utils.parse_git_url("https://mock/mock/mock"),
209+
git_ref="mock",
210+
),
211+
["git+https://mock/mock/mock@mock"],
212+
),
213+
(
214+
ir.GitImport(
215+
id="mock-import",
216+
repo_url=_git_url_utils.parse_git_url(
217+
"ssh://git@mock/mock/mock"
218+
),
219+
git_ref="mock",
220+
),
221+
["git+ssh://git@mock/mock/mock@mock"],
222+
),
223+
(
224+
ir.GitImport(
225+
id="mock-import",
226+
repo_url=_git_url_utils.parse_git_url("git@mock:mock/mock"),
227+
git_ref="mock",
228+
),
229+
["git+ssh://git@mock/mock/mock@mock"],
230+
),
231+
(
232+
ir.GitImport(
233+
id="mock-import",
234+
repo_url=_git_url_utils.parse_git_url("git@mock:mock/mock"),
235+
git_ref="mock",
236+
package_name="pack_mock",
237+
),
238+
["pack_mock @ git+ssh://git@mock/mock/mock@mock"],
239+
),
240+
(
241+
ir.GitImport(
242+
id="mock-import",
243+
repo_url=_git_url_utils.parse_git_url("git@mock:mock/mock"),
244+
git_ref="mock",
245+
package_name="pack_mock",
246+
extras=None,
247+
),
248+
["pack_mock @ git+ssh://git@mock/mock/mock@mock"],
249+
),
250+
(
251+
ir.GitImport(
252+
id="mock-import",
253+
repo_url=_git_url_utils.parse_git_url("git@mock:mock/mock"),
254+
git_ref="mock",
255+
package_name="pack_mock",
256+
extras=("extra_mock",),
257+
),
258+
["pack_mock[extra_mock] @ git+ssh://git@mock/mock/mock@mock"],
259+
),
260+
(
261+
ir.GitImport(
262+
id="mock-import",
263+
repo_url=_git_url_utils.parse_git_url("git@mock:mock/mock"),
264+
git_ref="mock",
265+
package_name="pack_mock",
266+
extras=("extra_mock", "e_mock"),
267+
),
268+
[
269+
"pack_mock[extra_mock,e_mock] @ "
270+
"git+ssh://git@mock/mock/mock@mock"
271+
],
272+
),
273+
(
274+
ir.GitImport(
275+
id="mock-import",
276+
repo_url=_git_url_utils.parse_git_url("git@mock:mock/mock"),
277+
git_ref="mock",
278+
package_name=None,
279+
extras=("extra_mock", "e_mock"),
280+
),
281+
["git+ssh://git@mock/mock/mock@mock"],
282+
),
283+
],
284+
)
285+
def test_build_pip_string(self, patch_env, imp, expected):
286+
assert _build_workflow._pip_string(imp) == expected
228287

229288
def test_no_env_set(self):
230289
imp = ir.GitImport(

projects/orquestra-sdk/tests/runtime/ray/test_integration.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222
from orquestra.sdk._client._base._config import LOCAL_RUNTIME_CONFIGURATION
2323
from orquestra.sdk._client._base._testing import _example_wfs, _ipc
2424
from orquestra.sdk._runtime._ray import _build_workflow, _client, _dag, _ray_logs
25-
from orquestra.sdk._runtime._ray._env import RAY_TEMP_PATH_ENV
25+
from orquestra.sdk._runtime._ray._env import (
26+
RAY_DOWNLOAD_GIT_IMPORTS_ENV,
27+
RAY_TEMP_PATH_ENV,
28+
)
2629
from orquestra.sdk._shared import exceptions
2730
from orquestra.sdk._shared.abc import RuntimeInterface
2831
from orquestra.sdk._shared.schema import ir
@@ -1489,3 +1492,60 @@ def wf():
14891492
assert '"OVERWRITTEN_IN_WF"' in artifacts # expected in inv3
14901493
assert '"OVERWRITTEN"' in artifacts # expected in inv1
14911494
assert artifacts.count('"SET_BEFORE_RAY_STARTS"') == 2 # expected in inv2 and 4
1495+
1496+
1497+
@pytest.mark.slow
1498+
class TestGithubImportExtras:
1499+
def test_passing_extras(self, runtime: _dag.RayRuntime, monkeypatch):
1500+
@sdk.task(
1501+
dependency_imports=[
1502+
sdk.GithubImport(
1503+
repo="SebastianMorawiec/test_repo", package_name="test_repo"
1504+
)
1505+
]
1506+
)
1507+
def task_no_extra():
1508+
exception_happened = False
1509+
1510+
try:
1511+
import polars # type: ignore # noqa
1512+
except ModuleNotFoundError:
1513+
exception_happened = True
1514+
1515+
assert exception_happened
1516+
1517+
return 21
1518+
1519+
@sdk.task(
1520+
dependency_imports=[
1521+
sdk.GithubImport(
1522+
repo="SebastianMorawiec/test_repo",
1523+
package_name="test_repo",
1524+
extras="polars",
1525+
)
1526+
]
1527+
)
1528+
def task_with_extra():
1529+
import polars # type: ignore # noqa
1530+
1531+
return 36
1532+
1533+
@sdk.workflow
1534+
def wf():
1535+
return task_no_extra(), task_with_extra()
1536+
1537+
# Given
1538+
# This package should not be installed before running test
1539+
with pytest.raises(ModuleNotFoundError):
1540+
import polars # type: ignore # noqa
1541+
monkeypatch.setenv(RAY_DOWNLOAD_GIT_IMPORTS_ENV, "1")
1542+
1543+
wf_model = wf().model
1544+
wf_run_id = runtime.create_workflow_run(wf_model, None, False)
1545+
_wait_to_finish_wf(wf_run_id, runtime, timeout=120)
1546+
1547+
results = runtime.get_workflow_run_outputs_non_blocking(wf_run_id)
1548+
1549+
artifacts = [res.value for res in results]
1550+
1551+
assert artifacts == ["21", "36"]

0 commit comments

Comments
 (0)