Skip to content

Commit 10fe155

Browse files
Timtech4u and CShorten authored
Weaviate multitenancy support (#2199)
* feat: weaviate multitenancy support * update `.tenant` to `.with_tenant` * Update weaviate_rm.py * Update WeaviateRM.md --------- Co-authored-by: Connor Shorten <[email protected]>
1 parent c86852b commit 10fe155

File tree

2 files changed

+31
-15
lines changed

2 files changed

+31
-15
lines changed

docs/docs/deep-dive/retrieval_models_clients/WeaviateRM.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,22 @@ WeaviateRM(
1919
)
2020
```
2121

22+
## Using Multitenancy
23+
Multi-tenancy allows a collection to efficiently serve isolated groups of data. Each "tenant" in a multi-tenant collection can only access its own data, while sharing the same data structure and settings.
24+
25+
If your Weaviate instance is tenant-aware, you can provide a `tenant_id` either in the `WeaviateRM` constructor or as a keyword argument when calling the retriever:
26+
27+
```python
28+
retriever_model = WeaviateRM(
29+
weaviate_collection_name="<WEAVIATE_COLLECTION>",
30+
weaviate_client=weaviate_client,
31+
tenant_id="tenant123"
32+
)
33+
34+
results = retriever_model("Your query here", tenant_id="tenantXYZ")
35+
```
36+
When `tenant_id` is specified, all retrieval requests are scoped to that tenant. A `tenant_id` passed as a keyword argument at call time takes precedence over the one given to the constructor.
37+
2238
## Under the Hood
2339

2440
`forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None, **kwargs) -> dspy.Prediction`

dspy/retrieve/weaviate_rm.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class WeaviateRM(dspy.Retrieve):
2424
weaviate_collection_name (str): The name of the Weaviate collection.
2525
weaviate_client (WeaviateClient): An instance of the Weaviate client.
2626
k (int, optional): The default number of top passages to retrieve. Defaults to 3.
27+
tenant_id (str, optional): The tenant to retrieve objects from.
2728
2829
Examples:
2930
Below is a code snippet that shows how to use Weaviate as the default retriever:
@@ -51,11 +52,13 @@ def __init__(
5152
weaviate_client: Union[weaviate.WeaviateClient, weaviate.Client],
5253
weaviate_collection_text_key: Optional[str] = "content",
5354
k: int = 3,
55+
tenant_id: Optional[str] = None,
5456
):
5557
self._weaviate_collection_name = weaviate_collection_name
5658
self._weaviate_client = weaviate_client
5759
self._weaviate_collection = self._weaviate_client.collections.get(self._weaviate_collection_name)
5860
self._weaviate_collection_text_key = weaviate_collection_text_key
61+
self._tenant_id = tenant_id
5962

6063
# Check the type of weaviate_client (this is added to support v3 and v4)
6164
if hasattr(weaviate_client, "collections"):
@@ -82,26 +85,23 @@ def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = No
8285
queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries
8386
queries = [q for q in queries if q]
8487
passages, parsed_results = [], []
88+
tenant = kwargs.pop("tenant_id", self._tenant_id)
8589
for query in queries:
8690
if self._client_type == "WeaviateClient":
87-
results = self._weaviate_collection.query.hybrid(
88-
query=query,
89-
limit=k,
90-
**kwargs,
91-
)
91+
if tenant:
92+
results = self._weaviate_collection.query.with_tenant(tenant).hybrid(query=query, limit=k, **kwargs)
93+
else:
94+
results = self._weaviate_collection.query.hybrid(query=query, limit=k, **kwargs)
9295

9396
parsed_results = [result.properties[self._weaviate_collection_text_key] for result in results.objects]
9497

9598
elif self._client_type == "Client":
96-
results = (
97-
self._weaviate_client.query.get(
98-
self._weaviate_collection_name,
99-
[self._weaviate_collection_text_key],
99+
q = self._weaviate_client.query.get(
100+
self._weaviate_collection_name, [self._weaviate_collection_text_key]
100101
)
101-
.with_hybrid(query=query)
102-
.with_limit(k)
103-
.do()
104-
)
102+
if tenant:
103+
q = q.with_tenant(tenant)
104+
results = q.with_hybrid(query=query).with_limit(k).do()
105105

106106
results = results["data"]["Get"][self._weaviate_collection_name]
107107
parsed_results = [result[self._weaviate_collection_text_key] for result in results]
@@ -115,7 +115,7 @@ def get_objects(self, num_samples: int, fields: List[str]) -> List[dict]:
115115
if self._client_type == "WeaviateClient":
116116
objects = []
117117
counter = 0
118-
for item in self._weaviate_collection.iterator():
118+
for item in self._weaviate_collection.iterator(): # TODO: add tenancy scoping
119119
if counter >= num_samples:
120120
break
121121
new_object = {}
@@ -133,6 +133,6 @@ def insert(self, new_object_properties: dict):
133133
self._weaviate_collection.data.insert(
134134
properties=new_object_properties,
135135
uuid=get_valid_uuid(uuid4())
136-
)
136+
) # TODO: add tenancy scoping
137137
else:
138138
raise AttributeError("`insert` is not supported for the v3 Weaviate Python client, please upgrade to v4.")

0 commit comments

Comments
 (0)