Commit 5b6c013

Levi-JQ committed
Independent Flashcomm2 == [TODO1]
1 parent 32559d0 commit 5b6c013

File tree: 4 files changed (+100, -16 lines)

vllm_ascend/ascend_config.py

Lines changed: 3 additions & 3 deletions
@@ -101,15 +101,15 @@ def __init__(self, vllm_config):
             )
         if self.oproj_tensor_parallel_size is not None:
             raise AssertionError(
-                "flashcomm2_oproj_tensor_parallel_size cannot be enabled simultaneously with oproj_tensor_parallel_size"
+                f"flashcomm2_oproj_tensor_parallel_size cannot be enabled simultaneously with oproj_tensor_parallel_size"
             )
         if global_tp_size <= self.flashcomm2_oproj_tensor_parallel_size:
             raise AssertionError(
-                "flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size}) cannot exceed global tensor parallel size ({global_tp_size})"
+                f"flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size}) cannot exceed global tensor parallel size ({global_tp_size})"
             )
         if global_tp_size % self.flashcomm2_oproj_tensor_parallel_size != 0:
             raise AssertionError(
-                "Global tensor parallel size ({global_tp_size}) must be divisible by flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size})"
+                f"Global tensor parallel size ({global_tp_size}) must be divisible by flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size})"
             )
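
The three checks above pin down the legal values of flashcomm2_oproj_tensor_parallel_size: it must be a proper divisor of the global TP size. A standalone sketch of the same constraints (illustrative only, not the vllm_ascend AscendConfig API):

```python
# Standalone sketch of the constraints enforced above (not the AscendConfig API).
def check_flashcomm2_otp_size(global_tp_size: int, otp_size: int) -> None:
    if global_tp_size <= otp_size:
        raise AssertionError(
            f"flashcomm2_oproj_tensor_parallel_size ({otp_size}) cannot exceed "
            f"global tensor parallel size ({global_tp_size})")
    if global_tp_size % otp_size != 0:
        raise AssertionError(
            f"Global tensor parallel size ({global_tp_size}) must be divisible "
            f"by flashcomm2_oproj_tensor_parallel_size ({otp_size})")

check_flashcomm2_otp_size(8, 2)      # ok: 2 < 8 and 8 % 2 == 0
try:
    check_flashcomm2_otp_size(8, 3)  # 8 % 3 != 0
except AssertionError as err:
    print(err)
```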

vllm_ascend/models/layers/mla.py

Lines changed: 13 additions & 0 deletions
@@ -26,6 +26,7 @@
 from torch import nn
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, get_current_vllm_config
+from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.layers.mla import MultiHeadLatentAttention
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -133,6 +134,18 @@ def forward(
         if num_tokens % self.tp_size:
             rows += 1
         output_shape = (rows, hidden_states.shape[1])
+
+        forward_context = get_forward_context()
+        is_prefill = forward_context.with_prefill
+        if forward_context.flashcomm_v2_enabled and forward_context.flashcomm1_ds_prefill:
+            num_padding_tokens = forward_context.pad_size
+            if is_prefill and self.debug_layer_idx > 0 and self.debug_layer_idx < self.layers:
+                output_shape = hidden_states.shape
+            else:
+                B = (hidden_states.shape[0] + num_padding_tokens) // get_tensor_model_parallel_world_size()
+                H = hidden_states.shape[1]
+                output_shape = (B, H)
+
         # FIXME: This does not seem right, should make sure the buffer is fixed
         output = torch.empty(output_shape,
                              dtype=hidden_states.dtype,
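
The new branch sizes the attention output buffer for the token shard each rank owns once FlashComm v2 has padded the batch; outside the prefill special case the row count is simply (tokens + pad) / TP. A worked example with assumed numbers (the real values come from hidden_states and the forward context):

```python
# Worked example of the per-rank output shape computed above (numbers assumed).
num_tokens = 10      # hidden_states.shape[0]
hidden_size = 7168   # hidden_states.shape[1]
tp_size = 4          # get_tensor_model_parallel_world_size()

# pad_size is assumed to round the token count up to a multiple of tp_size
num_padding_tokens = (-num_tokens) % tp_size       # 2
B = (num_tokens + num_padding_tokens) // tp_size   # 3 rows per rank
H = hidden_size
print((B, H))                                      # (3, 7168)
```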

vllm_ascend/ops/linear_op.py

Lines changed: 77 additions & 9 deletions
@@ -38,6 +38,7 @@
 
 from typing import Optional, Tuple, Union
 
+from torch import nn
 import torch
 import torch.distributed as dist
 import torch_npu
@@ -182,6 +183,69 @@ def apply_impl(
         return output, output_bias
 
 
+class Flashcomm2MergedColumnParallelOp(CustomColumnParallelOp):
+
+    def apply_impl(
+        self, input_: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        """Linear layer with column parallelism.
+
+        Implemented multiple optimization projects for dense models, such as FlashComm and
+        communication-computation fusion.
+        """
+
+        bias = self.bias if not self.skip_bias_add else None
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+
+        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
+        output_parallel = self.quant_method.apply(self.layer, input_, bias)
+
+        if self.gather_output:
+            # All-gather across the partitions.
+            output = self.comm_group.all_gather(output_parallel)
+        else:
+            output = output_parallel
+        output_bias = self.bias if self.skip_bias_add else None
+        return output, output_bias
+
+
+class Flashcomm2QKVParallelOp(CustomColumnParallelOp):
+
+    def __init__(self, layer, prefix):
+        super().__init__(layer)
+        self.prefix = prefix
+
+    def apply_impl(
+        self, input_: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        """Linear layer with column parallelism.
+
+        Implemented multiple optimization projects for dense models, such as FlashComm and
+        communication-computation fusion.
+        """
+
+        bias = self.bias if not self.skip_bias_add else None
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+
+        layer_num = self.prefix.split('.')[2]
+
+        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+            input_, layer_num != '0')
+        output_parallel = self.quant_method.apply(self.layer, input_, bias)
+
+        if self.gather_output:
+            # All-gather across the partitions.
+            output = self.comm_group.all_gather(output_parallel)
+        else:
+            output = output_parallel
+        output_bias = self.bias if self.skip_bias_add else None
+        return output, output_bias
+
+
 class SequenceQKVParallelOp(CustomColumnParallelOp):
 
     def __init__(self, layer, prefix):
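
Flashcomm2QKVParallelOp gates the all-gather/unpad on the decoder layer index parsed from the module prefix, so layer 0 skips the gather (presumably because its input has not yet been scattered by an earlier projection). A small sketch of that gating, assuming the usual "model.layers.<idx>.self_attn.qkv_proj" prefix form; the helper name is illustrative, not part of the commit:

```python
# Sketch of the prefix-based gating in Flashcomm2QKVParallelOp.apply_impl.
def qkv_needs_all_gather(prefix: str) -> bool:
    layer_num = prefix.split('.')[2]
    return layer_num != '0'   # layer 0 skips the all-gather/unpad

print(qkv_needs_all_gather("model.layers.0.self_attn.qkv_proj"))   # False
print(qkv_needs_all_gather("model.layers.13.self_attn.qkv_proj"))  # True
```
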
@@ -316,20 +380,24 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
 
     def __init__(self, layer):
         super().__init__(layer)
-        self.forward_type = "flashcomm2_oproj_tp"
         self.odp_group = get_flashcomm2_odp_group()
         self.odp_size = self.odp_group.world_size
         self.reorgnized_batch_ids = get_flashcomm2_reorgnized_batch_ids(get_tp_group().world_size)
         self.group_indices = torch.tensor(self.reorgnized_batch_ids).npu()
 
     @property
     def comm_group(self):
+        # TODO: when otp_size == 1, get_flashcomm2_otp_group() returns None; this case needs separate handling.
         return get_flashcomm2_otp_group()
 
     def apply_impl(
         self,
         input_: torch.Tensor,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        """Linear layer for Flashcomm2.
+        Input.shape  = [batchsize*seqlength, headnum*headdim/TP]
+        Output.shape = [(batchsize*seqlength+padsize)/TP, hiddensize]
+        """
         # Handle input parallelism - split or use as-is
         if self.input_is_parallel:
             input_parallel = input_
@@ -383,7 +451,7 @@ def apply_impl(
         # Only fuse bias add into GEMM for rank 0 (this ensures that
         # bias will not get added more than once in TP>1 case)
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
-        output_parallel = self.quant_method.apply(self,
+        output_parallel = self.quant_method.apply(self.layer,
                                                   input_parallel,
                                                   bias=bias_)
         # output_parallel shape: [bs/(TP/flashcomm2_otp_size), hiddenstate]
@@ -392,11 +460,6 @@ def apply_impl(
             output = self.comm_group.reduce_scatter(output_parallel, dim=0)
         else:
             output = output_parallel
-        if not forward_context.flashcomm1_ds_prefill:
-            # flashcomm1 not enabled
-            output = get_tp_group().all_gather(output, 0)
-            if num_padding_tokens > 0:
-                output = output[:-num_padding_tokens]
 
         output_bias = self.bias if self.skip_bias_add else None
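
Taken together with the docstring added above, the Flashcomm2 o_proj path maps a head-sharded input of shape [batchsize*seqlength, headnum*headdim/TP] to a token-sharded output of shape [(batchsize*seqlength+padsize)/TP, hiddensize]: the local GEMM yields the [bs/(TP/flashcomm2_otp_size), hiddenstate] tensor noted in the comment, and the reduce-scatter over the OTP group supplies the remaining division. Shape bookkeeping with assumed sizes:

```python
# Shape bookkeeping for the Flashcomm2 o_proj path (sizes assumed: TP = 8,
# flashcomm2_oproj_tensor_parallel_size = 2, 16 tokens after padding).
tokens = 16
tp = 8
otp_size = 2
odp_size = tp // otp_size               # 4

rows_gemm_out = tokens // odp_size      # output_parallel rows: 4
rows_final = rows_gemm_out // otp_size  # after reduce_scatter over the OTP group: 2
assert rows_final == tokens // tp       # matches Output.shape in the docstring
```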

@@ -510,22 +573,27 @@ def update_attrs(self):
 def get_column_parallel_op(
     disable_tp, prefix, layer
 ) -> Tuple[Optional[Union[MLPColumnParallelOp, SequenceMergedColumnParallelOp,
-                          SequenceQKVParallelOp]], int, int]:
+                          SequenceQKVParallelOp, Flashcomm2MergedColumnParallelOp, Flashcomm2QKVParallelOp]], int, int]:
     if disable_tp:
         return None, 0, 1
 
     custom_op: Optional[Union[
         MLPColumnParallelOp,
         SequenceMergedColumnParallelOp,
         SequenceQKVParallelOp,
+        Flashcomm2MergedColumnParallelOp,
+        Flashcomm2QKVParallelOp
     ]] = None
     if "gate_up_proj" in prefix and mlp_tp_enable():
         custom_op = MLPColumnParallelOp(layer)
     elif "gate_up_proj" in prefix and enable_sp():
         custom_op = SequenceMergedColumnParallelOp(layer)
+    elif "gate_up_proj" in prefix and flashcomm2_enable():
+        custom_op = Flashcomm2MergedColumnParallelOp(layer)
     elif enable_sp():
         custom_op = SequenceQKVParallelOp(layer, prefix)
-
+    elif flashcomm2_enable():
+        custom_op = Flashcomm2QKVParallelOp(layer, prefix)
     if custom_op is not None:
         return custom_op, custom_op.tp_rank, custom_op.tp_size
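
With the two new branches, get_column_parallel_op now routes gate_up_proj and QKV layers to the Flashcomm2 ops when flashcomm2 is enabled and no higher-priority mode applies. A condensed sketch of the selection order (the enable flags are mocked as plain booleans here; in the real code they are mlp_tp_enable(), enable_sp() and flashcomm2_enable()):

```python
# Condensed sketch of the selection order in get_column_parallel_op.
def pick_op(prefix: str, mlp_tp: bool, sp: bool, flashcomm2: bool) -> str:
    if "gate_up_proj" in prefix and mlp_tp:
        return "MLPColumnParallelOp"
    elif "gate_up_proj" in prefix and sp:
        return "SequenceMergedColumnParallelOp"
    elif "gate_up_proj" in prefix and flashcomm2:
        return "Flashcomm2MergedColumnParallelOp"
    elif sp:
        return "SequenceQKVParallelOp"
    elif flashcomm2:
        return "Flashcomm2QKVParallelOp"
    return "no custom op"

print(pick_op("model.layers.3.mlp.gate_up_proj", False, False, True))
# Flashcomm2MergedColumnParallelOp
print(pick_op("model.layers.3.self_attn.qkv_proj", False, False, True))
# Flashcomm2QKVParallelOp
```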

vllm_ascend/ops/register_custom_ops.py

Lines changed: 7 additions & 4 deletions
@@ -21,8 +21,9 @@ def _maybe_chunk_residual_impl(x: torch.Tensor,
 
     if x.size(0) != residual.size(0):
         sp_enabled = forward_context.sp_enabled
-        assert sp_enabled is True, ("Currently, this situation only occurs "
-                                    "when sp is enabled")
+        flashcomm_v2_enabled = forward_context.flashcomm_v2_enabled
+        assert sp_enabled or flashcomm_v2_enabled is True, ("Currently, this situation only occurs "
+                                                            "when sp or flashcomm_v2 is enabled")
         pad_size = forward_context.pad_size
         if pad_size > 0:
             residual = F.pad(residual, (0, 0, 0, pad_size))
@@ -41,7 +42,8 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
         return x
 
     sp_enabled = forward_context.sp_enabled
-    if sp_enabled and label:
+    flashcomm_v2_enabled = forward_context.flashcomm_v2_enabled
+    if (sp_enabled or flashcomm_v2_enabled) and label:
         x = tensor_model_parallel_all_gather(x, 0)
         pad_size = forward_context.pad_size
         if pad_size > 0:
@@ -56,7 +58,8 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
         return tensor_model_parallel_all_reduce(x)
 
     sp_enabled = forward_context.sp_enabled
-    if sp_enabled:
+    flashcomm_v2_enabled = forward_context.flashcomm_v2_enabled
+    if sp_enabled or flashcomm_v2_enabled:
         pad_size = forward_context.pad_size
         if pad_size > 0:
             x = F.pad(x, (0, 0, 0, pad_size))
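
All three wrappers now pad along the token (first) dimension whenever sequence parallelism or flashcomm_v2 is enabled. For reference, F.pad(x, (0, 0, 0, pad_size)) on a 2-D tensor leaves the last dimension alone and appends pad_size zero rows:

```python
# Quick check of the padding pattern used in the ops above.
import torch
import torch.nn.functional as F

x = torch.ones(10, 4)            # 10 tokens, hidden size 4
pad_size = 2                     # e.g. forward_context.pad_size
padded = F.pad(x, (0, 0, 0, pad_size))
print(padded.shape)              # torch.Size([12, 4])
print(padded[-pad_size:].sum())  # tensor(0.), the appended rows are zeros
```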
