Commit 12a46ae

refactor: Cerebras to use OpenAI client instead of SDK
- Changed CerebrasProvider to use AsyncOpenAI instead of AsyncCerebras SDK
- Simplified CerebrasModel by removing custom _completions_create override
- Updated dependency from cerebras-cloud-sdk to openai package
- Follows OpenRouter pattern for consistency
- Reduced codebase by ~200 lines while maintaining all functionality
- All Cerebras tests passing (5/5)

This aligns with the Pydantic team's request to use an OpenAI-compatible approach and removes the need for a separate SDK dependency.
1 parent 376573e commit 12a46ae
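For orientation, here is a minimal usage sketch (not part of the commit) of the user-facing API the refactor preserves. It assumes the module paths and model names that appear in the diffs below; the API key is a placeholder.

```python
# Minimal sketch: CerebrasModel now delegates to OpenAIChatModel, with
# CerebrasProvider wrapping an AsyncOpenAI client pointed at Cerebras'
# OpenAI-compatible endpoint. Assumes a valid key is supplied.
from pydantic_ai import Agent
from pydantic_ai.models.cerebras import CerebrasModel
from pydantic_ai.providers.cerebras import CerebrasProvider

model = CerebrasModel('llama-3.3-70b', provider=CerebrasProvider(api_key='your-api-key'))
agent = Agent(model)

result = agent.run_sync('What is the capital of France?')
print(result.output)
```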

File tree

8 files changed (+69, -167 lines)


docs/models/cerebras.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -2,7 +2,7 @@
 
 ## Install
 
-To use `CerebrasModel`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `cerebras` optional group (which installs the `cerebras-cloud-sdk`):
+To use `CerebrasModel`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `cerebras` optional group (which installs `openai`):
 
 ```bash
 pip install "pydantic-ai-slim[cerebras]"
````

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -819,7 +819,7 @@ def infer_model( # noqa: C901
     if model_kind == 'cerebras':
         from .cerebras import CerebrasModel
 
-        return CerebrasModel(model_name, provider=provider)
+        return CerebrasModel(model_name, provider=provider)  # type: ignore[arg-type]
     elif model_kind == 'openai-chat':
         from .openai import OpenAIChatModel
 
```

pydantic_ai_slim/pydantic_ai/models/cerebras.py

Lines changed: 12 additions & 100 deletions
```diff
@@ -3,24 +3,21 @@
 from __future__ import annotations as _annotations
 
 from dataclasses import dataclass
-from typing import Any, Literal
+from typing import Literal
+
+from ..profiles import ModelProfileSpec
+from ..providers import Provider
+from ..settings import ModelSettings
+from .openai import OpenAIChatModel
 
 try:
-    from cerebras.cloud.sdk import AsyncCerebras  # noqa: F401
+    from openai import AsyncOpenAI
 except ImportError as _import_error:  # pragma: no cover
     raise ImportError(
-        'Please install the `cerebras-cloud-sdk` package to use the Cerebras model, '
-        'you can use the `cerebras` optional group — `pip install "pydantic-ai-slim[cerebras]"`'
+        'Please install the `openai` package to use the Cerebras model, '
+        'you can use the `cerebras` optional group — `pip install "pydantic-ai-slim[cerebras]"'
     ) from _import_error
 
-from ..profiles import ModelProfile, ModelProfileSpec
-from ..profiles.harmony import harmony_model_profile
-from ..profiles.meta import meta_model_profile
-from ..profiles.qwen import qwen_model_profile
-from ..providers import Provider
-from ..settings import ModelSettings
-from .openai import OpenAIChatModel, OpenAIModelProfile  # type: ignore[attr-defined]
-
 __all__ = ('CerebrasModel', 'CerebrasModelName')
 
 CerebrasModelName = Literal[
@@ -46,101 +43,16 @@ def __init__(
         self,
         model_name: CerebrasModelName,
         *,
-        provider: Literal['cerebras'] | Provider[Any] = 'cerebras',
+        provider: Literal['cerebras'] | Provider[AsyncOpenAI] = 'cerebras',
         profile: ModelProfileSpec | None = None,
         settings: ModelSettings | None = None,
     ):
         """Initialize a Cerebras model.
 
         Args:
             model_name: The name of the Cerebras model to use.
-            provider: The provider to use. Can be 'cerebras' or a Provider instance.
+            provider: The provider to use. Defaults to 'cerebras'.
             profile: The model profile to use. Defaults to a profile based on the model name.
             settings: Model-specific settings that will be used as defaults for this model.
         """
-        if provider == 'cerebras':
-            from ..providers.cerebras import CerebrasProvider
-
-            # Extract api_key from settings if provided
-            api_key = settings.get('api_key') if settings else None
-            provider = CerebrasProvider(api_key=api_key) if api_key else CerebrasProvider()  # type: ignore[call-overload]
-
-        # Use our custom model_profile method if no profile is provided
-        if profile is None:
-            profile = self._cerebras_model_profile
-
-        super().__init__(model_name, provider=provider, profile=profile, settings=settings)  # type: ignore[arg-type]
-
-    def _cerebras_model_profile(self, model_name: str) -> ModelProfile:
-        """Get the model profile for this Cerebras model.
-
-        Returns a profile with web search disabled since Cerebras doesn't support it.
-        """
-        model_name_lower = model_name.lower()
-
-        # Get base profile based on model family
-        if model_name_lower.startswith('llama'):
-            base_profile = meta_model_profile(model_name)
-        elif model_name_lower.startswith('qwen'):
-            base_profile = qwen_model_profile(model_name)
-        elif model_name_lower.startswith('gpt-oss'):
-            base_profile = harmony_model_profile(model_name)
-        else:
-            # Default profile for unknown models
-            base_profile = ModelProfile()
-
-        # Wrap in OpenAIModelProfile with web search disabled
-        return OpenAIModelProfile(
-            openai_chat_supports_web_search=False,
-        ).update(base_profile)
-
-    async def _completions_create(
-        self,
-        messages: list[Any],
-        stream: bool,
-        model_settings: dict[str, Any],
-        model_request_parameters: Any,
-    ) -> Any:
-        """Override to remove web_search_options parameter and convert Cerebras response to OpenAI format."""
-        from openai._types import NOT_GIVEN
-        from openai.types.chat import ChatCompletion
-
-        # Get the original client method
-        original_create = self.client.chat.completions.create
-
-        # Create a wrapper that removes web_search_options and filters OMIT values
-        async def create_without_web_search(**kwargs):
-            # Remove web_search_options if present
-            kwargs.pop('web_search_options', None)
-
-            # Remove all keys with OMIT or NOT_GIVEN values
-            keys_to_remove = []
-            for key, value in kwargs.items():
-                # Check if it's OMIT by checking the type name
-                if hasattr(value, '__class__') and value.__class__.__name__ == 'Omit':
-                    keys_to_remove.append(key)
-                elif value is NOT_GIVEN:
-                    keys_to_remove.append(key)
-
-            for key in keys_to_remove:
-                del kwargs[key]
-
-            # Call Cerebras SDK
-            cerebras_response = await original_create(**kwargs)
-
-            # Convert Cerebras response to OpenAI ChatCompletion
-            # The Cerebras SDK returns a compatible structure, we just need to convert the type
-            response_dict = (
-                cerebras_response.model_dump() if hasattr(cerebras_response, 'model_dump') else cerebras_response
-            )
-            return ChatCompletion.model_validate(response_dict)
-
-        # Temporarily replace the method
-        self.client.chat.completions.create = create_without_web_search  # type: ignore
-
-        try:
-            # Call the parent implementation
-            return await super()._completions_create(messages, stream, model_settings, model_request_parameters)  # type: ignore
-        finally:
-            # Restore the original method
-            self.client.chat.completions.create = original_create  # type: ignore
+        super().__init__(model_name, provider=provider, profile=profile, settings=settings)
```
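Since `CerebrasModel` is now a thin subclass of `OpenAIChatModel`, the same endpoint can also be reached with `OpenAIChatModel` directly, which is what the updated `tests/models/test_openai.py` further down does. A hedged sketch of that equivalent path, with a placeholder key:

```python
# Sketch mirroring the updated test_openai.py tests: OpenAIChatModel talking to
# Cerebras via CerebrasProvider. The `# type: ignore[arg-type]` matches the
# tests, since CerebrasProvider is typed as Provider[AsyncOpenAI].
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.cerebras import CerebrasProvider

provider = CerebrasProvider(api_key='your-api-key')
m = OpenAIChatModel('gpt-oss-120b', provider=provider)  # type: ignore[arg-type]
```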

pydantic_ai_slim/pydantic_ai/providers/cerebras.py

Lines changed: 40 additions & 35 deletions
```diff
@@ -10,19 +10,20 @@
 from pydantic_ai.models import cached_async_http_client
 from pydantic_ai.profiles.harmony import harmony_model_profile
 from pydantic_ai.profiles.meta import meta_model_profile
+from pydantic_ai.profiles.openai import OpenAIModelProfile
 from pydantic_ai.profiles.qwen import qwen_model_profile
 from pydantic_ai.providers import Provider
 
 try:
-    from cerebras.cloud.sdk import AsyncCerebras
+    from openai import AsyncOpenAI
 except ImportError as _import_error:  # pragma: no cover
     raise ImportError(
-        'Please install the `cerebras-cloud-sdk` package to use the Cerebras provider, '
+        'Please install the `openai` package to use the Cerebras provider, '
         'you can use the `cerebras` optional group — `pip install "pydantic-ai-slim[cerebras]"`'
     ) from _import_error
 
 
-class CerebrasProvider(Provider[AsyncCerebras]):
+class CerebrasProvider(Provider[AsyncOpenAI]):
     """Provider for Cerebras API."""
 
     @property
@@ -31,10 +32,10 @@ def name(self) -> str:
 
     @property
     def base_url(self) -> str:
-        return 'https://api.cerebras.ai'
+        return 'https://api.cerebras.ai/v1'
 
     @property
-    def client(self) -> AsyncCerebras:
+    def client(self) -> AsyncOpenAI:
         return self._client
 
     def model_profile(self, model_name: str) -> ModelProfile | None:
@@ -44,54 +45,58 @@ def model_profile(self, model_name: str) -> ModelProfile | None:
             'gpt-oss': harmony_model_profile,
         }
 
+        profile = None
+        model_name_lower = model_name.lower()
         for prefix, profile_func in prefix_to_profile.items():
-            model_name = model_name.lower()
-            if model_name.startswith(prefix):
-                return profile_func(model_name)
+            if model_name_lower.startswith(prefix):
+                profile = profile_func(model_name)
+                break
 
-        return None
+        # Wrap in OpenAIModelProfile with web search disabled
+        # Cerebras doesn't support web search
+        return OpenAIModelProfile(openai_chat_supports_web_search=False).update(profile)
 
     @overload
-    def __init__(self, *, cerebras_client: AsyncCerebras | None = None) -> None: ...
+    def __init__(self) -> None: ...
 
     @overload
-    def __init__(
-        self, *, api_key: str | None = None, base_url: str | None = None, http_client: httpx.AsyncClient | None = None
-    ) -> None: ...
+    def __init__(self, *, api_key: str) -> None: ...
+
+    @overload
+    def __init__(self, *, api_key: str, http_client: httpx.AsyncClient) -> None: ...
+
+    @overload
+    def __init__(self, *, http_client: httpx.AsyncClient) -> None: ...
+
+    @overload
+    def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...
 
     def __init__(
         self,
         *,
         api_key: str | None = None,
-        base_url: str | None = None,
-        cerebras_client: AsyncCerebras | None = None,
+        openai_client: AsyncOpenAI | None = None,
         http_client: httpx.AsyncClient | None = None,
     ) -> None:
         """Create a new Cerebras provider.
 
         Args:
             api_key: The API key to use for authentication, if not provided, the `CEREBRAS_API_KEY` environment variable
                 will be used if available.
-            base_url: The base url for the Cerebras requests. If not provided, defaults to Cerebras's base url.
-            cerebras_client: An existing `AsyncCerebras` client to use. If provided, `api_key` and `http_client` must be `None`.
+            openai_client: An existing `AsyncOpenAI` client to use. If provided, `api_key` and `http_client` must be `None`.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
         """
-        if cerebras_client is not None:
-            assert http_client is None, 'Cannot provide both `cerebras_client` and `http_client`'
-            assert api_key is None, 'Cannot provide both `cerebras_client` and `api_key`'
-            assert base_url is None, 'Cannot provide both `cerebras_client` and `base_url`'
-            self._client = cerebras_client
+        api_key = api_key or os.getenv('CEREBRAS_API_KEY')
+        if not api_key and openai_client is None:
+            raise UserError(
+                'Set the `CEREBRAS_API_KEY` environment variable or pass it via `CerebrasProvider(api_key=...)` '
+                'to use the Cerebras provider.'
+            )
+
+        if openai_client is not None:
+            self._client = openai_client
+        elif http_client is not None:
+            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
         else:
-            api_key = api_key or os.getenv('CEREBRAS_API_KEY')
-            base_url = base_url or 'https://api.cerebras.ai'
-
-            if not api_key:
-                raise UserError(
-                    'Set the `CEREBRAS_API_KEY` environment variable or pass it via `CerebrasProvider(api_key=...)` '
-                    'to use the Cerebras provider.'
-                )
-            elif http_client is not None:
-                self._client = AsyncCerebras(base_url=base_url, api_key=api_key, http_client=http_client)
-            else:
-                http_client = cached_async_http_client(provider='cerebras')
-                self._client = AsyncCerebras(base_url=base_url, api_key=api_key, http_client=http_client)
+            http_client = cached_async_http_client(provider='cerebras')
+            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
```
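A hedged sketch of the construction paths the new `__init__` overloads describe; the keys and the httpx timeout below are placeholders.

```python
# Construction paths implied by the new CerebrasProvider overloads.
import httpx
from openai import AsyncOpenAI

from pydantic_ai.providers.cerebras import CerebrasProvider

p1 = CerebrasProvider()                  # api_key read from CEREBRAS_API_KEY
p2 = CerebrasProvider(api_key='my-key')  # explicit key, cached httpx client

# Bring your own httpx client:
p3 = CerebrasProvider(api_key='my-key', http_client=httpx.AsyncClient(timeout=30))

# Or hand over a pre-configured AsyncOpenAI client (api_key/http_client must then be None):
p4 = CerebrasProvider(
    openai_client=AsyncOpenAI(base_url='https://api.cerebras.ai/v1', api_key='my-key')
)
```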

pydantic_ai_slim/pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -72,7 +72,7 @@ cohere = ["cohere>=5.18.0; platform_system != 'Emscripten'"]
 vertexai = ["google-auth>=2.36.0", "requests>=2.32.2"]
 google = ["google-genai>=1.51.0"]
 anthropic = ["anthropic>=0.70.0"]
-cerebras = ["cerebras-cloud-sdk>=1.0.0"]
+cerebras = ["openai>=1.107.2"]
 groq = ["groq>=0.25.0"]
 mistral = ["mistralai>=1.9.10"]
 bedrock = ["boto3>=1.40.14"]
```

tests/models/test_cerebras.py

Lines changed: 9 additions & 7 deletions
```diff
@@ -13,36 +13,38 @@
 
 
 def test_cerebras_model_init():
-    model = CerebrasModel('llama-3.3-70b', settings={'api_key': 'test_key'})
+    model = CerebrasModel('llama-3.3-70b', provider=CerebrasProvider(api_key='test_key'))
     assert model.model_name == 'llama-3.3-70b'
-    assert isinstance(model._provider, CerebrasProvider)
-    assert model._provider.client.api_key == 'test_key'
+    assert isinstance(model._provider, CerebrasProvider)  # type: ignore[reportPrivateUsage]
+    assert model._provider.client.api_key == 'test_key'  # type: ignore[reportPrivateUsage]
 
 
 def test_cerebras_model_profile():
+    provider = CerebrasProvider(api_key='test_key')
+
     # Test Llama model
-    model = CerebrasModel('llama-3.3-70b', settings={'api_key': 'test_key'})
+    model = CerebrasModel('llama-3.3-70b', provider=provider)
     profile = model.profile
     assert isinstance(profile, OpenAIModelProfile)
     assert profile.json_schema_transformer == InlineDefsJsonSchemaTransformer
     assert OpenAIModelProfile.from_profile(profile).openai_chat_supports_web_search is False
 
     # Test Qwen model
-    model = CerebrasModel('qwen-3-235b-a22b-instruct-2507', settings={'api_key': 'test_key'})
+    model = CerebrasModel('qwen-3-235b-a22b-instruct-2507', provider=provider)
     profile = model.profile
     assert isinstance(profile, OpenAIModelProfile)
     assert profile.json_schema_transformer == InlineDefsJsonSchemaTransformer
     assert OpenAIModelProfile.from_profile(profile).openai_chat_supports_web_search is False
 
     # Test GPT-OSS model
-    model = CerebrasModel('gpt-oss-120b', settings={'api_key': 'test_key'})
+    model = CerebrasModel('gpt-oss-120b', provider=provider)
     profile = model.profile
     assert isinstance(profile, OpenAIModelProfile)
     assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer
     assert OpenAIModelProfile.from_profile(profile).openai_chat_supports_web_search is False
 
     # Test unknown model - use zai-glm which is valid but won't match any prefix
-    model = CerebrasModel('zai-glm-4.6', settings={'api_key': 'test_key'})
+    model = CerebrasModel('zai-glm-4.6', provider=provider)
     profile = model.profile
     assert isinstance(profile, OpenAIModelProfile)
     assert OpenAIModelProfile.from_profile(profile).openai_chat_supports_web_search is False
```

tests/models/test_openai.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -3022,7 +3022,7 @@ async def test_openai_model_settings_temperature_ignored_on_gpt_5(allow_model_re
 
 
 async def test_openai_model_cerebras_provider(allow_model_requests: None, cerebras_api_key: str):
-    m = OpenAIChatModel('llama3.3-70b', provider=CerebrasProvider(api_key=cerebras_api_key))
+    m = OpenAIChatModel('llama3.3-70b', provider=CerebrasProvider(api_key=cerebras_api_key))  # type: ignore[arg-type]
     agent = Agent(m)
 
     result = await agent.run('What is the capital of France?')
@@ -3034,15 +3034,15 @@ class Location(TypedDict):
         city: str
         country: str
 
-    m = OpenAIChatModel('qwen-3-coder-480b', provider=CerebrasProvider(api_key=cerebras_api_key))
+    m = OpenAIChatModel('qwen-3-coder-480b', provider=CerebrasProvider(api_key=cerebras_api_key))  # type: ignore[arg-type]
     agent = Agent(m, output_type=Location)
 
     result = await agent.run('What is the capital of France?')
     assert result.output == snapshot({'city': 'Paris', 'country': 'France'})
 
 
 async def test_openai_model_cerebras_provider_harmony(allow_model_requests: None, cerebras_api_key: str):
-    m = OpenAIChatModel('gpt-oss-120b', provider=CerebrasProvider(api_key=cerebras_api_key))
+    m = OpenAIChatModel('gpt-oss-120b', provider=CerebrasProvider(api_key=cerebras_api_key))  # type: ignore[arg-type]
     agent = Agent(m)
 
     result = await agent.run('What is the capital of France?')
```
