Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion integration/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def test_create_export_and_recreate(client: weaviate.WeaviateClient, request: Su
assert export.generative_config is not None
assert export.generative_config.generative == GenerativeSearches.COHERE
assert export.generative_config.model["model"] == "command-r-plus"
assert export.generative_config.model["kProperty"] == 10
assert export.generative_config.model["k"] == 10

client.collections.delete([name1, name2])
assert not client.collections.exists(name1)
Expand Down
48 changes: 30 additions & 18 deletions test/collection/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,11 +844,11 @@ def test_config_with_vectorizer_and_properties(
{
"generative-openai": {
"model": "gpt-4",
"frequencyPenaltyProperty": 0.5,
"maxTokensProperty": 100,
"presencePenaltyProperty": 0.5,
"temperatureProperty": 0.5,
"topPProperty": 0.5,
"frequencyPenalty": 0.5,
"maxTokens": 100,
"presencePenalty": 0.5,
"temperature": 0.5,
"topP": 0.5,
"baseURL": "https://api.openai.com/",
"reasoningEffort": "high",
"verbosity": "verbose",
Expand All @@ -861,19 +861,17 @@ def test_config_with_vectorizer_and_properties(
model="model",
k=10,
max_tokens=100,
return_likelihoods="ALL",
stop_sequences=["stop"],
temperature=0.5,
base_url="https://api.cohere.ai",
),
{
"generative-cohere": {
"model": "model",
"kProperty": 10,
"maxTokensProperty": 100,
"returnLikelihoodsProperty": "ALL",
"stopSequencesProperty": ["stop"],
"temperatureProperty": 0.5,
"k": 10,
"maxTokens": 100,
"stopSequences": ["stop"],
"temperature": 0.5,
"baseURL": "https://api.cohere.ai/",
}
},
Expand Down Expand Up @@ -939,12 +937,26 @@ def test_config_with_vectorizer_and_properties(
},
),
(
Configure.Generative.aws(model="cohere.command-light-text-v14", region="us-east-1"),
Configure.Generative.aws(
model="cohere.command-light-text-v14",
region="us-east-1",
service="bedrock",
endpoint="custom-endpoint",
max_tokens=100,
temperature=0.5,
target_model="target-model",
target_variant="target-variant",
),
{
"generative-aws": {
"model": "cohere.command-light-text-v14",
"region": "us-east-1",
"service": "bedrock",
"endpoint": "custom-endpoint",
"maxTokens": 100,
"temperature": 0.5,
"targetModel": "target-model",
"targetVariant": "target-variant",
}
},
),
Expand All @@ -970,13 +982,13 @@ def test_config_with_vectorizer_and_properties(
),
{
"generative-openai": {
"deploymentId": "id",
"resourceName": "name",
"frequencyPenaltyProperty": 0.5,
"maxTokensProperty": 100,
"presencePenaltyProperty": 0.5,
"temperatureProperty": 0.5,
"topPProperty": 0.5,
"deploymentId": "id",
"frequencyPenalty": 0.5,
"maxTokens": 100,
"presencePenalty": 0.5,
"temperature": 0.5,
"topP": 0.5,
"baseURL": "https://api.openai.com/",
}
},
Expand Down
74 changes: 46 additions & 28 deletions weaviate/collections/classes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ class _GenerativeXai(_GenerativeProvider):
model: Optional[str]
maxTokens: Optional[int]
baseURL: Optional[str]
topP: Optional[float]


class _GenerativeFriendliai(_GenerativeProvider):
Expand All @@ -416,11 +417,11 @@ class _GenerativeOpenAIConfigBase(_GenerativeProvider):
default=GenerativeSearches.OPENAI, frozen=True, exclude=True
)
baseURL: Optional[AnyHttpUrl]
frequencyPenaltyProperty: Optional[float]
presencePenaltyProperty: Optional[float]
maxTokensProperty: Optional[int]
temperatureProperty: Optional[float]
topPProperty: Optional[float]
frequencyPenalty: Optional[float]
presencePenalty: Optional[float]
maxTokens: Optional[int]
temperature: Optional[float]
topP: Optional[float]

def _to_dict(self) -> Dict[str, Any]:
ret_dict = super()._to_dict()
Expand All @@ -445,12 +446,12 @@ class _GenerativeCohereConfig(_GenerativeProvider):
default=GenerativeSearches.COHERE, frozen=True, exclude=True
)
baseURL: Optional[AnyHttpUrl]
kProperty: Optional[int]
k: Optional[int]
model: Optional[str]
maxTokensProperty: Optional[int]
returnLikelihoodsProperty: Optional[str]
stopSequencesProperty: Optional[List[str]]
temperatureProperty: Optional[float]
maxTokens: Optional[int]
# returnLikelihoods: Optional[str] # Not implemented server-side
stopSequences: Optional[List[str]]
temperature: Optional[float]

def _to_dict(self) -> Dict[str, Any]:
ret_dict = super()._to_dict()
Expand Down Expand Up @@ -481,6 +482,9 @@ class _GenerativeAWSConfig(_GenerativeProvider):
model: Optional[str]
endpoint: Optional[str]
maxTokens: Optional[int]
temperature: Optional[float]
targetModel: Optional[str]
targetVariant: Optional[str]


class _GenerativeAnthropicConfig(_GenerativeProvider):
Expand Down Expand Up @@ -691,6 +695,7 @@ def xai(
model: Optional[str] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
top_p: Optional[float] = None,
) -> _GenerativeProvider:
"""Create a `_GenerativeXai` object for use when performing AI generation using the `generative-xai` module.

Expand All @@ -699,9 +704,10 @@ def xai(
model: The model to use. Defaults to `None`, which uses the server-defined default
temperature: The temperature to use. Defaults to `None`, which uses the server-defined default
max_tokens: The maximum number of tokens to generate. Defaults to `None`, which uses the server-defined default
top_p: The top P value to use. Defaults to `None`, which uses the server-defined default
"""
return _GenerativeXai(
model=model, temperature=temperature, maxTokens=max_tokens, baseURL=base_url
model=model, temperature=temperature, maxTokens=max_tokens, topP=top_p, baseURL=base_url
)

@staticmethod
Expand Down Expand Up @@ -750,12 +756,12 @@ def openai(
"""
return _GenerativeOpenAIConfig(
baseURL=base_url,
frequencyPenaltyProperty=frequency_penalty,
maxTokensProperty=max_tokens,
frequencyPenalty=frequency_penalty,
maxTokens=max_tokens,
model=model,
presencePenaltyProperty=presence_penalty,
temperatureProperty=temperature,
topPProperty=top_p,
presencePenalty=presence_penalty,
temperature=temperature,
topP=top_p,
verbosity=verbosity,
reasoningEffort=reasoning_effort,
)
Expand Down Expand Up @@ -789,12 +795,12 @@ def azure_openai(
return _GenerativeAzureOpenAIConfig(
baseURL=base_url,
deploymentId=deployment_id,
frequencyPenaltyProperty=frequency_penalty,
maxTokensProperty=max_tokens,
presencePenaltyProperty=presence_penalty,
frequencyPenalty=frequency_penalty,
maxTokens=max_tokens,
presencePenalty=presence_penalty,
resourceName=resource_name,
temperatureProperty=temperature,
topPProperty=top_p,
temperature=temperature,
topP=top_p,
)

@staticmethod
Expand All @@ -816,19 +822,18 @@ def cohere(
model: The model to use. Defaults to `None`, which uses the server-defined default
k: The number of sequences to generate. Defaults to `None`, which uses the server-defined default
max_tokens: The maximum number of tokens to generate. Defaults to `None`, which uses the server-defined default
return_likelihoods: Whether to return the likelihoods. Defaults to `None`, which uses the server-defined default
return_likelihoods: (Deprecated) The return likelihoods setting to use. No longer has any effect.
stop_sequences: The stop sequences to use. Defaults to `None`, which uses the server-defined default
temperature: The temperature to use. Defaults to `None`, which uses the server-defined default
base_url: The base URL where the API request should go. Defaults to `None`, which uses the server-defined default
"""
return _GenerativeCohereConfig(
baseURL=base_url,
kProperty=k,
maxTokensProperty=max_tokens,
k=k,
maxTokens=max_tokens,
model=model,
returnLikelihoodsProperty=return_likelihoods,
stopSequencesProperty=stop_sequences,
temperatureProperty=temperature,
stopSequences=stop_sequences,
temperature=temperature,
)

@staticmethod
Expand Down Expand Up @@ -916,6 +921,9 @@ def aws(
endpoint: Optional[str] = None,
service: Union[AWSService, str] = "bedrock",
max_tokens: Optional[int] = None,
target_model: Optional[str] = None,
target_variant: Optional[str] = None,
temperature: Optional[float] = None,
) -> _GenerativeProvider:
"""Create a `_GenerativeAWSConfig` object for use when performing AI generation using the `generative-aws` module.

Expand All @@ -928,9 +936,19 @@ def aws(
region: The AWS region to run the model from, REQUIRED.
endpoint: The model to use, REQUIRED for service "sagemaker".
service: The AWS service to use, options are "bedrock" and "sagemaker".
target_model: The target model to use. Defaults to `None`, which uses the server-defined default
target_variant: The target variant to use. Defaults to `None`, which uses the server-defined default
temperature: The temperature to use. Defaults to `None`, which uses the server-defined default
"""
return _GenerativeAWSConfig(
model=model, region=region, service=service, endpoint=endpoint, maxTokens=max_tokens
model=model,
region=region,
service=service,
endpoint=endpoint,
maxTokens=max_tokens,
temperature=temperature,
targetModel=target_model,
targetVariant=target_variant,
)

@staticmethod
Expand Down