
Commit 2f04bc0

Authored by schleidl (Daniel Schleicher) and dbschmigelski
feat(litellm): handle litellm non streaming responses (#512)
---------

Co-authored-by: Daniel Schleicher <dschlei@amazon.de>
Co-authored-by: Dean Schmigelski <dbschmigelski+github@gmail.com>
1 parent 1e27d79 commit 2f04bc0
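
For orientation, a minimal usage sketch of the behavior this commit enables, assuming the LiteLLMModel class from src/strands/models/litellm.py; the constructor keywords and the model_id value shown are illustrative assumptions, not taken from this diff:

    from strands.models.litellm import LiteLLMModel

    # The old code rejected stream=False passed via kwargs; with this commit the
    # flag is read from the model's params and routed to a non-streaming handler.
    streaming_model = LiteLLMModel(model_id="openai/gpt-4o")  # params omit "stream" -> streaming path

    non_streaming_model = LiteLLMModel(
        model_id="openai/gpt-4o",       # illustrative model id
        params={"stream": False},       # opt out of streaming -> non-streaming handler
    )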

File tree

3 files changed: +461 −77 lines

src/strands/models/litellm.py

Lines changed: 191 additions & 62 deletions
@@ -269,75 +269,29 @@ async def stream(
         )
         logger.debug("request=<%s>", request)

-        logger.debug("invoking model")
-        try:
-            if kwargs.get("stream") is False:
-                raise ValueError("stream parameter cannot be explicitly set to False")
-            response = await litellm.acompletion(**self.client_args, **request)
-        except ContextWindowExceededError as e:
-            logger.warning("litellm client raised context window overflow")
-            raise ContextWindowOverflowException(e) from e
+        # Check if streaming is disabled in the params
+        config = self.get_config()
+        params = config.get("params") or {}
+        is_streaming = params.get("stream", True)

-        logger.debug("got response from model")
-        yield self.format_chunk({"chunk_type": "message_start"})
+        litellm_request = {**request}

-        tool_calls: dict[int, list[Any]] = {}
-        data_type: str | None = None
+        litellm_request["stream"] = is_streaming

-        async for event in response:
-            # Defensive: skip events with empty or missing choices
-            if not getattr(event, "choices", None):
-                continue
-            choice = event.choices[0]
+        logger.debug("invoking model with stream=%s", litellm_request.get("stream"))

-            if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
-                chunks, data_type = self._stream_switch_content("reasoning_content", data_type)
-                for chunk in chunks:
+        try:
+            if is_streaming:
+                async for chunk in self._handle_streaming_response(litellm_request):
                     yield chunk
-
-                yield self.format_chunk(
-                    {
-                        "chunk_type": "content_delta",
-                        "data_type": data_type,
-                        "data": choice.delta.reasoning_content,
-                    }
-                )
-
-            if choice.delta.content:
-                chunks, data_type = self._stream_switch_content("text", data_type)
-                for chunk in chunks:
+            else:
+                async for chunk in self._handle_non_streaming_response(litellm_request):
                     yield chunk
+        except ContextWindowExceededError as e:
+            logger.warning("litellm client raised context window overflow")
+            raise ContextWindowOverflowException(e) from e

-                yield self.format_chunk(
-                    {"chunk_type": "content_delta", "data_type": data_type, "data": choice.delta.content}
-                )
-
-            for tool_call in choice.delta.tool_calls or []:
-                tool_calls.setdefault(tool_call.index, []).append(tool_call)
-
-            if choice.finish_reason:
-                if data_type:
-                    yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type})
-                break
-
-        for tool_deltas in tool_calls.values():
-            yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})
-
-            for tool_delta in tool_deltas:
-                yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})
-
-            yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
-
-        yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})
-
-        # Skip remaining events as we don't have use for anything except the final usage payload
-        async for event in response:
-            _ = event
-
-        if event.usage:
-            yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
-
-        logger.debug("finished streaming response from model")
+        logger.debug("finished processing response from model")

     @override
     async def structured_output(
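
The dispatch rule in the hunk above is compact; here is a standalone sketch of just that rule, using the same params shape the diff reads via get_config(). The helper name resolve_is_streaming is hypothetical, for illustration only:

    def resolve_is_streaming(config: dict) -> bool:
        # mirrors the new code: params = config.get("params") or {}; params.get("stream", True)
        params = config.get("params") or {}
        return params.get("stream", True)

    assert resolve_is_streaming({}) is True                               # no params -> streaming path
    assert resolve_is_streaming({"params": None}) is True                 # params=None -> still streaming
    assert resolve_is_streaming({"params": {"stream": False}}) is False   # explicit opt-out -> non-streaming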
@@ -422,6 +376,181 @@ async def _structured_output_using_tool(
         except (json.JSONDecodeError, TypeError, ValueError) as e:
             raise ValueError(f"Failed to parse or load content into model: {e}") from e

+    async def _process_choice_content(
+        self, choice: Any, data_type: str | None, tool_calls: dict[int, list[Any]], is_streaming: bool = True
+    ) -> AsyncGenerator[tuple[str | None, StreamEvent], None]:
+        """Process content from a choice object (streaming or non-streaming).
+
+        Args:
+            choice: The choice object from the response.
+            data_type: Current data type being processed.
+            tool_calls: Dictionary to collect tool calls.
+            is_streaming: Whether this is from a streaming response.
+
+        Yields:
+            Tuples of (updated_data_type, stream_event).
+        """
+        # Get the content source - this is the only difference between streaming/non-streaming
+        # We use duck typing here: both choice.delta and choice.message have the same interface
+        # (reasoning_content, content, tool_calls attributes) but different object structures
+        content_source = choice.delta if is_streaming else choice.message
+
+        # Process reasoning content
+        if hasattr(content_source, "reasoning_content") and content_source.reasoning_content:
+            chunks, data_type = self._stream_switch_content("reasoning_content", data_type)
+            for chunk in chunks:
+                yield data_type, chunk
+            chunk = self.format_chunk(
+                {
+                    "chunk_type": "content_delta",
+                    "data_type": "reasoning_content",
+                    "data": content_source.reasoning_content,
+                }
+            )
+            yield data_type, chunk
+
+        # Process text content
+        if hasattr(content_source, "content") and content_source.content:
+            chunks, data_type = self._stream_switch_content("text", data_type)
+            for chunk in chunks:
+                yield data_type, chunk
+            chunk = self.format_chunk(
+                {
+                    "chunk_type": "content_delta",
+                    "data_type": "text",
+                    "data": content_source.content,
+                }
+            )
+            yield data_type, chunk
+
+        # Process tool calls
+        if hasattr(content_source, "tool_calls") and content_source.tool_calls:
+            if is_streaming:
+                # Streaming: tool calls have index attribute for out-of-order delivery
+                for tool_call in content_source.tool_calls:
+                    tool_calls.setdefault(tool_call.index, []).append(tool_call)
+            else:
+                # Non-streaming: tool calls arrive in order, use enumerated index
+                for i, tool_call in enumerate(content_source.tool_calls):
+                    tool_calls.setdefault(i, []).append(tool_call)
+
+    async def _process_tool_calls(self, tool_calls: dict[int, list[Any]]) -> AsyncGenerator[StreamEvent, None]:
+        """Process and yield tool call events.
+
+        Args:
+            tool_calls: Dictionary of tool calls indexed by their position.
+
+        Yields:
+            Formatted tool call chunks.
+        """
+        for tool_deltas in tool_calls.values():
+            yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})
+
+            for tool_delta in tool_deltas:
+                yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})
+
+            yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
+
+    async def _handle_non_streaming_response(
+        self, litellm_request: dict[str, Any]
+    ) -> AsyncGenerator[StreamEvent, None]:
+        """Handle non-streaming response from LiteLLM.
+
+        Args:
+            litellm_request: The formatted request for LiteLLM.
+
+        Yields:
+            Formatted message chunks from the model.
+        """
+        response = await litellm.acompletion(**self.client_args, **litellm_request)
+
+        logger.debug("got non-streaming response from model")
+        yield self.format_chunk({"chunk_type": "message_start"})
+
+        tool_calls: dict[int, list[Any]] = {}
+        data_type: str | None = None
+        finish_reason: str | None = None
+
+        if hasattr(response, "choices") and response.choices and len(response.choices) > 0:
+            choice = response.choices[0]
+
+            if hasattr(choice, "message") and choice.message:
+                # Process content using shared logic
+                async for updated_data_type, chunk in self._process_choice_content(
+                    choice, data_type, tool_calls, is_streaming=False
+                ):
+                    data_type = updated_data_type
+                    yield chunk
+
+            if hasattr(choice, "finish_reason"):
+                finish_reason = choice.finish_reason
+
+        # Stop the current content block if we have one
+        if data_type:
+            yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type})
+
+        # Process tool calls
+        async for chunk in self._process_tool_calls(tool_calls):
+            yield chunk
+
+        yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason})
+
+        # Add usage information if available
+        if hasattr(response, "usage"):
+            yield self.format_chunk({"chunk_type": "metadata", "data": response.usage})
+
+    async def _handle_streaming_response(self, litellm_request: dict[str, Any]) -> AsyncGenerator[StreamEvent, None]:
+        """Handle streaming response from LiteLLM.
+
+        Args:
+            litellm_request: The formatted request for LiteLLM.
+
+        Yields:
+            Formatted message chunks from the model.
+        """
+        # For streaming, use the streaming API
+        response = await litellm.acompletion(**self.client_args, **litellm_request)
+
+        logger.debug("got response from model")
+        yield self.format_chunk({"chunk_type": "message_start"})
+
+        tool_calls: dict[int, list[Any]] = {}
+        data_type: str | None = None
+        finish_reason: str | None = None
+
+        async for event in response:
+            # Defensive: skip events with empty or missing choices
+            if not getattr(event, "choices", None):
+                continue
+            choice = event.choices[0]
+
+            # Process content using shared logic
+            async for updated_data_type, chunk in self._process_choice_content(
+                choice, data_type, tool_calls, is_streaming=True
+            ):
+                data_type = updated_data_type
+                yield chunk
+
+            if choice.finish_reason:
+                finish_reason = choice.finish_reason
+                if data_type:
+                    yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type})
+                break
+
+        # Process tool calls
+        async for chunk in self._process_tool_calls(tool_calls):
+            yield chunk
+
+        yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason})
+
+        # Skip remaining events as we don't have use for anything except the final usage payload
+        async for event in response:
+            _ = event
+            if event.usage:
+                yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
+
+        logger.debug("finished streaming response from model")
+
     def _apply_proxy_prefix(self) -> None:
         """Apply litellm_proxy/ prefix to model_id when use_litellm_proxy is True.

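The new helpers lean on the duck-typing note inside _process_choice_content: choice.delta (streaming) and choice.message (non-streaming) expose the same reasoning_content / content / tool_calls attributes, so one code path can read either. An illustrative-only sketch with stand-in objects rather than real litellm response types:

    from types import SimpleNamespace

    # Stand-ins for a streaming chunk's choice and a non-streaming response's choice.
    streaming_choice = SimpleNamespace(
        delta=SimpleNamespace(reasoning_content=None, content="Hel", tool_calls=None),
        finish_reason=None,
    )
    non_streaming_choice = SimpleNamespace(
        message=SimpleNamespace(reasoning_content=None, content="Hello!", tool_calls=None),
        finish_reason="stop",
    )

    def content_source(choice, is_streaming: bool):
        # mirrors: content_source = choice.delta if is_streaming else choice.message
        return choice.delta if is_streaming else choice.message

    print(content_source(streaming_choice, True).content)        # Hel
    print(content_source(non_streaming_choice, False).content)   # Hello!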