Skip to content

Commit e851657

Browse files
Fix: workflow().with_resources(...) properly copies default source and dependency imports (#393)
# The problem `with_resources` function omitted copying of the imports causing problems when this parameter was used # This PR's solution properly copy all parameters # Checklist _Check that this PR satisfies the following items:_ - [x] Tests have been added for new features/changed behavior (if no new features have been added, check the box). - [x] The [changelog file](CHANGELOG.md) has been updated with a user-readable description of the changes (if the change isn't visible to the user in any way, check the box). - [x] The PR's title is prefixed with `<feat/fix/chore/imp[rovement]/int[ernal]/docs>[!]:` - [x] The PR is linked to a JIRA ticket (if there's no suitable ticket, check the box).
1 parent bb12529 commit e851657

File tree

5 files changed

+40
-5
lines changed

5 files changed

+40
-5
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
* `sdk.workflow(fn, resources=...)` will no longer show type errors from linters.
2121
* CLI log dumping now correctly saves stdout and stderr logs
22+
* `workflow().with_resources(...)` properly copies default source and dependency imports
2223

2324
💅 *Improvements*
2425

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@ dependencies = [
6464
# Capture stdout/stderr
6565
"wurlitzer>=3.0",
6666
# For dremio client
67-
"pyarrow>=10.0",
67+
# pyarrow 16.0 crashed ray workers for unknown reason. Crashes were not
68+
# reproducable on mac - so carefull with taking that restriction away.
69+
"pyarrow>=10.0,<16.0",
6870
"pandas>=1.4",
6971
]
7072

src/orquestra/sdk/_client/_base/_workflow.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,8 @@ def with_resources(
243243
data_aggregation=self._data_aggregation,
244244
workflow_args=self._workflow_args,
245245
workflow_kwargs=self._workflow_kwargs,
246+
default_source_import=self.default_source_import,
247+
default_dependency_imports=self.default_dependency_imports,
246248
)
247249

248250

tests/sdk/test_consistent_return_shapes.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,10 @@ def test_consistent_returns_for_single_value(
461461
m = re.match(
462462
r"Workflow Submitted! Run ID: (?P<run_id>.*)", run_ray.stdout.decode()
463463
)
464-
assert m is not None
464+
465+
assert (
466+
m is not None
467+
), f"STDOUT: {run_ray.stdout.decode()},\n\nSTDERR: {run_ray.stderr.decode()}"
465468
run_id_ray = m.group("run_id").strip()
466469
assert "Workflow Submitted!" in run_ce.stdout.decode()
467470

@@ -520,7 +523,9 @@ def test_consistent_returns_for_multiple_values(
520523
m = re.match(
521524
r"Workflow Submitted! Run ID: (?P<run_id>.*)", run_ray.stdout.decode()
522525
)
523-
assert m is not None
526+
assert (
527+
m is not None
528+
), f"STDOUT: {run_ray.stdout.decode()},\n\nSTDERR: {run_ray.stderr.decode()}"
524529
run_id_ray = m.group("run_id").strip()
525530
assert "Workflow Submitted!" in run_ce.stdout.decode()
526531

@@ -591,7 +596,9 @@ def test_consistent_downloads_for_single_value(
591596
m = re.match(
592597
r"Workflow Submitted! Run ID: (?P<run_id>.*)", run_ray.stdout.decode()
593598
)
594-
assert m is not None
599+
assert (
600+
m is not None
601+
), f"STDOUT: {run_ray.stdout.decode()},\n\nSTDERR: {run_ray.stderr.decode()}"
595602
run_id_ray = m.group("run_id").strip()
596603
assert mock_ce_run_single in run_ce.stdout.decode()
597604

@@ -669,7 +676,9 @@ def test_consistent_downloads_for_multiple_values(
669676
m = re.match(
670677
r"Workflow Submitted! Run ID: (?P<run_id>.*)", run_ray.stdout.decode()
671678
)
672-
assert m is not None
679+
assert (
680+
m is not None
681+
), f"STDOUT: {run_ray.stdout.decode()},\n\nSTDERR: {run_ray.stderr.decode()}"
673682
run_id_ray = m.group("run_id").strip()
674683
assert mock_ce_run_multiple in run_ce.stdout.decode()
675684

tests/sdk/test_workflow.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,3 +399,24 @@ def my_workflow():
399399
pass
400400

401401
assert my_workflow._default_dependency_imports == expected_imports
402+
403+
404+
def test_with_resources_copies_imports():
405+
@sdk.workflow(
406+
default_dependency_imports=[sdk.PythonImports("abc")],
407+
default_source_import=sdk.GitImport(repo_url="abc", git_ref="xyz"),
408+
)
409+
def my_workflow():
410+
pass
411+
412+
initial_workflow = my_workflow()
413+
modified_workflow = initial_workflow.with_resources(cpu="xyz")
414+
415+
assert (
416+
initial_workflow.default_dependency_imports
417+
== modified_workflow.default_dependency_imports
418+
)
419+
assert (
420+
initial_workflow.default_source_import
421+
== modified_workflow.default_source_import
422+
)

0 commit comments

Comments
 (0)