@@ -12,14 +12,19 @@
 
 import asyncio
 import time
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterator, Callable
 from typing import Any
 
 import httpx
 
 from guidellm.backends.backend import Backend
 from guidellm.backends.response_handlers import GenerationResponseHandlerFactory
-from guidellm.schemas import GenerationRequest, GenerationResponse, RequestInfo
+from guidellm.schemas import (
+    GenerationRequest,
+    GenerationRequestArguments,
+    GenerationResponse,
+    RequestInfo,
+)
 
 __all__ = ["OpenAIHTTPBackend"]
 
@@ -59,6 +64,10 @@ def __init__(
         follow_redirects: bool = True,
         verify: bool = False,
         validate_backend: bool | str | dict[str, Any] = True,
+        stream: bool = True,
+        extras: dict[str, Any] | GenerationRequestArguments | None = None,
+        max_tokens: int | None = None,
+        max_completion_tokens: int | None = None,
     ):
         """
         Initialize OpenAI HTTP backend with server configuration.
@@ -96,11 +105,28 @@ def __init__(
         self.validate_backend: dict[str, Any] | None = self._resolve_validate_kwargs(
             validate_backend
         )
+        self.stream: bool = stream
+        self.extras = (
+            GenerationRequestArguments(**extras)
+            if extras and isinstance(extras, dict)
+            else extras
+        )
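+        # max_completion_tokens is accepted as an alias; max_tokens wins if both are set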
+        self.max_tokens: int | None = max_tokens or max_completion_tokens
 
         # Runtime state
         self._in_process = False
         self._async_client: httpx.AsyncClient | None = None
 
+        # TODO: Find a better way to register formatters
+        self.request_formatters: dict[
+            str, Callable[[GenerationRequest], GenerationRequestArguments]
+        ] = {
+            "text_completions": self.formatter_text_completions,
+            "chat_completions": self.formatter_chat_completions,
+            "audio_transcriptions": self.formatter_audio_transcriptions,
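+            # Translations reuse the transcription formatter (same payload shape)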
+            "audio_translations": self.formatter_audio_transcriptions,
+        }
+
     @property
     def info(self) -> dict[str, Any]:
         """
@@ -227,31 +253,35 @@ async def resolve(  # type: ignore[override]
         if history is not None:
             raise NotImplementedError("Multi-turn requests not yet supported")
 
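+        # Build the arguments via the formatter registered for this request type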
+        arguments: GenerationRequestArguments = self.request_formatters[
+            request.request_type
+        ](request)
+
         if (request_path := self.api_routes.get(request.request_type)) is None:
             raise ValueError(f"Unsupported request type '{request.request_type}'")
 
         request_url = f"{self.target}/{request_path}"
         request_files = (
             {
                 key: tuple(value) if isinstance(value, list) else value
-                for key, value in request.arguments.files.items()
+                for key, value in arguments.files.items()
             }
-            if request.arguments.files
+            if arguments.files
             else None
         )
-        request_json = request.arguments.body if not request_files else None
-        request_data = request.arguments.body if request_files else None
+        request_json = arguments.body if not request_files else None
+        request_data = arguments.body if request_files else None
         response_handler = GenerationResponseHandlerFactory.create(
             request.request_type, handler_overrides=self.response_handlers
         )
 
-        if not request.arguments.stream:
+        if not arguments.stream:
             request_info.timings.request_start = time.time()
             response = await self._async_client.request(
-                request.arguments.method or "POST",
+                arguments.method or "POST",
                 request_url,
-                params=request.arguments.params,
-                headers=request.arguments.headers,
+                params=arguments.params,
+                headers=arguments.headers,
                 json=request_json,
                 data=request_data,
                 files=request_files,
@@ -266,10 +296,10 @@ async def resolve(  # type: ignore[override]
         request_info.timings.request_start = time.time()
 
         async with self._async_client.stream(
-            request.arguments.method or "POST",
+            arguments.method or "POST",
             request_url,
-            params=request.arguments.params,
-            headers=request.arguments.headers,
+            params=arguments.params,
+            headers=arguments.headers,
             json=request_json,
             data=request_data,
             files=request_files,
@@ -332,3 +362,177 @@ def _resolve_validate_kwargs(
             validate_kwargs["method"] = "GET"
 
         return validate_kwargs
+
+    def formatter_text_completions(
+        self, data: GenerationRequest
+    ) -> GenerationRequestArguments:
+        arguments: GenerationRequestArguments = GenerationRequestArguments()
+        arguments.body = {}  # The type checker works better setting this field here
+
+        # Add model
+        if self.model is not None:
+            arguments.body["model"] = self.model
+
+        # Configure streaming
+        if self.stream:
+            arguments.stream = True
+            arguments.body["stream"] = True
+            arguments.body["stream_options"] = {"include_usage": True}
+
+        # Handle output tokens
+        if data.output_metrics.text_tokens:
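+            # Pin the output length: no stop sequences, and ignore_eos
+            # (a vLLM-style extension, not standard OpenAI) generates past EOS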
+            arguments.body["max_tokens"] = data.output_metrics.text_tokens
+            arguments.body["stop"] = None
+            arguments.body["ignore_eos"] = True
+        elif self.max_tokens is not None:
+            arguments.body["max_tokens"] = self.max_tokens
+
+        # Apply extra arguments
+        if self.extras:
+            arguments.model_combine(self.extras)
+
+        # Build prompt
+        prefix = "".join(pre for pre in data.columns.get("prefix_column", []) if pre)
+        text = "".join(txt for txt in data.columns.get("text_column", []) if txt)
+        if prefix or text:
+            prompt = prefix + text
+            arguments.body["prompt"] = prompt
+
+        return arguments
+
+    def formatter_chat_completions(  # noqa: C901, PLR0912, PLR0915
+        self, data: GenerationRequest
+    ) -> GenerationRequestArguments:
+        arguments = GenerationRequestArguments()
+        arguments.body = {}  # The type checker works best with body assigned here
+
+        # Add model
+        if self.model is not None:
+            arguments.body["model"] = self.model
+
+        # Configure streaming
+        if self.stream:
+            arguments.stream = True
+            arguments.body["stream"] = True
+            arguments.body["stream_options"] = {"include_usage": True}
+
+        # Handle output tokens
+        if data.output_metrics.text_tokens:
+            arguments.body.update(
+                {
+                    "max_completion_tokens": data.output_metrics.text_tokens,
+                    "stop": None,
+                    "ignore_eos": True,
+                }
+            )
+        elif self.max_tokens is not None:
+            arguments.body["max_completion_tokens"] = self.max_tokens
+
+        # Apply extra arguments
+        if self.extras:
+            arguments.model_combine(self.extras)
+
+        # Build messages
+        arguments.body["messages"] = []
+
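+        # Prefixes become system messages; each text/image/video/audio
+        # entry is appended as its own user message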
+        for prefix in data.columns.get("prefix_column", []):
+            if not prefix:
+                continue
+
+            arguments.body["messages"].append({"role": "system", "content": prefix})
+
+        for text in data.columns.get("text_column", []):
+            if not text:
+                continue
+
+            arguments.body["messages"].append(
+                {"role": "user", "content": [{"type": "text", "text": text}]}
+            )
+
+        for image in data.columns.get("image_column", []):
+            if not image:
+                continue
+
+            arguments.body["messages"].append(
+                {
+                    "role": "user",
+                    "content": [{"type": "image_url", "image_url": image.get("image")}],
+                }
+            )
+
+        for video in data.columns.get("video_column", []):
+            if not video:
+                continue
+
+            arguments.body["messages"].append(
+                {
+                    "role": "user",
+                    "content": [{"type": "video_url", "video_url": video.get("video")}],
+                }
+            )
+
+        for audio in data.columns.get("audio_column", []):
+            if not audio:
+                continue
+
+            arguments.body["messages"].append(
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "input_audio",
+                            "input_audio": {
+                                "data": audio.get("audio"),
+                                "format": audio.get("format"),
+                            },
+                        }
+                    ],
+                }
+            )
+
+        return arguments
+
+    def formatter_audio_transcriptions(  # noqa: C901
+        self, data: GenerationRequest
+    ) -> GenerationRequestArguments:
+        arguments = GenerationRequestArguments(files={})
+        arguments.body = {}
+
+        # Add model
+        if self.model is not None:
+            arguments.body["model"] = self.model
+
+        # Configure streaming
+        if self.stream:
+            arguments.stream = True
+            arguments.body["stream"] = True
+            arguments.body["stream_options"] = {"include_usage": True}
+
+        # Apply extra arguments
+        if self.extras:
+            arguments.model_combine(self.extras)
+
+        # Build audio input
+        audio_columns = data.columns.get("audio_column", [])
+        if len(audio_columns) != 1:
+            raise ValueError(
+                f"formatter_audio_transcriptions expects exactly "
+                f"one audio column, but got {len(audio_columns)}."
+            )
+
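+        # httpx multipart file tuple: (filename, file content, content type)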
+        arguments.files = {
+            "file": (
+                audio_columns[0].get("file_name", "audio_input"),
+                audio_columns[0].get("audio"),
+                audio_columns[0].get("mimetype"),
+            )
+        }
+
+        # Build prompt
+        prefix = "".join(pre for pre in data.columns.get("prefix_column", []) if pre)
+        text = "".join(txt for txt in data.columns.get("text_column", []) if txt)
+        if prefix or text:
+            prompt = prefix + text
+            arguments.body["prompt"] = prompt
+
+        return arguments
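
A minimal usage sketch of the new constructor options. Everything below is
illustrative: the target URL, model name, and extras payload are assumptions,
not values taken from this change.

    backend = OpenAIHTTPBackend(
        target="http://localhost:8000",  # assumed OpenAI-compatible server
        model="test-model",  # hypothetical model name
        stream=False,  # send plain (non-SSE) requests
        extras={"body": {"temperature": 0.0}},  # combined into every request body
        max_tokens=256,  # fallback cap when a request has no output-token target
    )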