|
4 | 4 | import random
|
5 | 5 |
|
6 | 6 | import pytest
|
| 7 | +import torch |
7 | 8 |
|
8 | 9 | from vllm.attention import Attention
|
9 | 10 | from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
|
@@ -277,6 +278,54 @@ def test_update_states_request_resumed(model_runner):
|
277 | 278 | assert _is_req_state_block_table_match(model_runner, req_id)
|
278 | 279 |
|
279 | 280 |
|
| 281 | +def test_get_nans_in_logits(model_runner): |
| 282 | + req_ids = ("req_0", "req_1") |
| 283 | + |
| 284 | + scheduler_output = _schedule_new_request(*req_ids) |
| 285 | + model_runner._update_states(scheduler_output) |
| 286 | + |
| 287 | + logits = torch.tensor([ |
| 288 | + [1.0, 2.0, 3.0], |
| 289 | + [3.0, 2.0, 1.0], |
| 290 | + ], device=DEVICE) |
| 291 | + result = model_runner._get_nans_in_logits(logits) |
| 292 | + assert result == {"req_0": 0, "req_1": 0} |
| 293 | + |
| 294 | + logits = torch.tensor([ |
| 295 | + [1.0, float('nan'), 3.0], |
| 296 | + [4.0, float('nan'), float('nan')], |
| 297 | + ], |
| 298 | + device=DEVICE) |
| 299 | + result = model_runner._get_nans_in_logits(logits) |
| 300 | + assert result == {"req_0": 1, "req_1": 2} |
| 301 | + |
| 302 | + logits = torch.tensor([ |
| 303 | + [1.0, 2.0, 3.0], |
| 304 | + [4.0, float('nan'), float('nan')], |
| 305 | + ], |
| 306 | + device=DEVICE) |
| 307 | + result = model_runner._get_nans_in_logits(logits) |
| 308 | + assert result == {"req_0": 0, "req_1": 2} |
| 309 | + |
| 310 | + result = model_runner._get_nans_in_logits(logits=None) |
| 311 | + assert result == {"req_0": 0, "req_1": 0} |
| 312 | + |
| 313 | + logits = torch.tensor([ |
| 314 | + [1.0, float('nan'), 3.0], |
| 315 | + ], device=DEVICE) |
| 316 | + result = model_runner._get_nans_in_logits(logits) |
| 317 | + assert result == {'req_0': 1, 'req_1': 0} |
| 318 | + |
| 319 | + logits = torch.tensor([ |
| 320 | + [float('nan'), float('nan'), 2.0], |
| 321 | + [1.0, 2.0, 3.0], |
| 322 | + [float('nan'), 2.0, 3.0], |
| 323 | + ], |
| 324 | + device=DEVICE) |
| 325 | + result = model_runner._get_nans_in_logits(logits) |
| 326 | + assert result == {'req_0': 2, 'req_1': 0} |
| 327 | + |
| 328 | + |
280 | 329 | def test_update_states_no_changes(model_runner):
|
281 | 330 | req_id = "req_0"
|
282 | 331 |
|
|
0 commit comments