Skip to content

Commit 240fb64

Browse files
authored
Additional sorting options for snapshots and deployments (#4033)
* Additional sorting options for snapshots and deployments * Linting * Add stack to custom sorting options
1 parent c430198 commit 240fb64

File tree

3 files changed

+115
-1
lines changed

3 files changed

+115
-1
lines changed

src/zenml/models/v2/core/deployment.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
from zenml.zen_stores.schemas.base_schemas import BaseSchema
5555

5656
AnySchema = TypeVar("AnySchema", bound=BaseSchema)
57+
AnyQuery = TypeVar("AnyQuery", bound=Any)
5758

5859

5960
class DeploymentOperationalState(BaseModel):
@@ -338,6 +339,8 @@ class DeploymentFilter(ProjectScopedFilter, TaggableFilter):
338339
CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [
339340
*ProjectScopedFilter.CUSTOM_SORTING_OPTIONS,
340341
*TaggableFilter.CUSTOM_SORTING_OPTIONS,
342+
"snapshot",
343+
"pipeline",
341344
]
342345
FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [
343346
*ProjectScopedFilter.FILTER_EXCLUDE_FIELDS,
@@ -409,3 +412,55 @@ def get_custom_filters(
409412
custom_filters.append(pipeline_filter)
410413

411414
return custom_filters
415+
416+
def apply_sorting(
417+
self,
418+
query: "AnyQuery",
419+
table: Type["AnySchema"],
420+
) -> "AnyQuery":
421+
"""Apply sorting to the query.
422+
423+
Args:
424+
query: The query to which to apply the sorting.
425+
table: The query table.
426+
427+
Returns:
428+
The query with sorting applied.
429+
"""
430+
from sqlmodel import asc, desc
431+
432+
from zenml.enums import SorterOps
433+
from zenml.zen_stores.schemas import (
434+
DeploymentSchema,
435+
PipelineSchema,
436+
PipelineSnapshotSchema,
437+
)
438+
439+
sort_by, operand = self.sorting_params
440+
441+
if sort_by == "pipeline":
442+
query = query.outerjoin(
443+
PipelineSnapshotSchema,
444+
DeploymentSchema.snapshot_id == PipelineSnapshotSchema.id,
445+
).outerjoin(
446+
PipelineSchema,
447+
PipelineSnapshotSchema.pipeline_id == PipelineSchema.id,
448+
)
449+
column: Any = PipelineSchema.name
450+
elif sort_by == "snapshot":
451+
query = query.outerjoin(
452+
PipelineSnapshotSchema,
453+
DeploymentSchema.snapshot_id == PipelineSnapshotSchema.id,
454+
)
455+
column = PipelineSnapshotSchema.name
456+
else:
457+
return super().apply_sorting(query=query, table=table)
458+
459+
query = query.add_columns(column)
460+
461+
if operand == SorterOps.ASCENDING:
462+
query = query.order_by(asc(column))
463+
else:
464+
query = query.order_by(desc(column))
465+
466+
return query

src/zenml/models/v2/core/pipeline_run.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,6 @@ class PipelineRunFilter(
615615
*ProjectScopedFilter.CUSTOM_SORTING_OPTIONS,
616616
*TaggableFilter.CUSTOM_SORTING_OPTIONS,
617617
*RunMetadataFilterMixin.CUSTOM_SORTING_OPTIONS,
618-
"tag",
619618
"stack",
620619
"pipeline",
621620
"model",

src/zenml/models/v2/core/pipeline_snapshot.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
from zenml.zen_stores.schemas.base_schemas import BaseSchema
6565

6666
AnySchema = TypeVar("AnySchema", bound=BaseSchema)
67+
AnyQuery = TypeVar("AnyQuery", bound=Any)
6768

6869

6970
# ------------------ Request Model ------------------
@@ -617,6 +618,9 @@ class PipelineSnapshotFilter(ProjectScopedFilter, TaggableFilter):
617618
CUSTOM_SORTING_OPTIONS = [
618619
*ProjectScopedFilter.CUSTOM_SORTING_OPTIONS,
619620
*TaggableFilter.CUSTOM_SORTING_OPTIONS,
621+
"pipeline",
622+
"stack",
623+
"deployment",
620624
]
621625
CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [
622626
*ProjectScopedFilter.CLI_EXCLUDE_FIELDS,
@@ -771,6 +775,62 @@ def get_custom_filters(
771775

772776
return custom_filters
773777

778+
def apply_sorting(
779+
self,
780+
query: "AnyQuery",
781+
table: Type["AnySchema"],
782+
) -> "AnyQuery":
783+
"""Apply sorting to the query.
784+
785+
Args:
786+
query: The query to which to apply the sorting.
787+
table: The query table.
788+
789+
Returns:
790+
The query with sorting applied.
791+
"""
792+
from sqlmodel import asc, desc
793+
794+
from zenml.enums import SorterOps
795+
from zenml.zen_stores.schemas import (
796+
DeploymentSchema,
797+
PipelineSchema,
798+
PipelineSnapshotSchema,
799+
StackSchema,
800+
)
801+
802+
sort_by, operand = self.sorting_params
803+
804+
if sort_by == "pipeline":
805+
query = query.outerjoin(
806+
PipelineSchema,
807+
PipelineSnapshotSchema.pipeline_id == PipelineSchema.id,
808+
)
809+
column = PipelineSchema.name
810+
elif sort_by == "stack":
811+
query = query.outerjoin(
812+
StackSchema,
813+
PipelineSnapshotSchema.stack_id == StackSchema.id,
814+
)
815+
column = StackSchema.name
816+
elif sort_by == "deployment":
817+
query = query.outerjoin(
818+
DeploymentSchema,
819+
PipelineSnapshotSchema.id == DeploymentSchema.snapshot_id,
820+
)
821+
column = DeploymentSchema.name
822+
else:
823+
return super().apply_sorting(query=query, table=table)
824+
825+
query = query.add_columns(column)
826+
827+
if operand == SorterOps.ASCENDING:
828+
query = query.order_by(asc(column))
829+
else:
830+
query = query.order_by(desc(column))
831+
832+
return query
833+
774834

775835
# ------------------ Trigger Model ------------------
776836

0 commit comments

Comments
 (0)