55import queue
66import ssl
77import threading
8- from typing import Any , Dict , List , Optional , Union
8+ from typing import Any , Dict , List , Literal , Optional , Union
99
1010import websockets
1111from asgiref .sync import sync_to_async
1212from pydantic import BaseModel , ConfigDict
1313
14+ from llmstack .apps .types .agent import AgentConfigSchema
15+ from llmstack .apps .types .voice_agent import VoiceAgentConfigSchema
1416from llmstack .common .blocks .base .schema import StrEnum
1517from llmstack .common .utils .liquid import render_template
1618from llmstack .common .utils .provider_config import get_matched_provider_config
1719from llmstack .common .utils .sslr .types .chat .chat_completion import ChatCompletion
1820from llmstack .common .utils .sslr .types .chat .chat_completion_chunk import (
1921 ChatCompletionChunk ,
2022)
21- from llmstack .processors .providers .config import ProviderConfig
2223from llmstack .processors .providers .promptly import get_llm_client_from_provider_config
2324
2425logger = logging .getLogger (__name__ )
2526
2627
class AgentControllerConfig(BaseModel):
    """Configuration handed to AgentController.

    Bundles the resolved provider configurations, the agent's own config
    schema (voice or regular flavor), the tool definitions exposed to the
    model, and per-request metadata.
    """

    provider_configs: Dict[str, Any]
    agent_config: Union[AgentConfigSchema, VoiceAgentConfigSchema]
    is_voice_agent: bool = False
    tools: List[Dict]
    metadata: Dict[str, Any]

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def __init__(self, **data):
        raw_config = data.get("agent_config")
        if isinstance(raw_config, dict):
            # A plain dict is coerced into the schema matching the agent
            # flavor before pydantic sees it, so the Union resolves correctly.
            schema_cls = VoiceAgentConfigSchema if data.get("is_voice_agent", False) else AgentConfigSchema
            data["agent_config"] = schema_cls(**raw_config)

        super().__init__(**data)

        # Post-validation: the concrete schema type must agree with the flag.
        expected_cls = VoiceAgentConfigSchema if self.is_voice_agent else AgentConfigSchema
        if not isinstance(self.agent_config, expected_cls):
            raise ValueError(
                f"agent_config must be {expected_cls.__name__} when is_voice_agent is {self.is_voice_agent}"
            )
