@@ -1,15 +1,19 @@
 import math
+from unittest import mock
 from unittest.mock import MagicMock, patch

 import pytest
 import torch
+import torch_npu

 from tests.ut.base import TestBase
+from vllm_ascend.ops.rotary_embedding import __set_cos_sin_cache  # noqa: E402
 from vllm_ascend.ops.rotary_embedding import \
     __set_cos_sin_cache as raw__set_cos_sin_cache
 from vllm_ascend.ops.rotary_embedding import (custom_rotary_embedding_enabled,
                                               native_rope_deepseek_forward,
-                                              rope_forward_oot, rotate_half,
+                                              rope_forward, rope_forward_oot,
+                                              rotate_half,
                                               yarn_find_correction_dim,
                                               yarn_get_mscale)

@@ -363,3 +367,65 @@ def fake_register_buffer(name, tensor, persistent=True):
     assert buf.shape == expected_shape
     assert buf.device == device
     assert buf.dtype == torch.float32
+
+
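+# Minimal stand-in for the config object returned by get_ascend_config();
+# the test only assumes rope_forward reads torchair_graph_config.enabled.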
+class DummyConfig:
+
+    class TorchairGraphConfig:
+        enabled = True
+
+    torchair_graph_config = TorchairGraphConfig()
+
+
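+# Minimal stand-in for a rotary-embedding module, exposing only the
+# attributes and the embed() hook that rope_forward is assumed to use.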
+class DummyModel:
+
+    def __init__(self, head_size, max_pos):
+        self.head_size = head_size
+        self.max_position_embeddings = max_pos
+        self.cos = torch.randn(max_pos, head_size)
+        self.sin = torch.randn(max_pos, head_size)
+
+    def embed(self, positions, weight):
+        # Constant-valued lookup with the expected (B, S, head_size) shape.
+        B, S = positions.shape
+        return torch.ones(B, S, self.head_size) * 0.5
+
+
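+# Patch everything rope_forward touches outside pure torch: the ascend
+# config lookup, the torch_npu rotary kernel, and the cos/sin cache
+# builder, so the shape check below runs without NPU hardware.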
+@mock.patch("vllm_ascend.ops.rotary_embedding.get_ascend_config",
+            return_value=DummyConfig())
+@mock.patch.object(torch_npu, "npu_apply_rotary_pos_emb")
+@mock.patch("vllm_ascend.ops.rotary_embedding.__set_cos_sin_cache")
+def test_rope_forward_output_shape(mock_set_cache, mock_npu_apply,
+                                   mock_get_ascend_config):
+    batch_size = 2
+    seq_len = 4
+    num_heads = 3
+    head_size = 5
+
+    q = torch.randn(batch_size, seq_len, num_heads * head_size)
+    k = torch.randn_like(q)
+
+    positions = torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1)
+
+    model = DummyModel(head_size=head_size, max_pos=100)
+
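+    # Identity stand-in for the NPU kernel: rotation never changes tensor
+    # shapes, which is all this test asserts on.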
+    def fake_apply_rotary(q_in, k_in, cos, sin):
+        return q_in, k_in
+
+    mock_npu_apply.side_effect = fake_apply_rotary
+
+    q_out, k_out = rope_forward(
+        model,
+        positions=positions,
+        query=q,
+        key=k,
+        offsets=None,
+        is_neox_style_override=None,
+        max_seq_len=None,
+        is_prefill=False,  # skip the rope_forward_oot fallback path
+        is_qwen_torchair=True,  # take the torchair rotary branch
+    )
+
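+    # The torchair branch is expected to return tensors reshaped to
+    # (batch, 1, seq_len, num_heads * head_size).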
+    assert q_out.shape == (batch_size, 1, seq_len, num_heads * head_size)
+    assert k_out.shape == (batch_size, 1, seq_len, num_heads * head_size)
+
+    mock_set_cache.assert_not_called()