Commit 5f7143a

Merge remote-tracking branch 'upstream/main'
2 parents ebf35f8 + c22f025

File tree

10 files changed: +499 -0 lines

examples/rag-llm/README.md

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
# Retrieval-Augmented Generation on OpenShift AI

These examples show how a user can run Retrieval-Augmented Generation using Jupyter Notebooks provided by OpenShift AI.

The huggingface_rag example is based on this Hugging Face blog post: https://huggingface.co/blog/ngxson/make-your-own-rag
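
At a high level, the huggingface_rag notebook splits a small knowledge source into chunks, encodes each chunk into an embedding, indexes the embeddings with FAISS, and then lets a RAG model retrieve relevant chunks and generate an answer. The following is a condensed sketch of that flow, distilled from the notebook in this directory (the notebook itself is the runnable, step-by-step version):

```python
import urllib.request

import torch
from datasets import Dataset
from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizerFast,
    RagRetriever,
    RagSequenceForGeneration,
    RagTokenizer,
)

torch.set_grad_enabled(False)  # inference only

# 1. Build a chunked knowledge base: every line of the cat-facts file is one chunk.
link = "https://huggingface.co/ngxson/demo_simple_rag_py/raw/main/cat-facts.txt"
chunks = [{"text": line.decode("utf-8"), "title": "cats"} for line in urllib.request.urlopen(link)]
dataset = Dataset.from_list(chunks)

# 2. Encode each chunk into an embedding and index the embeddings with FAISS.
ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
dataset = dataset.map(
    lambda ex: {"embeddings": ctx_encoder(**ctx_tokenizer(ex["text"], return_tensors="pt"))[0][0].numpy()}
)
dataset.add_faiss_index(column="embeddings")

# 3. Retrieve the most relevant chunks and generate an answer conditioned on them.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
retriever = RagRetriever.from_pretrained(
    "facebook/rag-sequence-nq", index_name="custom", indexed_dataset=dataset
)
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)

inputs = tokenizer.prepare_seq2seq_batch("what is the name of the tiniest cat", return_tensors="pt")
answer_ids = model.generate(input_ids=inputs["input_ids"])
print(tokenizer.batch_decode(answer_ids, skip_special_tokens=True)[0])
```
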
## Requirements

* An OpenShift cluster with OpenShift AI (RHOAI) 2.17+ installed:
  * The `dashboard` and `workbenches` components enabled
* A worker node with sufficient resources to run the workbench, with NVIDIA GPUs (Ampere-based or newer recommended) or AMD GPUs (AMD Instinct MI300X or newer recommended)
## Setup

* Access the OpenShift AI dashboard, for example from the top navigation bar menu:
  ![](./docs/01.png)
* Log in, then go to _Data Science Projects_ and create a project:
  ![](./docs/02.png)
* Once the project is created, click on _Create a workbench_:
  ![](./docs/03.png)
* Then create a workbench with the following settings:
  * Select the `PyTorch` (or the `ROCm-PyTorch`) notebook image:
    ![](./docs/04a.png)
  * Select the `Medium` container size and add an accelerator (a quick check that the accelerator is visible is shown after these steps):
    ![](./docs/04b.png)
  * Keep the default 20GB workbench storage; it is enough to run the inference from within the workbench.
  * Review the configuration and click "Create workbench":
    ![](./docs/04c.png)
* From the "Workbenches" page, click on _Open_ when the workbench you've just created becomes ready:
  ![](./docs/05.png)
* From the workbench, clone this repository, i.e., `https://github.com/opendatahub-io/distributed-workloads.git`
* Navigate to the `distributed-workloads/examples/rag-llm` directory and open one of the available notebooks
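
Once the workbench is open, a quick sanity check from a notebook cell confirms that the accelerator you attached is visible to PyTorch (the ROCm build used by the `ROCm-PyTorch` image also reports its devices through the `torch.cuda` API):

```python
import torch

# True when an NVIDIA GPU (or, on the ROCm build, an AMD GPU) is visible to PyTorch.
print("Accelerator available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))
```
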
You can now proceed with the instructions from the notebook. Enjoy!

examples/rag-llm/docs/01.png (89.8 KB)
examples/rag-llm/docs/02.png (72.6 KB)
examples/rag-llm/docs/03.png (110 KB)
examples/rag-llm/docs/04a.png (79.8 KB)
examples/rag-llm/docs/04b.png (50.2 KB)
examples/rag-llm/docs/04c.png (67.7 KB)
examples/rag-llm/docs/05.png (117 KB)
Lines changed: 205 additions & 0 deletions
@@ -0,0 +1,205 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b29cb9f2-e3c0-44cc-8327-7757c5add287",
   "metadata": {},
   "source": [
    "# Setup\n",
    "\n",
    "Install all required dependencies."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bff9c793-7ca5-4f3b-8353-b55d3acb3b4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install --quiet --upgrade transformers datasets faiss-cpu"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24ef69a7-c616-4b06-b1ad-d3cb98abe7df",
   "metadata": {},
   "source": [
    "# Hugging Face RAG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f0abf06-145f-4644-b25d-823c6ffc58af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Models\n",
    "encoder_model = \"facebook/dpr-ctx_encoder-multiset-base\"\n",
    "generator_model = \"facebook/rag-sequence-nq\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43d74dc3-39f1-4d22-9c74-aaedd0131093",
   "metadata": {},
   "source": [
    "Prepare chunk dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb54a2f0-aef6-4308-a8b4-07e9c7cca23b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import urllib.request\n",
    "from datasets import Dataset\n",
    "\n",
    "link = \"https://huggingface.co/ngxson/demo_simple_rag_py/raw/main/cat-facts.txt\"\n",
    "dataset_list = []\n",
    "\n",
    "# Retrieve knowledge from provided link, use every line as a separate chunk.\n",
    "for line in urllib.request.urlopen(link):\n",
    "    dataset_list.append({\"text\": line.decode('utf-8'), \"title\": \"cats\"})\n",
    "\n",
    "print(f'Loaded {len(dataset_list)} entries')\n",
    "\n",
    "dataset = Dataset.from_list(dataset_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "677c95fe-1d36-4dfe-bf0d-1283857e5ee7",
   "metadata": {},
   "source": [
    "Encode dataset chunks into embeddings (vector representations), append embeddings into dataset.\n",
    "\n",
    "Add faiss index for similarity search."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c118e29-7fbf-4741-a474-3e5a3d46d8c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import (\n",
    "    DPRContextEncoder,\n",
    "    DPRContextEncoderTokenizerFast,\n",
    ")\n",
    "import torch\n",
    "\n",
    "\n",
    "torch.set_grad_enabled(False)\n",
    "\n",
    "ctx_encoder = DPRContextEncoder.from_pretrained(encoder_model)\n",
    "ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(encoder_model)\n",
    "ds_with_embeddings = dataset.map(lambda example: {'embeddings': ctx_encoder(**ctx_tokenizer(example[\"text\"], return_tensors=\"pt\"))[0][0].numpy()})\n",
    "ds_with_embeddings.add_faiss_index(column='embeddings')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc5bb3cd-9785-43fa-b1c4-e16e78b69073",
   "metadata": {},
   "source": [
    "**Specify user query here**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dc40487-bf1c-49ed-9106-0dc46e38820c",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_query = \"what is the name of the tiniest cat\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b52d2349-02aa-47d6-a5d0-783e8361feee",
   "metadata": {},
   "source": [
    "Generate response for user query using context from dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c61b305-5f62-4f99-8577-0708ba5e5f28",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration\n",
    "\n",
    "tokenizer = RagTokenizer.from_pretrained(generator_model)\n",
    "\n",
    "# Construct retriever to return relevant context from dataset\n",
    "retriever = RagRetriever.from_pretrained(\n",
    "    generator_model, index_name=\"custom\", indexed_dataset=ds_with_embeddings\n",
    ")\n",
    "\n",
    "model = RagSequenceForGeneration.from_pretrained(generator_model, retriever=retriever)\n",
    "\n",
    "# Move model to GPU\n",
    "device = 0\n",
    "model = model.to(device)\n",
    "\n",
    "input_dict = tokenizer.prepare_seq2seq_batch(input_query, return_tensors=\"pt\").to(device)\n",
    "\n",
    "generated = model.generate(input_ids=input_dict[\"input_ids\"])\n",
    "print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d52ad0e5-de8c-49f8-8e2c-0a4811e4f095",
   "metadata": {},
   "source": [
    "# Cleaning Up\n",
    "\n",
    "Delete model from GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff1c1fd1-ac51-42d4-b879-53d466b2c045",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "del model, input_dict\n",
    "torch.cuda.empty_cache()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.11",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
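
The "Prepare chunk dataset" cell above pulls a cat-facts file from a URL and treats every line as a separate chunk. To point the notebook at your own knowledge source instead, that cell can be adapted along these lines (a sketch only; `my_notes.txt` is a hypothetical local file, and here each paragraph becomes one chunk):

```python
from datasets import Dataset

# Hypothetical local file; split on blank lines so each paragraph becomes one chunk.
with open("my_notes.txt", encoding="utf-8") as f:
    chunks = [p.strip() for p in f.read().split("\n\n") if p.strip()]

dataset = Dataset.from_list([{"text": c, "title": "my notes"} for c in chunks])
print(f"Loaded {len(dataset)} entries")
```

The rest of the notebook (embedding, FAISS indexing, retrieval, and generation) can stay unchanged, since it only relies on the `text` and `title` columns of the dataset.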
