Skip to content

Commit 01840b3

Browse files
authored
Merge pull request #1887 from weaviate/contextual-ai-modules-support
Add Contextual AI's Generative and Reranker Client support
2 parents 4bb315c + 469cc7a commit 01840b3

File tree

12 files changed

+795
-288
lines changed

12 files changed

+795
-288
lines changed

integration/test_collection_openai.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,3 +740,170 @@ def test_near_text_generate_with_dynamic_rag(
740740
assert g0.debug is None
741741
assert g0.metadata is None
742742
assert g1.metadata is None
743+
744+
745+
@pytest.mark.parametrize("parameter,answer", [("text", "yes"), ("content", "no")])
746+
def test_contextualai_generative_search_single(
747+
collection_factory: CollectionFactory, parameter: str, answer: str
748+
) -> None:
749+
"""Test Contextual AI generative search with single prompt."""
750+
api_key = os.environ.get("CONTEXTUAL_API_KEY")
751+
if api_key is None:
752+
pytest.skip("No Contextual AI API key found.")
753+
754+
collection = collection_factory(
755+
name="TestContextualAIGenerativeSingle",
756+
generative_config=Configure.Generative.contextualai(
757+
model="v2",
758+
max_new_tokens=100,
759+
temperature=0.1,
760+
system_prompt="You are a helpful assistant that provides accurate and informative responses based on the given context. Answer with yes or no only.",
761+
avoid_commentary=False,
762+
),
763+
vectorizer_config=Configure.Vectorizer.none(),
764+
properties=[
765+
Property(name="text", data_type=DataType.TEXT),
766+
Property(name="content", data_type=DataType.TEXT),
767+
],
768+
headers={"X-Contextual-Api-Key": api_key},
769+
ports=(8086, 50057),
770+
)
771+
if collection._connection._weaviate_version.is_lower_than(1, 23, 1):
772+
pytest.skip("Generative search requires Weaviate 1.23.1 or higher")
773+
774+
collection.data.insert_many(
775+
[
776+
DataObject(properties={"text": "bananas are great", "content": "bananas are bad"}),
777+
DataObject(properties={"text": "apples are great", "content": "apples are bad"}),
778+
]
779+
)
780+
781+
res = collection.generate.fetch_objects(
782+
single_prompt=f"is it good or bad based on {{{parameter}}}? Just answer with yes or no without punctuation",
783+
)
784+
for obj in res.objects:
785+
assert obj.generated is not None
786+
assert obj.generated.lower() == answer
787+
assert res.generated is None
788+
789+
790+
def test_contextualai_generative_with_knowledge_parameter(
791+
collection_factory: CollectionFactory,
792+
) -> None:
793+
"""Test Contextual AI generative search with knowledge parameter override."""
794+
api_key = os.environ.get("CONTEXTUAL_API_KEY")
795+
if api_key is None:
796+
pytest.skip("No Contextual AI API key found.")
797+
798+
collection = collection_factory(
799+
name="TestContextualAIGenerativeKnowledge",
800+
generative_config=Configure.Generative.contextualai(
801+
model="v2",
802+
max_new_tokens=100,
803+
temperature=0.1,
804+
system_prompt="You are a helpful assistant.",
805+
avoid_commentary=False,
806+
),
807+
vectorizer_config=Configure.Vectorizer.none(),
808+
properties=[
809+
Property(name="text", data_type=DataType.TEXT),
810+
],
811+
headers={"X-Contextual-Api-Key": api_key},
812+
ports=(8086, 50057),
813+
)
814+
if collection._connection._weaviate_version.is_lower_than(1, 23, 1):
815+
pytest.skip("Generative search requires Weaviate 1.23.1 or higher")
816+
817+
collection.data.insert_many(
818+
[
819+
DataObject(properties={"text": "base knowledge"}),
820+
]
821+
)
822+
823+
# Test with knowledge parameter override
824+
res = collection.generate.fetch_objects(
825+
single_prompt="What is the custom knowledge?",
826+
config=GenerativeConfig.contextualai(
827+
knowledge=["Custom knowledge override", "Additional context"],
828+
),
829+
)
830+
for obj in res.objects:
831+
assert obj.generated is not None
832+
assert isinstance(obj.generated, str)
833+
834+
835+
def test_contextualai_generative_and_rerank_combined(collection_factory: CollectionFactory) -> None:
836+
"""Test Contextual AI generative search combined with reranking."""
837+
contextual_api_key = os.environ.get("CONTEXTUAL_API_KEY")
838+
if contextual_api_key is None:
839+
pytest.skip("No Contextual AI API key found.")
840+
841+
collection = collection_factory(
842+
name="TestContextualAIGenerativeAndRerank",
843+
generative_config=Configure.Generative.contextualai(
844+
model="v2",
845+
max_new_tokens=100,
846+
temperature=0.1,
847+
system_prompt="You are a helpful assistant that provides accurate and informative responses based on the given context.",
848+
avoid_commentary=False,
849+
),
850+
reranker_config=Configure.Reranker.contextualai(
851+
model="ctxl-rerank-v2-instruct-multilingual",
852+
instruction="Prioritize documents that contain the query term",
853+
),
854+
vectorizer_config=Configure.Vectorizer.text2vec_openai(),
855+
properties=[Property(name="text", data_type=DataType.TEXT)],
856+
headers={"X-Contextual-Api-Key": contextual_api_key},
857+
ports=(8086, 50057),
858+
)
859+
if collection._connection._weaviate_version < _ServerVersion(1, 23, 1):
860+
pytest.skip("Generative reranking requires Weaviate 1.23.1 or higher")
861+
862+
insert = collection.data.insert_many(
863+
[{"text": "This is a test"}, {"text": "This is another test"}]
864+
)
865+
uuid1 = insert.uuids[0]
866+
vector1 = collection.query.fetch_object_by_id(uuid1, include_vector=True).vector
867+
assert vector1 is not None
868+
869+
for _idx, query in enumerate(
870+
[
871+
lambda: collection.generate.bm25(
872+
"test",
873+
rerank=Rerank(prop="text", query="another"),
874+
single_prompt="What is it? {text}",
875+
),
876+
lambda: collection.generate.hybrid(
877+
"test",
878+
rerank=Rerank(prop="text", query="another"),
879+
single_prompt="What is it? {text}",
880+
),
881+
lambda: collection.generate.near_object(
882+
uuid1,
883+
rerank=Rerank(prop="text", query="another"),
884+
single_prompt="What is it? {text}",
885+
),
886+
lambda: collection.generate.near_vector(
887+
vector1["default"],
888+
rerank=Rerank(prop="text", query="another"),
889+
single_prompt="What is it? {text}",
890+
),
891+
lambda: collection.generate.near_text(
892+
"test",
893+
rerank=Rerank(prop="text", query="another"),
894+
single_prompt="What is it? {text}",
895+
),
896+
]
897+
):
898+
objects = query().objects
899+
assert len(objects) == 2
900+
assert objects[0].metadata.rerank_score is not None
901+
assert objects[0].generated is not None
902+
assert objects[1].metadata.rerank_score is not None
903+
assert objects[1].generated is not None
904+
905+
assert [obj for obj in objects if "another" in obj.properties["text"]][ # type: ignore
906+
0
907+
].metadata.rerank_score > [
908+
obj for obj in objects if "another" not in obj.properties["text"]
909+
][0].metadata.rerank_score

integration/test_collection_rerank.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,61 @@ def test_queries_with_rerank_and_group_by(collection_factory: CollectionFactory)
138138
].rerank_score > [group for prop, group in ret.groups.items() if "another" not in prop][
139139
0
140140
].rerank_score
141+
142+
143+
def test_queries_with_rerank_contextualai(collection_factory: CollectionFactory) -> None:
144+
"""Test Contextual AI reranker with various query types."""
145+
api_key = os.environ.get("CONTEXTUAL_API_KEY")
146+
if api_key is None:
147+
pytest.skip("No Contextual AI API key found.")
148+
149+
collection = collection_factory(
150+
name="Test_test_queries_with_rerank_contextualai",
151+
reranker_config=wvc.config.Configure.Reranker.contextualai(
152+
model="ctxl-rerank-v2-instruct-multilingual",
153+
instruction="Prioritize documents that contain the query term",
154+
),
155+
vectorizer_config=wvc.config.Configure.Vectorizer.text2vec_openai(),
156+
properties=[wvc.config.Property(name="text", data_type=wvc.config.DataType.TEXT)],
157+
headers={"X-Contextual-Api-Key": api_key},
158+
ports=(8086, 50057),
159+
)
160+
if collection._connection._weaviate_version.is_lower_than(1, 23, 1):
161+
pytest.skip("Reranking requires Weaviate 1.23.1 or higher")
162+
163+
insert = collection.data.insert_many(
164+
[{"text": "This is a test"}, {"text": "This is another test"}]
165+
)
166+
uuid1 = insert.uuids[0]
167+
vector1 = collection.query.fetch_object_by_id(uuid1, include_vector=True).vector
168+
assert vector1 is not None
169+
170+
for _idx, query in enumerate(
171+
[
172+
lambda: collection.query.bm25(
173+
"test", rerank=wvc.query.Rerank(prop="text", query="another")
174+
),
175+
lambda: collection.query.hybrid(
176+
"test", rerank=wvc.query.Rerank(prop="text", query="another")
177+
),
178+
lambda: collection.query.near_object(
179+
uuid1, rerank=wvc.query.Rerank(prop="text", query="another")
180+
),
181+
lambda: collection.query.near_vector(
182+
vector1["default"], rerank=wvc.query.Rerank(prop="text", query="another")
183+
),
184+
lambda: collection.query.near_text(
185+
"test", rerank=wvc.query.Rerank(prop="text", query="another")
186+
),
187+
]
188+
):
189+
objects = query().objects
190+
assert len(objects) == 2
191+
assert objects[0].metadata.rerank_score is not None
192+
assert objects[1].metadata.rerank_score is not None
193+
194+
assert [obj for obj in objects if "another" in obj.properties["text"]][ # type: ignore
195+
0
196+
].metadata.rerank_score > [
197+
obj for obj in objects if "another" not in obj.properties["text"]
198+
][0].metadata.rerank_score

