|
20 | 20 | "metadata": {}, |
21 | 21 | "outputs": [], |
22 | 22 | "source": [ |
23 | | - "# Note: `faiss-cpu` is included here due to an assumption in the Hugging Face `RagRetriever` class\n", |
24 | | - "# that the FAISS library is required, even though it's not directly used in this example.\n", |
25 | | - "%pip install --quiet feast[milvus] sentence-transformers datasets faiss-cpu\n", |
| 23 | + "%pip install --quiet feast[milvus] sentence-transformers datasets\n", |
26 | 24 | "%pip install bigtree==0.19.2\n", |
27 | 25 | "%pip install marshmallow==3.10.0 " |
28 | 26 | ] |
|
51 | 49 | ")" |
52 | 50 | ] |
53 | 51 | }, |
| 52 | + { |
| 53 | + "cell_type": "markdown", |
| 54 | + "metadata": {}, |
| 55 | + "source": [ |
| 56 | + "The dataset is chunked so that each chunk contains at most a preset number of characters, the maximum supported by Feast. Each chunk is built from whole words only, so the retrieved context can form complete sentences without truncated words." |
| 57 | + ] |
| 58 | + }, |
54 | 59 | { |
55 | 60 | "cell_type": "code", |
56 | 61 | "execution_count": null, |
57 | 62 | "metadata": {}, |
58 | 63 | "outputs": [], |
59 | 64 | "source": [ |
60 | | - "def chunk_dataset(examples, chunk_size=100, overlap=20, max_chars=500):\n", |
| 65 | + "def chunk_dataset(examples, max_chars=380):\n", |
61 | 66 | " all_chunks = []\n", |
62 | 67 | " all_ids = []\n", |
63 | 68 | " all_titles = []\n", |
64 | 69 | "\n", |
65 | | - " for i, text in enumerate(examples['text']): # Iterate over texts in the batch\n", |
| 70 | + " for i, text in enumerate(examples['text']): # Iterate over texts in the batch\n", |
66 | 71 | " words = text.split()\n", |
67 | | - " chunks = []\n", |
68 | | - " for j in range(0, len(words), chunk_size - overlap):\n", |
69 | | - " chunk_words = words[j:j + chunk_size]\n", |
70 | | - " if len(chunk_words) < 20:\n", |
71 | | - " continue\n", |
72 | | - " chunk_text_value = ' '.join(chunk_words) # Store the chunk text\n", |
73 | | - " chunk_text_value = chunk_text_value[:max_chars]\n", |
74 | | - " chunks.append(chunk_text_value)\n", |
75 | | - " all_ids.append(f\"{examples['id'][i]}_{j}\") # Unique ID for the chunk\n", |
76 | | - " all_titles.append(examples['title'][i])\n", |
| 72 | + " if not words:\n", |
| 73 | + " continue\n", |
77 | 74 | "\n", |
78 | | - " all_chunks.extend(chunks)\n", |
| 75 | + " current_chunk_words = []\n", |
| 76 | + " for word in words:\n", |
| 77 | + " # Check if adding the next word exceeds the character limit\n", |
| 78 | + " if len(' '.join(current_chunk_words + [word])) > max_chars:\n", |
| 79 | + " # If the current chunk is valid, save it\n", |
| 80 | + " if current_chunk_words:\n", |
| 81 | + " chunk_text = ' '.join(current_chunk_words)\n", |
| 82 | + " all_chunks.append(chunk_text)\n", |
| 83 | + " all_ids.append(f\"{examples['id'][i]}_{len(all_chunks)}\") # Unique ID for the chunk\n", |
| 84 | + " all_titles.append(examples['title'][i])\n", |
| 85 | + " # Start a new chunk with the current word\n", |
| 86 | + " current_chunk_words = [word]\n", |
| 87 | + " else:\n", |
| 88 | + " current_chunk_words.append(word)\n", |
| 89 | + "\n", |
| 90 | + " # Add the last remaining chunk\n", |
| 91 | + " if current_chunk_words:\n", |
| 92 | + " chunk_text = ' '.join(current_chunk_words)\n", |
| 93 | + " all_chunks.append(chunk_text)\n", |
| 94 | + " all_ids.append(f\"{examples['id'][i]}_{len(all_chunks)}\") # Unique ID for the chunk\n", |
| 95 | + " all_titles.append(examples['title'][i])\n", |
79 | 96 | "\n", |
80 | 97 | " return {'id': all_ids, 'title': all_titles, 'text': all_chunks}\n", |
81 | 98 | "\n", |
|
120 | 137 | "#### Create parquet file as historical data source" |
121 | 138 | ] |
122 | 139 | }, |
| 140 | + { |
| 141 | + "cell_type": "code", |
| 142 | + "execution_count": null, |
| 143 | + "metadata": {}, |
| 144 | + "outputs": [], |
| 145 | + "source": [ |
| 146 | + "%mkdir feature_repo/data" |
| 147 | + ] |
| 148 | + }, |
123 | 149 | { |
124 | 150 | "cell_type": "code", |
125 | 151 | "execution_count": null, |
|
145 | 171 | "print(df[\"embedding\"].apply(lambda x: len(x) if isinstance(x, list) else str(type(x))).value_counts()) # Check lengths\n", |
146 | 172 | "\n", |
147 | 173 | "# Save to Parquet\n", |
148 | | - "df.to_parquet(\"wiki_dpr.parquet\", index=False)\n", |
| 174 | + "df.to_parquet(\"feature_repo/data/wiki_dpr.parquet\", index=False)\n", |
149 | 175 | "print(\"Saved to wiki_dpr.parquet\")" |
150 | 176 | ] |
151 | 177 | }, |
|
231 | 257 | "source": [ |
232 | 258 | "import sys\n", |
233 | 259 | "sys.path.append(\"..\")\n", |
234 | | - "from feast_rag_retriever import FeastVectorStore, FeastRAGRetriever, FeastIndex\n", |
235 | | - "from rag_project_repo import wiki_passage_feature_view\n", |
| 260 | + "from ragproject_repo import wiki_passage_feature_view\n", |
| 261 | + "from feast.vector_store import FeastVectorStore\n", |
| 262 | + "from feast.rag_retriever import FeastIndex, FeastRAGRetriever\n", |
236 | 263 | "\n", |
237 | 264 | "generator_config=generator_model.config\n", |
238 | 265 | "question_encoder = AutoModel.from_pretrained(\"sentence-transformers/all-MiniLM-L6-v2\")\n", |
|
245 | 272 | "}\n", |
246 | 273 | "\n", |
247 | 274 | "vector_store = FeastVectorStore(\n", |
248 | | - " store=store,\n", |
| 275 | + " repo_path=\".\",\n", |
249 | 276 | " rag_view=wiki_passage_feature_view,\n", |
250 | | - " features=[\"wiki_passages:passage_text\", \"wiki_passages:embedding\"]\n", |
| 277 | + " features=[\"wiki_passages:passage_text\", \"wiki_passages:embedding\", \"wiki_passages:passage_id\"]\n", |
251 | 278 | ")\n", |
252 | 279 | "\n", |
253 | | - "feast_index = FeastIndex(vector_store=vector_store)\n", |
| 280 | + "feast_index = FeastIndex()\n", |
254 | 281 | "\n", |
255 | 282 | "config = RagConfig(\n", |
256 | 283 | " question_encoder=query_encoder_config,\n", |
|
262 | 289 | " question_encoder_tokenizer=question_encoder_tokenizer,\n", |
263 | 290 | " generator_tokenizer=generator_tokenizer,\n", |
264 | 291 | " feast_repo_path=\".\",\n", |
265 | | - " vector_store=vector_store,\n", |
| 292 | + " feature_view=vector_store.rag_view,\n", |
| 293 | + " features=vector_store.features,\n", |
266 | 294 | " generator_model=generator_model, \n", |
267 | 295 | " search_type=\"vector\",\n", |
268 | 296 | " id_field=\"passage_id\",\n", |
| 297 | + " text_field=\"passage_text\",\n", |
269 | 298 | " config=config,\n", |
270 | 299 | " index=feast_index,\n", |
271 | 300 | ")" |
|
0 commit comments