 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Annotated
+from typing import Annotated
+import logging
+import json
 
 import pandas as pd
 from datasets import Dataset
 from huggingface_hub import create_repo
 from litellm import completion
 from structures import Document
-from zenml import step, ArtifactConfig
+from zenml import ArtifactConfig, step
 from zenml.client import Client
 
+logger = logging.getLogger(__name__)
+
 LOCAL_MODEL = "ollama/mixtral"
 
 
@@ -52,31 +56,37 @@ def generate_question(chunk: str, local: bool = False) -> str:
 
 @step
 def generate_questions_from_chunks(
-    docs_with_embeddings: List[Document],
+    docs_with_embeddings: str,
     local: bool = False,
+    logging_interval: int = 10,
 ) -> Annotated[str, ArtifactConfig(name="synthetic_questions")]:
     """Generate questions from chunks.
 
     Args:
+        docs_with_embeddings: JSON string containing a list of Document objects with embeddings.
         local: Whether to run the pipeline with a local LLM.
 
     Returns:
         JSON string containing a list of documents with generated questions added.
     """
-    client = Client()
-    docs_with_embeddings = client.get_artifact_version(
-        name_id_or_prefix="documents_with_embeddings"
-    ).load()
-    for doc in docs_with_embeddings:
+    document_list = [
+        Document(**doc) for doc in json.loads(docs_with_embeddings)
+    ]
+
+    for i, doc in enumerate(document_list, 1):
         doc.generated_questions = [generate_question(doc.page_content, local)]
+        if i % logging_interval == 0:
+            logger.info(
+                f"Progress: {i}/{len(document_list)} documents processed"
+            )
+            logger.info(
+                f"Generated question for document {i}: {doc.generated_questions[0]}"
+            )
 
-    assert all(doc.generated_questions for doc in docs_with_embeddings)
+    assert all(doc.generated_questions for doc in document_list)
 
     # Convert List[Document] to DataFrame
-    df = pd.DataFrame([doc.__dict__ for doc in docs_with_embeddings])
-
-    # Convert numpy arrays to lists
-    df["embedding"] = df["embedding"].apply(lambda x: x.tolist())
+    df = pd.DataFrame([doc.__dict__ for doc in document_list])
 
     # upload the parquet file to a private dataset on the huggingface hub
     client = Client()
@@ -86,14 +96,15 @@ def generate_questions_from_chunks(
8696 "zenml/rag_qa_embedding_questions" ,
8797 token = hf_token ,
8898 exist_ok = True ,
89- private = True ,
9099 repo_type = "dataset" ,
91100 )
92101
102+ # add an extra `__pydantic_initialised__` column to the dataframe
103+ df ["__pydantic_initialised__" ] = True
104+
93105 dataset = Dataset .from_pandas (df )
94106 dataset .push_to_hub (
95107 repo_id = "zenml/rag_qa_embedding_questions" ,
96- private = True ,
97108 token = hf_token ,
98109 create_pr = True ,
99110 )
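
A note on the key change above: instead of loading a `List[Document]` artifact from the ZenML artifact store inside the step, the step now receives the documents as a JSON string and rebuilds `Document` objects with `Document(**doc)`. A minimal sketch of that round-trip, assuming a pydantic-style `Document` with roughly these fields (the real model lives in `structures.py` and is not shown in this diff):

```python
import json
from typing import List, Optional

from pydantic import BaseModel


class Document(BaseModel):
    """Minimal stand-in for structures.Document (assumed fields)."""

    page_content: str
    embedding: Optional[List[float]] = None
    generated_questions: List[str] = []


# Producer step: serialize documents into a JSON string artifact.
docs = [Document(page_content="ZenML is an MLOps framework.")]
payload = json.dumps([doc.dict() for doc in docs])

# Consumer step (as in the diff): parse the string back into Documents.
document_list = [Document(**d) for d in json.loads(payload)]
assert document_list[0].page_content == docs[0].page_content
```

This also presumably explains why the old numpy-to-list conversion could be dropped: documents parsed from JSON already hold embeddings as plain lists rather than numpy arrays.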
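Two smaller details: `create_pr=True` means the dataset upload lands as a pull request on the Hugging Face Hub rather than a direct push, and dropping `private=True` falls back to the default of a public repo. The extra `__pydantic_initialised__` column is, going by its name, a flag pydantic sets on initialised objects, written into the dataframe so that rows pulled back out of the dataset carry it too; that reading is an assumption. A sketch of the consumer side, reusing the stand-in `Document` above and assuming the created PR has been merged:

```python
from datasets import load_dataset

# Load the published split back from the Hub.
ds = load_dataset("zenml/rag_qa_embedding_questions", split="train")

# Rebuild Document objects from the rows; with its default config,
# pydantic ignores the extra `__pydantic_initialised__` key, so no
# column cleanup is needed before reconstruction.
document_list = [Document(**row) for row in ds]
```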