|
20 | 20 | "metadata": {},
|
21 | 21 | "outputs": [],
|
22 | 22 | "source": [
|
23 |
| - "# Note: `faiss-cpu` is included here due to an assumption in the Hugging Face `RagRetriever` class\n", |
24 |
| - "# that the FAISS library is required, even though it's not directly used in this example.\n", |
25 |
| - "%pip install --quiet feast[milvus] sentence-transformers datasets faiss-cpu\n", |
| 23 | + "%pip install --quiet feast[milvus] sentence-transformers datasets\n", |
26 | 24 | "%pip install bigtree==0.19.2\n",
|
27 | 25 | "%pip install marshmallow==3.10.0 "
|
28 | 26 | ]
|
|
51 | 49 | ")"
|
52 | 50 | ]
|
53 | 51 | },
|
| 52 | + { |
| 53 | + "cell_type": "markdown", |
| 54 | + "metadata": {}, |
| 55 | + "source": [ |
| 56 | + "The dataset is chunked into segments of at most a preset number of characters, the maximum supported by Feast. Each chunk is built from whole words only, so the retrieved context forms complete sentences without truncated words." |
| 57 | + ] |
| 58 | + }, |
54 | 59 | {
|
55 | 60 | "cell_type": "code",
|
56 | 61 | "execution_count": null,
|
57 | 62 | "metadata": {},
|
58 | 63 | "outputs": [],
|
59 | 64 | "source": [
|
60 |
| - "def chunk_dataset(examples, chunk_size=100, overlap=20, max_chars=500):\n", |
| 65 | + "def chunk_dataset(examples, max_chars=380):\n", |
61 | 66 | " all_chunks = []\n",
|
62 | 67 | " all_ids = []\n",
|
63 | 68 | " all_titles = []\n",
|
64 | 69 | "\n",
|
65 |
| - " for i, text in enumerate(examples['text']): # Iterate over texts in the batch\n", |
| 70 | + " for i, text in enumerate(examples['text']): # Iterate over texts in the batch\n", |
66 | 71 | " words = text.split()\n",
|
67 |
| - " chunks = []\n", |
68 |
| - " for j in range(0, len(words), chunk_size - overlap):\n", |
69 |
| - " chunk_words = words[j:j + chunk_size]\n", |
70 |
| - " if len(chunk_words) < 20:\n", |
71 |
| - " continue\n", |
72 |
| - " chunk_text_value = ' '.join(chunk_words) # Store the chunk text\n", |
73 |
| - " chunk_text_value = chunk_text_value[:max_chars]\n", |
74 |
| - " chunks.append(chunk_text_value)\n", |
75 |
| - " all_ids.append(f\"{examples['id'][i]}_{j}\") # Unique ID for the chunk\n", |
76 |
| - " all_titles.append(examples['title'][i])\n", |
| 72 | + " if not words:\n", |
| 73 | + " continue\n", |
77 | 74 | "\n",
|
78 |
| - " all_chunks.extend(chunks)\n", |
| 75 | + " current_chunk_words = []\n", |
| 76 | + " for word in words:\n", |
| 77 | + " # Check if adding the next word exceeds the character limit\n", |
| 78 | + " if len(' '.join(current_chunk_words + [word])) > max_chars:\n", |
| 79 | + " # If the current chunk is valid, save it\n", |
| 80 | + " if current_chunk_words:\n", |
| 81 | + " chunk_text = ' '.join(current_chunk_words)\n", |
| 82 | + " all_chunks.append(chunk_text)\n", |
| 83 | + " all_ids.append(f\"{examples['id'][i]}_{len(all_chunks)}\") # Unique ID for the chunk\n", |
| 84 | + " all_titles.append(examples['title'][i])\n", |
| 85 | + " # Start a new chunk with the current word\n", |
| 86 | + " current_chunk_words = [word]\n", |
| 87 | + " else:\n", |
| 88 | + " current_chunk_words.append(word)\n", |
| 89 | + "\n", |
| 90 | + " # Add the last remaining chunk\n", |
| 91 | + " if current_chunk_words:\n", |
| 92 | + " chunk_text = ' '.join(current_chunk_words)\n", |
| 93 | + " all_chunks.append(chunk_text)\n", |
| 94 | + " all_ids.append(f\"{examples['id'][i]}_{len(all_chunks)}\") # Unique ID for the chunk\n", |
| 95 | + " all_titles.append(examples['title'][i])\n", |
79 | 96 | "\n",
|
80 | 97 | " return {'id': all_ids, 'title': all_titles, 'text': all_chunks}\n",
|
81 | 98 | "\n",
|
|
120 | 137 | "#### Create parquet file as historical data source"
|
121 | 138 | ]
|
122 | 139 | },
|
| 140 | + { |
| 141 | + "cell_type": "code", |
| 142 | + "execution_count": null, |
| 143 | + "metadata": {}, |
| 144 | + "outputs": [], |
| 145 | + "source": [ |
| 146 | + "%mkdir feature_repo/data" |
| 147 | + ] |
| 148 | + }, |
123 | 149 | {
|
124 | 150 | "cell_type": "code",
|
125 | 151 | "execution_count": null,
|
|
145 | 171 | "print(df[\"embedding\"].apply(lambda x: len(x) if isinstance(x, list) else str(type(x))).value_counts()) # Check lengths\n",
|
146 | 172 | "\n",
|
147 | 173 | "# Save to Parquet\n",
|
148 |
| - "df.to_parquet(\"wiki_dpr.parquet\", index=False)\n", |
| 174 | + "df.to_parquet(\"feature_repo/data/wiki_dpr.parquet\", index=False)\n", |
149 |
| - "print(\"Saved to wiki_dpr.parquet\")" |
| 175 | + "print(\"Saved to feature_repo/data/wiki_dpr.parquet\")" |
|
150 | 176 | ]
|
151 | 177 | },
|
|
231 | 257 | "source": [
|
232 | 258 | "import sys\n",
|
233 | 259 | "sys.path.append(\"..\")\n",
|
234 |
| - "from feast_rag_retriever import FeastVectorStore, FeastRAGRetriever, FeastIndex\n", |
235 |
| - "from rag_project_repo import wiki_passage_feature_view\n", |
| 260 | + "from ragproject_repo import wiki_passage_feature_view\n", |
| 261 | + "from feast.vector_store import FeastVectorStore\n", |
| 262 | + "from feast.rag_retriever import FeastIndex, FeastRAGRetriever\n", |
236 | 263 | "\n",
|
237 | 264 | "generator_config=generator_model.config\n",
|
238 | 265 | "question_encoder = AutoModel.from_pretrained(\"sentence-transformers/all-MiniLM-L6-v2\")\n",
|
|
245 | 272 | "}\n",
|
246 | 273 | "\n",
|
247 | 274 | "vector_store = FeastVectorStore(\n",
|
248 |
| - " store=store,\n", |
| 275 | + " repo_path=\".\",\n", |
249 | 276 | " rag_view=wiki_passage_feature_view,\n",
|
250 |
| - " features=[\"wiki_passages:passage_text\", \"wiki_passages:embedding\"]\n", |
| 277 | + " features=[\"wiki_passages:passage_text\", \"wiki_passages:embedding\", \"wiki_passages:passage_id\"]\n", |
251 | 278 | ")\n",
|
252 | 279 | "\n",
|
253 |
| - "feast_index = FeastIndex(vector_store=vector_store)\n", |
| 280 | + "feast_index = FeastIndex()\n", |
254 | 281 | "\n",
|
255 | 282 | "config = RagConfig(\n",
|
256 | 283 | " question_encoder=query_encoder_config,\n",
|
|
262 | 289 | " question_encoder_tokenizer=question_encoder_tokenizer,\n",
|
263 | 290 | " generator_tokenizer=generator_tokenizer,\n",
|
264 | 291 | " feast_repo_path=\".\",\n",
|
265 |
| - " vector_store=vector_store,\n", |
| 292 | + " feature_view=vector_store.rag_view,\n", |
| 293 | + " features=vector_store.features,\n", |
266 | 294 | " generator_model=generator_model, \n",
|
267 | 295 | " search_type=\"vector\",\n",
|
268 | 296 | " id_field=\"passage_id\",\n",
|
| 297 | + " text_field=\"passage_text\",\n", |
269 | 298 | " config=config,\n",
|
270 | 299 | " index=feast_index,\n",
|
271 | 300 | ")"
|
|
0 commit comments