Skip to content

Commit 6ccc8e7

Browse files
authored
models - openai - use client context (#856)
1 parent 69d3910 commit 6ccc8e7

File tree

2 files changed

+58
-62
lines changed

2 files changed

+58
-62
lines changed

src/strands/models/openai.py

Lines changed: 52 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,10 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config:
6464
"""
6565
validate_config_keys(model_config, self.OpenAIConfig)
6666
self.config = dict(model_config)
67+
self.client_args = client_args or {}
6768

6869
logger.debug("config=<%s> | initializing", self.config)
6970

70-
client_args = client_args or {}
71-
self.client = openai.AsyncOpenAI(**client_args)
72-
7371
@override
7472
def update_config(self, **model_config: Unpack[OpenAIConfig]) -> None: # type: ignore[override]
7573
"""Update the OpenAI model configuration with the provided arguments.
@@ -379,58 +377,60 @@ async def stream(
379377
logger.debug("formatted request=<%s>", request)
380378

381379
logger.debug("invoking model")
382-
response = await self.client.chat.completions.create(**request)
383-
384-
logger.debug("got response from model")
385-
yield self.format_chunk({"chunk_type": "message_start"})
386-
yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
387-
388-
tool_calls: dict[int, list[Any]] = {}
389-
390-
async for event in response:
391-
# Defensive: skip events with empty or missing choices
392-
if not getattr(event, "choices", None):
393-
continue
394-
choice = event.choices[0]
395-
396-
if choice.delta.content:
397-
yield self.format_chunk(
398-
{"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
399-
)
400-
401-
if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
402-
yield self.format_chunk(
403-
{
404-
"chunk_type": "content_delta",
405-
"data_type": "reasoning_content",
406-
"data": choice.delta.reasoning_content,
407-
}
408-
)
409380

410-
for tool_call in choice.delta.tool_calls or []:
411-
tool_calls.setdefault(tool_call.index, []).append(tool_call)
381+
async with openai.AsyncOpenAI(**self.client_args) as client:
382+
response = await client.chat.completions.create(**request)
412383

413-
if choice.finish_reason:
414-
break
384+
logger.debug("got response from model")
385+
yield self.format_chunk({"chunk_type": "message_start"})
386+
yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
415387

416-
yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
388+
tool_calls: dict[int, list[Any]] = {}
417389

418-
for tool_deltas in tool_calls.values():
419-
yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})
390+
async for event in response:
391+
# Defensive: skip events with empty or missing choices
392+
if not getattr(event, "choices", None):
393+
continue
394+
choice = event.choices[0]
395+
396+
if choice.delta.content:
397+
yield self.format_chunk(
398+
{"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
399+
)
400+
401+
if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
402+
yield self.format_chunk(
403+
{
404+
"chunk_type": "content_delta",
405+
"data_type": "reasoning_content",
406+
"data": choice.delta.reasoning_content,
407+
}
408+
)
420409

421-
for tool_delta in tool_deltas:
422-
yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})
410+
for tool_call in choice.delta.tool_calls or []:
411+
tool_calls.setdefault(tool_call.index, []).append(tool_call)
423412

424-
yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
413+
if choice.finish_reason:
414+
break
425415

426-
yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})
416+
yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
427417

428-
# Skip remaining events as we don't have use for anything except the final usage payload
429-
async for event in response:
430-
_ = event
418+
for tool_deltas in tool_calls.values():
419+
yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})
431420

432-
if event.usage:
433-
yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
421+
for tool_delta in tool_deltas:
422+
yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})
423+
424+
yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
425+
426+
yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})
427+
428+
# Skip remaining events as we don't have use for anything except the final usage payload
429+
async for event in response:
430+
_ = event
431+
432+
if event.usage:
433+
yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
434434

435435
logger.debug("finished streaming response from model")
436436

@@ -449,11 +449,12 @@ async def structured_output(
449449
Yields:
450450
Model events with the last being the structured output.
451451
"""
452-
response: ParsedChatCompletion = await self.client.beta.chat.completions.parse( # type: ignore
453-
model=self.get_config()["model_id"],
454-
messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
455-
response_format=output_model,
456-
)
452+
async with openai.AsyncOpenAI(**self.client_args) as client:
453+
response: ParsedChatCompletion = await client.beta.chat.completions.parse(
454+
model=self.get_config()["model_id"],
455+
messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
456+
response_format=output_model,
457+
)
457458

458459
parsed: T | None = None
459460
# Find the first choice with tool_calls

tests/strands/models/test_openai.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,11 @@
88

99

1010
@pytest.fixture
11-
def openai_client_cls():
11+
def openai_client():
1212
with unittest.mock.patch.object(strands.models.openai.openai, "AsyncOpenAI") as mock_client_cls:
13-
yield mock_client_cls
14-
15-
16-
@pytest.fixture
17-
def openai_client(openai_client_cls):
18-
return openai_client_cls.return_value
13+
mock_client = unittest.mock.AsyncMock()
14+
mock_client_cls.return_value.__aenter__.return_value = mock_client
15+
yield mock_client
1916

2017

2118
@pytest.fixture
@@ -68,16 +65,14 @@ class TestOutputModel(pydantic.BaseModel):
6865
return TestOutputModel
6966

7067

71-
def test__init__(openai_client_cls, model_id):
72-
model = OpenAIModel({"api_key": "k1"}, model_id=model_id, params={"max_tokens": 1})
68+
def test__init__(model_id):
69+
model = OpenAIModel(model_id=model_id, params={"max_tokens": 1})
7370

7471
tru_config = model.get_config()
7572
exp_config = {"model_id": "m1", "params": {"max_tokens": 1}}
7673

7774
assert tru_config == exp_config
7875

79-
openai_client_cls.assert_called_once_with(api_key="k1")
80-
8176

8277
def test_update_config(model, model_id):
8378
model.update_config(model_id=model_id)

0 commit comments

Comments (0)