Commit abd4fe3

fixes and updates
1 parent e3dd869 commit abd4fe3

9 files changed: +115 -44 lines changed


llm-complete-guide/run.py

Lines changed: 12 additions & 0 deletions
@@ -42,6 +42,7 @@
 from materializers.document_materializer import DocumentMaterializer
 from pipelines import (
     finetune_embeddings,
+    generate_chunk_questions,
     generate_synthetic_data,
     llm_basic_rag,
     llm_eval,
@@ -145,6 +146,13 @@
     default=False,
     help="Whether to use the reranker.",
 )
+@click.option(
+    "--chunks",
+    "chunks",
+    is_flag=True,
+    default=False,
+    help="Generate chunks for Hugging Face dataset",
+)
 def main(
     rag: bool = False,
     evaluation: bool = False,
@@ -157,6 +165,7 @@ def main(
     dummyembeddings: bool = False,
     argilla: bool = False,
     reranked: bool = False,
+    chunks: bool = False,
 ):
     """Main entry point for the pipeline execution.

@@ -170,6 +179,7 @@ def main(
         local (bool): If `True`, the local LLM via Ollama will be used.
         embeddings (bool): If `True`, the embeddings will be fine-tuned.
         argilla (bool): If `True`, the Argilla annotations will be used.
+        chunks (bool): If `True`, the chunks pipeline will be run.
     """
     pipeline_args = {"enable_cache": not no_cache}
     embeddings_finetune_args = {
@@ -201,6 +211,8 @@
         finetune_embeddings.with_options(**embeddings_finetune_args)()
     if dummyembeddings:
         chunking_experiment.with_options(**pipeline_args)()
+    if chunks:
+        generate_chunk_questions.with_options(**pipeline_args)()


 if __name__ == "__main__":
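
The new flag plugs into the existing Click CLI, so the chunk-question pipeline is opt-in. A minimal sketch (not the project's run.py) of how the flag behaves, with the pipeline call stubbed out so it runs standalone; `python sketch.py --chunks` triggers the branch, omitting the flag leaves it False:

import click

@click.command()
@click.option(
    "--chunks",
    "chunks",
    is_flag=True,
    default=False,
    help="Generate chunks for Hugging Face dataset",
)
def main(chunks: bool = False) -> None:
    if chunks:
        # run.py calls generate_chunk_questions.with_options(**pipeline_args)() here
        print("running the chunk-question generation pipeline")

if __name__ == "__main__":
    main()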

llm-complete-guide/steps/eval_retrieval.py

Lines changed: 1 addition & 1 deletion
@@ -198,7 +198,7 @@ def perform_retrieval_evaluation(

         if all(url_ending not in url for url in urls):
             logging.error(
-                f"Failed for question: {question}. Expected URL ending: {url_ending}. Got: {urls}"
+                f"Failed for question: {question}. Expected URL containing: {url_ending}. Got: {urls}"
             )
             failures += 1
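
The reworded log message matches what the check actually does: `url_ending not in url` is a substring test, not a suffix match. A tiny illustration with made-up values:

urls = ["https://docs.zenml.io/how-to/setup?view=full"]  # made-up retrieval result
url_ending = "how-to/setup"

# True would mean the expected URL was not found anywhere in the results -> a failure.
print(all(url_ending not in url for url in urls))  # False: a result contains the fragment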

llm-complete-guide/steps/finetune_embeddings.py

Lines changed: 16 additions & 2 deletions
@@ -373,7 +373,14 @@ def visualize_results(
         color="red",
     )
     for i, v in enumerate(finetuned_values):
-        ax.text(v - 1.5, i - height / 2, f"{v:.1f}", va="center", ha="right", color="white")
+        ax.text(
+            v - 1.5,
+            i - height / 2,
+            f"{v:.1f}",
+            va="center",
+            ha="right",
+            color="white",
+        )
     ax.barh(
         [i + height / 2 for i in y],
         base_values,
@@ -382,7 +389,14 @@
         color="blue",
     )
     for i, v in enumerate(base_values):
-        ax.text(v - 1.5, i + height / 2, f"{v:.1f}", va="center", ha="right", color="white")
+        ax.text(
+            v - 1.5,
+            i + height / 2,
+            f"{v:.1f}",
+            va="center",
+            ha="right",
+            color="white",
+        )

     ax.set_xlabel("Scores (%)")
     ax.set_title("Evaluation Results")
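
The change here is purely formatting (the one-line ax.text calls are wrapped), so behaviour is unchanged. For context, a self-contained sketch of the same grouped-bar labelling pattern; the metric names and scores below are invented, not results from the project:

import matplotlib
matplotlib.use("Agg")  # render off-screen
import matplotlib.pyplot as plt

metrics = ["accuracy@1", "accuracy@5", "accuracy@10"]  # illustrative metric names
finetuned_values = [72.0, 88.0, 93.0]                  # invented scores
base_values = [55.0, 79.0, 86.0]                       # invented scores
y = range(len(metrics))
height = 0.4

fig, ax = plt.subplots(figsize=(8, 4))
ax.barh([i - height / 2 for i in y], finetuned_values, height, color="red", label="finetuned")
for i, v in enumerate(finetuned_values):
    ax.text(v - 1.5, i - height / 2, f"{v:.1f}", va="center", ha="right", color="white")
ax.barh([i + height / 2 for i in y], base_values, height, color="blue", label="base")
for i, v in enumerate(base_values):
    ax.text(v - 1.5, i + height / 2, f"{v:.1f}", va="center", ha="right", color="white")
ax.set_yticks(list(y))
ax.set_yticklabels(metrics)
ax.set_xlabel("Scores (%)")
ax.set_title("Evaluation Results")
ax.legend()
fig.savefig("evaluation_results.png")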

llm-complete-guide/steps/hf_dataset_loader.py

Lines changed: 0 additions & 1 deletion
@@ -29,4 +29,3 @@ def load_hf_dataset() -> (
     train_dataset = load_dataset(DATASET_NAME_DEFAULT, split="train")
     test_dataset = load_dataset(DATASET_NAME_DEFAULT, split="test")
     return train_dataset, test_dataset
-

llm-complete-guide/steps/populate_index.py

Lines changed: 32 additions & 18 deletions
@@ -19,9 +19,10 @@
 # https://www.timescale.com/blog/postgresql-as-a-vector-database-create-store-and-query-openai-embeddings-with-pgvector/
 # for providing the base implementation for this indexing functionality

+import json
 import logging
 import math
-from typing import Annotated, List
+from typing import Annotated

 from constants import (
     CHUNK_OVERLAP,
@@ -41,16 +42,16 @@

 @step
 def preprocess_documents(
-    documents: List[Document],
-) -> Annotated[List[Document], ArtifactConfig(name="split_chunks")]:
+    documents: str,
+) -> Annotated[str, ArtifactConfig(name="split_chunks")]:
     """
-    Preprocesses a list of documents by splitting them into chunks.
+    Preprocesses a JSON string of documents by splitting them into chunks.

     Args:
-        documents (List[Document]): A list of documents to be preprocessed.
+        documents (str): A JSON string containing a list of documents to be preprocessed.

     Returns:
-        Annotated[List[Document], ArtifactConfig(name="split_chunks")]: A list of preprocessed documents annotated with an ArtifactConfig.
+        Annotated[str, ArtifactConfig(name="split_chunks")]: A JSON string containing a list of preprocessed documents annotated with an ArtifactConfig.

     Raises:
         Exception: If an error occurs during preprocessing.
@@ -64,29 +65,34 @@ def preprocess_documents(
             },
         )

+        # Parse the JSON string into a list of Document objects
+        document_list = [Document(**doc) for doc in json.loads(documents)]
+
         split_docs = split_documents(
-            documents, chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
+            document_list, chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
         )
-        return split_docs
+
+        # Convert the list of Document objects back to a JSON string
+        split_docs_json = json.dumps([doc.__dict__ for doc in split_docs])
+
+        return split_docs_json
     except Exception as e:
         logger.error(f"Error in preprocess_documents: {e}")
         raise


 @step
 def generate_embeddings(
-    split_documents: List[Document],
-) -> Annotated[
-    List[Document], ArtifactConfig(name="documents_with_embeddings")
-]:
+    split_documents: str,
+) -> Annotated[str, ArtifactConfig(name="documents_with_embeddings")]:
     """
     Generates embeddings for a list of split documents using a SentenceTransformer model.

     Args:
         split_documents (List[Document]): A list of Document objects that have been split into chunks.

     Returns:
-        Annotated[List[Document], ArtifactConfig(name="embeddings")]: The list of Document objects with generated embeddings, annotated with an ArtifactConfig.
+        Annotated[str, ArtifactConfig(name="documents_with_embeddings")]: A JSON string containing the Document objects with generated embeddings, annotated with an ArtifactConfig.

     Raises:
         Exception: If an error occurs during the generation of embeddings.
@@ -95,7 +101,7 @@ def generate_embeddings(
         model = SentenceTransformer(EMBEDDINGS_MODEL)

         log_artifact_metadata(
-            artifact_name="embeddings",
+            artifact_name="documents_with_embeddings",
             metadata={
                 "embedding_type": EMBEDDINGS_MODEL,
                 "embedding_dimensionality": EMBEDDING_DIMENSIONALITY,
@@ -106,17 +112,22 @@
         embeddings = model.encode(document_texts)

         for doc, embedding in zip(split_documents, embeddings):
-            doc.embedding = embedding
+            doc.embedding = (
+                embedding.tolist()
+            )  # Convert numpy array to list for JSON serialization

-        return split_documents
+        # Convert the list of Document objects to a JSON string
+        documents_json = json.dumps([doc.__dict__ for doc in split_documents])
+
+        return documents_json
     except Exception as e:
         logger.error(f"Error in generate_embeddings: {e}")
         raise


 @step
 def index_generator(
-    documents: List[Document],
+    documents: str,
 ) -> None:
     """Generates an index for the given documents.

@@ -126,7 +137,7 @@ def index_generator(
     using the cosine distance measure.

     Args:
-        documents (List[Document]): The list of Document objects with generated embeddings.
+        documents (str): A JSON string containing the Document objects with generated embeddings.

     Raises:
         Exception: If an error occurs during the index generation.
@@ -155,6 +166,9 @@

         register_vector(conn)

+        # load the documents from the JSON string
+        documents = json.loads(documents)
+
         # Insert data only if it doesn't already exist
         for doc in documents:
             content = doc.page_content
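
The step signatures switch from List[Document] to plain JSON strings, presumably so the artifacts serialize without a custom materializer. A minimal sketch of the round-trip used above, assuming Document is roughly a simple container with page_content and embedding attributes as the diff implies (the stand-in dataclass below is not the project's structures.Document):

import json
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Document:  # stand-in for structures.Document
    page_content: str
    embedding: Optional[List[float]] = None

docs = [Document(page_content="ZenML steps are plain Python functions.")]
docs[0].embedding = [0.1, 0.2, 0.3]  # a real embedding comes from SentenceTransformer.encode

# Producer side: numpy arrays are converted with .tolist(), then the documents
# are dumped to a JSON string via their __dict__.
payload = json.dumps([doc.__dict__ for doc in docs])

# Consumer side: json.loads yields plain dicts, so the objects are rebuilt
# before attribute access, mirroring preprocess_documents above.
restored = [Document(**d) for d in json.loads(payload)]
print(restored[0].page_content, restored[0].embedding)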

llm-complete-guide/steps/synthetic_data.py

Lines changed: 9 additions & 6 deletions
@@ -14,14 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import List
+from typing import List, Annotated

 import pandas as pd
 from datasets import Dataset
 from huggingface_hub import create_repo
 from litellm import completion
 from structures import Document
-from zenml import step
+from zenml import step, ArtifactConfig
 from zenml.client import Client

 LOCAL_MODEL = "ollama/mixtral"
@@ -36,7 +36,7 @@ def generate_question(chunk: str, local: bool = False) -> str:
     Returns:
         Generated question.
     """
-    model = LOCAL_MODEL if local else "gpt-3.5-turbo"
+    model = LOCAL_MODEL if local else "gpt-4o"
     response = completion(
         model=model,
         messages=[
@@ -54,16 +54,19 @@ def generate_question(chunk: str, local: bool = False) -> str:
 def generate_questions_from_chunks(
     docs_with_embeddings: List[Document],
     local: bool = False,
-) -> List[Document]:
+) -> Annotated[str, ArtifactConfig(name="synthetic_questions")]:
     """Generate questions from chunks.

     Args:
-        docs_with_embeddings: List of documents with embeddings.
         local: Whether to run the pipeline with a local LLM.

     Returns:
-        List of documents with generated questions added.
+        JSON string containing a list of documents with generated questions added.
     """
+    client = Client()
+    docs_with_embeddings = client.get_artifact_version(
+        name_id_or_prefix="documents_with_embeddings"
+    ).load()
     for doc in docs_with_embeddings:
         doc.generated_questions = [generate_question(doc.page_content, local)]
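
Rather than receiving the documents as a step input, the step now pulls the upstream artifact by name through the ZenML client. A sketch of that lookup, assuming an initialized ZenML client, that the artifact exists, and that it holds the JSON string written by generate_embeddings; the question call is left as a comment since generate_question lives in this module:

import json
from zenml.client import Client

client = Client()
docs_json = client.get_artifact_version(
    name_id_or_prefix="documents_with_embeddings"
).load()

# If the artifact is the JSON string from generate_embeddings, json.loads
# returns plain dicts, so fields are read by key rather than attribute.
for doc in json.loads(docs_json):
    chunk = doc["page_content"]
    # generate_question(chunk, local) would be called here for each chunk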

llm-complete-guide/steps/url_scraper.py

Lines changed: 7 additions & 6 deletions
@@ -12,10 +12,11 @@
 # or implied. See the License for the specific language governing
 # permissions and limitations under the License.

-from typing import List
+
+import json

 from typing_extensions import Annotated
-from zenml import log_artifact_metadata, step
+from zenml import ArtifactConfig, log_artifact_metadata, step

 from steps.url_scraping_utils import get_all_pages

@@ -25,17 +26,16 @@ def url_scraper(
     docs_url: str = "https://docs.zenml.io",
     repo_url: str = "https://github.com/zenml-io/zenml",
     website_url: str = "https://zenml.io",
-) -> Annotated[List[str], "urls"]:
+) -> Annotated[str, ArtifactConfig(name="urls")]:
     """Generates a list of relevant URLs to scrape.

     Args:
         docs_url: URL to the documentation.
         repo_url: URL to the repository.
-        release_notes_url: URL to the release notes.
         website_url: URL to the website.

     Returns:
-        List of URLs to scrape.
+        JSON string containing a list of URLs to scrape.
     """
     # We comment this out to make this pipeline faster
     # examples_readme_urls = get_nested_readme_urls(repo_url)
@@ -44,8 +44,9 @@ def url_scraper(
     # all_urls = docs_urls + website_urls + examples_readme_urls
     all_urls = docs_urls
     log_artifact_metadata(
+        artifact_name="urls",
         metadata={
             "count": len(all_urls),
         },
     )
-    return all_urls
+    return json.dumps(all_urls)
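
Returning json.dumps(all_urls) makes the "urls" artifact a plain string, and the explicit artifact_name points the logged metadata at that renamed output. A small illustration of the round-trip with made-up URLs:

import json

all_urls = ["https://docs.zenml.io", "https://docs.zenml.io/getting-started"]  # made-up list
payload = json.dumps(all_urls)          # what the step now returns
assert json.loads(payload) == all_urls  # downstream steps recover the original list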

llm-complete-guide/steps/url_scraping_utils.py

Lines changed: 29 additions & 3 deletions
@@ -48,6 +48,18 @@ def is_valid_url(url: str, base: str) -> bool:
     return not re.search(version_pattern, url)


+def strip_query_params(url: str) -> str:
+    """Strip query parameters from a URL.
+
+    Args:
+        url (str): The URL to strip query parameters from.
+
+    Returns:
+        str: The URL without query parameters.
+    """
+    return url.split("?")[0]
+
+
 def get_all_pages(url: str) -> List[str]:
     """
     Retrieve all pages with the same base as the given URL.
@@ -60,10 +72,23 @@ def get_all_pages(url: str) -> List[str]:

     logger.info(f"Scraping all pages from {url}...")
     base_url = urlparse(url).netloc
-    pages = crawl(url, base_url)
-    logger.info(f"Found {len(pages)} pages.")
+
+    # Use a queue-based approach instead of recursion
+    pages = set()
+    queue = [url]
+    while queue:
+        current_url = queue.pop(0)
+        if current_url not in pages:
+            pages.add(current_url)
+            links = get_all_links(current_url, base_url)
+            queue.extend(links)
+            sleep(1 / RATE_LIMIT)  # Rate limit the requests
+
+    stripped_pages = [strip_query_params(page) for page in pages]
+
+    logger.info(f"Found {len(stripped_pages)} pages.")
     logger.info("Done scraping pages.")
-    return list(pages)
+    return list(stripped_pages)


 def crawl(url: str, base: str, visited: Set[str] = None) -> Set[str]:
@@ -118,6 +143,7 @@ def get_all_links(url: str, base: str) -> List[str]:
         parsed_url = urlparse(full_url)
         cleaned_url = parsed_url._replace(fragment="").geturl()
         if is_valid_url(cleaned_url, base):
+            print(cleaned_url)
             links.append(cleaned_url)

     logger.debug(f"Found {len(links)} valid links from {url}")
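
The recursive crawl is replaced by a breadth-first loop with a simple rate limit, and query parameters are stripped at the end. A self-contained sketch of the same approach, with the network call stubbed out so it runs offline; RATE_LIMIT, the fake site, and the simplified get_all_links are assumptions, not the module's real values:

from time import sleep

RATE_LIMIT = 5  # requests per second (illustrative)

FAKE_SITE = {  # page -> outgoing links, standing in for real HTTP fetches
    "https://docs.zenml.io": [
        "https://docs.zenml.io/a?ref=nav",
        "https://docs.zenml.io/b",
    ],
    "https://docs.zenml.io/a?ref=nav": ["https://docs.zenml.io"],
    "https://docs.zenml.io/b": [],
}

def get_all_links(url: str, base: str) -> list:
    return FAKE_SITE.get(url, [])

def strip_query_params(url: str) -> str:
    return url.split("?")[0]

def get_all_pages(url: str) -> list:
    base_url = "docs.zenml.io"
    pages = set()
    queue = [url]
    while queue:
        current_url = queue.pop(0)
        if current_url not in pages:
            pages.add(current_url)
            queue.extend(get_all_links(current_url, base_url))
            sleep(1 / RATE_LIMIT)  # rate limit between fetches
    return [strip_query_params(page) for page in pages]

print(sorted(get_all_pages("https://docs.zenml.io")))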
