2121
2222import click
2323import yaml
24+ from pipelines .llm_basic_rag import llm_basic_rag
2425from zenml .client import Client
2526from zenml .exceptions import ZenKeyError
2627
27- from pipelines .llm_basic_rag import llm_basic_rag
28-
2928
3029@click .command (
3130 help = """
3938 default = False ,
4039 help = "Disable cache." ,
4140)
42-
4341@click .option (
4442 "--create-template" ,
4543 "create_template" ,
5149 "--config" ,
5250 "config" ,
5351 default = "rag_local_dev.yaml" ,
54- help = "Specify a configuration file"
52+ help = "Specify a configuration file" ,
5553)
5654@click .option (
5755 "--service-account-id" ,
5856 "service_account_id" ,
5957 default = None ,
60- help = "Specify a service account ID"
58+ help = "Specify a service account ID" ,
6159)
6260@click .option (
6361 "--event-source-id" ,
6462 "event_source_id" ,
6563 default = None ,
66- help = "Specify an event source ID"
64+ help = "Specify an event source ID" ,
6765)
6866def main (
6967 no_cache : bool = False ,
70- config : Optional [str ]= "rag_local_dev.yaml" ,
68+ config : Optional [str ] = "rag_local_dev.yaml" ,
7169 create_template : bool = False ,
7270 service_account_id : Optional [str ] = None ,
73- event_source_id : Optional [str ] = None
71+ event_source_id : Optional [str ] = None ,
7472):
7573 """
7674 Executes the pipeline to train a basic RAG model.
@@ -86,43 +84,43 @@ def main(
8684 client = Client ()
8785 config_path = Path (__file__ ).parent / "configs" / config
8886
89- with ( open (config_path ,"r" ) as file ) :
87+ with open (config_path , "r" ) as file :
9088 config = yaml .safe_load (file )
9189
9290 if create_template :
93-
9491 # run pipeline
9592 run = llm_basic_rag .with_options (
96- config_path = str (config_path ),
97- enable_cache = not no_cache
93+ config_path = str (config_path ), enable_cache = not no_cache
9894 )()
9995 # create new run template
10096 rt = client .create_run_template (
10197 name = f"production-llm-complete-{ datetime .now ().strftime ('%Y-%m-%d_%H-%M-%S' )} " ,
102- deployment_id = run .deployment_id
98+ deployment_id = run .deployment_id ,
10399 )
104100
105101 try :
106102 # Check if an action ahs already be configured for this pipeline
107103 action = client .get_action (
108104 name_id_or_prefix = "LLM Complete (production)" ,
109- allow_name_prefix_match = True
105+ allow_name_prefix_match = True ,
110106 )
111107 except ZenKeyError :
112108 if not event_source_id :
113- raise RuntimeError ("An event source is required for this workflow." )
109+ raise RuntimeError (
110+ "An event source is required for this workflow."
111+ )
114112
115113 if not service_account_id :
116114 service_account_id = client .create_service_account (
117115 name = "github-action-sa" ,
118- description = "To allow triggered pipelines to run with M2M authentication."
116+ description = "To allow triggered pipelines to run with M2M authentication." ,
119117 ).id
120118
121119 action_id = client .create_action (
122120 name = "LLM Complete (production)" ,
123121 configuration = {
124122 "template_id" : str (rt .id ),
125- "run_config" : pop_restricted_configs (config )
123+ "run_config" : pop_restricted_configs (config ),
126124 },
127125 service_account_id = service_account_id ,
128126 auth_window = 0 ,
@@ -132,7 +130,7 @@ def main(
132130 event_source_id = UUID (event_source_id ),
133131 event_filter = {"event_type" : "tag_event" },
134132 action_id = action_id ,
135- description = "Trigger pipeline to reindex everytime the docs are updated through git."
133+ description = "Trigger pipeline to reindex everytime the docs are updated through git." ,
136134 )
137135 else :
138136 # update the action with the new template
@@ -141,14 +139,13 @@ def main(
141139 name_id_or_prefix = action .id ,
142140 configuration = {
143141 "template_id" : str (rt .id ),
144- "run_config" : pop_restricted_configs (config )
145- }
142+ "run_config" : pop_restricted_configs (config ),
143+ },
146144 )
147145
148146 else :
149147 llm_basic_rag .with_options (
150- config_path = str (config_path ),
151- enable_cache = not no_cache
148+ config_path = str (config_path ), enable_cache = not no_cache
152149 )()
153150
154151
@@ -162,22 +159,22 @@ def pop_restricted_configs(run_configuration: dict) -> dict:
162159 Modified dictionary with restricted items removed
163160 """
164161 # Pop top-level restricted items
165- run_configuration .pop (' parameters' , None )
166- run_configuration .pop (' build' , None )
167- run_configuration .pop (' schedule' , None )
162+ run_configuration .pop (" parameters" , None )
163+ run_configuration .pop (" build" , None )
164+ run_configuration .pop (" schedule" , None )
168165
169166 # Pop docker settings if they exist
170- if ' settings' in run_configuration :
171- run_configuration [' settings' ].pop (' docker' , None )
167+ if " settings" in run_configuration :
168+ run_configuration [" settings" ].pop (" docker" , None )
172169
173170 # Pop docker settings from steps if they exist
174- if ' steps' in run_configuration :
175- for step in run_configuration [' steps' ].values ():
176- if ' settings' in step :
177- step [' settings' ].pop (' docker' , None )
171+ if " steps" in run_configuration :
172+ for step in run_configuration [" steps" ].values ():
173+ if " settings" in step :
174+ step [" settings" ].pop (" docker" , None )
178175
179176 return run_configuration
180177
181178
182179if __name__ == "__main__" :
183- main ()
180+ main ()
0 commit comments