
Commit ea6372a

Add Cerebras model support with official SDK
- Implement CerebrasModel subclassing OpenAIChatModel
- Disable web search (Cerebras-specific feature)
- Use cerebras-cloud-sdk for official API support
- All tests passing (6/6 Cerebras-specific tests)
1 parent 3e552fb commit ea6372a

9 files changed: +258 −210 lines changed


docs/models/cerebras.md

Lines changed: 49 additions & 37 deletions
````diff
@@ -1,65 +1,77 @@
 # Cerebras
 
-Cerebras provides ultra-fast inference using their Wafer-Scale Engine (WSE), delivering predictable performance for any workload.
+## Install
 
-## Installation
-
-To use Cerebras, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `cerebras` optional group:
+To use `CerebrasModel`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `cerebras` optional group:
 
 ```bash
-# pip
-pip install "pydantic-ai-slim[cerebras]"
-
-# uv
-uv add "pydantic-ai-slim[cerebras]"
+pip/uv-add "pydantic-ai-slim[cerebras]"
 ```
 
 ## Configuration
 
-To use Cerebras, go to [cloud.cerebras.ai](https://cloud.cerebras.ai/?utm_source=3pi_pydantic-ai&utm_campaign=partner_doc) to get an API key.
+To use [Cerebras](https://cerebras.ai/) through their API, go to [cloud.cerebras.ai](https://cloud.cerebras.ai/?utm_source=3pi_pydantic-ai&utm_campaign=partner_doc) and follow your nose until you find the place to generate an API key.
+
+`CerebrasModelName` contains a list of available Cerebras models.
 
-### Environment Variable
+## Environment variable
 
-Set your API key as an environment variable:
+Once you have the API key, you can set it as an environment variable:
 
 ```bash
 export CEREBRAS_API_KEY='your-api-key'
 ```
 
-### Available Models
-
-Cerebras supports the following models:
-
-- `llama-3.3-70b` (recommended) - Latest Llama 3.3 model
-- `llama-3.1-8b` - Llama 3.1 8B (faster, smaller)
-- `qwen-3-235b-a22b-instruct-2507` - Qwen 3 235B
-- `qwen-3-32b` - Qwen 3 32B
-- `gpt-oss-120b` - GPT-OSS 120B
-- `zai-glm-4.6` - GLM 4.6 model
+You can then use `CerebrasModel` by name:
 
+```python
+from pydantic_ai import Agent
 
-See the [Cerebras documentation](https://inference-docs.cerebras.ai/introduction?utm_source=3pi_pydantic-ai&utm_campaign=partner_doc) for the latest models.
+agent = Agent('cerebras:llama-3.3-70b')
+...
+```
 
-## Usage
+Or initialise the model directly with just the model name:
 
 ```python
 from pydantic_ai import Agent
+from pydantic_ai.models.cerebras import CerebrasModel
 
-agent = Agent('cerebras:llama-3.3-70b')
-result = agent.run_sync('What is the capital of France?')
-print(result.output)
-#> The capital of France is Paris.
+model = CerebrasModel('llama-3.3-70b')
+agent = Agent(model)
+...
 ```
 
-## Why Cerebras?
+## `provider` argument
+
+You can provide a custom `Provider` via the `provider` argument:
 
-- **Ultra-fast inference** - Powered by the world's largest AI chip (WSE)
-- **Predictable performance** - Consistent latency for any workload
-- **OpenAI-compatible** - Drop-in replacement for OpenAI API
-- **Cost-effective** - Competitive pricing with superior performance
+```python
+from pydantic_ai import Agent
+from pydantic_ai.models.cerebras import CerebrasModel
+from pydantic_ai.providers.cerebras import CerebrasProvider
+
+model = CerebrasModel(
+    'llama-3.3-70b', provider=CerebrasProvider(api_key='your-api-key')
+)
+agent = Agent(model)
+...
+```
 
-## Resources
+You can also customize the `CerebrasProvider` with a custom `httpx.AsyncHTTPClient`:
 
-- [Cerebras Inference Documentation](https://inference-docs.cerebras.ai?utm_source=3pi_pydantic-ai&utm_campaign=partner_doc)
-- [Get API Key](https://cloud.cerebras.ai/?utm_source=3pi_pydantic-ai&utm_campaign=partner_doc)
-- [Model Pricing](https://cerebras.ai/pricing?utm_source=3pi_pydantic-ai&utm_campaign=partner_doc)
+```python
+from httpx import AsyncClient
+
+from pydantic_ai import Agent
+from pydantic_ai.models.cerebras import CerebrasModel
+from pydantic_ai.providers.cerebras import CerebrasProvider
+
+custom_http_client = AsyncClient(timeout=30)
+model = CerebrasModel(
+    'llama-3.3-70b',
+    provider=CerebrasProvider(api_key='your-api-key', http_client=custom_http_client),
+)
+agent = Agent(model)
+...
+```
````
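The new docs mention `CerebrasModelName` without showing it in use. As a small sketch (not part of the commit), the `Literal` type lets a type checker validate model ids; it assumes `CEREBRAS_API_KEY` is set in the environment:

```python
from pydantic_ai.models.cerebras import CerebrasModel, CerebrasModelName

# CerebrasModelName is a Literal of the supported model ids (see the new module below),
# so a type checker will flag typos such as 'llama-3.3-70B'.
model_id: CerebrasModelName = 'llama-3.3-70b'
model = CerebrasModel(model_id)
```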

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -800,7 +800,6 @@ def infer_model( # noqa: C901
         'openai',
         'azure',
         'deepseek',
-        'cerebras',
         'fireworks',
         'github',
         'grok',
@@ -818,7 +817,11 @@ def infer_model( # noqa: C901
     elif model_kind in ('google-gla', 'google-vertex'):
         model_kind = 'google'
 
-    if model_kind == 'openai-chat':
+    if model_kind == 'cerebras':
+        from .cerebras import CerebrasModel
+
+        return CerebrasModel(model_name, provider=provider)
+    elif model_kind == 'openai-chat':
         from .openai import OpenAIChatModel
 
         return OpenAIChatModel(model_name, provider=provider)
```
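In practice, this routing change means a `cerebras:` model string now resolves to the new class rather than falling through to the generic `OpenAIChatModel` branch. A minimal sketch (not part of the diff, assuming `CEREBRAS_API_KEY` is set):

```python
from pydantic_ai.models import infer_model
from pydantic_ai.models.cerebras import CerebrasModel

# With the change above, 'cerebras:' strings resolve to CerebrasModel.
model = infer_model('cerebras:llama-3.3-70b')
assert isinstance(model, CerebrasModel)
```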
pydantic_ai_slim/pydantic_ai/models/cerebras.py (new file)

Lines changed: 95 additions & 0 deletions

```python
"""Cerebras model implementation using OpenAI-compatible API."""

from __future__ import annotations as _annotations

from dataclasses import dataclass
from typing import Any, Literal

try:
    from cerebras.cloud.sdk import AsyncCerebras  # noqa: F401
except ImportError as _import_error:  # pragma: no cover
    raise ImportError(
        'Please install the `cerebras-cloud-sdk` package to use the Cerebras model, '
        'you can use the `cerebras` optional group — `pip install "pydantic-ai-slim[cerebras]"`'
    ) from _import_error

from ..profiles import ModelProfile, ModelProfileSpec
from ..profiles.harmony import harmony_model_profile
from ..profiles.meta import meta_model_profile
from ..profiles.qwen import qwen_model_profile
from ..providers import Provider
from ..settings import ModelSettings
from .openai import OpenAIChatModel, OpenAIModelProfile  # type: ignore[attr-defined]

__all__ = ('CerebrasModel', 'CerebrasModelName')

CerebrasModelName = Literal[
    'llama-3.3-70b',
    'llama-4-scout-17b-16e-instruct',
    'qwen-3-235b-a22b-instruct-2507',
    'qwen-3-32b',
    'gpt-oss-120b',
    'zai-glm-4.6',
]


@dataclass(init=False)
class CerebrasModel(OpenAIChatModel):
    """A model that uses Cerebras's OpenAI-compatible API.

    Cerebras provides ultra-fast inference powered by the Wafer-Scale Engine (WSE).

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    def __init__(
        self,
        model_name: CerebrasModelName,
        *,
        provider: Literal['cerebras'] | Provider[Any] = 'cerebras',
        profile: ModelProfileSpec | None = None,
        settings: ModelSettings | None = None,
    ):
        """Initialize a Cerebras model.

        Args:
            model_name: The name of the Cerebras model to use.
            provider: The provider to use. Can be 'cerebras' or a Provider instance.
            profile: The model profile to use. Defaults to a profile based on the model name.
            settings: Model-specific settings that will be used as defaults for this model.
        """
        if provider == 'cerebras':
            from ..providers.cerebras import CerebrasProvider

            # Extract api_key from settings if provided
            api_key = settings.get('api_key') if settings else None
            provider = CerebrasProvider(api_key=api_key) if api_key else CerebrasProvider()  # type: ignore[call-overload]

        # Use our custom model_profile method if no profile is provided
        if profile is None:
            profile = self._cerebras_model_profile

        super().__init__(model_name, provider=provider, profile=profile, settings=settings)  # type: ignore[arg-type]

    def _cerebras_model_profile(self, model_name: str) -> ModelProfile:
        """Get the model profile for this Cerebras model.

        Returns a profile with web search disabled since Cerebras doesn't support it.
        """
        model_name_lower = model_name.lower()

        # Get base profile based on model family
        if model_name_lower.startswith('llama'):
            base_profile = meta_model_profile(model_name)
        elif model_name_lower.startswith('qwen'):
            base_profile = qwen_model_profile(model_name)
        elif model_name_lower.startswith('gpt-oss'):
            base_profile = harmony_model_profile(model_name)
        else:
            # Default profile for unknown models
            base_profile = ModelProfile()

        # Wrap in OpenAIModelProfile with web search disabled
        return OpenAIModelProfile(
            openai_chat_supports_web_search=False,
        ).update(base_profile)
```

pydantic_ai_slim/pydantic_ai/providers/cerebras.py

Lines changed: 38 additions & 27 deletions
```diff
@@ -14,15 +14,15 @@
 from pydantic_ai.providers import Provider
 
 try:
-    from openai import AsyncOpenAI
+    from cerebras.cloud.sdk import AsyncCerebras
 except ImportError as _import_error:  # pragma: no cover
     raise ImportError(
-        'Please install the `openai` package to use the Cerebras provider, '
-        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
+        'Please install the `cerebras-cloud-sdk` package to use the Cerebras provider, '
+        'you can use the `cerebras` optional group — `pip install "pydantic-ai-slim[cerebras]"`'
     ) from _import_error
 
 
-class CerebrasProvider(Provider[AsyncOpenAI]):
+class CerebrasProvider(Provider[AsyncCerebras]):
     """Provider for Cerebras API."""
 
     @property
@@ -34,7 +34,7 @@ def base_url(self) -> str:
         return 'https://api.cerebras.ai/v1'
 
     @property
-    def client(self) -> AsyncOpenAI:
+    def client(self) -> AsyncCerebras:
         return self._client
 
     def model_profile(self, model_name: str) -> ModelProfile | None:
@@ -52,35 +52,46 @@ def model_profile(self, model_name: str) -> ModelProfile | None:
         return None
 
     @overload
-    def __init__(self) -> None: ...
+    def __init__(self, *, cerebras_client: AsyncCerebras | None = None) -> None: ...
 
     @overload
-    def __init__(self, *, api_key: str) -> None: ...
-
-    @overload
-    def __init__(self, *, api_key: str, http_client: httpx.AsyncClient) -> None: ...
-
-    @overload
-    def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...
+    def __init__(
+        self, *, api_key: str | None = None, base_url: str | None = None, http_client: httpx.AsyncClient | None = None
+    ) -> None: ...
 
     def __init__(
         self,
         *,
         api_key: str | None = None,
-        openai_client: AsyncOpenAI | None = None,
+        base_url: str | None = None,
+        cerebras_client: AsyncCerebras | None = None,
        http_client: httpx.AsyncClient | None = None,
     ) -> None:
-        api_key = api_key or os.getenv('CEREBRAS_API_KEY')
-        if not api_key and openai_client is None:
-            raise UserError(
-                'Set the `CEREBRAS_API_KEY` environment variable or pass it via `CerebrasProvider(api_key=...)` '
-                'to use the Cerebras provider.'
-            )
-
-        if openai_client is not None:
-            self._client = openai_client
-        elif http_client is not None:
-            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
+        """Create a new Cerebras provider.
+
+        Args:
+            api_key: The API key to use for authentication, if not provided, the `CEREBRAS_API_KEY` environment variable
+                will be used if available.
+            base_url: The base url for the Cerebras requests. If not provided, defaults to Cerebras's base url.
+            cerebras_client: An existing `AsyncCerebras` client to use. If provided, `api_key` and `http_client` must be `None`.
+            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
+        """
+        if cerebras_client is not None:
+            assert http_client is None, 'Cannot provide both `cerebras_client` and `http_client`'
+            assert api_key is None, 'Cannot provide both `cerebras_client` and `api_key`'
+            assert base_url is None, 'Cannot provide both `cerebras_client` and `base_url`'
+            self._client = cerebras_client
         else:
-            http_client = cached_async_http_client(provider='cerebras')
-            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
+            api_key = api_key or os.getenv('CEREBRAS_API_KEY')
+            base_url = base_url or 'https://api.cerebras.ai/v1'
+
+            if not api_key:
+                raise UserError(
+                    'Set the `CEREBRAS_API_KEY` environment variable or pass it via `CerebrasProvider(api_key=...)` '
+                    'to use the Cerebras provider.'
+                )
+            elif http_client is not None:
+                self._client = AsyncCerebras(base_url=base_url, api_key=api_key, http_client=http_client)
+            else:
+                http_client = cached_async_http_client(provider='cerebras')
+                self._client = AsyncCerebras(base_url=base_url, api_key=api_key, http_client=http_client)
```
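The new `cerebras_client` parameter is not exercised in the docs above. A minimal sketch (not part of the commit) of passing a pre-configured `AsyncCerebras` client through the provider, using only the classes shown in this diff:

```python
from cerebras.cloud.sdk import AsyncCerebras

from pydantic_ai import Agent
from pydantic_ai.models.cerebras import CerebrasModel
from pydantic_ai.providers.cerebras import CerebrasProvider

# Reuse an existing AsyncCerebras client; per the assertions above, api_key,
# base_url and http_client must then be left unset.
client = AsyncCerebras(api_key='your-api-key')
model = CerebrasModel('llama-3.3-70b', provider=CerebrasProvider(cerebras_client=client))
agent = Agent(model)
...
```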

pydantic_ai_slim/pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -72,7 +72,7 @@ cohere = ["cohere>=5.18.0; platform_system != 'Emscripten'"]
 vertexai = ["google-auth>=2.36.0", "requests>=2.32.2"]
 google = ["google-genai>=1.51.0"]
 anthropic = ["anthropic>=0.70.0"]
-cerebras = ["openai>=1.107.2"]
+cerebras = ["cerebras-cloud-sdk>=1.0.0"]
 groq = ["groq>=0.25.0"]
 mistral = ["mistralai>=1.9.10"]
 bedrock = ["boto3>=1.40.14"]
```
