Skip to content

Commit 329c44a

Browse files
committed
Add dynamic fan-out/fan-in with run templates (#3826)
* Add dynamic fan-out/fan-in with run templates
* Refactor code for advanced features in pipelines
* Refactor chunk processing and results aggregation
* Add `is_failed` property to `ExecutionStatus` enum
* Update check for failed runs to use run status value
* Remove unnecessary import in advanced_features.md
* Update `is_failed` property return statement to refer to failed execution
* Update advanced features with improved process_chunk logic
* Update advanced features documentation and usage example
* Update docs/book/how-to/steps-pipelines/advanced_features.md

(cherry picked from commit 07beafb)
1 parent fe9bd13 commit 329c44a

File tree

2 files changed

+184
-0
lines changed

2 files changed

+184
-0
lines changed

docs/book/how-to/steps-pipelines/advanced_features.md

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,181 @@ The fan-in, fan-out method has the following limitations:
303303
2. The number of steps need to be known ahead-of-time, and ZenML does not yet support the ability to dynamically create steps on the fly.
304304
{% endhint %}
305305

306+
### Dynamic Fan-out/Fan-in with Run Templates
307+
308+
For scenarios where you need to determine the number of parallel operations at runtime (e.g., based on database queries or dynamic data), you can use [run templates](https://docs.zenml.io/user-guides/tutorial/trigger-pipelines-from-external-systems) to create a more flexible fan-out/fan-in pattern. This approach allows you to trigger multiple pipeline runs dynamically and then aggregate their results.
309+
310+
```python
311+
from typing import List, Optional
312+
from uuid import UUID
313+
import time
314+
315+
from zenml import step, pipeline
316+
from zenml.client import Client
317+
318+
319+
@step
def load_relevant_chunks() -> List[str]:
    """Fetch the chunk identifiers to fan out over in this run.

    In a real pipeline this would come from a dynamic source — a
    database query, an API call, a filesystem listing — so the number
    of downstream runs is only known at runtime. A fixed list is
    returned here purely for illustration.
    """
    return ["chunk_1", "chunk_2", "chunk_3", "chunk_4"]
325+
326+
327+
@step
def trigger_chunk_processing(
    chunks: List[str],
    template_id: Optional[UUID] = None
) -> List[UUID]:
    """Fan out one pipeline run per chunk, then block until all finish.

    Args:
        chunks: Chunk identifiers; one asynchronous run is triggered
            per entry.
        template_id: Run template to trigger. When omitted, the latest
            template registered for ``chunk_processing_pipeline`` is
            resolved by pipeline name instead.

    Returns:
        The IDs of all triggered pipeline runs, in trigger order.
    """
    client = Client()

    # trigger_pipeline accepts either a template ID or a pipeline name;
    # passing the name resolves to that pipeline's newest template.
    pipeline_name = "chunk_processing_pipeline" if not template_id else None

    run_ids: List[UUID] = []
    for chunk in chunks:
        config = {
            "steps": {
                "process_chunk": {
                    "parameters": {"chunk_id": chunk}
                }
            }
        }
        triggered = client.trigger_pipeline(
            template_id=template_id,
            pipeline_name_or_id=pipeline_name,
            run_configuration=config,
            synchronous=False,  # fan out without blocking on each run
        )
        run_ids.append(triggered.id)

    # Poll until every triggered run reports a terminal status.
    print(f"Waiting for {len(run_ids)} chunk processing runs to complete...")
    finished = set()  # remember terminal runs so they are not re-fetched
    while True:
        for run_id in (r for r in run_ids if r not in finished):
            if client.get_pipeline_run(run_id).status.is_finished:
                finished.add(run_id)

        if len(finished) == len(run_ids):
            print("All chunk processing runs completed!")
            break

        print(f"Completed: {len(finished)}/{len(run_ids)} runs")
        time.sleep(10)  # back off between polling sweeps

    return run_ids
382+
383+
384+
@step
def aggregate_results(run_ids: List[UUID]) -> dict:
    """Fan the per-chunk run outputs back in to a single report.

    Args:
        run_ids: IDs of the chunk-processing runs to inspect.

    Returns:
        A dict with the loaded outputs of successful runs, details of
        failed runs, and a numeric summary of both.
    """
    client = Client()
    results_by_run = {}
    failures = []

    for run_id in run_ids:
        run = client.get_pipeline_run(run_id)

        # Record failed runs but keep aggregating the rest.
        if run.status.value == "failed":
            failures.append({
                "run_id": str(run_id),
                "status": run.status.value,
            })
            print(f"WARNING: Run {run_id} failed with status {run.status.value}")
            continue

        # Simple assumption: the run's process_chunk step has exactly
        # one output that can be loaded directly.
        if "process_chunk" in run.steps:
            results_by_run[str(run_id)] = run.steps["process_chunk"].output.load()

    total = len(run_ids)
    succeeded = len(results_by_run)

    print(f"Aggregation complete: {succeeded}/{total} runs successful")

    return {
        "successful_results": results_by_run,
        "failed_runs": failures,
        "summary": {
            "total_runs": total,
            "successful_runs": succeeded,
            "failed_runs": len(failures),
        },
    }
427+
428+
429+
@pipeline(enable_cache=False)
def fan_out_fan_in_pipeline(template_id: Optional[UUID] = None):
    """Orchestrate dynamic fan-out/fan-in over runtime-discovered chunks.

    Args:
        template_id: Run template used to trigger each chunk run; when
            None, the chunk pipeline's latest template is resolved by
            name inside the trigger step.
    """
    # Discover work items at runtime, fan out one run per item, then
    # fan the per-run outputs back in.
    chunk_ids = load_relevant_chunks()
    triggered = trigger_chunk_processing(chunk_ids, template_id)
    return aggregate_results(triggered)
442+
443+
444+
# Define the chunk processing pipeline that will be triggered
445+
@step
def process_chunk(chunk_id: Optional[str] = None) -> dict:
    """Process one chunk and report what was done.

    Args:
        chunk_id: Identifier of the chunk, injected per-run via the
            trigger's run configuration.

    Returns:
        A small status record describing the processed chunk.
    """
    # Stand-in for the real per-chunk work.
    print(f"Processing chunk: {chunk_id}")
    return {
        "chunk_id": chunk_id,
        "processed_items": 100,
        "status": "completed",
    }
455+
456+
457+
@pipeline
def chunk_processing_pipeline():
    """Pipeline wrapping a single process_chunk invocation; this is
    the pipeline that the run template triggers once per chunk."""
    return process_chunk()
462+
463+
464+
# Usage example
465+
if __name__ == "__main__":
    # One-time setup: register a run template for the chunk pipeline.
    # NOTE: a remote stack must be active before running this.
    template = chunk_processing_pipeline.create_run_template(
        name="chunk_processing_template",
        description="Template for processing individual chunks"
    )

    # Kick off the orchestrating pipeline; the template ID could also
    # be copied from the dashboard instead.
    fan_out_fan_in_pipeline(template_id=template.id)
477+
```
478+
479+
This pattern enables dynamic scaling, true parallelism, and database-driven workflows. Key advantages include fault tolerance and separate monitoring for each chunk. Consider resource management and proper error handling when implementing.
480+
306481
### Custom Step Invocation IDs
307482

308483
When calling a ZenML step as part of your pipeline, it gets assigned a unique **invocation ID** that you can use to reference this step invocation when defining the execution order of your pipeline steps or use it to fetch information about the invocation after the pipeline has finished running.

src/zenml/enums.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,15 @@ def is_successful(self) -> bool:
107107
"""
108108
return self in {ExecutionStatus.COMPLETED, ExecutionStatus.CACHED}
109109

110+
@property
def is_failed(self) -> bool:
    """Whether the execution status refers to a failed execution.

    Returns:
        True if this status is ``ExecutionStatus.FAILED``.
    """
    # Enum members are singletons, so an equality check against the
    # single failed member is equivalent to set membership.
    return self == ExecutionStatus.FAILED
118+
110119

111120
class LoggingLevels(Enum):
112121
"""Enum for logging levels."""

0 commit comments

Comments
 (0)