Skip to content

Commit 7cf5d5d

Browse files
Feature:3973 Extend list pipeline filters
- Filter pipelines by the name or ID of the user who executed the latest run. - Extend client.list_pipelines with the new filter option. - Add integration tests that switch between multiple users.
1 parent ea1ba85 commit 7cf5d5d

File tree

4 files changed

+184
-17
lines changed

4 files changed

+184
-17
lines changed

src/zenml/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2410,6 +2410,7 @@ def list_pipelines(
24102410
updated: Optional[Union[datetime, str]] = None,
24112411
name: Optional[str] = None,
24122412
latest_run_status: Optional[str] = None,
2413+
latest_run_user: Optional[Union[UUID, str]] = None,
24132414
project: Optional[Union[str, UUID]] = None,
24142415
user: Optional[Union[UUID, str]] = None,
24152416
tag: Optional[str] = None,
@@ -2429,6 +2430,8 @@ def list_pipelines(
24292430
name: The name of the pipeline to filter by.
24302431
latest_run_status: Filter by the status of the latest run of a
24312432
pipeline.
2433+
latest_run_user: Filter by the name or UUID of the user that
2434+
executed the latest run.
24322435
project: The project name/ID to filter by.
24332436
user: The name/ID of the user to filter by.
24342437
tag: Tag to filter by.
@@ -2449,6 +2452,7 @@ def list_pipelines(
24492452
updated=updated,
24502453
name=name,
24512454
latest_run_status=latest_run_status,
2455+
latest_run_user=latest_run_user,
24522456
project=project or self.active_project.id,
24532457
user=user,
24542458
tag=tag,

src/zenml/models/v2/core/pipeline.py

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
Optional,
2222
Type,
2323
TypeVar,
24+
Union,
2425
)
2526
from uuid import UUID
2627

@@ -268,6 +269,7 @@ class PipelineFilter(ProjectScopedFilter, TaggableFilter):
268269
*ProjectScopedFilter.FILTER_EXCLUDE_FIELDS,
269270
*TaggableFilter.FILTER_EXCLUDE_FIELDS,
270271
"latest_run_status",
272+
"latest_run_user",
271273
]
272274
CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [
273275
*ProjectScopedFilter.CLI_EXCLUDE_FIELDS,
@@ -284,6 +286,20 @@ class PipelineFilter(ProjectScopedFilter, TaggableFilter):
284286
"This will always be applied as an `AND` filter for now.",
285287
)
286288

289+
latest_run_user: Optional[Union[UUID, str]] = Field(
290+
default=None,
291+
description="Filter by the name or id of the last user that executed the pipeline. ",
292+
)
293+
294+
@property
295+
def filter_by_latest_execution(self) -> bool:
296+
"""Property returning whether filtering considers latest pipeline execution.
297+
298+
Returns:
299+
True if latest pipeline execution filters are used (e.g. latest_run_status).
300+
"""
301+
return bool(self.latest_run_user) or bool(self.latest_run_status)
302+
287303
def apply_filter(
288304
self, query: AnyQuery, table: Type["AnySchema"]
289305
) -> AnyQuery:
@@ -300,9 +316,13 @@ def apply_filter(
300316

301317
from sqlmodel import and_, col, func, select
302318

303-
from zenml.zen_stores.schemas import PipelineRunSchema, PipelineSchema
319+
from zenml.zen_stores.schemas import (
320+
PipelineRunSchema,
321+
PipelineSchema,
322+
UserSchema,
323+
)
304324

305-
if self.latest_run_status:
325+
if self.filter_by_latest_execution:
306326
latest_pipeline_run_subquery = (
307327
select(
308328
PipelineRunSchema.pipeline_id,
@@ -313,28 +333,39 @@ def apply_filter(
313333
.subquery()
314334
)
315335

316-
query = (
317-
query.join(
318-
PipelineRunSchema,
319-
PipelineSchema.id == PipelineRunSchema.pipeline_id,
336+
query = query.join(
337+
PipelineRunSchema,
338+
PipelineSchema.id == PipelineRunSchema.pipeline_id,
339+
).join(
340+
latest_pipeline_run_subquery,
341+
and_(
342+
PipelineRunSchema.pipeline_id
343+
== latest_pipeline_run_subquery.c.pipeline_id,
344+
PipelineRunSchema.created
345+
== latest_pipeline_run_subquery.c.created,
346+
),
347+
)
348+
349+
if self.latest_run_user:
350+
query = query.join(
351+
UserSchema, UserSchema.id == PipelineRunSchema.user_id
320352
)
321-
.join(
322-
latest_pipeline_run_subquery,
323-
and_(
324-
PipelineRunSchema.pipeline_id
325-
== latest_pipeline_run_subquery.c.pipeline_id,
326-
PipelineRunSchema.created
327-
== latest_pipeline_run_subquery.c.created,
328-
),
353+
354+
query = query.where(
355+
self.generate_name_or_id_query_conditions(
356+
value=self.latest_run_user,
357+
table=UserSchema,
358+
)
329359
)
330-
.where(
360+
361+
if self.latest_run_status:
362+
query = query.where(
331363
self.generate_custom_query_conditions_for_column(
332364
value=self.latest_run_status,
333365
table=PipelineRunSchema,
334366
column="status",
335367
)
336368
)
337-
)
338369

339370
return query
340371

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import uuid
2+
3+
import pytest
4+
from tests.integration.functional.zen_stores.utils import LoginContext
5+
6+
from zenml import pipeline, step
7+
from zenml.client import Client
8+
from zenml.config.global_config import GlobalConfiguration
9+
from zenml.enums import StoreType
10+
from zenml.models.v2.core.user import UserResponse
11+
12+
13+
@step
14+
def greet_step(name: str) -> str:
15+
return f"Welcome {name}"
16+
17+
18+
@step
19+
def crush_step():
20+
raise ValueError("Oh noooo...")
21+
22+
23+
@pipeline
24+
def pipeline_that_completes():
25+
greet_step("test")
26+
27+
28+
@pipeline()
29+
def pipeline_that_crushes():
30+
crush_step()
31+
32+
33+
def submit_pipelines_programmatically(client: Client):
34+
pipeline_that_completes()
35+
36+
try:
37+
pipeline_that_crushes()
38+
except RuntimeError:
39+
pass
40+
41+
42+
def check_pipelines(client: Client, user_name: str):
43+
pipelines = client.list_pipelines(
44+
latest_run_status="completed", name="pipeline_that_completes"
45+
)
46+
47+
assert pipelines.size == 1
48+
49+
pipelines = client.list_pipelines(
50+
latest_run_status="failed", name="pipeline_that_crushes"
51+
)
52+
53+
assert pipelines.size == 1
54+
55+
pipelines = client.list_pipelines(
56+
name="pipeline_that_completes", latest_run_user=user_name
57+
)
58+
59+
assert pipelines.size == 1
60+
61+
pipelines = client.list_pipelines(
62+
name="pipeline_that_crushes", latest_run_user=user_name
63+
)
64+
65+
assert pipelines.size == 1
66+
67+
68+
def check_all_user_id_alternatives_work(client: Client, user: UserResponse):
69+
pipelines = client.list_pipelines(
70+
name="pipeline_that_completes", latest_run_user=user.id
71+
)
72+
73+
assert pipelines.size == 1
74+
75+
pipelines = client.list_pipelines(
76+
name="pipeline_that_completes", latest_run_user=str(user.id)
77+
)
78+
79+
assert pipelines.size == 1
80+
81+
pipelines = client.list_pipelines(
82+
name="pipeline_that_completes", latest_run_user=user.name
83+
)
84+
85+
assert pipelines.size == 1
86+
87+
88+
def test_multi_user_pipeline_executions():
89+
if GlobalConfiguration().zen_store.config.type != StoreType.REST:
90+
pytest.skip("Multi-user pipelines are supported only over REST.")
91+
92+
client = Client()
93+
94+
submit_pipelines_programmatically(client)
95+
96+
check_pipelines(client, user_name="default")
97+
98+
user = client.create_user(
99+
name=f"tester_{uuid.uuid4()}", password="1234", is_admin=True
100+
)
101+
102+
with LoginContext(user_name=user.name, password="1234"):
103+
new_client = Client()
104+
105+
submit_pipelines_programmatically(new_client)
106+
107+
check_pipelines(new_client, user.name)
108+
109+
check_all_user_id_alternatives_work(new_client, user=user)

tests/unit/models/test_pipeline_models.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from zenml.config.pipeline_spec import PipelineSpec
2020
from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH
21-
from zenml.models import PipelineRequest
21+
from zenml.models import PipelineFilter, PipelineRequest
2222

2323

2424
def test_pipeline_request_model_fails_with_long_name():
@@ -41,3 +41,26 @@ def test_pipeline_request_model_fails_with_long_docstring():
4141
docstring=long_docstring,
4242
spec=PipelineSpec(steps=[]),
4343
)
44+
45+
46+
def test_pipeline_filter_by_latest_execution():
47+
f = PipelineFilter()
48+
49+
assert not f.filter_by_latest_execution
50+
51+
f = PipelineFilter(latest_run_status="completed")
52+
53+
assert f.filter_by_latest_execution
54+
55+
f = PipelineFilter(latest_run_user="test")
56+
57+
assert f.filter_by_latest_execution
58+
59+
# make sure latest pipeline run associated fields are not propagated to base table filters.
60+
61+
latest_run_fields = [
62+
"latest_run_status",
63+
"latest_run_user",
64+
]
65+
66+
assert all(field in f.FILTER_EXCLUDE_FIELDS for field in latest_run_fields)

0 commit comments

Comments
 (0)