
Commit 6369417

examples: adopt RAG examples for remote execution (#3117)
To run the RAG DAG in a deployment we need storage that is not a local file. The POC passed data between jobs through a local file, but deployed jobs do not share a file system, so they need another way to exchange data. Postgres-based storage was created for that, and the jobs were adapted to use it. The storage implementation is currently copied into both jobs; I will refactor that duplication away after this PR. I also ended up removing NLTK everywhere, plus a few doc fixes.
1 parent ddd7ac1 commit 6369417

29 files changed: +476 -338 lines changed
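As orientation before the diffs: a minimal, hypothetical sketch of the handoff described above, assuming a reachable Postgres instance and the DatabaseStorage class introduced by this commit; the connection string and storage name here are illustrative, not part of the commit.

# Hypothetical sketch: one job passes data to another via the shared
# Postgres-backed storage introduced in this commit (values are illustrative).
from common.database_storage import DatabaseStorage

# Job 1 (e.g. the Confluence reader) writes its output under a well-known name.
storage = DatabaseStorage("postgresql://vdk_user:vdk_password@postgres-host:5432/vdk")
storage.store("confluence_data", '[{"id": "123", "data": "page content"}]')

# Job 2 (e.g. the embedder), running on a different machine, reads it back.
data = storage.retrieve("confluence_data")
print(data)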
Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
# Copyright 2021-2024 VMware, Inc.
# SPDX-License-Identifier: Apache-2.0
import pickle
from typing import Any
from typing import List
from typing import Optional
from typing import Union

from common.storage import IStorage
from sqlalchemy import Column
from sqlalchemy import create_engine
from sqlalchemy import LargeBinary
from sqlalchemy import MetaData
from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy.exc import IntegrityError


class DatabaseStorage(IStorage):
    def __init__(self, connection_string: str):
        self.engine = create_engine(connection_string)
        self.metadata = MetaData()
        self.table = Table(
            "vdk_storage",
            self.metadata,
            Column("name", String, primary_key=True),
            Column("content", LargeBinary),
            Column("content_type", String),
        )
        self.metadata.create_all(self.engine)

    def store(self, name: str, content: Union[str, bytes, Any]) -> None:
        serialized_content, content_type = self._serialize_content(content)
        ins = self.table.insert().values(
            name=name, content=serialized_content, content_type=content_type
        )
        try:
            with self.engine.connect() as conn:
                conn.execute(ins)
                conn.commit()
        except IntegrityError:
            # Handle duplicate name by updating existing content
            upd = (
                self.table.update()
                .where(self.table.c.name == name)
                .values(content=serialized_content, content_type=content_type)
            )
            with self.engine.connect() as conn:
                conn.execute(upd)
                conn.commit()

    def retrieve(self, name: str) -> Optional[Union[str, bytes, Any]]:
        sel = self.table.select().where(self.table.c.name == name)
        with self.engine.connect() as conn:
            result = conn.execute(sel).fetchone()
            if result:
                return self._deserialize_content(result.content, result.content_type)
            return None

    def list_contents(self) -> List[str]:
        sel = select(self.table.c.name)
        with self.engine.connect() as conn:
            result = conn.execute(sel).fetchall()
            return [row[0] for row in result]

    def remove(self, name: str) -> bool:
        del_stmt = self.table.delete().where(self.table.c.name == name)
        with self.engine.connect() as conn:
            result = conn.execute(del_stmt)
            conn.commit()
            return result.rowcount > 0

    @staticmethod
    def _serialize_content(content: Union[str, bytes, Any]) -> tuple[bytes, str]:
        if isinstance(content, bytes):
            return content, "bytes"
        elif isinstance(content, str):
            return content.encode(), "string"
        else:
            # Fallback to pickle for other types
            return pickle.dumps(content), "pickle"

    @staticmethod
    def _deserialize_content(content: bytes, content_type: Optional[str]) -> Any:
        if content_type == "pickle":
            return pickle.loads(content)
        if content_type == "string":
            return content.decode()
        return content
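Because the class above relies only on SQLAlchemy Core, it should also work against other SQLAlchemy-supported databases for local experimentation. A minimal smoke-test sketch, assuming the file is importable as common.database_storage (the same path the reader job uses below) and SQLAlchemy 2.x connection semantics; the SQLite URL is illustrative.

# Local smoke test, not part of this commit; SQLite stands in for Postgres.
from common.database_storage import DatabaseStorage

storage = DatabaseStorage("sqlite:///vdk_storage_demo.db")

storage.store("greeting", "hello")                     # stored as "string"
storage.store("blob", b"\x00\x01")                     # stored as "bytes"
storage.store("doc", {"id": 1, "tags": ["a", "b"]})    # pickled fallback

assert storage.retrieve("greeting") == "hello"
assert storage.retrieve("blob") == b"\x00\x01"
assert storage.retrieve("doc") == {"id": 1, "tags": ["a", "b"]}
assert "greeting" in storage.list_contents()
assert storage.remove("blob")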
Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
# Copyright 2021-2024 VMware, Inc.
# SPDX-License-Identifier: Apache-2.0
import json
import os
from typing import Any
from typing import List
from typing import Optional
from typing import Union

from storage import IStorage


class FileStorage(IStorage):
    def __init__(self, base_path: str):
        self.base_path = base_path
        if not os.path.exists(self.base_path):
            os.makedirs(self.base_path)

    def _get_file_path(self, name: str) -> str:
        return os.path.join(self.base_path, name)

    def store(
        self,
        name: str,
        content: Union[str, bytes, Any],
        content_type: Optional[str] = None,
    ) -> None:
        file_path = self._get_file_path(name)
        with open(file_path, "w") as file:
            if isinstance(content, (str, bytes)):
                # Directly save strings and bytes
                file.write(content if isinstance(content, str) else content.decode())
            else:
                # Assume JSON serializable for other types
                json.dump(content, file)

    def retrieve(self, name: str) -> Optional[Union[str, bytes, Any]]:
        file_path = self._get_file_path(name)
        if not os.path.exists(file_path):
            return None
        with open(file_path) as file:
            try:
                return json.load(file)
            except json.JSONDecodeError:
                # Content was not JSON, return as string
                file.seek(0)
                return file.read()

    def list_contents(self) -> List[str]:
        return os.listdir(self.base_path)

    def remove(self, name: str) -> bool:
        file_path = self._get_file_path(name)
        if os.path.exists(file_path):
            os.remove(file_path)
            return True
        return False
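A short, illustrative usage of the file-backed variant above (the file_storage module name is an assumption, not part of this commit): retrieve() tries JSON first, so a stored dict round-trips while plain text comes back as a string.

# Illustrative usage only; module name "file_storage" is assumed.
import tempfile

from file_storage import FileStorage

storage = FileStorage(tempfile.mkdtemp())
storage.store("settings", {"space": "DOCS", "limit": 50})  # written as JSON
storage.store("note", "plain text, not JSON")              # written verbatim

assert storage.retrieve("settings") == {"space": "DOCS", "limit": 50}
assert storage.retrieve("note") == "plain text, not JSON"
assert sorted(storage.list_contents()) == ["note", "settings"]
assert storage.remove("note")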
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
# Copyright 2021-2024 VMware, Inc.
# SPDX-License-Identifier: Apache-2.0
from typing import Any
from typing import List
from typing import Optional
from typing import Union


class IStorage:
    def store(self, name: str, content: Union[str, bytes, Any]) -> None:
        """
        Stores the given content under the specified name. If the content is not a string or bytes,
        the method tries to serialize it based on the content_type (if provided) or infers the type.

        :param name: The unique name to store the content under.
        :param content: The content to store. Can be of type str, bytes, or any serializable type.
        """
        pass

    def retrieve(self, name: str) -> Optional[Union[str, bytes, Any]]:
        """
        Retrieves the content stored under the specified name. The method attempts to deserialize
        the content to its original type if possible.

        :param name: The name of the content to retrieve.
        :return: The retrieved content, which can be of type str, bytes, or any deserialized Python object.
                 Returns None if the content does not exist.
        """
        pass

    def list_contents(self) -> List[str]:
        """
        Lists the names of all stored contents.

        :return: A list of names representing the stored contents.
        """
        pass

    def remove(self, name: str) -> bool:
        """
        Removes the content stored under the specified name.

        :param name: The name of the content to remove.
        :return: True if the content was successfully removed, False otherwise.
        """
        pass
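The interface is small enough that a throwaway in-memory implementation is handy for unit-testing job code against it. A minimal sketch, not part of this commit; the common.storage import path matches the one used by DatabaseStorage above.

# Illustrative only; InMemoryStorage is not part of this commit.
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union

from common.storage import IStorage


class InMemoryStorage(IStorage):
    """Keeps everything in a dict; useful for tests, not for real jobs."""

    def __init__(self) -> None:
        self._data: Dict[str, Any] = {}

    def store(self, name: str, content: Union[str, bytes, Any]) -> None:
        self._data[name] = content

    def retrieve(self, name: str) -> Optional[Union[str, bytes, Any]]:
        return self._data.get(name)

    def list_contents(self) -> List[str]:
        return list(self._data.keys())

    def remove(self, name: str) -> bool:
        if name in self._data:
            del self._data[name]
            return True
        return False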

examples/confluence-reader/fetch_confluence_space.py

Lines changed: 31 additions & 19 deletions
@@ -3,12 +3,15 @@
 import json
 import logging
 import os
+import pathlib
 from datetime import datetime
 
+from common.database_storage import DatabaseStorage
 from confluence_document import ConfluenceDocument
 from langchain_community.document_loaders import ConfluenceLoader
 from vdk.api.job_input import IJobInput
 
+
 log = logging.getLogger(__name__)
 
@@ -109,17 +112,12 @@ def __init__(self, confluence_url, token, space_key):
         self.loader = ConfluenceLoader(url=self.confluence_url, token=self.token)
 
     def fetch_confluence_documents(self, cql_query):
-        try:
-            # TODO: think about configurable limits ? or some streaming solution
-            # How do we fit all documents in memory ?
-            raw_documents = self.loader.load(cql=cql_query, limit=50, max_pages=200)
-            return [
-                ConfluenceDocument(doc.metadata, doc.page_content)
-                for doc in raw_documents
-            ]
-        except Exception as e:
-            log.error(f"Error fetching documents from Confluence: {e}")
-            raise e
+        # TODO: think about configurable limits ? or some streaming solution
+        # How do we fit all documents in memory ?
+        raw_documents = self.loader.load(cql=cql_query, limit=50, max_pages=200)
+        return [
+            ConfluenceDocument(doc.metadata, doc.page_content) for doc in raw_documents
+        ]
 
     def fetch_updated_pages_in_confluence_space(
         self, last_date="1900-02-06 17:54", parent_page_id=None
@@ -147,9 +145,10 @@ def fetch_all_pages_in_confluence_space(self, parent_page_id=None):
 
 
 def get_value(job_input, key: str, default_value=None):
-    return job_input.get_arguments().get(
-        key, job_input.get_property(key, os.environ.get(key.upper(), default_value))
-    )
+    value = os.environ.get(key.upper(), default_value)
+    value = job_input.get_property(key, value)
+    value = job_input.get_secret(key, value)
+    return job_input.get_arguments().get(key, value)
 
 
 def set_property(job_input: IJobInput, key, value):
@@ -165,12 +164,20 @@ def run(job_input: IJobInput):
     token = get_value(job_input, "confluence_token")
     space_key = get_value(job_input, "confluence_space_key")
     parent_page_id = get_value(job_input, "confluence_parent_page_id")
-    last_date = get_value(job_input, "last_date", "1900-01-01 12:00")
-    data_file = get_value(
-        job_input,
-        "data_file",
-        os.path.join(job_input.get_temporary_write_directory(), "confluence_data.json"),
+    last_date = (
+        job_input.get_property(confluence_url, {})
+        .setdefault(space_key, {})
+        .setdefault(parent_page_id, {})
+        .get("last_date", "1900-01-01 12:00")
     )
+    data_file = os.path.join(
+        job_input.get_temporary_write_directory(), "confluence_data.json"
+    )
+    storage_name = get_value(job_input, "storage_name", "confluence_data")
+    storage = DatabaseStorage(get_value(job_input, "storage_connection_string"))
+    # TODO: this is not optimal . We just care about the IDs, we should not need to retrieve everything
+    data = storage.retrieve(storage_name)
+    pathlib.Path(data_file).write_text(data if data else "[]")
 
     confluence_reader = ConfluenceDataSource(confluence_url, token, space_key)
 
@@ -189,3 +196,8 @@ def run(job_input: IJobInput):
         data_file,
         confluence_reader.fetch_all_pages_in_confluence_space(parent_page_id),
     )
+
+    # TODO: it would be better to save each page in separate row.
+    # But that's quick solution for now to pass the data to the next job
+
+    storage.store(storage_name, pathlib.Path(data_file).read_text())
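The reworked get_value now layers its sources by precedence: environment variable, then job property, then secret, then job arguments (highest wins). A standalone sketch of that chain follows, with plain dicts standing in for the job's properties, secrets, and arguments; the resolve() helper and the sample values are illustrative only.

# Illustrative sketch of the same precedence chain used by get_value();
# dicts stand in for the job's properties, secrets, and arguments.
import os


def resolve(key, arguments, secrets, properties, default_value=None):
    value = os.environ.get(key.upper(), default_value)  # lowest precedence
    value = properties.get(key, value)
    value = secrets.get(key, value)
    return arguments.get(key, value)                    # highest precedence


print(resolve("confluence_space_key",
              arguments={},
              secrets={},
              properties={"confluence_space_key": "DOCS"}))  # -> "DOCS"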

examples/confluence-reader/requirements.txt

Lines changed: 2 additions & 0 deletions
@@ -5,3 +5,5 @@
 atlassian-python-api
 langchain_community
 lxml
+psycopg2-binary
+sqlalchemy
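The two new dependencies back the Postgres-based storage: sqlalchemy provides the engine and table definitions, and psycopg2-binary is the Postgres driver. The storage_connection_string consumed by the reader job would then be a standard SQLAlchemy URL, for example (host, database, and credentials are placeholders, not part of this commit):

# Placeholder values only.
connection_string = "postgresql+psycopg2://vdk_user:vdk_password@postgres-host:5432/vdk"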

examples/fetch-embed-job-example/10_fetch_confluence_space.py

Lines changed: 0 additions & 54 deletions
This file was deleted.
