
Commit 0ffc14c

refactor: deduplicate litellm streaming and non streaming
1 parent 483b3d5 commit 0ffc14c

File tree

3 files changed: +253 -132 lines changed


src/strands/models/litellm.py

Lines changed: 185 additions & 115 deletions
@@ -280,121 +280,16 @@ async def stream(
 
         logger.debug("invoking model with stream=%s", litellm_request.get("stream"))
 
-        if not is_streaming:
-            response = await litellm.acompletion(**self.client_args, **litellm_request)
-
-            logger.debug("got non-streaming response from model")
-            yield self.format_chunk({"chunk_type": "message_start"})
-            yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
-
-            tool_calls: dict[int, list[Any]] = {}
-            finish_reason = None
-
-            if hasattr(response, "choices") and response.choices and len(response.choices) > 0:
-                choice = response.choices[0]
-
-                if hasattr(choice, "message") and choice.message:
-                    if hasattr(choice.message, "content") and choice.message.content:
-                        yield self.format_chunk(
-                            {"chunk_type": "content_delta", "data_type": "text", "data": choice.message.content}
-                        )
-
-                    if hasattr(choice.message, "reasoning_content") and choice.message.reasoning_content:
-                        yield self.format_chunk(
-                            {
-                                "chunk_type": "content_delta",
-                                "data_type": "reasoning_content",
-                                "data": choice.message.reasoning_content,
-                            }
-                        )
-
-                    if hasattr(choice.message, "tool_calls") and choice.message.tool_calls:
-                        for i, tool_call in enumerate(choice.message.tool_calls):
-                            tool_calls.setdefault(i, []).append(tool_call)
-
-                if hasattr(choice, "finish_reason"):
-                    finish_reason = choice.finish_reason
-
-            yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
-
-            for tool_deltas in tool_calls.values():
-                yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})
-
-                for tool_delta in tool_deltas:
-                    yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})
-
-                yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
-
-            yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason})
-
-            # Add usage information if available
-            if hasattr(response, "usage"):
-                yield self.format_chunk({"chunk_type": "metadata", "data": response.usage})
-        else:
-            # For streaming, use the streaming API
-            response = await litellm.acompletion(**self.client_args, **litellm_request)
-
-            logger.debug("got streaming response from model")
-            yield self.format_chunk({"chunk_type": "message_start"})
-            yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
-
-            streaming_tool_calls: dict[int, list[Any]] = {}
-            finish_reason = None
-
-            try:
-                async for event in response:
-                    # Defensive: skip events with empty or missing choices
-                    if not getattr(event, "choices", None):
-                        continue
-                    choice = event.choices[0]
-
-                    if choice.delta.content:
-                        yield self.format_chunk(
-                            {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
-                        )
-
-                    if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
-                        yield self.format_chunk(
-                            {
-                                "chunk_type": "content_delta",
-                                "data_type": "reasoning_content",
-                                "data": choice.delta.reasoning_content,
-                            }
-                        )
-
-                    for tool_call in choice.delta.tool_calls or []:
-                        streaming_tool_calls.setdefault(tool_call.index, []).append(tool_call)
-
-                    if choice.finish_reason:
-                        finish_reason = choice.finish_reason
-                        break
-            except Exception as e:
-                logger.warning("Error processing streaming response: %s", e)
-
-            yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
-
-            # Process tool calls
-            for tool_deltas in streaming_tool_calls.values():
-                yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})
-
-                for tool_delta in tool_deltas:
-                    yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})
-
-                yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
-
-            yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason})
-
-            try:
-                last_event = None
-                async for event in response:
-                    last_event = event
-
-                # Use the last event for usage information
-                if last_event and hasattr(last_event, "usage"):
-                    yield self.format_chunk({"chunk_type": "metadata", "data": last_event.usage})
-            except Exception:
-                # If there's an error collecting remaining events, just continue
-                pass
+        try:
+            if is_streaming:
+                async for chunk in self._handle_streaming_response(litellm_request):
+                    yield chunk
+            else:
+                async for chunk in self._handle_non_streaming_response(litellm_request):
+                    yield chunk
+        except ContextWindowExceededError as e:
+            logger.warning("litellm client raised context window overflow")
+            raise ContextWindowOverflowException(e) from e
 
         logger.debug("finished processing response from model")
 
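For context, a minimal caller-side sketch of what the refactored dispatch means for consumers of stream(): the chunks look the same whether LiteLLM runs in streaming or non-streaming mode, and context-window overflows surface as ContextWindowOverflowException. The import paths, constructor arguments, and message shape below are assumptions for illustration, not part of this commit.

# Hypothetical usage sketch (not from this commit); assumes strands' LiteLLMModel
# and ContextWindowOverflowException live at these import paths.
import asyncio

from strands.models.litellm import LiteLLMModel
from strands.types.exceptions import ContextWindowOverflowException


async def main() -> None:
    model = LiteLLMModel(model_id="gpt-4o")  # model_id value is illustrative
    messages = [{"role": "user", "content": [{"text": "hello"}]}]

    try:
        # stream() now delegates internally to _handle_streaming_response or
        # _handle_non_streaming_response; callers just iterate the chunks.
        async for chunk in model.stream(messages):
            print(chunk)
    except ContextWindowOverflowException:
        # Raised when litellm reports ContextWindowExceededError (see the hunk above).
        print("prompt exceeded the model's context window")


asyncio.run(main())
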
@@ -481,6 +376,181 @@ async def _structured_output_using_tool(
         except (json.JSONDecodeError, TypeError, ValueError) as e:
             raise ValueError(f"Failed to parse or load content into model: {e}") from e
 
+    async def _process_choice_content(
+        self, choice: Any, data_type: str | None, tool_calls: dict[int, list[Any]], is_streaming: bool = True
+    ) -> AsyncGenerator[tuple[str | None, StreamEvent], None]:
+        """Process content from a choice object (streaming or non-streaming).
+
+        Args:
+            choice: The choice object from the response.
+            data_type: Current data type being processed.
+            tool_calls: Dictionary to collect tool calls.
+            is_streaming: Whether this is from a streaming response.
+
+        Yields:
+            Tuples of (updated_data_type, stream_event).
+        """
+        # Get the content source - this is the only difference between streaming/non-streaming
+        # We use duck typing here: both choice.delta and choice.message have the same interface
+        # (reasoning_content, content, tool_calls attributes) but different object structures
+        content_source = choice.delta if is_streaming else choice.message
+
+        # Process reasoning content
+        if hasattr(content_source, "reasoning_content") and content_source.reasoning_content:
+            chunks, data_type = self._stream_switch_content("reasoning_content", data_type)
+            for chunk in chunks:
+                yield data_type, chunk
+            chunk = self.format_chunk(
+                {
+                    "chunk_type": "content_delta",
+                    "data_type": "reasoning_content",
+                    "data": content_source.reasoning_content,
+                }
+            )
+            yield data_type, chunk
+
+        # Process text content
+        if hasattr(content_source, "content") and content_source.content:
+            chunks, data_type = self._stream_switch_content("text", data_type)
+            for chunk in chunks:
+                yield data_type, chunk
+            chunk = self.format_chunk(
+                {
+                    "chunk_type": "content_delta",
+                    "data_type": "text",
+                    "data": content_source.content,
+                }
+            )
+            yield data_type, chunk
+
+        # Process tool calls
+        if hasattr(content_source, "tool_calls") and content_source.tool_calls:
+            if is_streaming:
+                # Streaming: tool calls have index attribute for out-of-order delivery
+                for tool_call in content_source.tool_calls:
+                    tool_calls.setdefault(tool_call.index, []).append(tool_call)
+            else:
+                # Non-streaming: tool calls arrive in order, use enumerated index
+                for i, tool_call in enumerate(content_source.tool_calls):
+                    tool_calls.setdefault(i, []).append(tool_call)
+
+    async def _process_tool_calls(self, tool_calls: dict[int, list[Any]]) -> AsyncGenerator[StreamEvent, None]:
+        """Process and yield tool call events.
+
+        Args:
+            tool_calls: Dictionary of tool calls indexed by their position.
+
+        Yields:
+            Formatted tool call chunks.
+        """
+        for tool_deltas in tool_calls.values():
+            yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})
+
+            for tool_delta in tool_deltas:
+                yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})
+
+            yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
+
+    async def _handle_non_streaming_response(
+        self, litellm_request: dict[str, Any]
+    ) -> AsyncGenerator[StreamEvent, None]:
+        """Handle non-streaming response from LiteLLM.
+
+        Args:
+            litellm_request: The formatted request for LiteLLM.
+
+        Yields:
+            Formatted message chunks from the model.
+        """
+        response = await litellm.acompletion(**self.client_args, **litellm_request)
+
+        logger.debug("got non-streaming response from model")
+        yield self.format_chunk({"chunk_type": "message_start"})
+
+        tool_calls: dict[int, list[Any]] = {}
+        data_type: str | None = None
+        finish_reason: str | None = None
+
+        if hasattr(response, "choices") and response.choices and len(response.choices) > 0:
+            choice = response.choices[0]
+
+            if hasattr(choice, "message") and choice.message:
+                # Process content using shared logic
+                async for updated_data_type, chunk in self._process_choice_content(
+                    choice, data_type, tool_calls, is_streaming=False
+                ):
+                    data_type = updated_data_type
+                    yield chunk
+
+            if hasattr(choice, "finish_reason"):
+                finish_reason = choice.finish_reason
+
+        # Stop the current content block if we have one
+        if data_type:
+            yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type})
+
+        # Process tool calls
+        async for chunk in self._process_tool_calls(tool_calls):
+            yield chunk
+
+        yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason})
+
+        # Add usage information if available
+        if hasattr(response, "usage"):
+            yield self.format_chunk({"chunk_type": "metadata", "data": response.usage})
+
+    async def _handle_streaming_response(self, litellm_request: dict[str, Any]) -> AsyncGenerator[StreamEvent, None]:
+        """Handle streaming response from LiteLLM.
+
+        Args:
+            litellm_request: The formatted request for LiteLLM.
+
+        Yields:
+            Formatted message chunks from the model.
+        """
+        # For streaming, use the streaming API
+        response = await litellm.acompletion(**self.client_args, **litellm_request)
+
+        logger.debug("got response from model")
+        yield self.format_chunk({"chunk_type": "message_start"})
+
+        tool_calls: dict[int, list[Any]] = {}
+        data_type: str | None = None
+        finish_reason: str | None = None
+
+        async for event in response:
+            # Defensive: skip events with empty or missing choices
+            if not getattr(event, "choices", None):
+                continue
+            choice = event.choices[0]
+
+            # Process content using shared logic
+            async for updated_data_type, chunk in self._process_choice_content(
+                choice, data_type, tool_calls, is_streaming=True
+            ):
+                data_type = updated_data_type
+                yield chunk
+
+            if choice.finish_reason:
+                finish_reason = choice.finish_reason
+                if data_type:
+                    yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type})
+                break
+
+        # Process tool calls
+        async for chunk in self._process_tool_calls(tool_calls):
+            yield chunk
+
+        yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason})
+
+        # Skip remaining events as we don't have use for anything except the final usage payload
+        async for event in response:
+            _ = event
+        if event.usage:
+            yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
+
+        logger.debug("finished streaming response from model")
+
     def _apply_proxy_prefix(self) -> None:
         """Apply litellm_proxy/ prefix to model_id when use_litellm_proxy is True.
 
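The core of the deduplication is the duck-typed _process_choice_content helper above: a streaming choice.delta and a non-streaming choice.message expose the same content, reasoning_content, and tool_calls attributes, so one code path can walk either. A minimal, self-contained sketch of that pattern (the Delta, Message, and process_content names are illustrative, not the actual strands or litellm classes):

# Standalone illustration of the duck-typing pattern; not the real classes.
from dataclasses import dataclass, field
from typing import Any


@dataclass
class Delta:  # shape of a streaming chunk's choice.delta
    content: str | None = None
    reasoning_content: str | None = None
    tool_calls: list[Any] = field(default_factory=list)


@dataclass
class Message:  # shape of a non-streaming choice.message
    content: str | None = None
    reasoning_content: str | None = None
    tool_calls: list[Any] = field(default_factory=list)


def process_content(source: Delta | Message) -> list[dict[str, Any]]:
    """Emit the same chunk dicts regardless of which shape was passed in."""
    chunks: list[dict[str, Any]] = []
    if source.reasoning_content:
        chunks.append({"data_type": "reasoning_content", "data": source.reasoning_content})
    if source.content:
        chunks.append({"data_type": "text", "data": source.content})
    return chunks


# Both call sites funnel through the same helper:
print(process_content(Delta(content="partial ")))       # streaming path
print(process_content(Message(content="full answer")))  # non-streaming path

The tool-call handling is the only part that differs: the index comes from tool_call.index when streaming and from enumerate() otherwise, which is why the real helper takes an is_streaming flag.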