5
5
from dataclasses import dataclass , field
6
6
from http import HTTPStatus
7
7
from typing import Optional
8
- from unittest .mock import MagicMock
8
+ from unittest .mock import AsyncMock , MagicMock
9
9
10
10
import pytest
11
11
@@ -83,20 +83,31 @@ def register_mock_resolver():
83
83
def mock_serving_setup ():
84
84
"""Provides a mocked engine and serving completion instance."""
85
85
mock_engine = MagicMock (spec = AsyncLLM )
86
- mock_engine .get_tokenizer .return_value = get_tokenizer (MODEL_NAME )
87
86
mock_engine .errored = False
88
87
89
- def mock_add_lora_side_effect (lora_request : LoRARequest ):
88
+ tokenizer = get_tokenizer (MODEL_NAME )
89
+ mock_engine .get_tokenizer = AsyncMock (return_value = tokenizer )
90
+
91
+ async def mock_add_lora_side_effect (lora_request : LoRARequest ):
90
92
"""Simulate engine behavior when adding LoRAs."""
91
93
if lora_request .lora_name == "test-lora" :
92
94
# Simulate successful addition
93
- return
94
- elif lora_request .lora_name == "invalid-lora" :
95
+ return True
96
+ if lora_request .lora_name == "invalid-lora" :
95
97
# Simulate failure during addition (e.g. invalid format)
96
98
raise ValueError (f"Simulated failure adding LoRA: "
97
99
f"{ lora_request .lora_name } " )
100
+ return True
101
+
102
+ mock_engine .add_lora = AsyncMock (side_effect = mock_add_lora_side_effect )
103
+
104
+ async def mock_generate (* args , ** kwargs ):
105
+ for _ in []:
106
+ yield _
107
+
108
+ mock_engine .generate = MagicMock (spec = AsyncLLM .generate ,
109
+ side_effect = mock_generate )
98
110
99
- mock_engine .add_lora .side_effect = mock_add_lora_side_effect
100
111
mock_engine .generate .reset_mock ()
101
112
mock_engine .add_lora .reset_mock ()
102
113
@@ -131,7 +142,7 @@ async def test_serving_completion_with_lora_resolver(mock_serving_setup,
131
142
with suppress (Exception ):
132
143
await serving_completion .create_completion (req_found )
133
144
134
- mock_engine .add_lora .assert_called_once ()
145
+ mock_engine .add_lora .assert_awaited_once ()
135
146
called_lora_request = mock_engine .add_lora .call_args [0 ][0 ]
136
147
assert isinstance (called_lora_request , LoRARequest )
137
148
assert called_lora_request .lora_name == lora_model_name
@@ -157,7 +168,7 @@ async def test_serving_completion_resolver_not_found(mock_serving_setup,
157
168
158
169
response = await serving_completion .create_completion (req )
159
170
160
- mock_engine .add_lora .assert_not_called ()
171
+ mock_engine .add_lora .assert_not_awaited ()
161
172
mock_engine .generate .assert_not_called ()
162
173
163
174
assert isinstance (response , ErrorResponse )
@@ -181,7 +192,7 @@ async def test_serving_completion_resolver_add_lora_fails(
181
192
response = await serving_completion .create_completion (req )
182
193
183
194
# Assert add_lora was called before the failure
184
- mock_engine .add_lora .assert_called_once ()
195
+ mock_engine .add_lora .assert_awaited_once ()
185
196
called_lora_request = mock_engine .add_lora .call_args [0 ][0 ]
186
197
assert isinstance (called_lora_request , LoRARequest )
187
198
assert called_lora_request .lora_name == invalid_model
0 commit comments