
Commit 5188a1d

baskaryan authored and Erick Friis committed
integrations[patch]: remove non-required chat param defaults (#26730)
anthropic:
- max_retries

openai:
- n
- temperature
- max_retries

fireworks:
- temperature

groq:
- n
- max_retries
- temperature

mistral:
- max_retries
- timeout
- max_concurrent_requests
- temperature
- top_p
- safe_mode

---------

Co-authored-by: Erick Friis <[email protected]>
1 parent 648f20a commit 5188a1d

15 files changed: +51 −43 lines changed

libs/partners/anthropic/langchain_anthropic/chat_models.py

Lines changed: 4 additions & 4 deletions
@@ -307,7 +307,7 @@ class ChatAnthropic(BaseChatModel):
     Key init args — client params:
         timeout: Optional[float]
             Timeout for requests.
-        max_retries: int
+        max_retries: Optional[int]
             Max number of retries if a request fails.
         api_key: Optional[str]
             Anthropic API key. If not passed in will be read from env var ANTHROPIC_API_KEY.

@@ -558,8 +558,7 @@ class Joke(BaseModel):
     default_request_timeout: Optional[float] = Field(None, alias="timeout")
     """Timeout for requests to Anthropic Completion API."""

-    # sdk default = 2: https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#retries
-    max_retries: int = 2
+    max_retries: Optional[int] = None
     """Number of retries allowed for requests sent to the Anthropic Completion API."""

     stop_sequences: Optional[List[str]] = Field(None, alias="stop")

@@ -662,9 +661,10 @@ def _client_params(self) -> Dict[str, Any]:
         client_params: Dict[str, Any] = {
             "api_key": self.anthropic_api_key.get_secret_value(),
             "base_url": self.anthropic_api_url,
-            "max_retries": self.max_retries,
             "default_headers": (self.default_headers or None),
         }
+        if self.max_retries is not None:
+            client_params["max_retries"] = self.max_retries
         # value <= 0 indicates the param should be ignored. None is a meaningful value
         # for Anthropic client and treated differently than not specifying the param at
         # all.
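
Note: the net effect for Anthropic is that an unset max_retries is no longer pinned to 2 by LangChain; the key is omitted from the client params, so the Anthropic SDK's own retry default governs. A minimal usage sketch, assuming langchain_anthropic is installed and ANTHROPIC_API_KEY is set (the model name is only illustrative):

from langchain_anthropic import ChatAnthropic

# Unset means "not forwarded": the Anthropic SDK's retry default applies.
llm = ChatAnthropic(model="claude-3-5-sonnet-20240620")
assert llm.max_retries is None

# An explicit value is still forwarded to the underlying client.
llm_pinned = ChatAnthropic(model="claude-3-5-sonnet-20240620", max_retries=5)
assert llm_pinned.max_retries == 5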

libs/partners/fireworks/langchain_fireworks/chat_models.py

Lines changed: 1 addition & 1 deletion
@@ -316,7 +316,7 @@ def is_lc_serializable(cls) -> bool:
         default="accounts/fireworks/models/mixtral-8x7b-instruct", alias="model"
     )
     """Model name to use."""
-    temperature: float = 0.0
+    temperature: Optional[float] = None
     """What sampling temperature to use."""
     stop: Optional[Union[str, List[str]]] = Field(default=None, alias="stop_sequences")
     """Default stop sequences."""

libs/partners/fireworks/tests/unit_tests/__snapshots__/test_standard.ambr

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@
       'request_timeout': 60.0,
       'stop': list([
       ]),
+      'temperature': 0.0,
     }),
     'lc': 1,
     'name': 'ChatFireworks',

libs/partners/groq/langchain_groq/chat_models.py

Lines changed: 9 additions & 7 deletions
@@ -119,7 +119,7 @@ class ChatGroq(BaseChatModel):
     Key init args — client params:
         timeout: Union[float, Tuple[float, float], Any, None]
             Timeout for requests.
-        max_retries: int
+        max_retries: Optional[int]
            Max number of retries.
         api_key: Optional[str]
             Groq API key. If not passed in will be read from env var GROQ_API_KEY.

@@ -303,7 +303,7 @@ class Joke(BaseModel):
     async_client: Any = Field(default=None, exclude=True)  #: :meta private:
     model_name: str = Field(default="mixtral-8x7b-32768", alias="model")
     """Model name to use."""
-    temperature: float = 0.7
+    temperature: Optional[float] = None
     """What sampling temperature to use."""
     stop: Optional[Union[List[str], str]] = Field(default=None, alias="stop_sequences")
     """Default stop sequences."""

@@ -327,11 +327,11 @@ class Joke(BaseModel):
     )
     """Timeout for requests to Groq completion API. Can be float, httpx.Timeout or
     None."""
-    max_retries: int = 2
+    max_retries: Optional[int] = None
     """Maximum number of retries to make when generating."""
     streaming: bool = False
     """Whether to stream the results or not."""
-    n: int = 1
+    n: Optional[int] = None
     """Number of chat completions to generate for each prompt."""
     max_tokens: Optional[int] = None
     """Maximum number of tokens to generate."""

@@ -379,10 +379,11 @@ def build_extra(cls, values: Dict[str, Any]) -> Any:
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
-        if self.n < 1:
+        if self.n is not None and self.n < 1:
             raise ValueError("n must be at least 1.")
-        if self.n > 1 and self.streaming:
+        elif self.n is not None and self.n > 1 and self.streaming:
             raise ValueError("n must be 1 when streaming.")
+
         if self.temperature == 0:
             self.temperature = 1e-8

@@ -392,10 +393,11 @@ def validate_environment(self) -> Self:
             ),
             "base_url": self.groq_api_base,
             "timeout": self.request_timeout,
-            "max_retries": self.max_retries,
             "default_headers": self.default_headers,
             "default_query": self.default_query,
         }
+        if self.max_retries is not None:
+            client_params["max_retries"] = self.max_retries

         try:
             import groq
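
Note: with n and max_retries now Optional, the Groq validator only enforces its bounds on explicitly provided values. A sketch, assuming GROQ_API_KEY is set:

from langchain_groq import ChatGroq

# Unset values defer to the Groq API / SDK defaults rather than n=1, retries=2.
llm = ChatGroq(model="mixtral-8x7b-32768")
assert llm.n is None and llm.max_retries is None

# Explicit but invalid values are still rejected at construction time.
try:
    ChatGroq(model="mixtral-8x7b-32768", n=0)
except ValueError:
    pass  # pydantic surfaces the "n must be at least 1." error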

libs/partners/groq/tests/unit_tests/__snapshots__/test_standard.ambr

Lines changed: 0 additions & 1 deletion
@@ -17,7 +17,6 @@
     'max_retries': 2,
     'max_tokens': 100,
     'model_name': 'mixtral-8x7b-32768',
-    'n': 1,
     'request_timeout': 60.0,
     'stop': list([
     ]),

libs/partners/mistralai/langchain_mistralai/chat_models.py

Lines changed: 9 additions & 6 deletions
@@ -95,9 +95,12 @@ def _create_retry_decorator(
     """Returns a tenacity retry decorator, preconfigured to handle exceptions"""

     errors = [httpx.RequestError, httpx.StreamError]
-    return create_base_retry_decorator(
+    kwargs: dict = dict(
         error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
     )
+    return create_base_retry_decorator(
+        **{k: v for k, v in kwargs.items() if v is not None}
+    )


 def _is_valid_mistral_tool_call_id(tool_call_id: str) -> bool:

@@ -380,13 +383,13 @@ class ChatMistralAI(BaseChatModel):
         default_factory=secret_from_env("MISTRAL_API_KEY", default=None),
     )
     endpoint: Optional[str] = Field(default=None, alias="base_url")
-    max_retries: int = 5
-    timeout: int = 120
-    max_concurrent_requests: int = 64
+    max_retries: Optional[int] = None
+    timeout: Optional[int] = None
+    max_concurrent_requests: Optional[int] = None
     model: str = Field(default="mistral-small", alias="model_name")
-    temperature: float = 0.7
+    temperature: Optional[float] = None
     max_tokens: Optional[int] = None
-    top_p: float = 1
+    top_p: Optional[float] = None
     """Decode using nucleus sampling: consider the smallest set of tokens whose
     probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
     random_seed: Optional[int] = None
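
Note: the _create_retry_decorator change builds the kwargs first and then drops None entries, so create_base_retry_decorator's own default for max_retries applies when the user never set one. A standalone sketch of that filter-None-kwargs pattern (callee and its defaults are hypothetical stand-ins):

def callee(error_types, max_retries=3, run_manager=None):  # hypothetical defaults
    return max_retries

kwargs = dict(error_types=[], max_retries=None, run_manager=None)
# Dropping None-valued keys lets the callee's default win...
assert callee(**{k: v for k, v in kwargs.items() if v is not None}) == 3
# ...while explicit values pass through unchanged.
kwargs["max_retries"] = 5
assert callee(**{k: v for k, v in kwargs.items() if v is not None}) == 5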

libs/partners/mistralai/tests/unit_tests/__snapshots__/test_standard.ambr

Lines changed: 0 additions & 2 deletions
@@ -9,7 +9,6 @@
     ]),
     'kwargs': dict({
       'endpoint': 'boo',
-      'max_concurrent_requests': 64,
       'max_retries': 2,
       'max_tokens': 100,
       'mistral_api_key': dict({

@@ -22,7 +21,6 @@
       'model': 'mistral-small',
       'temperature': 0.0,
       'timeout': 60,
-      'top_p': 1,
     }),
     'lc': 1,
     'name': 'ChatMistralAI',

libs/partners/openai/langchain_openai/chat_models/azure.py

Lines changed: 5 additions & 4 deletions
@@ -79,7 +79,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
         https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning
     timeout: Union[float, Tuple[float, float], Any, None]
         Timeout for requests.
-    max_retries: int
+    max_retries: Optional[int]
        Max number of retries.
     organization: Optional[str]
         OpenAI organization ID. If not passed in will be read from env

@@ -586,9 +586,9 @@ def is_lc_serializable(cls) -> bool:
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
-        if self.n < 1:
+        if self.n is not None and self.n < 1:
             raise ValueError("n must be at least 1.")
-        if self.n > 1 and self.streaming:
+        elif self.n is not None and self.n > 1 and self.streaming:
             raise ValueError("n must be 1 when streaming.")

         if self.disabled_params is None:

@@ -641,10 +641,11 @@ def validate_environment(self) -> Self:
             "organization": self.openai_organization,
             "base_url": self.openai_api_base,
             "timeout": self.request_timeout,
-            "max_retries": self.max_retries,
             "default_headers": self.default_headers,
             "default_query": self.default_query,
         }
+        if self.max_retries is not None:
+            client_params["max_retries"] = self.max_retries
         if not self.client:
             sync_specific = {"http_client": self.http_client}
             self.root_client = openai.AzureOpenAI(**client_params, **sync_specific)  # type: ignore[arg-type]
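
Note: the Azure change uses the same conditional-forwarding pattern as the other clients: max_retries is added to the client kwargs only when explicitly set, so the OpenAI SDK's built-in retry default governs otherwise. A minimal standalone sketch (build_client_params is a hypothetical reduction of the real method):

from typing import Any, Dict, Optional

def build_client_params(max_retries: Optional[int]) -> Dict[str, Any]:
    params: Dict[str, Any] = {"timeout": 60.0}  # stands in for the other entries
    if max_retries is not None:  # only forward what the user actually set
        params["max_retries"] = max_retries
    return params

assert "max_retries" not in build_client_params(None)
assert build_client_params(4)["max_retries"] == 4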

libs/partners/openai/langchain_openai/chat_models/base.py

Lines changed: 11 additions & 9 deletions
@@ -414,7 +414,7 @@ class BaseChatOpenAI(BaseChatModel):
     root_async_client: Any = Field(default=None, exclude=True)  #: :meta private:
     model_name: str = Field(default="gpt-3.5-turbo", alias="model")
     """Model name to use."""
-    temperature: float = 0.7
+    temperature: Optional[float] = None
     """What sampling temperature to use."""
     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
     """Holds any model parameters valid for `create` call not explicitly specified."""

@@ -435,7 +435,7 @@ class BaseChatOpenAI(BaseChatModel):
     )
     """Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
     None."""
-    max_retries: int = 2
+    max_retries: Optional[int] = None
     """Maximum number of retries to make when generating."""
     presence_penalty: Optional[float] = None
     """Penalizes repeated tokens."""

@@ -453,7 +453,7 @@ class BaseChatOpenAI(BaseChatModel):
     """Modify the likelihood of specified tokens appearing in the completion."""
     streaming: bool = False
     """Whether to stream the results or not."""
-    n: int = 1
+    n: Optional[int] = None
     """Number of chat completions to generate for each prompt."""
     top_p: Optional[float] = None
     """Total probability mass of tokens to consider at each step."""

@@ -537,9 +537,9 @@ def validate_temperature(cls, values: Dict[str, Any]) -> Any:
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
-        if self.n < 1:
+        if self.n is not None and self.n < 1:
             raise ValueError("n must be at least 1.")
-        if self.n > 1 and self.streaming:
+        elif self.n is not None and self.n > 1 and self.streaming:
             raise ValueError("n must be 1 when streaming.")

         # Check OPENAI_ORGANIZATION for backwards compatibility.

@@ -556,10 +556,12 @@ def validate_environment(self) -> Self:
             "organization": self.openai_organization,
             "base_url": self.openai_api_base,
             "timeout": self.request_timeout,
-            "max_retries": self.max_retries,
             "default_headers": self.default_headers,
             "default_query": self.default_query,
         }
+        if self.max_retries is not None:
+            client_params["max_retries"] = self.max_retries
+
         if self.openai_proxy and (self.http_client or self.http_async_client):
             openai_proxy = self.openai_proxy
             http_client = self.http_client

@@ -614,14 +616,14 @@ def _default_params(self) -> Dict[str, Any]:
             "stop": self.stop or None,  # also exclude empty list for this
             "max_tokens": self.max_tokens,
             "extra_body": self.extra_body,
+            "n": self.n,
+            "temperature": self.temperature,
             "reasoning_effort": self.reasoning_effort,
         }

         params = {
             "model": self.model_name,
             "stream": self.streaming,
-            "n": self.n,
-            "temperature": self.temperature,
             **{k: v for k, v in exclude_if_none.items() if v is not None},
             **self.model_kwargs,
         }

@@ -1573,7 +1575,7 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]

     timeout: Union[float, Tuple[float, float], Any, None]
         Timeout for requests.
-    max_retries: int
+    max_retries: Optional[int]
         Max number of retries.
     api_key: Optional[str]
         OpenAI API key. If not passed in will be read from env var OPENAI_API_KEY.
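
Note: in _default_params, "n" and "temperature" move into the exclude_if_none dict, so an unset value is omitted from the request payload entirely instead of being sent as a LangChain-chosen default. A standalone sketch of that payload assembly (default_params here is a hypothetical reduction of the real method):

from typing import Any, Dict, Optional

def default_params(n: Optional[int], temperature: Optional[float]) -> Dict[str, Any]:
    exclude_if_none = {"n": n, "temperature": temperature}
    return {
        "model": "gpt-3.5-turbo",
        **{k: v for k, v in exclude_if_none.items() if v is not None},
    }

assert default_params(None, None) == {"model": "gpt-3.5-turbo"}
assert default_params(2, 0.7) == {"model": "gpt-3.5-turbo", "n": 2, "temperature": 0.7}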

libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_azure_standard.ambr

Lines changed: 0 additions & 1 deletion
@@ -15,7 +15,6 @@
     }),
     'max_retries': 2,
     'max_tokens': 100,
-    'n': 1,
     'openai_api_key': dict({
       'id': list([
         'AZURE_OPENAI_API_KEY',
