Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions integration/test_collection_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,9 +913,11 @@ def test_config_vector_index_hnsw_and_quantizer_pq(collection_factory: Collectio
[
(Configure.Reranker.cohere(), Rerankers.COHERE, {}),
(
Configure.Reranker.cohere(model="rerank-english-v2.0"),
Configure.Reranker.cohere(
model="rerank-english-v2.0", base_url="https://some-cohere-baseurl.ai/"
),
Rerankers.COHERE,
{"model": "rerank-english-v2.0"},
{"model": "rerank-english-v2.0", "baseURL": "https://some-cohere-baseurl.ai/"},
),
(Configure.Reranker.transformers(), Rerankers.TRANSFORMERS, {}),
],
Expand Down
3 changes: 2 additions & 1 deletion test/collection/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,10 +1092,11 @@ def test_config_with_generative(

TEST_CONFIG_WITH_RERANKER = [
(
Configure.Reranker.cohere(model="model"),
Configure.Reranker.cohere(model="model", base_url="https://some.base.url/"),
{
"reranker-cohere": {
"model": "model",
"baseURL": "https://some.base.url/",
},
},
),
Expand Down
15 changes: 13 additions & 2 deletions weaviate/collections/classes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)

from deprecation import deprecated as docstring_deprecated
from pydantic import AnyHttpUrl, Field, ValidationInfo, field_validator
from pydantic import AnyHttpUrl, AnyUrl, Field, ValidationInfo, field_validator
from typing_extensions import TypeAlias
from typing_extensions import deprecated as typing_deprecated

Expand Down Expand Up @@ -524,6 +524,13 @@ class _RerankerCohereConfig(_RerankerProvider):
default=Rerankers.COHERE, frozen=True, exclude=True
)
model: Optional[Union[RerankerCohereModel, str]] = Field(default=None)
baseURL: Optional[AnyHttpUrl]

def _to_dict(self) -> Dict[str, Any]:
ret_dict = super()._to_dict()
if self.baseURL is not None:
ret_dict["baseURL"] = self.baseURL.unicode_string()
return ret_dict


class _RerankerCustomConfig(_RerankerProvider):
Expand Down Expand Up @@ -1062,6 +1069,7 @@ def custom(
@staticmethod
def cohere(
model: Optional[Union[RerankerCohereModel, str]] = None,
base_url: Optional[str] = None,
) -> _RerankerProvider:
"""Create a `_RerankerCohereConfig` object for use when reranking using the `reranker-cohere` module.

Expand All @@ -1070,8 +1078,11 @@ def cohere(

Args:
model: The model to use. Defaults to `None`, which uses the server-defined default
base_url: The base URL to send the reranker requests to. Defaults to `None`, which uses the server-defined default.
"""
return _RerankerCohereConfig(model=model)
return _RerankerCohereConfig(
model=model, baseURL=AnyUrl(base_url) if base_url is not None else None
)

@staticmethod
def jinaai(
Expand Down
Loading