Skip to content

Commit 03aa795

Browse files
authored
Improved pipeline/step run fetching (#3776)
* Pass project ID when fetching steps of a run * Improved loading of step runs * Improved loading of pipeline runs * Linting * Viewonly relationship * Optimize DAG DB queries * Linting
1 parent 67fea79 commit 03aa795

File tree

4 files changed

+67
-32
lines changed

4 files changed

+67
-32
lines changed

src/zenml/models/v2/core/pipeline_run.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,7 @@ def steps(self) -> Dict[str, "StepRunResponse"]:
480480
for step in pagination_utils.depaginate(
481481
Client().list_run_steps,
482482
pipeline_run_id=self.id,
483+
project=self.project_id,
483484
)
484485
}
485486

src/zenml/zen_stores/schemas/pipeline_run_schemas.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from pydantic import ConfigDict
2222
from sqlalchemy import UniqueConstraint
23-
from sqlalchemy.orm import joinedload
23+
from sqlalchemy.orm import selectinload
2424
from sqlalchemy.sql.base import ExecutableOption
2525
from sqlmodel import TEXT, Column, Field, Relationship
2626

@@ -259,19 +259,19 @@ def get_query_options(
259259
from zenml.zen_stores.schemas import ModelVersionSchema
260260

261261
options = [
262-
joinedload(jl_arg(PipelineRunSchema.deployment)).joinedload(
262+
selectinload(jl_arg(PipelineRunSchema.deployment)).joinedload(
263263
jl_arg(PipelineDeploymentSchema.pipeline)
264264
),
265-
joinedload(jl_arg(PipelineRunSchema.deployment)).joinedload(
265+
selectinload(jl_arg(PipelineRunSchema.deployment)).joinedload(
266266
jl_arg(PipelineDeploymentSchema.stack)
267267
),
268-
joinedload(jl_arg(PipelineRunSchema.deployment)).joinedload(
268+
selectinload(jl_arg(PipelineRunSchema.deployment)).joinedload(
269269
jl_arg(PipelineDeploymentSchema.build)
270270
),
271-
joinedload(jl_arg(PipelineRunSchema.deployment)).joinedload(
271+
selectinload(jl_arg(PipelineRunSchema.deployment)).joinedload(
272272
jl_arg(PipelineDeploymentSchema.schedule)
273273
),
274-
joinedload(jl_arg(PipelineRunSchema.deployment)).joinedload(
274+
selectinload(jl_arg(PipelineRunSchema.deployment)).joinedload(
275275
jl_arg(PipelineDeploymentSchema.code_reference)
276276
),
277277
]
@@ -286,14 +286,14 @@ def get_query_options(
286286
if include_resources:
287287
options.extend(
288288
[
289-
joinedload(
289+
selectinload(
290290
jl_arg(PipelineRunSchema.model_version)
291291
).joinedload(
292292
jl_arg(ModelVersionSchema.model), innerjoin=True
293293
),
294-
joinedload(jl_arg(PipelineRunSchema.logs)),
295-
joinedload(jl_arg(PipelineRunSchema.user)),
296-
# joinedload(jl_arg(PipelineRunSchema.tags)),
294+
selectinload(jl_arg(PipelineRunSchema.logs)),
295+
selectinload(jl_arg(PipelineRunSchema.user)),
296+
selectinload(jl_arg(PipelineRunSchema.tags)),
297297
]
298298
)
299299

src/zenml/zen_stores/schemas/step_run_schemas.py

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from pydantic import ConfigDict
2222
from sqlalchemy import TEXT, Column, String, UniqueConstraint
2323
from sqlalchemy.dialects.mysql import MEDIUMTEXT
24-
from sqlalchemy.orm import joinedload
24+
from sqlalchemy.orm import joinedload, selectinload
2525
from sqlalchemy.sql.base import ExecutableOption
2626
from sqlmodel import Field, Relationship, SQLModel
2727

@@ -50,6 +50,7 @@
5050
from zenml.zen_stores.schemas.constants import MODEL_VERSION_TABLENAME
5151
from zenml.zen_stores.schemas.pipeline_deployment_schemas import (
5252
PipelineDeploymentSchema,
53+
StepConfigurationSchema,
5354
)
5455
from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema
5556
from zenml.zen_stores.schemas.project_schemas import ProjectSchema
@@ -187,6 +188,14 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True):
187188
original_step_run: Optional["StepRunSchema"] = Relationship(
188189
sa_relationship_kwargs={"remote_side": "StepRunSchema.id"}
189190
)
191+
step_configuration_schema: Optional["StepConfigurationSchema"] = (
192+
Relationship(
193+
sa_relationship_kwargs=dict(
194+
viewonly=True,
195+
primaryjoin="and_(foreign(StepConfigurationSchema.name) == StepRunSchema.name, foreign(StepConfigurationSchema.deployment_id) == StepRunSchema.deployment_id)",
196+
),
197+
)
198+
)
190199

191200
model_config = ConfigDict(protected_namespaces=()) # type: ignore[assignment]
192201

@@ -209,17 +218,25 @@ def get_query_options(
209218
Returns:
210219
A list of query options.
211220
"""
212-
from zenml.zen_stores.schemas import ModelVersionSchema
221+
from zenml.zen_stores.schemas import (
222+
ArtifactVersionSchema,
223+
ModelVersionSchema,
224+
)
213225

214226
options = [
215-
joinedload(jl_arg(StepRunSchema.deployment)),
216-
joinedload(jl_arg(StepRunSchema.pipeline_run)),
227+
selectinload(jl_arg(StepRunSchema.deployment)).load_only(
228+
jl_arg(PipelineDeploymentSchema.pipeline_configuration)
229+
),
230+
selectinload(jl_arg(StepRunSchema.pipeline_run)).load_only(
231+
jl_arg(PipelineRunSchema.start_time)
232+
),
233+
joinedload(jl_arg(StepRunSchema.step_configuration_schema)),
217234
]
218235

219236
if include_metadata:
220237
options.extend(
221238
[
222-
joinedload(jl_arg(StepRunSchema.logs)),
239+
selectinload(jl_arg(StepRunSchema.logs)),
223240
# joinedload(jl_arg(StepRunSchema.parents)),
224241
# joinedload(jl_arg(StepRunSchema.run_metadata)),
225242
]
@@ -228,12 +245,28 @@ def get_query_options(
228245
if include_resources:
229246
options.extend(
230247
[
231-
joinedload(jl_arg(StepRunSchema.model_version)).joinedload(
248+
selectinload(
249+
jl_arg(StepRunSchema.model_version)
250+
).joinedload(
232251
jl_arg(ModelVersionSchema.model), innerjoin=True
233252
),
234-
joinedload(jl_arg(StepRunSchema.user)),
235-
# joinedload(jl_arg(StepRunSchema.input_artifacts)),
236-
# joinedload(jl_arg(StepRunSchema.output_artifacts)),
253+
selectinload(jl_arg(StepRunSchema.user)),
254+
selectinload(jl_arg(StepRunSchema.input_artifacts))
255+
.joinedload(
256+
jl_arg(StepRunInputArtifactSchema.artifact_version),
257+
innerjoin=True,
258+
)
259+
.joinedload(
260+
jl_arg(ArtifactVersionSchema.artifact), innerjoin=True
261+
),
262+
selectinload(jl_arg(StepRunSchema.output_artifacts))
263+
.joinedload(
264+
jl_arg(StepRunOutputArtifactSchema.artifact_version),
265+
innerjoin=True,
266+
)
267+
.joinedload(
268+
jl_arg(ArtifactVersionSchema.artifact), innerjoin=True
269+
),
237270
]
238271
)
239272

@@ -290,10 +323,7 @@ def to_model(
290323
"""
291324
step = None
292325
if self.deployment is not None:
293-
step_configurations = self.deployment.get_step_configurations(
294-
include=[self.name]
295-
)
296-
if step_configurations:
326+
if self.step_configuration_schema:
297327
pipeline_configuration = (
298328
PipelineConfiguration.model_validate_json(
299329
self.deployment.pipeline_configuration
@@ -304,7 +334,7 @@ def to_model(
304334
inplace=True,
305335
)
306336
step = Step.from_dict(
307-
json.loads(step_configurations[0].config),
337+
json.loads(self.step_configuration_schema.config),
308338
pipeline_configuration=pipeline_configuration,
309339
)
310340
if not step and self.step_configuration:

src/zenml/zen_stores/sql_zen_store.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
ArgumentError,
6565
IntegrityError,
6666
)
67-
from sqlalchemy.orm import Mapped, joinedload, noload
67+
from sqlalchemy.orm import Mapped, noload, selectinload
6868
from sqlalchemy.sql.base import ExecutableOption
6969
from sqlalchemy.util import immutabledict
7070
from sqlmodel import Session as SqlModelSession
@@ -5332,13 +5332,17 @@ def get_pipeline_run_dag(self, pipeline_run_id: UUID) -> PipelineRunDAG:
53325332
schema_class=PipelineRunSchema,
53335333
session=session,
53345334
query_options=[
5335-
joinedload(jl_arg(PipelineRunSchema.deployment)),
5336-
# joinedload(jl_arg(PipelineRunSchema.step_runs)).sele(
5337-
# jl_arg(StepRunSchema.input_artifacts)
5338-
# ),
5339-
# joinedload(jl_arg(PipelineRunSchema.step_runs)).joinedload(
5340-
# jl_arg(StepRunSchema.output_artifacts)
5341-
# ),
5335+
selectinload(
5336+
jl_arg(PipelineRunSchema.deployment)
5337+
).load_only(
5338+
jl_arg(PipelineDeploymentSchema.pipeline_configuration)
5339+
),
5340+
selectinload(
5341+
jl_arg(PipelineRunSchema.step_runs)
5342+
).selectinload(jl_arg(StepRunSchema.input_artifacts)),
5343+
selectinload(
5344+
jl_arg(PipelineRunSchema.step_runs)
5345+
).selectinload(jl_arg(StepRunSchema.output_artifacts)),
53425346
],
53435347
)
53445348
assert run.deployment is not None

0 commit comments

Comments
 (0)