
Commit 3440fbc

Commit message: formatting
1 parent e1b10ce · commit 3440fbc

File tree: 6 files changed, +111 -66 lines changed


llm-complete-guide/gh_action_rag.py

Lines changed: 29 additions & 32 deletions
@@ -21,11 +21,10 @@
 
 import click
 import yaml
+from pipelines.llm_basic_rag import llm_basic_rag
 from zenml.client import Client
 from zenml.exceptions import ZenKeyError
 
-from pipelines.llm_basic_rag import llm_basic_rag
-
 
 @click.command(
     help="""
@@ -39,7 +38,6 @@
     default=False,
     help="Disable cache.",
 )
-
 @click.option(
     "--create-template",
     "create_template",
@@ -51,26 +49,26 @@
     "--config",
     "config",
     default="rag_local_dev.yaml",
-    help="Specify a configuration file"
+    help="Specify a configuration file",
 )
 @click.option(
     "--service-account-id",
     "service_account_id",
     default=None,
-    help="Specify a service account ID"
+    help="Specify a service account ID",
 )
 @click.option(
     "--event-source-id",
     "event_source_id",
     default=None,
-    help="Specify an event source ID"
+    help="Specify an event source ID",
 )
 def main(
     no_cache: bool = False,
-    config: Optional[str]= "rag_local_dev.yaml",
+    config: Optional[str] = "rag_local_dev.yaml",
     create_template: bool = False,
     service_account_id: Optional[str] = None,
-    event_source_id: Optional[str] = None
+    event_source_id: Optional[str] = None,
 ):
     """
     Executes the pipeline to train a basic RAG model.
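For reference, a minimal sketch of how the Click entrypoint defined above could be exercised programmatically (for example in a smoke test) using click.testing.CliRunner. The import path gh_action_rag and the placeholder event-source UUID are assumptions for illustration only; a real invocation triggers an actual pipeline run on the configured ZenML stack.

# Hedged sketch: drive the CLI above with Click's test runner.
# `gh_action_rag` as the module path and the placeholder UUID are assumptions.
from click.testing import CliRunner

from gh_action_rag import main

runner = CliRunner()
result = runner.invoke(
    main,
    [
        "--create-template",
        "--config", "rag_local_dev.yaml",
        "--event-source-id", "00000000-0000-0000-0000-000000000000",
    ],
)
print(result.exit_code, result.output)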
@@ -86,43 +84,43 @@ def main(
     client = Client()
     config_path = Path(__file__).parent / "configs" / config
 
-    with (open(config_path,"r") as file):
+    with open(config_path, "r") as file:
         config = yaml.safe_load(file)
 
     if create_template:
-
         # run pipeline
         run = llm_basic_rag.with_options(
-            config_path=str(config_path),
-            enable_cache=not no_cache
+            config_path=str(config_path), enable_cache=not no_cache
         )()
         # create new run template
         rt = client.create_run_template(
             name=f"production-llm-complete-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}",
-            deployment_id=run.deployment_id
+            deployment_id=run.deployment_id,
         )
 
         try:
             # Check if an action has already been configured for this pipeline
             action = client.get_action(
                 name_id_or_prefix="LLM Complete (production)",
-                allow_name_prefix_match=True
+                allow_name_prefix_match=True,
             )
         except ZenKeyError:
             if not event_source_id:
-                raise RuntimeError("An event source is required for this workflow.")
+                raise RuntimeError(
+                    "An event source is required for this workflow."
+                )
 
             if not service_account_id:
                 service_account_id = client.create_service_account(
                     name="github-action-sa",
-                    description="To allow triggered pipelines to run with M2M authentication."
+                    description="To allow triggered pipelines to run with M2M authentication.",
                 ).id
 
             action_id = client.create_action(
                 name="LLM Complete (production)",
                 configuration={
                     "template_id": str(rt.id),
-                    "run_config": pop_restricted_configs(config)
+                    "run_config": pop_restricted_configs(config),
                 },
                 service_account_id=service_account_id,
                 auth_window=0,
@@ -132,7 +130,7 @@ def main(
                 event_source_id=UUID(event_source_id),
                 event_filter={"event_type": "tag_event"},
                 action_id=action_id,
-                description="Trigger pipeline to reindex everytime the docs are updated through git."
+                description="Trigger pipeline to reindex everytime the docs are updated through git.",
             )
         else:
             # update the action with the new template
@@ -141,14 +139,13 @@ def main(
                 name_id_or_prefix=action.id,
                 configuration={
                     "template_id": str(rt.id),
-                    "run_config": pop_restricted_configs(config)
-                }
+                    "run_config": pop_restricted_configs(config),
+                },
             )
 
     else:
         llm_basic_rag.with_options(
-            config_path=str(config_path),
-            enable_cache=not no_cache
+            config_path=str(config_path), enable_cache=not no_cache
         )()
 
 
@@ -162,22 +159,22 @@ def pop_restricted_configs(run_configuration: dict) -> dict:
         Modified dictionary with restricted items removed
     """
     # Pop top-level restricted items
-    run_configuration.pop('parameters', None)
-    run_configuration.pop('build', None)
-    run_configuration.pop('schedule', None)
+    run_configuration.pop("parameters", None)
+    run_configuration.pop("build", None)
+    run_configuration.pop("schedule", None)
 
     # Pop docker settings if they exist
-    if 'settings' in run_configuration:
-        run_configuration['settings'].pop('docker', None)
+    if "settings" in run_configuration:
+        run_configuration["settings"].pop("docker", None)
 
     # Pop docker settings from steps if they exist
-    if 'steps' in run_configuration:
-        for step in run_configuration['steps'].values():
-            if 'settings' in step:
-                step['settings'].pop('docker', None)
+    if "steps" in run_configuration:
+        for step in run_configuration["steps"].values():
+            if "settings" in step:
+                step["settings"].pop("docker", None)
 
     return run_configuration
 
 
 if __name__ == "__main__":
-    main()
+    main()
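The pop_restricted_configs helper above strips the fields that a triggered run must not override (parameters, build, schedule, and any Docker settings). A small, hedged illustration of its effect; the input keys below are made up for the example and do not come from the repository's config files:

# Hedged example of pop_restricted_configs; the input dict is illustrative only.
sample_config = {
    "parameters": {"docs_url": "https://docs.zenml.io"},
    "build": "some-build-id",
    "schedule": {"cron_expression": "0 0 * * *"},
    "settings": {
        "docker": {"requirements": ["langchain"]},
        "resources": {"cpu_count": 2},
    },
    "steps": {
        "url_scraper": {"settings": {"docker": {"requirements": ["bs4"]}}},
    },
}

cleaned = pop_restricted_configs(sample_config)
# cleaned == {
#     "settings": {"resources": {"cpu_count": 2}},
#     "steps": {"url_scraper": {"settings": {}}},
# }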

llm-complete-guide/pipelines/llm_basic_rag.py

Lines changed: 1 addition & 2 deletions
@@ -15,15 +15,14 @@
 # limitations under the License.
 #
 
-from zenml import pipeline
-
 from steps.populate_index import (
     generate_embeddings,
     index_generator,
     preprocess_documents,
 )
 from steps.url_scraper import url_scraper
 from steps.web_url_loader import web_url_loader
+from zenml import pipeline
 
 
 @pipeline

llm-complete-guide/pipelines/llm_eval.py

Lines changed: 4 additions & 10 deletions
@@ -13,12 +13,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from pathlib import Path
 from typing import Optional
 
 import click
-
 from steps.eval_e2e import e2e_evaluation, e2e_evaluation_llm_judged
 from steps.eval_retrieval import (
     retrieval_evaluation_full,
@@ -82,12 +80,9 @@ def llm_eval() -> None:
     "--config",
     "config",
     default="rag_local_dev.yaml",
-    help="Specify a configuration file"
+    help="Specify a configuration file",
 )
-def main(
-    no_cache: bool = False,
-    config: Optional[str] = "rag_eval.yaml"
-):
+def main(no_cache: bool = False, config: Optional[str] = "rag_eval.yaml"):
     """
     Executes the pipeline to train a basic RAG model.
 
@@ -98,10 +93,9 @@ def main(
     config_path = Path(__file__).parent.parent / "configs" / config
 
     llm_eval.with_options(
-        config_path=str(config_path),
-        enable_cache=not no_cache
+        config_path=str(config_path), enable_cache=not no_cache
     )()
 
 
 if __name__ == "__main__":
-    main()
+    main()
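Since main() above is only a thin CLI wrapper, the evaluation pipeline can also be launched directly. A hedged sketch, assuming the pipeline is importable as pipelines.llm_eval and that the config file lives under configs/ as in the diff:

# Hedged sketch: run the evaluation pipeline without the CLI wrapper.
# The import path and the config location are assumptions based on the diff.
from pathlib import Path

from pipelines.llm_eval import llm_eval

config_path = Path("configs") / "rag_eval.yaml"
llm_eval.with_options(config_path=str(config_path), enable_cache=False)()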

llm-complete-guide/steps/finetune_embeddings.py

Lines changed: 3 additions & 2 deletions
@@ -23,7 +23,8 @@
     DATASET_NAME_DISTILABEL,
     EMBEDDINGS_MODEL_ID_BASELINE,
     EMBEDDINGS_MODEL_ID_FINE_TUNED,
-    EMBEDDINGS_MODEL_MATRYOSHKA_DIMS, SECRET_NAME,
+    EMBEDDINGS_MODEL_MATRYOSHKA_DIMS,
+    SECRET_NAME,
 )
 from datasets import DatasetDict, concatenate_datasets, load_dataset
 from datasets.arrow_dataset import Dataset
@@ -294,7 +295,7 @@ def finetune(
     trainer.model.push_to_hub(
         f"zenml/{EMBEDDINGS_MODEL_ID_FINE_TUNED}",
         exist_ok=True,
-        token=zenml_client.get_secret(SECRET_NAME).secret_values["hf_token"]
+        token=zenml_client.get_secret(SECRET_NAME).secret_values["hf_token"],
     )
 
     log_model_metadata(
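The push_to_hub call above reads the Hugging Face token from a ZenML secret rather than from an environment variable. A hedged sketch of registering and reading back such a secret with the ZenML client; the secret name and value below are placeholders, and using Client.create_secret for setup is an assumption, while the get_secret call mirrors the diff:

# Hedged sketch: store and read the Hugging Face token as a ZenML secret.
# SECRET_NAME is imported from the project's constants in the real code; the
# name and value below are placeholders for illustration only.
from zenml.client import Client

SECRET_NAME = "llm-complete"  # placeholder name

client = Client()
client.create_secret(name=SECRET_NAME, values={"hf_token": "<your-hf-token>"})

hf_token = client.get_secret(SECRET_NAME).secret_values["hf_token"]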
