diff --git a/integration/test_collection_config.py b/integration/test_collection_config.py index ccbd998af..8b0c8466f 100644 --- a/integration/test_collection_config.py +++ b/integration/test_collection_config.py @@ -913,9 +913,11 @@ def test_config_vector_index_hnsw_and_quantizer_pq(collection_factory: Collectio [ (Configure.Reranker.cohere(), Rerankers.COHERE, {}), ( - Configure.Reranker.cohere(model="rerank-english-v2.0"), + Configure.Reranker.cohere( + model="rerank-english-v2.0", base_url="https://some-cohere-baseurl.ai/" + ), Rerankers.COHERE, - {"model": "rerank-english-v2.0"}, + {"model": "rerank-english-v2.0", "baseURL": "https://some-cohere-baseurl.ai/"}, ), (Configure.Reranker.transformers(), Rerankers.TRANSFORMERS, {}), ], diff --git a/test/collection/test_config.py b/test/collection/test_config.py index b9b57e36d..1964e248d 100644 --- a/test/collection/test_config.py +++ b/test/collection/test_config.py @@ -1092,10 +1092,11 @@ def test_config_with_generative( TEST_CONFIG_WITH_RERANKER = [ ( - Configure.Reranker.cohere(model="model"), + Configure.Reranker.cohere(model="model", base_url="https://some.base.url/"), { "reranker-cohere": { "model": "model", + "baseURL": "https://some.base.url/", }, }, ), diff --git a/weaviate/collections/classes/config.py b/weaviate/collections/classes/config.py index 99cdf5ff7..1263960ab 100644 --- a/weaviate/collections/classes/config.py +++ b/weaviate/collections/classes/config.py @@ -14,7 +14,7 @@ ) from deprecation import deprecated as docstring_deprecated -from pydantic import AnyHttpUrl, Field, ValidationInfo, field_validator +from pydantic import AnyHttpUrl, AnyUrl, Field, ValidationInfo, field_validator from typing_extensions import TypeAlias from typing_extensions import deprecated as typing_deprecated @@ -524,6 +524,13 @@ class _RerankerCohereConfig(_RerankerProvider): default=Rerankers.COHERE, frozen=True, exclude=True ) model: Optional[Union[RerankerCohereModel, str]] = Field(default=None) + baseURL: Optional[AnyHttpUrl] + + def _to_dict(self) -> Dict[str, Any]: + ret_dict = super()._to_dict() + if self.baseURL is not None: + ret_dict["baseURL"] = self.baseURL.unicode_string() + return ret_dict class _RerankerCustomConfig(_RerankerProvider): @@ -1062,6 +1069,7 @@ def custom( @staticmethod def cohere( model: Optional[Union[RerankerCohereModel, str]] = None, + base_url: Optional[str] = None, ) -> _RerankerProvider: """Create a `_RerankerCohereConfig` object for use when reranking using the `reranker-cohere` module. @@ -1070,8 +1078,11 @@ def cohere( Args: model: The model to use. Defaults to `None`, which uses the server-defined default + base_url: The base URL to send the reranker requests to. Defaults to `None`, which uses the server-defined default. """ - return _RerankerCohereConfig(model=model) + return _RerankerCohereConfig( + model=model, baseURL=AnyUrl(base_url) if base_url is not None else None + ) @staticmethod def jinaai(