Skip to content

Commit 25be631

Browse files
wangyueluckywangyue.demon
andauthored
feat(realtime): Implement Doubao realtime voice live session with ASR… (#370)
* feat(realtime): Implement Doubao realtime voice live session with ASR/TTS/dialog * fix ut: TestRealtimeModelConfig Clear any cached properties before each test --------- Co-authored-by: wangyue.demon <[email protected]>
1 parent a752bf7 commit 25be631

16 files changed

+1691
-1
lines changed

.gitleaks.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,4 +73,4 @@ description = "Empty environment variables with KEY pattern"
7373
regex = '''os\.environ\[".*?KEY"\]\s*=\s*".+"'''
7474

7575
[allowlist]
76-
paths = ["requirements.txt", "tests"]
76+
paths = ["requirements.txt", "tests", "veadk/realtime/client.py", "veadk/realtime/live.py"]

tests/config/test_model_config.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
from unittest import TestCase, mock
17+
from veadk.configs.model_configs import RealtimeModelConfig
18+
19+
20+
class TestRealtimeModelConfig(TestCase):
21+
def test_default_values(self):
22+
"""Test that default values are set correctly"""
23+
config = RealtimeModelConfig()
24+
self.assertEqual(config.name, "doubao_realtime_voice_model")
25+
self.assertEqual(
26+
config.api_base, "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
27+
)
28+
29+
@mock.patch.dict(os.environ, {"MODEL_REALTIME_API_KEY": "test_api_key"})
30+
def test_api_key_from_env(self):
31+
"""Test api_key is retrieved from environment variable"""
32+
config = RealtimeModelConfig()
33+
self.assertEqual(config.api_key, "test_api_key")
34+
35+
@mock.patch.dict(os.environ, {}, clear=True)
36+
@mock.patch(
37+
"veadk.configs.model_configs.get_speech_token", return_value="mocked_token"
38+
)
39+
def test_api_key_from_get_speech_token(self, mock_get_token):
40+
"""Test api_key falls back to get_speech_token when env var is not set"""
41+
config = RealtimeModelConfig()
42+
self.assertEqual(config.api_key, "mocked_token")
43+
mock_get_token.assert_called_once()
44+
45+
@mock.patch.dict(os.environ, {"MODEL_REALTIME_API_KEY": ""})
46+
@mock.patch(
47+
"veadk.configs.model_configs.get_speech_token", return_value="mocked_token"
48+
)
49+
def test_api_key_empty_env_var(self, mock_get_token):
50+
"""Test api_key falls back when env var is empty string"""
51+
config = RealtimeModelConfig()
52+
self.assertEqual(config.api_key, "mocked_token")
53+
mock_get_token.assert_called_once()
54+
55+
def test_api_key_caching(self):
56+
"""Test that api_key is properly cached"""
57+
with mock.patch.dict(os.environ, {"MODEL_REALTIME_API_KEY": "test_key"}):
58+
config = RealtimeModelConfig()
59+
first_call = config.api_key
60+
second_call = config.api_key
61+
self.assertEqual(first_call, second_call)
62+
self.assertEqual(first_call, "test_key")
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import unittest
17+
from unittest.mock import patch, MagicMock
18+
from google.genai._api_client import BaseApiClient
19+
from veadk.realtime.client import DoubaoClient, DoubaoAsyncClient
20+
from veadk.utils.logger import get_logger
21+
22+
logger = get_logger(__name__)
23+
24+
25+
class TestDoubaoAsyncClient(unittest.TestCase):
26+
def setUp(self):
27+
self.mock_api_client = MagicMock(spec=BaseApiClient)
28+
self.async_client = DoubaoAsyncClient(self.mock_api_client)
29+
30+
def test_initialization(self):
31+
self.assertIsInstance(self.async_client, DoubaoAsyncClient)
32+
self.assertEqual(self.async_client._api_client, self.mock_api_client)
33+
34+
def test_live_property(self):
35+
from veadk.realtime.live import DoubaoAsyncLive
36+
37+
live_instance = self.async_client.live
38+
self.assertIsInstance(live_instance, DoubaoAsyncLive)
39+
self.assertEqual(live_instance._api_client, self.mock_api_client)
40+
41+
42+
class TestDoubaoClient(unittest.TestCase):
43+
def setUp(self):
44+
self.patcher = patch.dict("os.environ", {}, clear=True)
45+
self.patcher.start()
46+
47+
def tearDown(self):
48+
self.patcher.stop()
49+
50+
def test_initialization_without_google_key(self):
51+
# Test when GOOGLE_API_KEY is not set
52+
os.environ["REALTIME_API_KEY"] = "hack_google_api_key"
53+
client = DoubaoClient()
54+
self.assertEqual(os.environ["GOOGLE_API_KEY"], "hack_google_api_key")
55+
self.assertIsNotNone(client._aio)
56+
57+
def test_initialization_with_google_key(self):
58+
# Test when GOOGLE_API_KEY is already set
59+
os.environ["GOOGLE_API_KEY"] = "existing_key"
60+
os.environ["REALTIME_API_KEY"] = "existing_key"
61+
client = DoubaoClient()
62+
self.assertEqual(os.environ["GOOGLE_API_KEY"], "existing_key")
63+
self.assertIsNotNone(client._aio)
64+
65+
@patch(
66+
"veadk.realtime.client.DoubaoAsyncClient", side_effect=Exception("Test error")
67+
)
68+
def test_initialization_failure(self, mock_async_client):
69+
# Test when DoubaoAsyncClient initialization fails
70+
os.environ["REALTIME_API_KEY"] = "hack_google_api_key"
71+
client = DoubaoClient()
72+
self.assertIsNone(client._aio)
73+
74+
def test_aio_property(self):
75+
os.environ["REALTIME_API_KEY"] = "hack_google_api_key"
76+
client = DoubaoClient()
77+
aio_client = client.aio
78+
self.assertIsInstance(aio_client, DoubaoAsyncClient)
79+
80+
81+
if __name__ == "__main__":
82+
unittest.main()
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
from unittest.mock import AsyncMock, MagicMock, patch
17+
from google.genai import types
18+
from veadk.realtime.doubao_realtime_voice_llm import DoubaoRealtimeVoice
19+
from google.adk.models.llm_request import LlmRequest
20+
from google.adk.models.base_llm_connection import BaseLlmConnection
21+
from google.genai.types import GenerateContentConfig
22+
import os
23+
from veadk.realtime.client import DoubaoClient
24+
from veadk.realtime.doubao_realtime_voice_llm import (
25+
_AGENT_ENGINE_TELEMETRY_TAG,
26+
_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME,
27+
)
28+
29+
30+
class TestDoubaoRealtimeVoice:
31+
@pytest.fixture
32+
def mock_llm_request(self):
33+
request = MagicMock(spec=LlmRequest)
34+
request.model = "doubao_realtime_voice"
35+
request.config = GenerateContentConfig()
36+
request.config.system_instruction = "Test instruction"
37+
request.config.tools = []
38+
request.live_connect_config = types.LiveConnectConfig(
39+
http_options=types.HttpOptions()
40+
)
41+
return request
42+
43+
def test_supported_models(self):
44+
"""Test supported_models returns correct model patterns"""
45+
models = DoubaoRealtimeVoice.supported_models()
46+
assert isinstance(models, list)
47+
assert len(models) == 2
48+
assert r"doubao_realtime_voice.*" in models
49+
assert r"Doubao_scene_SLM_Doubao_realtime_voice_model.*" in models
50+
51+
def test_api_client_property(self):
52+
"""Test api_client property returns DoubaoClient with correct options"""
53+
model = DoubaoRealtimeVoice()
54+
client = model.api_client
55+
assert isinstance(client, DoubaoClient)
56+
assert client._api_client._http_options.retry_options == model.retry_options
57+
58+
def test_live_api_client_property(self):
59+
"""Test _live_api_client property returns DoubaoClient with correct version"""
60+
model = DoubaoRealtimeVoice()
61+
client = model._live_api_client
62+
assert isinstance(client, DoubaoClient)
63+
assert client._api_client._http_options.api_version == model._live_api_version
64+
65+
def test_tracking_headers_without_env(self):
66+
"""Test _tracking_headers without environment variable"""
67+
model = DoubaoRealtimeVoice()
68+
headers = model._tracking_headers
69+
assert "x-volcengine-api-client" in headers
70+
assert "user-agent" in headers
71+
assert _AGENT_ENGINE_TELEMETRY_TAG not in headers["x-volcengine-api-client"]
72+
73+
@patch.dict(os.environ, {_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME: "test_id"})
74+
def test_tracking_headers_with_env(self):
75+
"""Test _tracking_headers with environment variable set"""
76+
model = DoubaoRealtimeVoice()
77+
headers = model._tracking_headers
78+
assert _AGENT_ENGINE_TELEMETRY_TAG in headers["x-volcengine-api-client"]
79+
80+
@pytest.mark.asyncio
81+
async def test_connect_with_speech_config(self, mock_llm_request):
82+
"""Test connect method with speech config"""
83+
speech_config = types.SpeechConfig()
84+
model = DoubaoRealtimeVoice(speech_config=speech_config)
85+
86+
# 修正异步上下文管理器的 mock 设置
87+
with patch.object(model._live_api_client.aio.live, "connect") as mock_connect:
88+
# 创建模拟的异步上下文管理器
89+
mock_session = AsyncMock()
90+
mock_connect.return_value.__aenter__.return_value = mock_session
91+
92+
async with model.connect(mock_llm_request) as connection:
93+
assert isinstance(connection, BaseLlmConnection)
94+
assert (
95+
mock_llm_request.live_connect_config.speech_config == speech_config
96+
)
97+
mock_connect.assert_called_once_with(
98+
model=mock_llm_request.model,
99+
config=mock_llm_request.live_connect_config,
100+
)
101+
102+
@pytest.mark.asyncio
103+
async def test_connect_without_speech_config(self, mock_llm_request):
104+
"""Test connect method without speech config"""
105+
model = DoubaoRealtimeVoice()
106+
107+
with patch.object(model._live_api_client.aio.live, "connect") as mock_connect:
108+
# 使用AsyncMock模拟会话对象,更贴近真实场景
109+
mock_session = AsyncMock()
110+
mock_connect.return_value.__aenter__.return_value = mock_session
111+
112+
async with model.connect(mock_llm_request) as connection:
113+
assert isinstance(connection, BaseLlmConnection)
114+
# 验证speech_config为None而非检查属性是否存在
115+
assert mock_llm_request.live_connect_config.speech_config is None
116+
mock_connect.assert_called_once_with(
117+
model=mock_llm_request.model,
118+
config=mock_llm_request.live_connect_config,
119+
)
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
from unittest.mock import AsyncMock
17+
from veadk.realtime.doubao_realtime_voice_llm_connection import (
18+
DoubaoRealtimeVoiceLlmConnection,
19+
)
20+
from google.genai import types
21+
22+
23+
@pytest.mark.asyncio
24+
async def test_send_realtime_with_blob():
25+
"""Test sending Blob input."""
26+
# Setup
27+
mock_session = AsyncMock()
28+
connection = DoubaoRealtimeVoiceLlmConnection(gemini_session=mock_session)
29+
connection._gemini_session = mock_session
30+
31+
blob_input = types.Blob()
32+
33+
# Execute
34+
await connection.send_realtime(blob_input)
35+
36+
# Verify
37+
mock_session.send_realtime_input.assert_called_once_with(media=blob_input)
38+
39+
40+
@pytest.mark.asyncio
41+
async def test_send_realtime_with_activity_start():
42+
"""Test sending ActivityStart input."""
43+
# Setup
44+
mock_session = AsyncMock()
45+
connection = DoubaoRealtimeVoiceLlmConnection(gemini_session=mock_session)
46+
connection._gemini_session = mock_session
47+
48+
activity_start = types.ActivityStart()
49+
50+
# Execute
51+
await connection.send_realtime(activity_start)
52+
53+
# Verify
54+
mock_session.send_realtime_input.assert_called_once_with(
55+
activity_start=activity_start
56+
)
57+
58+
59+
@pytest.mark.asyncio
60+
async def test_send_realtime_with_activity_end():
61+
"""Test sending ActivityEnd input."""
62+
# Setup
63+
mock_session = AsyncMock()
64+
connection = DoubaoRealtimeVoiceLlmConnection(gemini_session=mock_session)
65+
connection._gemini_session = mock_session
66+
67+
activity_end = types.ActivityEnd()
68+
69+
# Execute
70+
await connection.send_realtime(activity_end)
71+
72+
# Verify
73+
mock_session.send_realtime_input.assert_called_once_with(activity_end=activity_end)
74+
75+
76+
@pytest.mark.asyncio
77+
async def test_send_realtime_with_unsupported_type():
78+
"""Test sending unsupported input type."""
79+
# Setup
80+
mock_session = AsyncMock()
81+
connection = DoubaoRealtimeVoiceLlmConnection(gemini_session=mock_session)
82+
connection._gemini_session = mock_session
83+
84+
unsupported_input = "unsupported_type"
85+
86+
# Execute & Verify
87+
with pytest.raises(ValueError) as excinfo:
88+
await connection.send_realtime(unsupported_input)
89+
90+
assert "Unsupported input type" in str(excinfo.value)

0 commit comments

Comments
 (0)