11import asyncio
22from datetime import timedelta
33from pydantic import BaseModel , Field
4- from restack_ai .workflow import workflow , log , workflow_info , import_functions
4+ from restack_ai .workflow import workflow , log , workflow_info , import_functions , NonRetryableError
55from .child import ChildWorkflow , ChildWorkflowInput
66
77with import_functions ():
@@ -14,34 +14,39 @@ class ExampleWorkflowInput(BaseModel):
1414class ExampleWorkflow :
1515 @workflow .run
1616 async def run (self , input : ExampleWorkflowInput ):
17- # use the parent run id to create child workflow ids
18- parent_workflow_id = workflow_info ().workflow_id
19-
20- tasks = []
21- for i in range (input .amount ):
22- log .info (f"Queue ChildWorkflow { i + 1 } for execution" )
23- task = workflow .child_execute (
24- workflow = ChildWorkflow ,
25- workflow_id = f"{ parent_workflow_id } -child-execute-{ i + 1 } " ,
26- input = ChildWorkflowInput (name = f"child workflow { i + 1 } " )
17+
18+ try :
19+ # use the parent run id to create child workflow ids
20+ parent_workflow_id = workflow_info ().workflow_id
21+
22+ tasks = []
23+ for i in range (input .amount ):
24+ log .info (f"Queue ChildWorkflow { i + 1 } for execution" )
25+ task = workflow .child_execute (
26+ workflow = ChildWorkflow ,
27+ workflow_id = f"{ parent_workflow_id } -child-execute-{ i + 1 } " ,
28+ workflow_input = ChildWorkflowInput (prompt = "Generate a random joke in max 20 words." ),
29+ )
30+ tasks .append (task )
31+
32+ # Run all child workflows in parallel and wait for their results
33+ results = await asyncio .gather (* tasks )
34+
35+ for i , result in enumerate (results , start = 1 ):
36+ log .info (f"ChildWorkflow { i } completed" , result = result )
37+
38+ generated_text = await workflow .step (
39+ function = llm_generate ,
40+ function_input = GenerateInput (prompt = f"Give me the top 3 unique jokes according to the results. { results } " ),
41+ task_queue = "llm" ,
42+ start_to_close_timeout = timedelta (minutes = 2 )
2743 )
28- tasks .append (task )
2944
30- # Run all child workflows in parallel and wait for their results
31- results = await asyncio .gather (* tasks )
32-
33- for i , result in enumerate (results , start = 1 ):
34- log .info (f"ChildWorkflow { i } completed" , result = result )
35-
36- generated_text = await workflow .step (
37- function = llm_generate ,
38- function_input = GenerateInput (prompt = f"Give me the top 3 unique jokes according to the results. { results } " ),
39- task_queue = "llm" ,
40- start_to_close_timeout = timedelta (minutes = 2 )
41- )
42-
43- return {
44- "top_jokes" : generated_text ,
45- "results" : results
46- }
45+ return {
46+ "top_jokes" : generated_text ,
47+ "results" : results
48+ }
4749
50+ except Exception as e :
51+ log .error (f"ExampleWorkflow failed { e } " )
52+ raise NonRetryableError (message = f"ExampleWorkflow failed { e } " ) from e
0 commit comments