1010
1111from __future__ import annotations
1212
13- import json
14- from typing import Any , Protocol , cast
13+ from typing import Any , Protocol
1514
1615from guidellm .schemas import GenerationRequest , GenerationResponse , UsageMetrics
17- from guidellm .utils import RegistryMixin
18-
19- try :
20- import orjson
21- except ImportError :
22- orjson = None # type: ignore[assignment]
16+ from guidellm .utils import RegistryMixin , json
2317
2418__all__ = [
2519 "AudioResponseHandler" ,
@@ -115,8 +109,7 @@ def compile_non_streaming(
115109 :param response: Complete API response containing choices and usage data
116110 :return: Standardized GenerationResponse with extracted text and metrics
117111 """
118- choices = cast ("list[dict]" , response .get ("choices" , []))
119- usage = cast ("dict[str, int | dict[str, int]]" , response .get ("usage" , {}))
112+ choices , usage = self .extract_choices_and_usage (response )
120113 input_metrics , output_metrics = self .extract_metrics (usage )
121114
122115 return GenerationResponse (
@@ -139,26 +132,17 @@ def add_streaming_line(self, line: str) -> int | None:
139132 :param line: Raw SSE line from the streaming response
140133 :return: 1 if text content was extracted, 0 if line ignored, None if done
141134 """
142- if line == " data: [DONE]" :
143- return None
135+ if not ( data := self . extract_line_data ( line )) :
136+ return None if data is None else 0
144137
145- if not line or not (line := line .strip ()) or not line .startswith ("data:" ):
146- return 0
147-
148- line = line [len ("data:" ) :].strip ()
149- data = cast (
150- "dict[str, Any]" ,
151- json .loads (line ) if orjson is None else orjson .loads (line ),
152- )
153138 updated = False
139+ choices , usage = self .extract_choices_and_usage (data )
154140
155- if (choices := cast ("list[dict]" , data .get ("choices" ))) and (
156- text := choices [0 ].get ("text" )
157- ):
141+ if text := choices [0 ].get ("text" ):
158142 self .streaming_texts .append (text )
159143 updated = True
160144
161- if usage := cast ( "dict[str, int | dict[str, int]]" , data . get ( "usage" )) :
145+ if usage :
162146 self .streaming_usage = usage
163147
164148 return 1 if updated else 0
@@ -182,6 +166,34 @@ def compile_streaming(self, request: GenerationRequest) -> GenerationResponse:
182166 output_metrics = output_metrics ,
183167 )
184168
169+ def extract_line_data (self , line : str ) -> dict [str , Any ] | None :
170+ """
171+ Extract JSON data from a streaming response line.
172+
173+ :param line: Raw line from the streaming response
174+ :return: Parsed JSON data as a dictionary, or None if line is invalid
175+ """
176+ if line == "data: [DONE]" :
177+ return None
178+
179+ if not line or not (line := line .strip ()) or not line .startswith ("data:" ):
180+ return {}
181+
182+ line = line [len ("data:" ) :].strip ()
183+
184+ return json .loads (line )
185+
186+ def extract_choices_and_usage (
187+ self , response : dict
188+ ) -> tuple [list [dict ], dict [str , int | dict [str , int ]]]:
189+ """
190+ Extract choices and usage data from the API response.
191+
192+ :param response: Complete API response containing choices and usage data
193+ :return: Tuple of (choices list, usage dictionary)
194+ """
195+ return response .get ("choices" , []), response .get ("usage" , {})
196+
185197 def extract_metrics (
186198 self , usage : dict [str , int | dict [str , int ]] | None
187199 ) -> tuple [UsageMetrics , UsageMetrics ]:
@@ -194,15 +206,14 @@ def extract_metrics(
194206 if not usage :
195207 return UsageMetrics (), UsageMetrics ()
196208
197- input_details = cast ( " dict[str, int]" , usage .get ("prompt_tokens_details" , {}))
198- output_details = cast (
199- "dict[str, int]" , usage .get ("completion_tokens_details" , {})
209+ input_details : dict [str , int ] = usage .get ("prompt_tokens_details" , {}) or {}
210+ output_details : dict [ str , int ] = (
211+ usage .get ("completion_tokens_details" , {}) or {}
200212 )
201213
202214 return UsageMetrics (
203215 text_tokens = (
204- input_details .get ("prompt_tokens" )
205- or cast ("int" , usage .get ("prompt_tokens" ))
216+ input_details .get ("prompt_tokens" ) or usage .get ("prompt_tokens" )
206217 ),
207218 image_tokens = input_details .get ("image_tokens" ),
208219 video_tokens = input_details .get ("video_tokens" ),
@@ -211,7 +222,7 @@ def extract_metrics(
211222 ), UsageMetrics (
212223 text_tokens = (
213224 output_details .get ("completion_tokens" )
214- or cast ( "int" , usage .get ("completion_tokens" ) )
225+ or usage .get ("completion_tokens" )
215226 ),
216227 image_tokens = output_details .get ("image_tokens" ),
217228 video_tokens = output_details .get ("video_tokens" ),
@@ -243,18 +254,15 @@ def compile_non_streaming(
243254 :param response: Complete API response containing choices and usage data
244255 :return: Standardized GenerationResponse with extracted content and metrics
245256 """
246- choices = cast ("list[dict]" , response .get ("choices" , []))
247- usage = cast ("dict[str, int | dict[str, int]]" , response .get ("usage" , {}))
257+ choices , usage = self .extract_choices_and_usage (response )
248258 input_metrics , output_metrics = self .extract_metrics (usage )
249259
250260 return GenerationResponse (
251261 request_id = request .request_id ,
252262 request_args = str (
253263 request .arguments .model_dump () if request .arguments else None
254264 ),
255- text = cast ("dict" , choices [0 ].get ("message" , {})).get ("content" , "" )
256- if choices
257- else "" ,
265+ text = (choices [0 ].get ("message" , {}).get ("content" , "" ) if choices else "" ),
258266 input_metrics = input_metrics ,
259267 output_metrics = output_metrics ,
260268 )
@@ -269,27 +277,17 @@ def add_streaming_line(self, line: str) -> int | None:
269277 :param line: Raw SSE line from the streaming response
270278 :return: 1 if content was extracted, 0 if line ignored, None if done
271279 """
272- if line == " data: [DONE]" :
273- return None
280+ if not ( data := self . extract_line_data ( line )) :
281+ return None if data is None else 0
274282
275- if not line or not (line := line .strip ()) or not line .startswith ("data:" ):
276- return 0
277-
278- line = line [len ("data:" ) :].strip ()
279- data = cast (
280- "dict[str, Any]" ,
281- json .loads (line ) if orjson is None else orjson .loads (line ),
282- )
283283 updated = False
284+ choices , usage = self .extract_choices_and_usage (data )
284285
285- # Extract delta content for chat completion chunks
286- if choices := cast ("list[dict]" , data .get ("choices" )):
287- delta = choices [0 ].get ("delta" , {})
288- if content := delta .get ("content" ):
289- self .streaming_texts .append (content )
286+ if choices and (content := choices [0 ].get ("delta" , {}).get ("content" )):
287+ self .streaming_texts .append (content )
290288 updated = True
291289
292- if usage := cast ( "dict[str, int | dict[str, int]]" , data . get ( "usage" )) :
290+ if usage :
293291 self .streaming_usage = usage
294292
295293 return 1 if updated else 0
@@ -355,10 +353,10 @@ def compile_non_streaming(
355353 :param response: Complete API response containing text and usage data
356354 :return: Standardized GenerationResponse with extracted text and metrics
357355 """
358- usage = cast ( " dict[str, int]" , response .get ("usage" , {}) )
359- input_details = cast ( " dict[str, int]" , usage .get ("input_token_details" , {}))
360- output_details = cast ( " dict[str, int]" , usage .get ("output_token_details" , {}))
361- text = response .get ("text" , "" )
356+ usage : dict [ str , int | dict [str , int ]] = response .get ("usage" , {})
357+ input_details : dict [str , int ] = usage .get ("input_token_details" , {}) or {}
358+ output_details : dict [str , int ] = usage .get ("output_token_details" , {}) or {}
359+ text : str = response .get ("text" , "" )
362360
363361 return GenerationResponse (
364362 request_id = request .request_id ,
@@ -396,17 +394,16 @@ def add_streaming_line(self, line: str) -> int | None:
396394 if not line or not (line := line .strip ()) or not line .startswith ("{" ):
397395 return 0
398396
399- data = cast (
400- "dict[str, Any]" ,
401- json .loads (line ) if orjson is None else orjson .loads (line ),
402- )
397+ data : dict [str , Any ] = json .loads (line )
398+ text : str
399+ usage : dict [str , int | dict [str , int ]]
403400 updated = False
404401
405402 if text := data .get ("text" ):
406403 self .streaming_texts .append (text )
407404 updated = True
408405
409- if usage := cast ( "dict[str, int | dict[str, int]]" , data .get ("usage" ) ):
406+ if usage := data .get ("usage" ):
410407 self .streaming_usage = usage
411408
412409 return 1 if updated else 0
@@ -445,22 +442,15 @@ def extract_metrics(
445442 if not usage :
446443 return UsageMetrics (), UsageMetrics ()
447444
448- input_details = cast ( " dict[str, int]" , usage .get ("input_token_details" , {}))
449- output_details = cast ( " dict[str, int]" , usage .get ("output_token_details" , {}))
445+ input_details : dict [str , int ] = usage .get ("input_token_details" , {}) or {}
446+ output_details : dict [str , int ] = usage .get ("output_token_details" , {}) or {}
450447
451448 return UsageMetrics (
452- text_tokens = (
453- input_details .get ("text_tokens" )
454- or cast ("int" , usage .get ("input_tokens" ))
455- ),
449+ text_tokens = (input_details .get ("text_tokens" ) or usage .get ("input_tokens" )),
456450 audio_tokens = (
457- input_details .get ("audio_tokens" )
458- or cast ("int" , usage .get ("audio_tokens" ))
459- ),
460- audio_seconds = (
461- input_details .get ("seconds" ) or cast ("int" , usage .get ("seconds" ))
451+ input_details .get ("audio_tokens" ) or usage .get ("audio_tokens" )
462452 ),
453+ audio_seconds = (input_details .get ("seconds" ) or usage .get ("seconds" )),
463454 ), UsageMetrics (
464- text_tokens = output_details .get ("text_tokens" )
465- or cast ("int" , usage .get ("output_tokens" )),
455+ text_tokens = output_details .get ("text_tokens" ) or usage .get ("output_tokens" ),
466456 )
0 commit comments