test/collection/test_classes_generative.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,29 @@ def test_generative_parameters_images_parsing(
414414
),
415415
),
416416
),
417+
(
418+
GenerativeConfig.contextualai(
419+
model="v2",
420+
max_new_tokens=100,
421+
temperature=0.5,
422+
top_p=0.9,
423+
system_prompt="You are a helpful assistant that provides accurate and informative responses based on the given context.",
424+
avoid_commentary=False,
425+
knowledge=["knowledge1", "knowledge2"],
426+
)._to_grpc(_GenerativeConfigRuntimeOptions(return_metadata=True)),
427+
generative_pb2.GenerativeProvider(
428+
return_metadata=True,
429+
contextualai=generative_pb2.GenerativeContextualAI(
430+
model="v2",
431+
max_new_tokens=100,
432+
temperature=0.5,
433+
top_p=0.9,
434+
system_prompt="You are a helpful assistant that provides accurate and informative responses based on the given context.",
435+
avoid_commentary=False,
436+
knowledge=base_pb2.TextArray(values=["knowledge1", "knowledge2"]),
437+
),
438+
),
439+
),
417440
],
418441
)
419442
def test_generative_provider_to_grpc(

test/collection/test_config.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,6 +1043,34 @@ def test_config_with_vectorizer_and_properties(
10431043
}
10441044
},
10451045
),
1046+
(
1047+
Configure.Generative.contextualai(),
1048+
{
1049+
"generative-contextualai": {},
1050+
},
1051+
),
1052+
(
1053+
Configure.Generative.contextualai(
1054+
model="v2",
1055+
temperature=0.7,
1056+
top_p=0.9,
1057+
max_new_tokens=512,
1058+
system_prompt="You are a helpful assistant that provides accurate and informative responses based on the given context.",
1059+
avoid_commentary=False,
1060+
knowledge=["fact1", "fact2"],
1061+
),
1062+
{
1063+
"generative-contextualai": {
1064+
"model": "v2",
1065+
"temperature": 0.7,
1066+
"topP": 0.9,
1067+
"maxNewTokens": 512,
1068+
"systemPrompt": "You are a helpful assistant that provides accurate and informative responses based on the given context.",
1069+
"avoidCommentary": False,
1070+
"knowledge": ["fact1", "fact2"],
1071+
}
1072+
},
1073+
),
10461074
]
10471075

10481076

@@ -1125,6 +1153,26 @@ def test_config_with_generative(
11251153
"reranker-transformers": {},
11261154
},
11271155
),
1156+
(
1157+
Configure.Reranker.contextualai(),
1158+
{
1159+
"reranker-contextualai": {},
1160+
},
1161+
),
1162+
(
1163+
Configure.Reranker.contextualai(
1164+
model="ctxl-rerank-v2-instruct-multilingual",
1165+
instruction="Prioritize recent documents",
1166+
top_n=5,
1167+
),
1168+
{
1169+
"reranker-contextualai": {
1170+
"model": "ctxl-rerank-v2-instruct-multilingual",
1171+
"instruction": "Prioritize recent documents",
1172+
"topN": 5,
1173+
}
1174+
},
1175+
),
11281176
]
11291177

11301178

0 commit comments

Comments
 (0)