@@ -64,12 +64,10 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config:
64
64
"""
65
65
validate_config_keys (model_config , self .OpenAIConfig )
66
66
self .config = dict (model_config )
67
+ self .client_args = client_args or {}
67
68
68
69
logger .debug ("config=<%s> | initializing" , self .config )
69
70
70
- client_args = client_args or {}
71
- self .client = openai .AsyncOpenAI (** client_args )
72
-
73
71
@override
74
72
def update_config (self , ** model_config : Unpack [OpenAIConfig ]) -> None : # type: ignore[override]
75
73
"""Update the OpenAI model configuration with the provided arguments.
@@ -379,58 +377,60 @@ async def stream(
379
377
logger .debug ("formatted request=<%s>" , request )
380
378
381
379
logger .debug ("invoking model" )
382
- response = await self .client .chat .completions .create (** request )
383
-
384
- logger .debug ("got response from model" )
385
- yield self .format_chunk ({"chunk_type" : "message_start" })
386
- yield self .format_chunk ({"chunk_type" : "content_start" , "data_type" : "text" })
387
-
388
- tool_calls : dict [int , list [Any ]] = {}
389
-
390
- async for event in response :
391
- # Defensive: skip events with empty or missing choices
392
- if not getattr (event , "choices" , None ):
393
- continue
394
- choice = event .choices [0 ]
395
-
396
- if choice .delta .content :
397
- yield self .format_chunk (
398
- {"chunk_type" : "content_delta" , "data_type" : "text" , "data" : choice .delta .content }
399
- )
400
-
401
- if hasattr (choice .delta , "reasoning_content" ) and choice .delta .reasoning_content :
402
- yield self .format_chunk (
403
- {
404
- "chunk_type" : "content_delta" ,
405
- "data_type" : "reasoning_content" ,
406
- "data" : choice .delta .reasoning_content ,
407
- }
408
- )
409
380
410
- for tool_call in choice . delta . tool_calls or [] :
411
- tool_calls . setdefault ( tool_call . index , []). append ( tool_call )
381
+ async with openai . AsyncOpenAI ( ** self . client_args ) as client :
382
+ response = await client . chat . completions . create ( ** request )
412
383
413
- if choice .finish_reason :
414
- break
384
+ logger .debug ("got response from model" )
385
+ yield self .format_chunk ({"chunk_type" : "message_start" })
386
+ yield self .format_chunk ({"chunk_type" : "content_start" , "data_type" : "text" })
415
387
416
- yield self . format_chunk ({ "chunk_type" : "content_stop" , "data_type" : "text" })
388
+ tool_calls : dict [ int , list [ Any ]] = {}
417
389
418
- for tool_deltas in tool_calls .values ():
419
- yield self .format_chunk ({"chunk_type" : "content_start" , "data_type" : "tool" , "data" : tool_deltas [0 ]})
390
+ async for event in response :
391
+ # Defensive: skip events with empty or missing choices
392
+ if not getattr (event , "choices" , None ):
393
+ continue
394
+ choice = event .choices [0 ]
395
+
396
+ if choice .delta .content :
397
+ yield self .format_chunk (
398
+ {"chunk_type" : "content_delta" , "data_type" : "text" , "data" : choice .delta .content }
399
+ )
400
+
401
+ if hasattr (choice .delta , "reasoning_content" ) and choice .delta .reasoning_content :
402
+ yield self .format_chunk (
403
+ {
404
+ "chunk_type" : "content_delta" ,
405
+ "data_type" : "reasoning_content" ,
406
+ "data" : choice .delta .reasoning_content ,
407
+ }
408
+ )
420
409
421
- for tool_delta in tool_deltas :
422
- yield self . format_chunk ({ "chunk_type" : "content_delta" , "data_type" : "tool" , "data" : tool_delta } )
410
+ for tool_call in choice . delta . tool_calls or [] :
411
+ tool_calls . setdefault ( tool_call . index , []). append ( tool_call )
423
412
424
- yield self .format_chunk ({"chunk_type" : "content_stop" , "data_type" : "tool" })
413
+ if choice .finish_reason :
414
+ break
425
415
426
- yield self .format_chunk ({"chunk_type" : "message_stop " , "data " : choice . finish_reason })
416
+ yield self .format_chunk ({"chunk_type" : "content_stop " , "data_type " : "text" })
427
417
428
- # Skip remaining events as we don't have use for anything except the final usage payload
429
- async for event in response :
430
- _ = event
418
+ for tool_deltas in tool_calls .values ():
419
+ yield self .format_chunk ({"chunk_type" : "content_start" , "data_type" : "tool" , "data" : tool_deltas [0 ]})
431
420
432
- if event .usage :
433
- yield self .format_chunk ({"chunk_type" : "metadata" , "data" : event .usage })
421
+ for tool_delta in tool_deltas :
422
+ yield self .format_chunk ({"chunk_type" : "content_delta" , "data_type" : "tool" , "data" : tool_delta })
423
+
424
+ yield self .format_chunk ({"chunk_type" : "content_stop" , "data_type" : "tool" })
425
+
426
+ yield self .format_chunk ({"chunk_type" : "message_stop" , "data" : choice .finish_reason })
427
+
428
+ # Skip remaining events as we don't have use for anything except the final usage payload
429
+ async for event in response :
430
+ _ = event
431
+
432
+ if event .usage :
433
+ yield self .format_chunk ({"chunk_type" : "metadata" , "data" : event .usage })
434
434
435
435
logger .debug ("finished streaming response from model" )
436
436
@@ -449,11 +449,12 @@ async def structured_output(
449
449
Yields:
450
450
Model events with the last being the structured output.
451
451
"""
452
- response : ParsedChatCompletion = await self .client .beta .chat .completions .parse ( # type: ignore
453
- model = self .get_config ()["model_id" ],
454
- messages = self .format_request (prompt , system_prompt = system_prompt )["messages" ],
455
- response_format = output_model ,
456
- )
452
+ async with openai .AsyncOpenAI (** self .client_args ) as client :
453
+ response : ParsedChatCompletion = await client .beta .chat .completions .parse (
454
+ model = self .get_config ()["model_id" ],
455
+ messages = self .format_request (prompt , system_prompt = system_prompt )["messages" ],
456
+ response_format = output_model ,
457
+ )
457
458
458
459
parsed : T | None = None
459
460
# Find the first choice with tool_calls
0 commit comments