
Commit 600b08f

lidenghui1110 and zzhx1 authored
[Feat]: Add custom lmhead tensor model parallel (#2309)
### What this PR does / why we need it?

This PR introduces LMHead tensor model parallelism to reduce memory consumption and improve TPOT performance. It supports both eager mode and graph mode. In a DeepSeek R1 W8A8 PD-disaggregated Decode instance using pure DP with lmhead_tensor_parallel_size = 8, we observed a 1 ms TPOT improvement and saved 1.48 GB of NPU memory per rank.

Performance data: https://github.com/user-attachments/assets/3c5ef0d3-a7c7-46fd-9797-4de728eb0cb0

### Does this PR introduce _any_ user-facing change?

This PR introduces one new config in `additional_config`.

| Name | Effect | Required | Type | Constraints |
| :-- | :-- | :-- | :-- | :-- |
| lmhead_tensor_parallel_size | Split the lm_head matrix along the column dimension (vocab_size) into lmhead_tensor_parallel_size pieces | No | int | Default value is None; once this value is set, the feature is enabled. vocab_size must be divisible by this value. |

Example: `--additional_config={"lmhead_tensor_parallel_size": 8}`

### How was this patch tested?

- vLLM version: v0.10.1.1
- vLLM main: vllm-project/vllm@de533ab

---------

Signed-off-by: zzhx1 <[email protected]>
Co-authored-by: zhangzihang <[email protected]>
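Below is a minimal, hedged launch sketch for the new option. It assumes an Ascend deployment in pure DP (tensor_parallel_size=1) and that the installed vLLM version forwards `additional_config` through to the vllm-ascend plugin; the model name and request are placeholders, not part of this PR.

```python
# Hypothetical usage sketch (not from this PR): enable lmhead tensor parallelism
# via additional_config on a pure-DP deployment. Model name and sizes are
# placeholders; the option only takes effect when tensor_parallel_size == 1.
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-R1",                       # placeholder model
    tensor_parallel_size=1,                                # pure DP requirement
    additional_config={"lmhead_tensor_parallel_size": 8},  # split lm_head 8 ways
)
print(llm.generate(["Hello"], SamplingParams(max_tokens=8)))
```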
1 parent e7ad4a6 commit 600b08f

File tree

14 files changed: +458 −22 lines changed

docs/source/user_guide/configuration/additional_config.md

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ The following table lists the additional configuration options available in vLLM
 | `enable_prefetch` | bool | `False` | Whether to enable weight prefetch. |
 | `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. |
 | `enable_shared_expert_dp` | bool | `False` | When the shared expert in DP, it has better performance but consumes more memory. Currently only DeepSeek series models are supported to use. |
+| `lmhead_tensor_parallel_size` | int | `None` | The custom tensor parallel size of lmhead. |
 
 The details of each config option are as follows:

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+from vllm.config import ParallelConfig
+
+from vllm_ascend.distributed.parallel_state import (
+    _LMTP, _MC2, destroy_ascend_model_parallel, get_lmhead_tp_group,
+    get_mc2_group, init_ascend_model_parallel)
+
+
+@pytest.fixture
+def parallel_config():
+    return ParallelConfig(data_parallel_size=2,
+                          tensor_parallel_size=2,
+                          pipeline_parallel_size=2)
+
+
+@pytest.fixture
+def mock_distributed():
+    with patch('torch.distributed.is_initialized', return_value=True), \
+            patch('torch.distributed.get_world_size', return_value=8), \
+            patch('torch.distributed.get_backend', return_value='nccl'), \
+            patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group:
+        mock_group.return_value.local_rank = 0
+        mock_group.return_value.device_group = MagicMock()
+        yield
+
+
+def test_init_ascend_model_parallel(mock_distributed, parallel_config):
+    mock_ascend_config = MagicMock()
+    mock_ascend_config.lmhead_tensor_parallel_size = 2
+    with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
+            patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
+            patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config):
+        init_ascend_model_parallel(parallel_config)
+
+        mc2_group = get_mc2_group()
+        assert mc2_group is not None
+        lmheadtp_group = get_lmhead_tp_group()
+        assert lmheadtp_group is not None
+
+        destroy_ascend_model_parallel()
+        assert _MC2 is None
+        assert _LMTP is None

tests/ut/models/test_deepseek_mtp.py

Lines changed: 16 additions & 1 deletion
@@ -31,6 +31,11 @@ def setup_mtp_layer(self, mocker: MockerFixture):
         mocker_deepseek_v2_decode_layer = mocker.patch(
             "vllm_ascend.models.deepseek_v2.CustomDeepseekV2DecoderLayer.__init__",
             return_value=None)
+        mocker.patch(
+            "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
+            return_value=None)
+        mocker.patch("vllm_ascend.utils.get_ascend_config",
+                     return_value=mocker.Mock())
 
         mtp_layer = CustomDeepSeekMultiTokenPredictorLayer(config, "", None)
         mocker_deepseek_v2_decode_layer.assert_called_once()
@@ -83,6 +88,11 @@ def setup_predictor(self, mocker: MockerFixture):
         mocker.patch(
             "vllm_ascend.models.deepseek_mtp.CustomDeepSeekMultiTokenPredictorLayer.__init__",
             return_value=None)
+        mocker.patch(
+            "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
+            return_value=None)
+        mocker.patch("vllm_ascend.utils.get_ascend_config",
+                     return_value=mocker.Mock())
 
         predictor = CustomDeepSeekMultiTokenPredictor(
             vllm_config=mock_vllm_config)
@@ -157,6 +167,11 @@ def setup_mtp(self, mocker: MockerFixture):
             return_value=None)
         mocker.patch("vllm.model_executor.layers.sampler.get_sampler",
                      return_value=None)
+        mocker.patch(
+            "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
+            return_value=None)
+        mocker.patch("vllm_ascend.utils.get_ascend_config",
+                     return_value=mocker.Mock())
 
         mtp = CustomDeepSeekMTP(vllm_config=vllm_config)
         return mtp
@@ -177,4 +192,4 @@ def test_forward(self, mocker: MockerFixture, setup_mtp):
         output = setup_mtp.forward(input_ids, positions, kv_caches, None,
                                    previous_hidden_states, inputs_embeds,
                                    spec_step_idx)
-        assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))
+        assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))

tests/ut/models/test_deepseek_v2.py

Lines changed: 28 additions & 1 deletion
@@ -26,7 +26,7 @@
                                               CustomDeepseekV2MLP, CustomDeepseekV2MoE,
                                               CustomDeepseekV2RowParallelLinear,
                                               CustomDeepseekV2RowParallelLinearReplaceAllreduce,
-                                              CustomDeepseekV2SiluAndMul)
+                                              CustomDeepseekV2SiluAndMul, LogitsProcessor, ParallelLMHead)
 
 
 @pytest.fixture
@@ -266,3 +266,30 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
                                           kv_lora_rank=16,
                                           prefix="layers.1.self_attn")
     assert hasattr(attn, "q_proj")
+
+
+def test_deepseek_v2_lmhead(mock_distributed, vllm_config):
+    # Create a simple config object
+    class SimpleConfig:
+
+        def __init__(self):
+            self.vocab_size = 10000
+            self.hidden_size = 128
+
+    config = SimpleConfig()
+
+    # Create the lmhead and logits_processor directly
+    lmhead = ParallelLMHead(config.vocab_size, config.hidden_size)
+    logits_processor = LogitsProcessor(config.vocab_size)
+
+    # Create mock outputs
+    mock_output = torch.randn(2, 4, config.hidden_size)
+    mock_logits = torch.randn(2, 4, config.vocab_size)
+
+    # Test the logits_processor directly
+    with patch.object(lmhead.quant_method, "apply", return_value=mock_logits):
+        with patch.object(logits_processor,
+                          "_gather_logits",
+                          return_value=mock_logits):
+            logits = logits_processor(lmhead, mock_output)
+            assert logits.shape == (2, 4, config.vocab_size)

tests/ut/ops/test_vocab_parallel_embedding.py

Lines changed: 59 additions & 3 deletions
@@ -18,8 +18,8 @@
 
 import torch
 
-from vllm_ascend.ops.vocab_parallel_embedding import \
-    AscendVocabParallelEmbedding
+from vllm_ascend.ops.vocab_parallel_embedding import (
+    AscendLogitsProcessor, AscendParallelLMHead, AscendVocabParallelEmbedding)
 
 VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128
 
@@ -34,7 +34,11 @@ def setUp(self):
 
     def _create_layer(self):
         # Patch methods and dependencies for VocabParallelEmbedding
-        with patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank", return_value=0), \
+        mock_group = MagicMock()
+        mock_group.world_size = 2
+        mock_group.rank_in_group = 0
+        with patch("vllm_ascend.ops.vocab_parallel_embedding.get_tp_group", return_value=mock_group), \
+                patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank", return_value=0), \
                 patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size", return_value=2), \
                 patch("vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size", side_effect=lambda x, y: x + y), \
                 patch("vllm.model_executor.layers.vocab_parallel_embedding.divide", side_effect=lambda x, y: x // y):
@@ -174,3 +178,55 @@ def test_output_shape(self):
         # Call the forward method
         output = layer.forward(input_)
         self.assertEqual(output.shape, expected_shape)
+
+
+class TestAscendLogitsProcessor(unittest.TestCase):
+
+    def setUp(self):
+        self.vocab_size = 50
+        self.num_embeddings = 50
+        self.embedding_dim = 10
+        self.org_num_embeddings = 40
+        self.padding_size = 8
+
+        self.mock_group = MagicMock()
+        self.mock_group.world_size = 2
+        self.mock_group.rank_in_group = 0
+        self.mock_ascend_config = MagicMock()
+        self.mock_quant_method = MagicMock()
+        self.mock_quant_method.apply = MagicMock(
+            return_value=torch.randn(1, self.vocab_size))
+        self.patches = [
+            patch("vllm_ascend.ascend_config.get_ascend_config",
+                  return_value=self.mock_ascend_config),
+            patch(
+                "vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group",
+                return_value=self.mock_group),
+            patch("vllm_ascend.ops.vocab_parallel_embedding.lmhead_tp_enable",
+                  return_value=True),
+            patch(
+                "vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group.all_to_all",
+                return_value=torch.randn(1, self.vocab_size))
+        ]
+
+        for p in self.patches:
+            p.start()
+
+    def tearDown(self):
+        for p in self.patches:
+            p.stop()
+
+    def test_create_processor(self):
+        processor = AscendLogitsProcessor(vocab_size=self.vocab_size)
+        self.assertEqual(processor.vocab_size, self.vocab_size)
+
+    def test_get_logits(self):
+        processor = AscendLogitsProcessor(vocab_size=self.vocab_size)
+        lmhead = AscendParallelLMHead(num_embeddings=self.num_embeddings,
+                                      embedding_dim=self.embedding_dim,
+                                      prefix="lm_head")
+        lmhead.quant_method = self.mock_quant_method
+        lmhead.quant_method.apply = self.mock_quant_method.apply
+        hidden_state = torch.randn(1, self.org_num_embeddings)
+        processor._get_logits(hidden_state, lmhead)
+        self.mock_quant_method.apply.assert_called_once()
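To make the shape relationships in the tests above easier to follow, here is a small conceptual sketch of what lmhead tensor parallelism computes: the lm_head weight is split along the vocab dimension, each shard produces partial logits, and the shards are gathered back into full logits. This is an illustration only, not the AscendParallelLMHead/AscendLogitsProcessor implementation (which uses collective communication across the lmhead TP group rather than a local loop).

```python
# Conceptual sketch of lmhead tensor parallelism (assumed behavior, local
# single-process illustration): split lm_head along vocab_size, compute
# per-shard logits, then concatenate to recover the full logits.
import torch

hidden_size, vocab_size, lmhead_tp = 128, 1000, 4
hidden = torch.randn(2, hidden_size)
lm_head_weight = torch.randn(vocab_size, hidden_size)

# Column split along vocab_size (vocab_size must be divisible by lmhead_tp).
shards = lm_head_weight.chunk(lmhead_tp, dim=0)
partial_logits = [hidden @ w.t() for w in shards]  # each: (2, vocab_size // lmhead_tp)
logits = torch.cat(partial_logits, dim=-1)          # gathered full logits: (2, vocab_size)

assert torch.allclose(logits, hidden @ lm_head_weight.t(), atol=1e-5)
```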

tests/ut/test_ascend_config.py

Lines changed: 11 additions & 2 deletions
@@ -16,7 +16,7 @@
 import os
 
 from transformers import PretrainedConfig
-from vllm.config import ModelConfig, VllmConfig
+from vllm.config import ModelConfig, ParallelConfig, VllmConfig
 
 from tests.ut.base import TestBase
 from vllm_ascend.ascend_config import (_check_torchair_supported,
@@ -75,7 +75,7 @@ def test_init_ascend_config_with_additional_config(self):
                 "enabled": True
             },
             "expert_map_path": "test_expert_map_path",
-            "refresh": True
+            "refresh": True,
         }
         ascend_config = init_ascend_config(test_vllm_config)
         self.assertEqual(ascend_config.expert_map_path, "test_expert_map_path")
@@ -304,3 +304,12 @@ def test_ascend_config_load_error(self):
                 "refresh": True
             }
             init_ascend_config(test_vllm_config)
+
+        with self.assertRaises(AssertionError):
+            test_vllm_config.additional_config = {
+                "lmhead_tensor_parallel_size": 2,
+                "refresh": True
+            }
+            test_vllm_config.parallel_config = ParallelConfig(
+                data_parallel_size=4, tensor_parallel_size=2)
+            init_ascend_config(test_vllm_config)

tests/ut/test_utils.py

Lines changed: 2 additions & 2 deletions
@@ -289,13 +289,13 @@ def test_register_ascend_customop(self, mock_ascend_rmsnorm,
         # ascend custom op is not registered
         utils.register_ascend_customop()
         # should call register_oot three
-        self.assertEqual(mock_customop.register_oot.call_count, 10)
+        self.assertEqual(mock_customop.register_oot.call_count, 12)
         self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED)
 
         # ascend custom op is already registered
         utils.register_ascend_customop()
         # should not register_oot again, thus only called three in this ut
-        self.assertEqual(mock_customop.register_oot.call_count, 10)
+        self.assertEqual(mock_customop.register_oot.call_count, 12)
 
 
 class TestProfileExecuteDuration(TestBase):

tests/ut/torchair/models/test_torchair_deepseek_mtp.py

Lines changed: 16 additions & 1 deletion
@@ -31,6 +31,11 @@ def setup_mtp_layer(self, mocker: MockerFixture):
         mocker_deepseek_v2_decode_layer = mocker.patch(
             "vllm_ascend.torchair.models.torchair_deepseek_v2.TorchairDeepseekV2DecoderLayer.__init__",
             return_value=None)
+        mocker.patch(
+            "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
+            return_value=None)
+        mocker.patch("vllm_ascend.utils.get_ascend_config",
+                     return_value=mocker.Mock())
 
         mtp_layer = TorchairDeepSeekMultiTokenPredictorLayer(config, "", None)
         mocker_deepseek_v2_decode_layer.assert_called_once()
@@ -83,6 +88,11 @@ def setup_predictor(self, mocker: MockerFixture):
         mocker.patch(
             "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__init__",
             return_value=None)
+        mocker.patch(
+            "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
+            return_value=None)
+        mocker.patch("vllm_ascend.utils.get_ascend_config",
+                     return_value=mocker.Mock())
 
         predictor = TorchairDeepSeekMultiTokenPredictor(
             vllm_config=mock_vllm_config)
@@ -157,6 +167,11 @@ def setup_mtp(self, mocker: MockerFixture):
             return_value=None)
         mocker.patch("vllm.model_executor.layers.sampler.get_sampler",
                      return_value=None)
+        mocker.patch(
+            "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
+            return_value=None)
+        mocker.patch("vllm_ascend.utils.get_ascend_config",
+                     return_value=mocker.Mock())
 
         mtp = TorchairDeepSeekMTP(vllm_config=vllm_config)
         return mtp
@@ -177,4 +192,4 @@ def test_forward(self, mocker: MockerFixture, setup_mtp):
         output = setup_mtp.forward(input_ids, positions, kv_caches, None,
                                    previous_hidden_states, inputs_embeds,
                                    spec_step_idx)
-        assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))
+        assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))

vllm_ascend/ascend_config.py

Lines changed: 10 additions & 0 deletions
@@ -51,6 +51,16 @@ def __init__(self, vllm_config):
             "enable_shared_expert_dp", False
         ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
         self.enable_prefetch = additional_config.get("enable_prefetch", False)
+        self.lmhead_tensor_parallel_size = additional_config.get(
+            "lmhead_tensor_parallel_size", None)
+        if self.lmhead_tensor_parallel_size is not None:
+            logger.info(
+                f"Enable lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size} in pure DP scenario"
+            )
+            if vllm_config.parallel_config.tensor_parallel_size != 1:
+                raise AssertionError(
+                    "lmhead_tensor_parallel_size is only supported in the pure DP scenario"
+                )
 
 
 class TorchairGraphConfig:
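As a reading aid for the check added above and the divisibility constraint noted in the PR description, here is a hedged sketch of the two validation rules. The helper name and example values are illustrative, not part of vllm-ascend.

```python
# Illustrative sketch of the configuration constraints (not the actual
# vllm-ascend checks): lmhead TP requires pure DP, and vocab_size must be
# evenly divisible by lmhead_tensor_parallel_size.
def validate_lmhead_tp(lmhead_tp_size, tensor_parallel_size, vocab_size):
    if lmhead_tp_size is None:
        return  # feature disabled
    assert tensor_parallel_size == 1, (
        "lmhead_tensor_parallel_size is only supported in the pure DP scenario")
    assert vocab_size % lmhead_tp_size == 0, (
        "vocab_size must be divisible by lmhead_tensor_parallel_size")


validate_lmhead_tp(lmhead_tp_size=8, tensor_parallel_size=1, vocab_size=128000)
```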

vllm_ascend/distributed/parallel_state.py

Lines changed: 31 additions & 0 deletions
@@ -6,17 +6,26 @@
                                              init_model_parallel_group)
 
 import vllm_ascend.envs as envs_ascend
+from vllm_ascend.ascend_config import get_ascend_config
 
 # Currently, mc2 op need their own group coordinator.
 _MC2: Optional[GroupCoordinator] = None
 _MLP_TP: Optional[GroupCoordinator] = None
 
+_LMTP: Optional[GroupCoordinator] = None
+
 
 def get_mc2_group() -> GroupCoordinator:
     assert _MC2 is not None, ("mc2 group is not initialized")
     return _MC2
 
 
+def get_lmhead_tp_group() -> GroupCoordinator:
+    assert _LMTP is not None, (
+        "lm head tensor parallel group is not initialized")
+    return _LMTP
+
+
 def get_mlp_tp_group() -> GroupCoordinator:
     assert _MLP_TP is not None, ("mlp group is not initialized")
     return _MLP_TP
@@ -65,6 +74,23 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
                                          backend,
                                          group_name="mlp_tp")
 
+    lmhead_tensor_parallel_size = get_ascend_config(
+    ).lmhead_tensor_parallel_size
+    if lmhead_tensor_parallel_size is not None:
+        group_ranks = []
+        global _LMTP
+        num_lmhead_tensor_parallel_groups: int = (world_size //
+                                                  lmhead_tensor_parallel_size)
+        for i in range(num_lmhead_tensor_parallel_groups):
+            ranks = list(
+                range(i * lmhead_tensor_parallel_size,
+                      (i + 1) * lmhead_tensor_parallel_size))
+            group_ranks.append(ranks)
+        _LMTP = init_model_parallel_group(group_ranks,
+                                          get_world_group().local_rank,
+                                          backend,
+                                          group_name="lmheadtp")
+
 
 def get_mlp_tensor_model_parallel_world_size():
     """Return world size for the tensor model parallel group."""
@@ -86,3 +112,8 @@ def destroy_ascend_model_parallel():
     if _MLP_TP:
         _MLP_TP.destroy()
     _MLP_TP = None
+
+    global _LMTP
+    if _LMTP:
+        _LMTP.destroy()
+    _LMTP = None
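For reference, a standalone sketch of the rank partitioning performed in `init_ascend_model_parallel` above: ranks are grouped into consecutive blocks of size lmhead_tensor_parallel_size. The helper function below is illustrative only, not part of the module.

```python
# Minimal sketch (not the vllm-ascend implementation) of how the lmhead TP
# groups partition the world: with world_size=8 and lmhead_tensor_parallel_size=2,
# ranks are grouped as [0, 1], [2, 3], [4, 5], [6, 7].
def build_lmhead_tp_groups(world_size: int, lmhead_tp_size: int) -> list[list[int]]:
    assert world_size % lmhead_tp_size == 0
    num_groups = world_size // lmhead_tp_size
    return [
        list(range(i * lmhead_tp_size, (i + 1) * lmhead_tp_size))
        for i in range(num_groups)
    ]


print(build_lmhead_tp_groups(8, 2))
# [[0, 1], [2, 3], [4, 5], [6, 7]]
```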
