-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvad_mode.py
More file actions
148 lines (120 loc) · 3.53 KB
/
vad_mode.py
File metadata and controls
148 lines (120 loc) · 3.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import base64
import json
import ssl
import time
import uuid
from collections import deque
from threading import Lock, Thread

import pyaudio
from loguru import logger
from websockets.sync.client import ClientConnection, connect

from client_events import (
    InputAudioBufferAppend,
    Session,
    SessionUpdate,
    TurnDetection,
    TurnDetectionType,
)
from server_events import ServerEventType
# Audio capture/playback parameters: 16-bit mono PCM at 24 kHz.
# 480 frames at 24 kHz = 20 ms of audio per microphone read.
FORMAT = pyaudio.paInt16
CHANNELS = 1
SAMPLE_RATE = 24000
CHUNK_FRAMES = 480
WS_ENDPOINT = "wss://api.stepfun.com/v1/realtime"
MODEL_NAME = "step-1o-audio"
# Fresh per-run id, sent as the X-Trace-Id header for request tracing.
TRACE_ID = str(uuid.uuid4())
API_KEY = "replace with your api key"  # NOTE(review): keep real keys out of source control
# init local input/output audio
# (NOTE: `input` shadows the builtin of the same name; kept as-is because the
# rest of the script refers to it.)
p = pyaudio.PyAudio()
input = p.open(
    format=FORMAT,
    channels=CHANNELS,
    rate=SAMPLE_RATE,
    input=True,
    frames_per_buffer=CHUNK_FRAMES,
)
output = p.open(
    format=FORMAT,
    channels=CHANNELS,
    rate=SAMPLE_RATE,
    output=True,
)
# init audio buffer
# playback is interrupted by replacing the whole buffer object
# for finer-grained interruption, implement AudioBuffer at the frame level
class AudioBuffer:
    """Thread-safe FIFO of raw PCM chunks shared between receiver and player.

    The websocket receive loop appends decoded audio deltas; the playback
    thread pops them one chunk at a time. Interruption is handled elsewhere
    by replacing the whole buffer object, so no clear() method is needed.
    """

    def __init__(self):
        # deque gives O(1) popleft; the previous list.pop(0) was O(n) per
        # chunk, which degrades on long responses.
        self._buffer = deque()
        self._lock = Lock()

    def read(self) -> bytes:
        """Pop and return the oldest queued chunk, or b"" if empty."""
        with self._lock:
            if self._buffer:
                return self._buffer.popleft()
            return b""

    def append(self, data: bytes) -> None:
        """Queue one chunk of PCM bytes for playback."""
        with self._lock:
            self._buffer.append(data)
# Shared playback queue: filled by the receive loop, drained by play_audio.
buffer = AudioBuffer()
# init websocket
# Single realtime connection; the model is selected via the query string while
# auth and tracing travel in the headers.
ws: ClientConnection = connect(
    f"{WS_ENDPOINT}?model={MODEL_NAME}",
    ssl=ssl.create_default_context(),
    additional_headers={
        "X-Trace-Id": TRACE_ID,
        "Authorization": f"Bearer {API_KEY}",
    },
)
logger.info(f"connected! trace id: {TRACE_ID}")
def run_sender():
    """Configure the session for server-side VAD, then stream mic audio forever.

    Runs on a background thread; every chunk read from the microphone is
    base64-encoded and sent as an input_audio_buffer.append event.
    """
    session_update = SessionUpdate(
        session=Session(
            turn_detection=TurnDetection(type=TurnDetectionType.ServerVAD)
        )
    )
    ws.send(session_update.model_dump_json(exclude_none=True))
    # Stream the microphone in small chunks to keep latency low and smooth.
    while True:
        chunk = input.read(CHUNK_FRAMES, exception_on_overflow=False)
        append_event = InputAudioBufferAppend(
            audio=base64.b64encode(chunk).decode("utf-8")
        )
        ws.send(append_event.model_dump_json(exclude_none=True))
# Stream microphone audio to the server on a background thread.
sender_thread = Thread(target=run_sender)
sender_thread.start()
# True while play_audio is actively writing to the speaker; the receive loop
# reads this to detect user barge-in during playback.
audio_playing = False
def play_audio():
    """Drain the shared playback buffer to the speaker on its own thread.

    Sets the module-global `audio_playing` flag so the receive loop can tell
    whether a speech-start event should be treated as an interruption.
    """
    global audio_playing
    while True:
        chunk = buffer.read()
        if not chunk:
            # Nothing queued: mark playback idle and back off briefly so the
            # empty-buffer poll doesn't spin a full core.
            audio_playing = False
            time.sleep(0.02)
            continue
        audio_playing = True
        output.write(chunk)
# Drain the playback buffer on its own thread so receiving and playing overlap.
play_thread = Thread(target=play_audio)
play_thread.start()
first_audio_received = False
# Main receive loop: dispatch server events until the websocket closes.
for msg in ws:
    server_event = json.loads(msg)
    match server_event["type"]:
        case ServerEventType.ResponseAudioDelta:
            # Decode each base64 PCM delta and queue it for playback.
            if not first_audio_received:
                first_audio_received = True
                logger.info("audio started!")
            logger.info(
                f"received audio delta {len(server_event['delta'])} bytes, event id: {server_event['event_id']}"
            )
            buffer.append(base64.b64decode(server_event["delta"]))
        case ServerEventType.ResponseDone:
            # Response finished; re-arm the "audio started" log for the next turn.
            first_audio_received = False
            logger.info("response done!")
        case ServerEventType.InputAudioBufferSpeechStarted:
            # User spoke while audio was playing: drop all queued audio by
            # rebinding the module-global buffer. play_audio re-reads the
            # global on every iteration, so it picks up the fresh, empty one.
            if audio_playing:
                logger.info("interrupt detected! resetting audio buffer")
                buffer = AudioBuffer()
        case _:
            logger.info(f"received event {server_event}")