
Commit 0e36f38

Merge remote-tracking branch 'upstream/main'
2 parents b8108cd + 6cfc426 commit 0e36f38


6 files changed: +665 -2 lines changed


examples/kfto_feast_rag/README.md

Lines changed: 79 additions & 0 deletions
# End-to-end RAG example using Feast, Milvus, and OpenShift AI

## Introduction

This example notebook provides a step-by-step demonstration of building and using a RAG system with Feast Feature Store on OpenShift AI. The notebook walks through:

1. Data Preparation
   - Loads a subset of the Wikipedia DPR dataset (1% of the training data)
   - Implements text chunking with configurable chunk size and overlap (see the sketch after this list)
   - Processes text into manageable passages with unique IDs

2. Embedding Generation
   - Uses the `all-MiniLM-L6-v2` sentence transformer model
   - Generates 384-dimensional embeddings for text passages
   - Demonstrates batch processing with GPU support

3. Feature Store Setup
   - Creates a Parquet file as the historical data source
   - Configures Feast with the feature repository
   - Demonstrates writing embeddings from the data source to the Milvus online store, which can also be used for model training later

4. RAG System Implementation
   - **Embedding Model**: `all-MiniLM-L6-v2` (configurable)
   - **Generator Model**: `granite-3.2-2b-instruct` (configurable)
   - **Vector Store**: Custom implementation with Feast integration
   - **Retriever**: Custom implementation with Feast integration, extending HuggingFace's RagRetriever

5. Query Demonstration
   - Performs inference with retrieved context
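For orientation, here is a rough sketch of what steps 1-3 look like in code. This is an illustration only, not the notebook's exact code: the `chunk_text` helper, chunk sizes, and batch size are assumptions.

```python
# Illustrative sketch of steps 1-3; helper names and parameter values
# are assumptions, not the notebook's exact code.
from sentence_transformers import SentenceTransformer


def chunk_text(text: str, chunk_size: int = 512, overlap: int = 64) -> list[str]:
    """Split text into overlapping character windows."""
    chunks = []
    start = 0
    step = chunk_size - overlap
    while start < len(text):
        chunks.append(text[start : start + chunk_size])
        start += step
    return chunks


# all-MiniLM-L6-v2 produces 384-dimensional embeddings
model = SentenceTransformer("all-MiniLM-L6-v2")
passages = chunk_text("...text of a Wikipedia article...")
embeddings = model.encode(passages, batch_size=64)  # shape: (len(passages), 384)

# Step 3 (sketch): once the passages and embeddings are written to a Parquet
# file and registered with Feast, they can be pushed to the Milvus online
# store, e.g. via store.write_to_online_store(...).
```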
## Requirements

- An OpenShift cluster with OpenShift AI (RHOAI) 2.20+ installed:
  - The dashboard, feastoperator, and workbenches components enabled
- A workbench with a medium-sized container, 1 NVIDIA GPU accelerator, and 200 GB of cluster storage
- A standalone Milvus deployment. See the example [here](https://github.com/rh-aiservices-bu/llm-on-openshift/tree/main/vector-databases/milvus#deployment).
## Running the example

From the workbench, clone this repository: https://github.com/opendatahub-io/distributed-workloads.git

Navigate to the `distributed-workloads/examples/kfto-feast-rag` directory. Here you will find the following files:
* **feast_rag_retriever.py**

  This module implements a custom RAG retriever by combining Feast feature store capabilities with HuggingFace transformer-based models. The implementation provides:

  - A flexible vector store interface with Feast integration (`FeastVectorStore`)
  - A custom RAG retriever (`FeastRAGRetriever`) that supports three search modes (sketched briefly after this list):
    - Text-based search
    - Vector-based search
    - Hybrid search
  - Seamless integration with the HuggingFace transformers library and sentence-transformers
  - Configurable document formatting and retrieval options

* **feature_repo/feature_store.yaml**

  This is the core configuration file for the RAG project's feature store, configuring a Milvus online store on a local provider. To configure Milvus, update `feature_store.yaml` with your Milvus connection details:

  - host
  - port (default: 19530)
  - credentials (if required)

* **feature_repo/rag_project_repo.py**

  This is the Feast feature repository configuration that defines the schema and data source for the Wikipedia passage embeddings.

* **rag_feast_kfto.ipynb**

  This notebook demonstrates the implementation of a RAG system using the Feast feature store. The notebook provides:

  - A complete end-to-end example of building a RAG system with:
    - Data preparation using the Wiki DPR dataset
    - Text chunking and preprocessing
    - Vector embedding generation using sentence-transformers
    - Integration with the Milvus vector store
    - Inference utilising a custom RagRetriever: `FeastRAGRetriever`
  - Uses `all-MiniLM-L6-v2` for generating embeddings
  - Implements `granite-3.2-2b-instruct` as the generator model
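As a quick illustration of the three search modes, a hypothetical snippet along these lines queries the vector store directly. The construction follows `rag_project_repo.py` and assumes it is run from the example directory; the query text is made up:

```python
# Hypothetical sketch of the three FeastVectorStore.query search modes.
import numpy as np
from feast import FeatureStore
from sentence_transformers import SentenceTransformer

from feast_rag_retriever import FeastVectorStore
from feature_repo.rag_project_repo import wiki_passage_feature_view

store = FeatureStore(repo_path="feature_repo")
vector_store = FeastVectorStore(
    store=store,
    rag_view=wiki_passage_feature_view,
    features=["wiki_passages:passage_text", "wiki_passages:embedding"],
)

encoder = SentenceTransformer("all-MiniLM-L6-v2")
query_vector = np.asarray(encoder.encode("capital of Greece"))

# Text-based search: keyword match against the online store
text_hits = vector_store.query(query_string="capital of Greece", top_k=3)
# Vector-based search: similarity over the embedding index
vector_hits = vector_store.query(query_vector=query_vector, top_k=3)
# Hybrid search: combines both signals
hybrid_hits = vector_store.query(
    query_string="capital of Greece", query_vector=query_vector, top_k=3
)
```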
Open `rag_feast_kfto.ipynb` and follow the steps in the notebook to run the example.

### Helpful Information

- Ensure your Milvus instance is properly configured and running
- Vector dimensions and similarity metrics can be adjusted in the feature store configuration
- The example uses Wikipedia data, but the system can be adapted to other datasets
examples/kfto_feast_rag/feast_rag_retriever.py

Lines changed: 204 additions & 0 deletions
```python
from abc import ABC, abstractmethod
from typing import Callable, Dict, List, Optional, Union, Any, Tuple

import numpy as np
import torch
from feast import FeatureStore, FeatureView
from sentence_transformers import SentenceTransformer
from transformers import RagRetriever


class VectorStore(ABC):
    @abstractmethod
    def query(
        self,
        query_vector: Optional[np.ndarray] = None,
        query_string: Optional[str] = None,
        top_k: int = 10,
    ) -> List[Dict[str, Any]]:
        pass


class FeastVectorStore(VectorStore):
    def __init__(self, store: FeatureStore, rag_view: FeatureView, features: List[str]):
        self.store = store
        self.rag_view = rag_view
        self.store.apply([rag_view])
        self.features = features

    def query(
        self,
        query_vector: Optional[np.ndarray] = None,
        query_string: Optional[str] = None,
        top_k: int = 10,
    ) -> List[Dict[str, Any]]:
        distance_metric = "COSINE" if query_vector is not None else None
        query_list = query_vector.tolist() if query_vector is not None else None

        response = self.store.retrieve_online_documents_v2(
            features=self.features,
            query=query_list,
            query_string=query_string,
            top_k=top_k,
            distance_metric=distance_metric,
        ).to_dict()

        # Pivot the column-oriented response into one dict per document
        results = []
        for feature_name in self.features:
            short_name = feature_name.split(":")[-1]
            feature_values = response[short_name]
            for i, value in enumerate(feature_values):
                if i >= len(results):
                    results.append({})
                results[i][short_name] = value

        return results


# Dummy index - an index is required by the HF Transformers RagRetriever class
class FeastIndex:
    def __init__(self, vector_store: VectorStore):
        self.vector_store = vector_store

    def get_top_docs(self, query_vectors: np.ndarray, n_docs: int = 5):
        raise NotImplementedError("get_top_docs is not yet implemented.")

    def get_doc_dicts(self, doc_ids: List[str]):
        raise NotImplementedError("get_doc_dicts is not yet implemented.")


class FeastRAGRetriever(RagRetriever):
    VALID_SEARCH_TYPES = {"text", "vector", "hybrid"}

    def __init__(
        self,
        question_encoder_tokenizer,
        question_encoder,
        generator_tokenizer,
        generator_model,
        feast_repo_path: str,
        vector_store: VectorStore,
        search_type: str,
        config: Dict[str, Any],
        index: FeastIndex,
        format_document: Optional[Callable[[Dict[str, Any]], str]] = None,
        id_field: str = "",
        query_encoder_model: Union[str, SentenceTransformer] = "all-MiniLM-L6-v2",
        **kwargs,
    ):
        if search_type.lower() not in self.VALID_SEARCH_TYPES:
            raise ValueError(
                f"Unsupported search_type {search_type}. "
                f"Must be one of: {self.VALID_SEARCH_TYPES}"
            )
        super().__init__(
            config=config,
            question_encoder_tokenizer=question_encoder_tokenizer,
            generator_tokenizer=generator_tokenizer,
            index=index,
            init_retrieval=False,
            **kwargs,
        )
        self.question_encoder = question_encoder
        self.generator_model = generator_model
        self.generator_tokenizer = generator_tokenizer
        self.feast = FeatureStore(repo_path=feast_repo_path)
        self.vector_store = vector_store
        self.search_type = search_type.lower()
        self.format_document = format_document or FeastRAGRetriever._default_format_document
        self.id_field = id_field

        if isinstance(query_encoder_model, str):
            self.query_encoder = SentenceTransformer(query_encoder_model)
        else:
            self.query_encoder = query_encoder_model

    def retrieve(
        self,
        question_hidden_states: torch.Tensor,
        n_docs: int,
        query: Optional[str] = None,
    ) -> Tuple[np.ndarray, List[Dict[str, str]]]:
        # Convert hidden states to a query vector by mean pooling
        query_vector = question_hidden_states.mean(dim=1).squeeze().detach().cpu().numpy()

        # Decode a text query if needed (for hybrid or text search)
        if query is None and self.search_type in ("text", "hybrid"):
            query = self.question_encoder_tokenizer.decode(
                question_hidden_states.argmax(axis=-1),
                skip_special_tokens=True,
            )

        if self.search_type == "text":
            results = self.vector_store.query(query_string=query, top_k=n_docs)
        elif self.search_type == "vector":
            results = self.vector_store.query(query_vector=query_vector, top_k=n_docs)
        elif self.search_type == "hybrid":
            results = self.vector_store.query(
                query_string=query,
                query_vector=query_vector,
                top_k=n_docs,
            )
        else:
            raise ValueError(f"Unsupported search type: {self.search_type}")

        # Cosine similarity scoring
        doc_embeddings = np.array([doc["embedding"] for doc in results])
        query_norm = np.linalg.norm(query_vector)
        doc_norms = np.linalg.norm(doc_embeddings, axis=1)

        # Guard against division by zero
        query_norm = np.maximum(query_norm, 1e-10)
        doc_norms = np.maximum(doc_norms, 1e-10)

        similarities = np.dot(doc_embeddings, query_vector) / (doc_norms * query_norm)
        doc_scores = similarities.reshape(1, -1)
        # passage_text is hardcoded at the moment
        doc_dicts = [{"text": doc["passage_text"]} for doc in results]

        return doc_scores, doc_dicts

    def generate_answer(
        self, query: str, top_k: int = 5, max_new_tokens: int = 100
    ) -> str:
        # Convert the query to the hidden-states format expected by retrieve
        inputs = self.question_encoder_tokenizer(
            query, return_tensors="pt", padding=True, truncation=True
        )
        question_hidden_states = self.question_encoder(**inputs).last_hidden_state

        # Get documents using the retrieve method
        doc_scores, doc_dicts = self.retrieve(question_hidden_states, n_docs=top_k)

        # Format context from the retrieved documents
        contexts = [doc["text"] for doc in doc_dicts]
        context = "\n\n".join(contexts)

        prompt = (
            f"Use the following context to answer the question. Context:\n{context}\n\n"
            f"Question: {query}\nAnswer:"
        )

        self.generator_tokenizer.pad_token = self.generator_tokenizer.eos_token
        inputs = self.generator_tokenizer(
            prompt, return_tensors="pt", padding=True, truncation=True
        )
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        output_ids = self.generator_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            pad_token_id=self.generator_tokenizer.pad_token_id,
        )
        return self.generator_tokenizer.decode(output_ids[0], skip_special_tokens=True)

    @staticmethod
    def _default_format_document(doc: Dict[str, Any]) -> str:
        lines = []
        for key, value in doc.items():
            # Skip vectors by checking for long float lists
            if (
                isinstance(value, list)
                and len(value) > 10
                and all(isinstance(x, (float, int)) for x in value)
            ):
                continue
            lines.append(f"{key.replace('_', ' ').capitalize()}: {value}")
        return "\n".join(lines)
```
examples/kfto_feast_rag/feature_repo/feature_store.yaml

Lines changed: 18 additions & 0 deletions
```yaml
project: ragproject
provider: local
registry: data/registry.db
online_store:
  type: milvus
  host: # Insert Milvus route host
  username: # Insert Milvus username if required
  password: # Insert Milvus password if required
  port: 19530
  vector_enabled: true
  embedding_dim: 384
  index_type: FLAT
  metric_type: COSINE
offline_store:
  type: file
entity_key_serialization_version: 3
auth:
  type: no_auth
```
examples/kfto_feast_rag/feature_repo/rag_project_repo.py

Lines changed: 48 additions & 0 deletions
```python
from datetime import timedelta

from feast import Entity, FeatureView, Field, FileSource, ValueType
from feast.data_format import ParquetFormat
from feast.types import Array, Float32, String

# Define your entity (primary key for feature lookup)
wiki_passage = Entity(
    name="passage_id",
    join_keys=["passage_id"],
    value_type=ValueType.STRING,
    description="Unique ID of a Wikipedia passage",
)

parquet_file_path = "data/wiki_dpr.parquet"

# Define the offline source
wiki_dpr_source = FileSource(
    name="wiki_dpr_source",
    file_format=ParquetFormat(),
    path=parquet_file_path,
    timestamp_field="event_timestamp",
)

# Define the feature view for the Wikipedia passage content
wiki_passage_feature_view = FeatureView(
    name="wiki_passages",
    entities=[wiki_passage],
    ttl=timedelta(days=1),
    schema=[
        Field(
            name="passage_text",
            dtype=String,
            description="Content of the Wikipedia passage",
        ),
        Field(
            name="embedding",
            dtype=Array(Float32),
            description="vectors",
            vector_index=True,
            vector_length=384,
            vector_search_metric="COSINE",
        ),
    ],
    online=True,
    source=wiki_dpr_source,
    description="Content features of Wikipedia passages",
)
```
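For reference, the Parquet source above is expected to carry one row per passage. A hypothetical sketch of producing a compatible file (the column names follow the definitions above; the sample values are made up):

```python
# Hypothetical sketch of a Parquet file matching wiki_dpr_source;
# sample values are made up.
from datetime import datetime, timezone

import pandas as pd

df = pd.DataFrame(
    {
        "passage_id": ["passage-0"],
        "passage_text": ["Athens is the capital and largest city of Greece."],
        "embedding": [[0.0] * 384],  # must match vector_length=384
        "event_timestamp": [datetime.now(timezone.utc)],
    }
)
df.to_parquet("data/wiki_dpr.parquet")
```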
