1+ # Apache Software License 2.0
2+ #
3+ # Copyright (c) ZenML GmbH 2024. All rights reserved.
4+ #
5+ # Licensed under the Apache License, Version 2.0 (the "License");
6+ # you may not use this file except in compliance with the License.
7+ # You may obtain a copy of the License at
8+ #
9+ # http://www.apache.org/licenses/LICENSE-2.0
10+ #
11+ # Unless required by applicable law or agreed to in writing, software
12+ # distributed under the License is distributed on an "AS IS" BASIS,
13+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+ # See the License for the specific language governing permissions and
15+ # limitations under the License.
16+ #
17+
18+ from datetime import datetime
19+ from pathlib import Path
20+ from typing import Optional
21+ from uuid import UUID
22+
23+ import click
24+ from zenml .client import Client
25+
26+ from pipelines import llm_basic_rag
27+
28+
29+ @click .command (
30+ help = """
31+ ZenML LLM Complete - Rag Pipeline
32+ """
33+ )
34+ @click .option (
35+ "--no-cache" ,
36+ "no_cache" ,
37+ is_flag = True ,
38+ default = False ,
39+ help = "Disable cache." ,
40+ )
41+
42+ @click .option (
43+ "--create-template" ,
44+ "create_template" ,
45+ is_flag = True ,
46+ default = False ,
47+ help = "Create a run template." ,
48+ )
49+ @click .option (
50+ "--config" ,
51+ "config" ,
52+ default = "rag_local_dev.yaml" ,
53+ help = "Specify a configuration file"
54+ )
55+ @click .option (
56+ "--action-id" ,
57+ "action_id" ,
58+ default = None ,
59+ help = "Specify an action ID"
60+ )
61+ def main (
62+ no_cache : bool = False ,
63+ config : Optional [str ]= "rag_local_dev.yaml" ,
64+ create_template : bool = False ,
65+ action_id : Optional [str ] = None
66+ ):
67+ """
68+ Executes the pipeline to train a basic RAG model.
69+
70+ Args:
71+ no_cache (bool): If `True`, cache will be disabled.
72+ config (str): The path to the configuration file.
73+ create_template (bool): If `True`, a run template will be created.
74+ action_id (str): The action ID.
75+ """
76+ client = Client ()
77+ config_path = Path (__file__ ).parent .parent / "configs" / config
78+
79+ if create_template :
80+ # run pipeline
81+ run = llm_basic_rag .with_options (
82+ config_path = str (config_path ),
83+ enable_cache = not no_cache
84+ )()
85+ # create new run template
86+ rt = client .create_run_template (
87+ name = f"production-llm-complete-{ datetime .now ().strftime ('%Y-%m-%d_%H-%M-%S' )} " ,
88+ deployment_id = run .deployment_id
89+ )
90+ # update the action with the new template
91+ client .update_action (
92+ name_id_or_prefix = UUID (action_id ),
93+ configuration = {
94+ "template_id" : str (rt .id )
95+ }
96+ )
97+
98+ else :
99+ llm_basic_rag .with_options (
100+ config_path = str (config_path ),
101+ enable_cache = not no_cache
102+ )()
103+
104+
105+ if __name__ == "__main__" :
106+ main ()
0 commit comments