
Commit f5226e3

Angazenn and zengyanjia authored
[0.9.1] Add LMhead TP communication groups. (#1956)
### What this PR does / why we need it?

In pure DP scenarios (such as DP32), LMHead computation takes 1~2 ms. This PR customizes the parallelism of LMHead, enabling a separate TP group for LMHead. The computation flow is as follows:

```
get_lmhead_group().all_gather      # [num_tokens, hid_dim] --> [num_tokens * lmhead_tp, hid_dim]
--> lmhead matmul                  # [num_tokens * lmhead_tp, hid_dim] --> [num_tokens * lmhead_tp, vocab_size // lmhead_tp]
--> get_lmhead_group().all_to_all  # [num_tokens * lmhead_tp, vocab_size // lmhead_tp] --> [num_tokens, vocab_size]
```

This saves 0.5~1 ms for DeepSeek with a batch size of 28 on a single die, with MTP enabled.

In addition, this PR fixes a bug introduced by LMHead quantization: the op `npu_quant_matmul` only accepts dims < 65536, while `vocab_size` is > 65536 when using TP 1. Setting the LMHead TP size > 1 avoids this bug.

Main-branch version of this PR: #2309.

### Does this PR introduce _any_ user-facing change?

Yes. We introduce a new configurable option, `lmhead_tp_size`, in ascend_config. For example:

```
additional_config={
    "lmhead_tp_size": 16,
}
```

The default value is -1, in which case `lmhead_tp_size` is automatically set to `tensor_parallel_size`. It is suggested to use this option only when running full DP, to avoid the additional communication introduced by TP; accordingly, if TP > 1, the parallel size of the `lmhead` group is also set to `tensor_parallel_size`, falling back to the normal TP+DP case.

### How was this patch tested?

---------

Signed-off-by: angazenn <[email protected]>
Signed-off-by: zengyanjia <[email protected]>
Co-authored-by: angazenn <[email protected]>
Co-authored-by: zengyanjia <[email protected]>
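
A shape-only sketch may help make the flow above concrete. This is not code from the PR: the two collectives are mimicked with reshape/concat in a single process, and all sizes (`lmhead_tp`, `num_tokens`, `hid_dim`, `vocab_size`) are made up for illustration.

```python
# Single-process walk-through of the shapes in the LMHead-TP flow.
import torch

lmhead_tp, num_tokens, hid_dim, vocab_size = 4, 8, 16, 32

# Each of the `lmhead_tp` DP ranks starts with its own hidden states
# [num_tokens, hid_dim] and a vocab shard of the LMHead weight
# [hid_dim, vocab_size // lmhead_tp].
hidden = [torch.randn(num_tokens, hid_dim) for _ in range(lmhead_tp)]
weight = [torch.randn(hid_dim, vocab_size // lmhead_tp) for _ in range(lmhead_tp)]

# 1) all_gather inside the lmhead group: every rank now holds the tokens of
#    all ranks, [num_tokens * lmhead_tp, hid_dim].
gathered = torch.cat(hidden, dim=0)

# 2) local matmul against the vocab shard:
#    [num_tokens * lmhead_tp, vocab_size // lmhead_tp] on each rank.
partial = [gathered @ w for w in weight]

# 3) all_to_all: each rank keeps only its own tokens but receives every
#    rank's vocab shard, giving [num_tokens, vocab_size] per rank.
logits = [
    torch.cat([p[r * num_tokens:(r + 1) * num_tokens] for p in partial], dim=-1)
    for r in range(lmhead_tp)
]
assert logits[0].shape == (num_tokens, vocab_size)
```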
1 parent 3f65494 commit f5226e3

File tree: 11 files changed, +373 −18 lines


vllm_ascend/ascend_config.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -62,6 +62,7 @@ def __init__(self, vllm_config):
             False)  # Whether to enable DeepSeek models' prefill optimizations
         self.enable_cpu_binding = additional_config.get(  # Whether to enable the cpu binding
             "enable_cpu_binding", False)
+        self.lmhead_tp_size = additional_config.get("lmhead_tp_size", -1)


 class TorchairGraphConfig:
```
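
The resolution rule for the new option is described in the PR text rather than in this hunk; a hedged restatement of that rule as a standalone helper (this function is illustrative, not code from the PR):

```python
# -1 means "follow tensor_parallel_size"; when TP > 1 the lmhead group also
# falls back to the TP size, per the PR description.
def resolve_lmhead_tp_size(lmhead_tp_size: int, tensor_parallel_size: int) -> int:
    if lmhead_tp_size == -1 or tensor_parallel_size > 1:
        return tensor_parallel_size
    return lmhead_tp_size

print(resolve_lmhead_tp_size(-1, 1))   # 1  -> lmhead TP effectively disabled
print(resolve_lmhead_tp_size(16, 1))   # 16 -> pure-DP deployment, lmhead TP of 16
print(resolve_lmhead_tp_size(16, 8))   # 8  -> falls back to the normal TP+DP case
```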

vllm_ascend/distributed/parallel_state.py

Lines changed: 22 additions & 0 deletions
```diff
@@ -6,19 +6,26 @@

 # Currently, mc2 op need their own group coordinator.
 _MC2: Optional[GroupCoordinator] = None
+_LMHEAD: Optional[GroupCoordinator] = None


 def get_mc2_group() -> GroupCoordinator:
     assert _MC2 is not None, ("mc2 group is not initialized")
     return _MC2


+def get_lmhead_group() -> GroupCoordinator:
+    assert _LMHEAD is not None, ("lmhead group is not initialized")
+    return _LMHEAD
+
+
 def model_parallel_initialized():
     return (_MC2 is not None)


 def init_ascend_model_parallel(
     expert_parallel_size: int = 1,
+    lm_head_tp_size: int = -1,
     backend: Optional[str] = None,
 ):
     if model_parallel_initialized():
@@ -41,9 +48,24 @@ def init_ascend_model_parallel(
                                      backend,
                                      group_name="mc2")

+    if lm_head_tp_size > 0:
+        all_ranks = torch.arange(world_size).reshape(-1, lm_head_tp_size)
+        global _LMHEAD
+        group_ranks = all_ranks.unbind(0)
+        group_ranks = [x.tolist() for x in group_ranks]
+
+        _LMHEAD = init_model_parallel_group(group_ranks,
+                                            get_world_group().local_rank,
+                                            backend,
+                                            group_name="lmhead")
+

 def destroy_ascend_model_parallel():
     global _MC2
     if _MC2:
         _MC2.destroy()
         _MC2 = None
+    global _LMHEAD
+    if _LMHEAD:
+        _LMHEAD.destroy()
+        _LMHEAD = None
```
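
For intuition, the rank grouping above is just a reshape of the flat world. The snippet below reproduces the grouping arithmetic in isolation, without creating any process groups; the values 32 and 16 are only an example.

```python
import torch

world_size, lm_head_tp_size = 32, 16
all_ranks = torch.arange(world_size).reshape(-1, lm_head_tp_size)
group_ranks = [x.tolist() for x in all_ranks.unbind(0)]
print(group_ranks)
# [[0, 1, ..., 15], [16, 17, ..., 31]] -> two lmhead groups of 16 ranks each
```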

vllm_ascend/distributed/utils.py

Lines changed: 37 additions & 0 deletions
```diff
@@ -0,0 +1,37 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+
+from vllm.distributed.parallel_state import (
+    get_dp_group, get_tensor_model_parallel_world_size)
+
+from vllm_ascend.distributed.parallel_state import get_lmhead_group
+
+
+def is_lmhead_tp():
+    # We only activate optimization of lmhead communication
+    # when tp_size == 1, dp_size > 1 and lmhead_tp_size > 1.
+
+    try:
+        get_lmhead_group()
+    except AssertionError:
+        return False
+
+    tp_size = get_tensor_model_parallel_world_size()
+    dp_size = get_dp_group().world_size
+    lmhead_tp_size = get_lmhead_group().world_size
+
+    return tp_size == 1 and dp_size > 1 and lmhead_tp_size > 1
```
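
Stripped of the distributed-state lookups, the gating condition is just a function of the three sizes. A minimal restatement with made-up example values (this helper is illustrative, not code from the PR):

```python
def lmhead_tp_enabled(tp_size: int, dp_size: int, lmhead_tp_size: int) -> bool:
    return tp_size == 1 and dp_size > 1 and lmhead_tp_size > 1

print(lmhead_tp_enabled(tp_size=1, dp_size=32, lmhead_tp_size=16))  # True: pure DP
print(lmhead_tp_enabled(tp_size=8, dp_size=4, lmhead_tp_size=16))   # False: TP already splits LMHead
print(lmhead_tp_enabled(tp_size=1, dp_size=32, lmhead_tp_size=1))   # False: no lmhead group
```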

vllm_ascend/models/deepseek_mtp.py

Lines changed: 10 additions & 8 deletions
```diff
@@ -26,18 +26,20 @@
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import get_sampler
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.layers.vocab_parallel_embedding import \
+    VocabParallelEmbedding
 from vllm.model_executor.models.deepseek_mtp import (
     DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
     SharedHead)
 from vllm.model_executor.models.utils import maybe_prefix
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

+from vllm_ascend.ops.lmhead import CustomParallelLMHead
+from vllm_ascend.ops.logits_processor import CustomLogitsProcessor
+
 from .deepseek_v2 import CustomDeepseekV2DecoderLayer


@@ -49,10 +51,10 @@ def __init__(self,
                  prefix: str = "") -> None:
         nn.Module.__init__(self)
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.head = ParallelLMHead(config.vocab_size,
-                                   config.hidden_size,
-                                   quant_config=quant_config,
-                                   prefix=maybe_prefix(prefix, "head"))
+        self.head = CustomParallelLMHead(config.vocab_size,
+                                         config.hidden_size,
+                                         quant_config=quant_config,
+                                         prefix=maybe_prefix(prefix, "head"))


 class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
@@ -145,7 +147,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             for idx in range(self.mtp_start_layer_idx,
                              self.mtp_start_layer_idx + self.num_mtp_layers)
         ]
-        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.logits_processor = CustomLogitsProcessor(config.vocab_size)

     def forward(
         self,
```

vllm_ascend/models/deepseek_v2.py

Lines changed: 10 additions & 9 deletions
```diff
@@ -48,12 +48,11 @@
                                                ReplicatedLinear,
                                                RowParallelLinear,
                                                UnquantizedLinearMethod)
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import get_sampler
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.layers.vocab_parallel_embedding import \
+    VocabParallelEmbedding
 from vllm.model_executor.models.deepseek_v2 import \
     DeepseekV2ForCausalLM  # noqa: E501
 from vllm.model_executor.models.deepseek_v2 import \
@@ -68,6 +67,8 @@

 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
+from vllm_ascend.ops.lmhead import CustomParallelLMHead
+from vllm_ascend.ops.logits_processor import CustomLogitsProcessor
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
 from vllm_ascend.utils import dispose_tensor, npu_prefetch
@@ -835,14 +836,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                            prefix=maybe_prefix(
                                                prefix, "model"))
         if get_pp_group().is_last_rank:
-            self.lm_head = ParallelLMHead(config.vocab_size,
-                                          config.hidden_size,
-                                          quant_config=quant_config,
-                                          prefix=maybe_prefix(
-                                              prefix, "lm_head"))
+            self.lm_head = CustomParallelLMHead(config.vocab_size,
+                                                config.hidden_size,
+                                                quant_config=quant_config,
+                                                prefix=maybe_prefix(
+                                                    prefix, "lm_head"))
         else:
             self.lm_head = PPMissingLayer()
-        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.logits_processor = CustomLogitsProcessor(config.vocab_size)
         self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
```

vllm_ascend/ops/lmhead.py

Lines changed: 150 additions & 0 deletions
```diff
@@ -0,0 +1,150 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Adapted from vllm/model_executor/layers/lmhead.py
+# This file is a part of the vllm-ascend project.
+
+from typing import Optional
+
+import torch
+from torch.nn.parameter import Parameter
+from vllm.distributed import divide
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    UnquantizedEmbeddingMethod, VocabParallelEmbedding, pad_vocab_size)
+from vllm.model_executor.utils import set_weight_attrs
+
+from vllm_ascend.distributed.parallel_state import get_lmhead_group
+
+DEFAULT_VOCAB_PADDING_SIZE = 64
+
+
+class CustomParallelLMHead(VocabParallelEmbedding):
+    """Parallelized LM head.
+
+    Output logits weight matrices used in the Sampler. The weight and bias
+    tensors are padded to make sure they are divisible by the number of
+    model parallel GPUs.
+
+    Args:
+        num_embeddings: vocabulary size.
+        embedding_dim: size of hidden state.
+        bias: whether to use bias.
+        params_dtype: type of the parameters.
+        org_num_embeddings: original vocabulary size (without LoRA).
+        padding_size: padding size for the vocabulary.
+    """
+
+    def __init__(self,
+                 num_embeddings: int,
+                 embedding_dim: int,
+                 bias: bool = False,
+                 params_dtype: Optional[torch.dtype] = None,
+                 org_num_embeddings: Optional[int] = None,
+                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
+        super().__init__(num_embeddings, embedding_dim, params_dtype,
+                         org_num_embeddings, padding_size, quant_config,
+                         prefix)
+        # Keep the input dimensions.
+        tp_rank = get_lmhead_group().rank_in_group
+        self.tp_size = get_lmhead_group().world_size
+        self.num_embeddings = num_embeddings
+        self.padding_size = padding_size
+        self.org_vocab_size = org_num_embeddings or num_embeddings
+        num_added_embeddings = num_embeddings - self.org_vocab_size
+        self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
+                                                    self.padding_size)
+        self.num_embeddings_padded = pad_vocab_size(
+            self.org_vocab_size_padded + num_added_embeddings,
+            self.padding_size)
+        assert self.org_vocab_size_padded <= self.num_embeddings_padded
+
+        self.shard_indices = self._get_indices(self.num_embeddings_padded,
+                                               self.org_vocab_size_padded,
+                                               self.num_embeddings,
+                                               self.org_vocab_size, tp_rank,
+                                               self.tp_size)
+        self.embedding_dim = embedding_dim
+
+        quant_method = None
+        if quant_config is not None:
+            quant_method = quant_config.get_quant_method(self, prefix=prefix)
+        if quant_method is None:
+            quant_method = UnquantizedEmbeddingMethod()
+
+        # If we are making an embedding layer, then our quantization linear
+        # method must implement the embedding operation. If we are another
+        # layer type like ParallelLMHead, this is not important.
+        is_embedding_layer = type(self) is VocabParallelEmbedding
+        quant_method_implements_embedding = method_has_implemented_embedding(
+            type(quant_method))
+        if is_embedding_layer and not quant_method_implements_embedding:
+            raise NotImplementedError(
+                f"The class {type(quant_method).__name__} must implement "
+                "the 'embedding' method, see UnquantizedEmbeddingMethod.")
+
+        self.quant_method: QuantizeMethodBase = quant_method
+
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        # Divide the weight matrix along the vocaburaly dimension.
+        self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
+        self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
+                                                   self.tp_size)
+        assert (self.shard_indices.num_elements_padded ==
+                self.num_embeddings_per_partition)
+        self.num_org_embeddings_per_partition = (
+            self.shard_indices.org_vocab_end_index -
+            self.shard_indices.org_vocab_start_index)
+        self.num_added_embeddings_per_partition = (
+            self.shard_indices.added_vocab_end_index -
+            self.shard_indices.added_vocab_start_index)
+
+        self.quant_method.create_weights(self,
+                                         self.embedding_dim,
+                                         [self.num_embeddings_per_partition],
+                                         self.embedding_dim,
+                                         self.num_embeddings_padded,
+                                         params_dtype=params_dtype,
+                                         weight_loader=self.weight_loader)
+
+        self.quant_config = quant_config
+        if bias:
+            self.bias = Parameter(
+                torch.empty(self.num_embeddings_per_partition,
+                            dtype=params_dtype))
+            set_weight_attrs(self.bias, {
+                "output_dim": 0,
+                "weight_loader": self.weight_loader,
+            })
+        else:
+            self.register_parameter("bias", None)
+
+    def tie_weights(self, embed_tokens: VocabParallelEmbedding):
+        """Tie the weights with word embeddings."""
+        # GGUF quantized embed_tokens.
+        if self.quant_config and self.quant_config.get_name() == "gguf":
+            return embed_tokens
+        else:
+            self.weight = embed_tokens.weight
+            return self
+
+    def forward(self, input_):
+        del input_
+        raise RuntimeError("LMHead's weights should be used in the sampler.")
```
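
The partition arithmetic above is what sidesteps the `npu_quant_matmul` limit called out in the PR description: the weight is sharded along the vocabulary dimension of the lmhead group rather than kept whole. A standalone restatement with illustrative numbers (a DeepSeek-style vocabulary of 129280, padding 64, lmhead TP of 16; none of these values come from this diff):

```python
def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
    # Same rounding rule as vLLM's pad_vocab_size, restated locally.
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to

vocab_size, lmhead_tp = 129280, 16
padded = pad_vocab_size(vocab_size)   # 129280, already a multiple of 64
per_partition = padded // lmhead_tp   # 8080 weight rows per rank
# 8080 < 65536 satisfies the npu_quant_matmul dim limit mentioned in the
# PR description, while the unsplit 129280 does not.
print(padded, per_partition)
```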
