2222import click
2323import yaml
2424from zenml .client import Client
25+ from zenml .exceptions import ZenKeyError
2526
2627from pipelines .llm_basic_rag import llm_basic_rag
2728
5354 help = "Specify a configuration file"
5455)
5556@click .option (
56- "--action -id" ,
57- "action_id " ,
57+ "--service-account -id" ,
58+ "service_account_id " ,
5859 default = None ,
59- help = "Specify an action ID"
60+ help = "Specify a service account ID"
61+ )
62+ @click .option (
63+ "--event-source-id" ,
64+ "event_source_id" ,
65+ default = None ,
66+ help = "Specify an event source ID"
6067)
6168def main (
6269 no_cache : bool = False ,
6370 config : Optional [str ]= "rag_local_dev.yaml" ,
6471 create_template : bool = False ,
65- action_id : Optional [str ] = None
72+ service_account_id : Optional [str ] = None ,
73+ event_source_id : Optional [str ] = None
6674):
6775 """
6876 Executes the pipeline to train a basic RAG model.
@@ -72,6 +80,8 @@ def main(
7280 config (str): The path to the configuration file.
7381 create_template (bool): If `True`, a run template will be created.
7482 action_id (str): The action ID.
83+ service_account_id (str): The service account ID.
84+ event_source_id (str): The event source ID.
7585 """
7686 client = Client ()
7787 config_path = Path (__file__ ).parent / "configs" / config
@@ -80,6 +90,7 @@ def main(
8090 config = yaml .safe_load (file )
8191
8292 if create_template :
93+
8394 # run pipeline
8495 run = llm_basic_rag .with_options (
8596 config_path = str (config_path ),
@@ -90,14 +101,49 @@ def main(
90101 name = f"production-llm-complete-{ datetime .now ().strftime ('%Y-%m-%d_%H-%M-%S' )} " ,
91102 deployment_id = run .deployment_id
92103 )
93- # update the action with the new template
94- client .update_action (
95- name_id_or_prefix = UUID (action_id ),
96- configuration = {
97- "template_id" : str (rt .id ),
98- "run_config" : pop_restricted_configs (config )
99- }
100- )
104+
105+ try :
106+ # Check if an action ahs already be configured for this pipeline
107+ action = client .get_action (
108+ name_id_or_prefix = "LLM Complete (production)" ,
109+ allow_name_prefix_match = True
110+ )
111+ except ZenKeyError :
112+ if not event_source_id :
113+ raise RuntimeError ("An event source is required for this workflow." )
114+
115+ if not service_account_id :
116+ service_account_id = client .create_service_account (
117+ name = "github-action-sa" ,
118+ description = "To allow triggered pipelines to run with M2M authentication."
119+ ).id
120+
121+ action_id = client .create_action (
122+ name = "LLM Complete (production)" ,
123+ configuration = {
124+ "template_id" : str (rt .id ),
125+ "run_config" : pop_restricted_configs (config )
126+ },
127+ service_account_id = service_account_id ,
128+ auth_window = 0 ,
129+ ).id
130+ client .create_trigger (
131+ name = "Production Trigger LLM-Complete" ,
132+ event_source_id = UUID (event_source_id ),
133+ event_filter = {"event_type" : "tag_event" },
134+ action_id = action_id ,
135+ description = "Trigger pipeline to reindex everytime the docs are updated through git."
136+ )
137+ else :
138+ # update the action with the new template
139+ # here we can assume the trigger is fully set up already
140+ client .update_action (
141+ name_id_or_prefix = action .id ,
142+ configuration = {
143+ "template_id" : str (rt .id ),
144+ "run_config" : pop_restricted_configs (config )
145+ }
146+ )
101147
102148 else :
103149 llm_basic_rag .with_options (
0 commit comments