1+ # Apache Software License 2.0
2+ #
3+ # Copyright (c) ZenML GmbH 2024. All rights reserved.
4+ #
5+ # Licensed under the Apache License, Version 2.0 (the "License");
6+ # you may not use this file except in compliance with the License.
7+ # You may obtain a copy of the License at
8+ #
9+ # http://www.apache.org/licenses/LICENSE-2.0
10+ #
11+ # Unless required by applicable law or agreed to in writing, software
12+ # distributed under the License is distributed on an "AS IS" BASIS,
13+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+ # See the License for the specific language governing permissions and
15+ # limitations under the License.
16+ #
17+ import json
18+ from datetime import datetime
19+ from pathlib import Path
20+ from typing import Optional
21+ from uuid import UUID
22+
23+ import click
24+ import yaml
25+ from zenml .client import Client
26+
27+ from pipelines .llm_basic_rag import llm_basic_rag
28+
29+
30+ @click .command (
31+ help = """
32+ ZenML LLM Complete - Rag Pipeline
33+ """
34+ )
35+ @click .option (
36+ "--no-cache" ,
37+ "no_cache" ,
38+ is_flag = True ,
39+ default = False ,
40+ help = "Disable cache." ,
41+ )
42+
43+ @click .option (
44+ "--create-template" ,
45+ "create_template" ,
46+ is_flag = True ,
47+ default = False ,
48+ help = "Create a run template." ,
49+ )
50+ @click .option (
51+ "--config" ,
52+ "config" ,
53+ default = "rag_local_dev.yaml" ,
54+ help = "Specify a configuration file"
55+ )
56+ @click .option (
57+ "--action-id" ,
58+ "action_id" ,
59+ default = None ,
60+ help = "Specify an action ID"
61+ )
62+ def main (
63+ no_cache : bool = False ,
64+ config : Optional [str ]= "rag_local_dev.yaml" ,
65+ create_template : bool = False ,
66+ action_id : Optional [str ] = None
67+ ):
68+ """
69+ Executes the pipeline to train a basic RAG model.
70+
71+ Args:
72+ no_cache (bool): If `True`, cache will be disabled.
73+ config (str): The path to the configuration file.
74+ create_template (bool): If `True`, a run template will be created.
75+ action_id (str): The action ID.
76+ """
77+ client = Client ()
78+ config_path = Path (__file__ ).parent / "configs" / config
79+
80+ with (open (config_path ,"r" ) as file ):
81+ config = yaml .safe_load (file )
82+
83+ if create_template :
84+ # run pipeline
85+ run = llm_basic_rag .with_options (
86+ config_path = str (config_path ),
87+ enable_cache = not no_cache
88+ )()
89+ # create new run template
90+ rt = client .create_run_template (
91+ name = f"production-llm-complete-{ datetime .now ().strftime ('%Y-%m-%d_%H-%M-%S' )} " ,
92+ deployment_id = run .deployment_id
93+ )
94+ # update the action with the new template
95+ client .update_action (
96+ name_id_or_prefix = UUID (action_id ),
97+ configuration = {
98+ "template_id" : str (rt .id ),
99+ "run_config" : pop_restricted_configs (config )
100+ }
101+ )
102+
103+ else :
104+ llm_basic_rag .with_options (
105+ config_path = str (config_path ),
106+ enable_cache = not no_cache
107+ )()
108+
109+
110+ def pop_restricted_configs (run_configuration : dict ) -> dict :
111+ """Removes restricted configuration items from a run configuration dictionary.
112+
113+ Args:
114+ run_configuration: Dictionary containing run configuration settings
115+
116+ Returns:
117+ Modified dictionary with restricted items removed
118+ """
119+ # Pop top-level restricted items
120+ run_configuration .pop ('parameters' , None )
121+ run_configuration .pop ('build' , None )
122+ run_configuration .pop ('schedule' , None )
123+
124+ # Pop docker settings if they exist
125+ if 'settings' in run_configuration :
126+ run_configuration ['settings' ].pop ('docker' , None )
127+
128+ # Pop docker settings from steps if they exist
129+ if 'steps' in run_configuration :
130+ for step in run_configuration ['steps' ].values ():
131+ if 'settings' in step :
132+ step ['settings' ].pop ('docker' , None )
133+
134+ return run_configuration
135+
136+ if __name__ == "__main__" :
137+ main ()
0 commit comments