3953
4054
4155class AgentControllerDataType (StrEnum ):
@@ -54,6 +68,8 @@ class AgentUsageData(BaseModel):
5468 prompt_tokens : int = 0
5569 completion_tokens : int = 0
5670 total_tokens : int = 0
71+ provider : str = ""
72+ source : str = ""
5773
5874
5975class AgentMessageRole (StrEnum ):
@@ -116,19 +132,10 @@ class AgentController:
116132 def __init__ (self , output_queue : asyncio .Queue , config : AgentControllerConfig ):
117133 self ._output_queue = output_queue
118134 self ._config = config
119- self ._messages : List [AgentMessage ] = [
120- AgentSystemMessage (
121- role = AgentMessageRole .SYSTEM ,
122- content = [
123- AgentMessageContent (
124- type = AgentMessageContentType .TEXT ,
125- data = render_template (self ._config .system_message , {}),
126- )
127- ],
128- )
129- ]
135+ self ._messages : List [AgentMessage ] = []
130136 self ._llm_client = None
131137 self ._websocket = None
138+ self ._provider_config = None
132139
133140 self ._input_text_stream = None
134141 self ._input_audio_stream = None
@@ -154,11 +161,16 @@ async def _handle_websocket_messages(self):
154161 if event ["type" ] == "session.created" :
155162 logger .info (f"Session created: { event ['session' ]['id' ]} " )
156163 session = {}
157- session ["instructions" ] = self ._config .system_message
164+ session ["instructions" ] = self ._config .agent_config . system_message
158165 session ["tools" ] = [
159166 {"type" : "function" , ** t ["function" ]} for t in self ._config .tools if t ["type" ] == "function"
160167 ]
161168
169+ if self ._config .agent_config .input_audio_format :
170+ session ["input_audio_format" ] = self ._config .agent_config .input_audio_format
171+ if self ._config .agent_config .output_audio_format :
172+ session ["output_audio_format" ] = self ._config .agent_config .output_audio_format
173+
162174 updated_session = {
163175 "type" : "session.update" ,
164176 "session" : session ,
@@ -173,6 +185,12 @@ async def _init_websocket_connection(self):
173185 from llmstack .apps .models import AppSessionFiles
174186 from llmstack .assets .stream import AssetStream
175187
188+ self ._provider_config = get_matched_provider_config (
189+ provider_configs = self ._config .provider_configs ,
190+ provider_slug = self ._config .agent_config .backend .provider ,
191+ model_slug = self ._config .agent_config .backend .model ,
192+ )
193+
176194 # Create the output streams
177195 self ._output_audio_stream = AssetStream (
178196 await sync_to_async (AppSessionFiles .create_streaming_asset )(
@@ -191,9 +209,9 @@ async def _init_websocket_connection(self):
191209 ssl_context .check_hostname = False
192210 ssl_context .verify_mode = ssl .CERT_NONE
193211
194- websocket_url = f"wss://api.openai.com/v1/realtime?model={ self ._config .model_slug } "
212+ websocket_url = f"wss://api.openai.com/v1/realtime?model={ self ._config .agent_config . backend . model } "
195213 headers = {
196- "Authorization" : f"Bearer { self ._config . provider_config .api_key } " ,
214+ "Authorization" : f"Bearer { self ._provider_config .api_key } " ,
197215 "OpenAI-Beta" : "realtime=v1" ,
198216 }
199217
@@ -208,16 +226,34 @@ async def _init_websocket_connection(self):
208226 self ._loop .create_task (self ._handle_websocket_messages ())
209227
210228 def _init_llm_client (self ):
229+ self ._provider_config = get_matched_provider_config (
230+ provider_configs = self ._config .provider_configs ,
231+ provider_slug = self ._config .agent_config .provider ,
232+ model_slug = self ._config .agent_config .model ,
233+ )
234+
211235 self ._llm_client = get_llm_client_from_provider_config (
212- self ._config .provider_slug ,
213- self ._config .model_slug ,
236+ self ._config .agent_config . provider ,
237+ self ._config .agent_config . model ,
214238 lambda provider_slug , model_slug : get_matched_provider_config (
215239 provider_configs = self ._config .provider_configs ,
216240 provider_slug = provider_slug ,
217241 model_slug = model_slug ,
218242 ),
219243 )
220244
245+ self ._messages .append (
246+ AgentSystemMessage (
247+ role = AgentMessageRole .SYSTEM ,
248+ content = [
249+ AgentMessageContent (
250+ type = AgentMessageContentType .TEXT ,
251+ data = render_template (self ._config .agent_config .system_message , {}),
252+ )
253+ ],
254+ )
255+ )
256+
221257 async def _process_input_audio_stream (self ):
222258 if self ._input_audio_stream :
223259 async for chunk in self ._input_audio_stream .read_async ():
@@ -317,8 +353,8 @@ def process(self, data: AgentControllerData):
317353 self ._messages .append (data .data )
318354
319355 try :
320- if len (self ._messages ) > self ._config .max_steps :
321- raise Exception (f"Max steps ({ self ._config .max_steps } ) exceeded: { len (self ._messages )} " )
356+ if len (self ._messages ) > self ._config .agent_config . max_steps :
357+ raise Exception (f"Max steps ({ self ._config .agent_config . max_steps } ) exceeded: { len (self ._messages )} " )
322358
323359 if data .type != AgentControllerDataType .AGENT_OUTPUT :
324360 self ._input_messages_queue .put (data )
@@ -334,7 +370,7 @@ def process(self, data: AgentControllerData):
334370 )
335371
336372 async def process_messages (self , data : AgentControllerData ):
337- if self ._config .realtime :
373+         if self._config.is_voice_agent and self._config.agent_config.backend.backend_type == "multi_modal":
338374 if not self ._websocket :
339375 await self ._init_websocket_connection ()
340376
@@ -391,14 +427,15 @@ async def process_messages(self, data: AgentControllerData):
391427 self ._init_llm_client ()
392428
393429 client_messages = self ._convert_messages_to_llm_client_format ()
430+ stream = True if self ._config .agent_config .stream is None else self ._config .agent_config .stream
394431 response = self ._llm_client .chat .completions .create (
395- model = self ._config .model_slug ,
432+ model = self ._config .agent_config . model ,
396433 messages = client_messages ,
397- stream = self . _config . stream ,
434+ stream = stream ,
398435 tools = self ._config .tools ,
399436 )
400437
401- if self . _config . stream :
438+ if stream :
402439 for chunk in response :
403440 self .add_llm_client_response_to_output_queue (chunk )
404441 else :
@@ -419,6 +456,8 @@ def add_llm_client_response_to_output_queue(self, response: Any):
419456 prompt_tokens = response .usage .input_tokens ,
420457 completion_tokens = response .usage .output_tokens ,
421458 total_tokens = response .usage .total_tokens ,
459+ source = self ._provider_config .provider_config_source ,
460+ provider = str (self ._provider_config ),
422461 ),
423462 )
424463 )
@@ -621,6 +660,10 @@ async def add_ws_event_to_output_queue(self, event: Any):
621660 type = AgentControllerDataType .INPUT_STREAM ,
622661 )
623662 )
663+ elif event_type == "input_audio_buffer.speech_stopped" :
664+ pass
665+ elif event_type == "conversation.item.input_audio_transcription.completed" :
666+ pass
624667 elif event_type == "error" :
625668 logger .error (f"WebSocket error: { event } " )
626669
0 commit comments