@@ -139,6 +139,8 @@ def __init__(
139139 self .last_first_text_delay = None
140140 self .last_first_audio_delay = None
141141 self .metrics = []
142+ # 添加用于同步等待连接关闭的事件
143+ self .disconnect_event = None
142144
143145 def _generate_event_id (self ):
144146 '''
@@ -296,6 +298,49 @@ def update_session(self,
296298 'session' : self .config
297299 }))
298300
301+ def end_session (self , timeout : int = 20 ) -> None :
302+ """
303+ end session
304+
305+ Parameters:
306+ -----------
307+ timeout: int
308+ Timeout in seconds to wait for the session to end. Default is 20 seconds.
309+ """
310+ if self .disconnect_event is not None :
311+ # if the event is already set, do nothing
312+ return
313+
314+ # create the event
315+ self .disconnect_event = threading .Event ()
316+
317+ self .__send_str (
318+ json .dumps ({
319+ 'event_id' : self ._generate_event_id (),
320+ 'type' : 'session.finish'
321+ }))
322+
323+ # wait for the event to be set
324+ finish_success = self .disconnect_event .wait (timeout )
325+ # clear the event
326+ self .disconnect_event = None
327+
328+ # if the event is not set, close the connection
329+ if not finish_success :
330+ self .close ()
331+ raise TimeoutError ("Session end timeout after {} seconds" .format (timeout ))
332+
333+ def end_session_async (self , ) -> None :
334+ """
335+ end session asynchronously. you need close the connection manually
336+ """
337+ # 发送结束会话消息
338+ self .__send_str (
339+ json .dumps ({
340+ 'event_id' : self ._generate_event_id (),
341+ 'type' : 'session.finish'
342+ }))
343+
299344 def append_audio (self , audio_b64 : str ) -> None :
300345 '''
301346 send audio in base64 format
@@ -414,7 +459,13 @@ def on_message(self, ws, message):
414459 self .callback .on_event (json_data )
415460 if 'type' in message :
416461 if 'session.created' == json_data ['type' ]:
462+ logger .info ('[omni realtime] session created' )
417463 self .session_id = json_data ['session' ]['id' ]
464+ elif 'session.finished' == json_data ['type' ]:
465+ # wait for the event to be set
466+ logger .info ('[omni realtime] session finished' )
467+ if self .disconnect_event is not None :
468+ self .disconnect_event .set ()
418469 if 'response.created' == json_data ['type' ]:
419470 self .last_response_id = json_data ['response' ]['id' ]
420471 self .last_response_create_time = time .time () * 1000
0 commit comments