Skip to content

Commit 35e8485

Browse files
Merge branch 'add-cerebras-support': Refactor Cerebras to use OpenAI client
2 parents eae558b + 12a46ae commit 35e8485

File tree

11 files changed

+235
-120
lines changed

11 files changed

+235
-120
lines changed

docs/models/cerebras.md

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Cerebras
2+
3+
## Install
4+
5+
To use `CerebrasModel`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `cerebras` optional group (which installs `openai`):
6+
7+
```bash
8+
pip install "pydantic-ai-slim[cerebras]"
9+
```
10+
11+
or
12+
13+
```bash
14+
uv add "pydantic-ai-slim[cerebras]"
15+
```
16+
17+
## Configuration
18+
19+
To use [Cerebras](https://cerebras.ai/) through their API, go to [cloud.cerebras.ai](https://cloud.cerebras.ai/?utm_source=3pi_pydantic-ai&utm_campaign=partner_doc) and follow your nose until you find the place to generate an API key.
20+
21+
`CerebrasModelName` contains a list of available Cerebras models.
22+
23+
## Environment variable
24+
25+
Once you have the API key, you can set it as an environment variable:
26+
27+
```bash
28+
export CEREBRAS_API_KEY='your-api-key'
29+
```
30+
31+
You can then use `CerebrasModel` by name:
32+
33+
```python
34+
from pydantic_ai import Agent
35+
36+
agent = Agent('cerebras:llama-3.3-70b')
37+
...
38+
```
39+
40+
Or initialise the model directly with just the model name:
41+
42+
```python
43+
from pydantic_ai import Agent
44+
from pydantic_ai.models.cerebras import CerebrasModel
45+
46+
model = CerebrasModel('llama-3.3-70b')
47+
agent = Agent(model)
48+
...
49+
```
50+
51+
## `provider` argument
52+
53+
You can provide a custom `Provider` via the `provider` argument:
54+
55+
```python
56+
from pydantic_ai import Agent
57+
from pydantic_ai.models.cerebras import CerebrasModel
58+
from pydantic_ai.providers.cerebras import CerebrasProvider
59+
60+
model = CerebrasModel(
61+
'llama-3.3-70b', provider=CerebrasProvider(api_key='your-api-key')
62+
)
63+
agent = Agent(model)
64+
...
65+
```
66+
67+
You can also customize the `CerebrasProvider` with a custom `httpx.AsyncClient`:
68+
69+
```python
70+
from httpx import AsyncClient
71+
72+
from pydantic_ai import Agent
73+
from pydantic_ai.models.cerebras import CerebrasModel
74+
from pydantic_ai.providers.cerebras import CerebrasProvider
75+
76+
custom_http_client = AsyncClient(timeout=30)
77+
model = CerebrasModel(
78+
'llama-3.3-70b',
79+
provider=CerebrasProvider(api_key='your-api-key', http_client=custom_http_client),
80+
)
81+
agent = Agent(model)
82+
...
83+
```

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ nav:
3030
- models/anthropic.md
3131
- models/google.md
3232
- models/bedrock.md
33+
- models/cerebras.md
3334
- models/cohere.md
3435
- models/groq.md
3536
- models/mistral.md

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,6 @@
134134
'cerebras:llama-3.3-70b',
135135
'cerebras:llama3.1-8b',
136136
'cerebras:qwen-3-235b-a22b-instruct-2507',
137-
'cerebras:qwen-3-235b-a22b-thinking-2507',
138137
'cerebras:qwen-3-32b',
139138
'cerebras:zai-glm-4.6',
140139
'cohere:c4ai-aya-expanse-32b',
@@ -800,7 +799,6 @@ def infer_model( # noqa: C901
800799
'openai',
801800
'azure',
802801
'deepseek',
803-
'cerebras',
804802
'fireworks',
805803
'github',
806804
'grok',
@@ -818,7 +816,11 @@ def infer_model( # noqa: C901
818816
elif model_kind in ('google-gla', 'google-vertex'):
819817
model_kind = 'google'
820818

821-
if model_kind == 'openai-chat':
819+
if model_kind == 'cerebras':
820+
from .cerebras import CerebrasModel
821+
822+
return CerebrasModel(model_name, provider=provider) # type: ignore[arg-type]
823+
elif model_kind == 'openai-chat':
822824
from .openai import OpenAIChatModel
823825

824826
return OpenAIChatModel(model_name, provider=provider)
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
"""Cerebras model implementation using OpenAI-compatible API."""
2+
3+
from __future__ import annotations as _annotations
4+
5+
from dataclasses import dataclass
6+
from typing import Literal
7+
8+
from ..profiles import ModelProfileSpec
9+
from ..providers import Provider
10+
from ..settings import ModelSettings
11+
from .openai import OpenAIChatModel
12+
13+
try:
14+
from openai import AsyncOpenAI
15+
except ImportError as _import_error: # pragma: no cover
16+
raise ImportError(
17+
'Please install the `openai` package to use the Cerebras model, '
18+
'you can use the `cerebras` optional group — `pip install "pydantic-ai-slim[cerebras]"'
19+
) from _import_error
20+
21+
__all__ = ('CerebrasModel', 'CerebrasModelName')
22+
23+
CerebrasModelName = Literal[
24+
'gpt-oss-120b',
25+
'llama-3.3-70b',
26+
'llama3.1-8b',
27+
'qwen-3-235b-a22b-instruct-2507',
28+
'qwen-3-32b',
29+
'zai-glm-4.6',
30+
]
31+
32+
33+
@dataclass(init=False)
34+
class CerebrasModel(OpenAIChatModel):
35+
"""A model that uses Cerebras's OpenAI-compatible API.
36+
37+
Cerebras provides ultra-fast inference powered by the Wafer-Scale Engine (WSE).
38+
39+
Apart from `__init__`, all methods are private or match those of the base class.
40+
"""
41+
42+
def __init__(
43+
self,
44+
model_name: CerebrasModelName,
45+
*,
46+
provider: Literal['cerebras'] | Provider[AsyncOpenAI] = 'cerebras',
47+
profile: ModelProfileSpec | None = None,
48+
settings: ModelSettings | None = None,
49+
):
50+
"""Initialize a Cerebras model.
51+
52+
Args:
53+
model_name: The name of the Cerebras model to use.
54+
provider: The provider to use. Defaults to 'cerebras'.
55+
profile: The model profile to use. Defaults to a profile based on the model name.
56+
settings: Model-specific settings that will be used as defaults for this model.
57+
"""
58+
super().__init__(model_name, provider=provider, profile=profile, settings=settings)

pydantic_ai_slim/pydantic_ai/providers/cerebras.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pydantic_ai.models import cached_async_http_client
1111
from pydantic_ai.profiles.harmony import harmony_model_profile
1212
from pydantic_ai.profiles.meta import meta_model_profile
13-
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
13+
from pydantic_ai.profiles.openai import OpenAIModelProfile
1414
from pydantic_ai.profiles.qwen import qwen_model_profile
1515
from pydantic_ai.providers import Provider
1616

@@ -19,7 +19,7 @@
1919
except ImportError as _import_error: # pragma: no cover
2020
raise ImportError(
2121
'Please install the `openai` package to use the Cerebras provider, '
22-
'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
22+
'you can use the `cerebras` optional group — `pip install "pydantic-ai-slim[cerebras]"`'
2323
) from _import_error
2424

2525

@@ -39,27 +39,22 @@ def client(self) -> AsyncOpenAI:
3939
return self._client
4040

4141
def model_profile(self, model_name: str) -> ModelProfile | None:
42-
prefix_to_profile = {'llama': meta_model_profile, 'qwen': qwen_model_profile, 'gpt-oss': harmony_model_profile}
42+
prefix_to_profile = {
43+
'llama': meta_model_profile,
44+
'qwen': qwen_model_profile,
45+
'gpt-oss': harmony_model_profile,
46+
}
4347

4448
profile = None
49+
model_name_lower = model_name.lower()
4550
for prefix, profile_func in prefix_to_profile.items():
46-
model_name = model_name.lower()
47-
if model_name.startswith(prefix):
51+
if model_name_lower.startswith(prefix):
4852
profile = profile_func(model_name)
53+
break
4954

50-
# According to https://inference-docs.cerebras.ai/resources/openai#currently-unsupported-openai-features,
51-
# Cerebras doesn't support some model settings.
52-
unsupported_model_settings = (
53-
'frequency_penalty',
54-
'logit_bias',
55-
'presence_penalty',
56-
'parallel_tool_calls',
57-
'service_tier',
58-
)
59-
return OpenAIModelProfile(
60-
json_schema_transformer=OpenAIJsonSchemaTransformer,
61-
openai_unsupported_model_settings=unsupported_model_settings,
62-
).update(profile)
55+
# Wrap in OpenAIModelProfile with web search disabled
56+
# Cerebras doesn't support web search
57+
return OpenAIModelProfile(openai_chat_supports_web_search=False).update(profile)
6358

6459
@overload
6560
def __init__(self) -> None: ...
@@ -70,6 +65,9 @@ def __init__(self, *, api_key: str) -> None: ...
7065
@overload
7166
def __init__(self, *, api_key: str, http_client: httpx.AsyncClient) -> None: ...
7267

68+
@overload
69+
def __init__(self, *, http_client: httpx.AsyncClient) -> None: ...
70+
7371
@overload
7472
def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...
7573

@@ -80,6 +78,14 @@ def __init__(
8078
openai_client: AsyncOpenAI | None = None,
8179
http_client: httpx.AsyncClient | None = None,
8280
) -> None:
81+
"""Create a new Cerebras provider.
82+
83+
Args:
84+
api_key: The API key to use for authentication, if not provided, the `CEREBRAS_API_KEY` environment variable
85+
will be used if available.
86+
openai_client: An existing `AsyncOpenAI` client to use. If provided, `api_key` and `http_client` must be `None`.
87+
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
88+
"""
8389
api_key = api_key or os.getenv('CEREBRAS_API_KEY')
8490
if not api_key and openai_client is None:
8591
raise UserError(

pydantic_ai_slim/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ cohere = ["cohere>=5.18.0; platform_system != 'Emscripten'"]
7272
vertexai = ["google-auth>=2.36.0", "requests>=2.32.2"]
7373
google = ["google-genai>=1.51.0"]
7474
anthropic = ["anthropic>=0.70.0"]
75+
cerebras = ["openai>=1.107.2"]
7576
groq = ["groq>=0.25.0"]
7677
mistral = ["mistralai>=1.9.10"]
7778
bedrock = ["boto3>=1.40.14"]

tests/models/cassettes/test_model_names/test_known_model_names.yaml

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ interactions:
124124
alt-svc:
125125
- h3=":443"; ma=86400
126126
content-length:
127-
- '570'
127+
- '479'
128128
content-type:
129129
- application/json
130130
referrer-policy:
@@ -133,32 +133,28 @@ interactions:
133133
- max-age=3600; includeSubDomains
134134
parsed_body:
135135
data:
136-
- created: 0
137-
id: qwen-3-235b-a22b-thinking-2507
138-
object: model
139-
owned_by: Cerebras
140136
- created: 0
141137
id: llama-3.3-70b
142138
object: model
143139
owned_by: Cerebras
144140
- created: 0
145-
id: qwen-3-235b-a22b-instruct-2507
141+
id: llama3.1-8b
146142
object: model
147143
owned_by: Cerebras
148144
- created: 0
149-
id: qwen-3-32b
145+
id: zai-glm-4.6
150146
object: model
151147
owned_by: Cerebras
152148
- created: 0
153-
id: zai-glm-4.6
149+
id: qwen-3-32b
154150
object: model
155151
owned_by: Cerebras
156152
- created: 0
157153
id: gpt-oss-120b
158154
object: model
159155
owned_by: Cerebras
160156
- created: 0
161-
id: llama3.1-8b
157+
id: qwen-3-235b-a22b-instruct-2507
162158
object: model
163159
owned_by: Cerebras
164160
object: list

tests/models/test_cerebras.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
"""Tests for `CerebrasModel`: construction and provider-derived model profiles."""

import pytest

from pydantic_ai._json_schema import InlineDefsJsonSchemaTransformer
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile

from ..conftest import try_import

# Cerebras support depends on the optional `openai` package; skip cleanly when absent.
with try_import() as imports_successful:
    from pydantic_ai.models.cerebras import CerebrasModel
    from pydantic_ai.providers.cerebras import CerebrasProvider

pytestmark = pytest.mark.skipif(not imports_successful(), reason='openai not installed')


def test_cerebras_model_init():
    # The model should record its name and wire the API key through to the
    # underlying OpenAI-compatible client held by the provider.
    model = CerebrasModel('llama-3.3-70b', provider=CerebrasProvider(api_key='test_key'))
    assert model.model_name == 'llama-3.3-70b'
    assert isinstance(model._provider, CerebrasProvider)  # type: ignore[reportPrivateUsage]
    assert model._provider.client.api_key == 'test_key'  # type: ignore[reportPrivateUsage]


def test_cerebras_model_profile():
    # Each model-name prefix maps to a profile family; web search is disabled
    # for every Cerebras model regardless of prefix.
    provider = CerebrasProvider(api_key='test_key')

    # Test Llama model
    model = CerebrasModel('llama-3.3-70b', provider=provider)
    profile = model.profile
    assert isinstance(profile, OpenAIModelProfile)
    assert profile.json_schema_transformer == InlineDefsJsonSchemaTransformer
    assert OpenAIModelProfile.from_profile(profile).openai_chat_supports_web_search is False

    # Test Qwen model
    model = CerebrasModel('qwen-3-235b-a22b-instruct-2507', provider=provider)
    profile = model.profile
    assert isinstance(profile, OpenAIModelProfile)
    assert profile.json_schema_transformer == InlineDefsJsonSchemaTransformer
    assert OpenAIModelProfile.from_profile(profile).openai_chat_supports_web_search is False

    # Test GPT-OSS model
    model = CerebrasModel('gpt-oss-120b', provider=provider)
    profile = model.profile
    assert isinstance(profile, OpenAIModelProfile)
    assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer
    assert OpenAIModelProfile.from_profile(profile).openai_chat_supports_web_search is False

    # Test unknown model - use zai-glm which is valid but won't match any prefix
    model = CerebrasModel('zai-glm-4.6', provider=provider)
    profile = model.profile
    assert isinstance(profile, OpenAIModelProfile)
    assert OpenAIModelProfile.from_profile(profile).openai_chat_supports_web_search is False

tests/models/test_openai.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3022,7 +3022,7 @@ async def test_openai_model_settings_temperature_ignored_on_gpt_5(allow_model_re
30223022

30233023

30243024
async def test_openai_model_cerebras_provider(allow_model_requests: None, cerebras_api_key: str):
3025-
m = OpenAIChatModel('llama3.3-70b', provider=CerebrasProvider(api_key=cerebras_api_key))
3025+
m = OpenAIChatModel('llama3.3-70b', provider=CerebrasProvider(api_key=cerebras_api_key)) # type: ignore[arg-type]
30263026
agent = Agent(m)
30273027

30283028
result = await agent.run('What is the capital of France?')
@@ -3034,15 +3034,15 @@ class Location(TypedDict):
30343034
city: str
30353035
country: str
30363036

3037-
m = OpenAIChatModel('qwen-3-coder-480b', provider=CerebrasProvider(api_key=cerebras_api_key))
3037+
m = OpenAIChatModel('qwen-3-coder-480b', provider=CerebrasProvider(api_key=cerebras_api_key)) # type: ignore[arg-type]
30383038
agent = Agent(m, output_type=Location)
30393039

30403040
result = await agent.run('What is the capital of France?')
30413041
assert result.output == snapshot({'city': 'Paris', 'country': 'France'})
30423042

30433043

30443044
async def test_openai_model_cerebras_provider_harmony(allow_model_requests: None, cerebras_api_key: str):
3045-
m = OpenAIChatModel('gpt-oss-120b', provider=CerebrasProvider(api_key=cerebras_api_key))
3045+
m = OpenAIChatModel('gpt-oss-120b', provider=CerebrasProvider(api_key=cerebras_api_key)) # type: ignore[arg-type]
30463046
agent = Agent(m)
30473047

30483048
result = await agent.run('What is the capital of France?')

0 commit comments

Comments
 (0)