Commit 693c0ec

Author: Levi-JQ (committed)
Commit message: Independent Flashcomm2
1 parent b216949, commit 693c0ec

File tree: 4 files changed (+112, -16 lines)

vllm_ascend/ascend_config.py

Lines changed: 3 additions & 3 deletions
@@ -103,15 +103,15 @@ def __init__(self, vllm_config):
             )
         if self.oproj_tensor_parallel_size is not None:
             raise AssertionError(
-                "flashcomm2_oproj_tensor_parallel_size cannot be enabled simultaneously with oproj_tensor_parallel_size"
+                f"flashcomm2_oproj_tensor_parallel_size cannot be enabled simultaneously with oproj_tensor_parallel_size"
             )
         if global_tp_size <= self.flashcomm2_oproj_tensor_parallel_size:
             raise AssertionError(
-                "flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size}) cannot exceed global tensor parallel size ({global_tp_size})"
+                f"flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size}) cannot exceed global tensor parallel size ({global_tp_size})"
             )
         if global_tp_size % self.flashcomm2_oproj_tensor_parallel_size != 0:
             raise AssertionError(
-                "Global tensor parallel size ({global_tp_size}) must be divisible by flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size})"
+                f"Global tensor parallel size ({global_tp_size}) must be divisible by flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size})"
             )
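The three assertions above enforce a simple relationship between the two TP sizes. A minimal standalone sketch of that check; validate_flashcomm2_otp is a hypothetical helper name used only for illustration, not part of the AscendConfig API:

# Illustrative sketch of the constraints asserted in AscendConfig.__init__.
def validate_flashcomm2_otp(global_tp_size, flashcomm2_otp_size, oproj_tp_size=None):
    if oproj_tp_size is not None:
        # flashcomm2 oproj TP and plain oproj TP are mutually exclusive.
        raise AssertionError(
            "flashcomm2_oproj_tensor_parallel_size cannot be enabled "
            "simultaneously with oproj_tensor_parallel_size")
    if global_tp_size <= flashcomm2_otp_size:
        # The oproj TP group must be strictly smaller than the global TP group...
        raise AssertionError(
            f"flashcomm2_oproj_tensor_parallel_size ({flashcomm2_otp_size}) cannot "
            f"exceed global tensor parallel size ({global_tp_size})")
    if global_tp_size % flashcomm2_otp_size != 0:
        # ...and must divide it evenly.
        raise AssertionError(
            f"Global tensor parallel size ({global_tp_size}) must be divisible "
            f"by flashcomm2_oproj_tensor_parallel_size ({flashcomm2_otp_size})")

validate_flashcomm2_otp(global_tp_size=16, flashcomm2_otp_size=2)    # passes: 16 % 2 == 0
# validate_flashcomm2_otp(global_tp_size=16, flashcomm2_otp_size=3)  # raises: 16 % 3 != 0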
vllm_ascend/models/layers/mla.py

Lines changed: 13 additions & 0 deletions
@@ -26,6 +26,7 @@
 from torch import nn
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, get_current_vllm_config
+from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.layers.mla import MultiHeadLatentAttention
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -133,6 +134,18 @@ def forward(
         if num_tokens % self.tp_size:
             rows += 1
         output_shape = (rows, hidden_states.shape[1])
+
+        forward_context = get_forward_context()
+        is_prefill = forward_context.with_prefill
+        if forward_context.flashcomm_v2_enabled and forward_context.flashcomm1_ds_prefill:
+            num_padding_tokens = forward_context.pad_size
+            if is_prefill and self.debug_layer_idx > 0 and self.debug_layer_idx < self.layers:
+                output_shape = hidden_states.shape
+            else:
+                B = (hidden_states.shape[0] + num_padding_tokens) // get_tensor_model_parallel_world_size()
+                H = hidden_states.shape[1]
+                output_shape = (B, H)
+
         # FIXME: This does not seem right, should make sure the buffer is fixed
         output = torch.empty(output_shape,
                              dtype=hidden_states.dtype,
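The new branch sizes the output buffer for flashcomm_v2: tokens are padded so they split evenly across TP ranks, and each rank allocates only its slice. A small sketch of that arithmetic with illustrative names (the real code reads these values from the forward context):

# Sketch of the flashcomm_v2 output-shape arithmetic added above; names are illustrative.
def flashcomm2_rows_per_rank(num_tokens: int, pad_size: int, tp_world_size: int) -> int:
    # pad_size brings the token count up to a multiple of the TP world size,
    # so every rank owns an equal slice of rows.
    assert (num_tokens + pad_size) % tp_world_size == 0
    return (num_tokens + pad_size) // tp_world_size

# e.g. 30 tokens padded by 2 across TP=8 ranks -> a (4, hidden_size) buffer per rank
print(flashcomm2_rows_per_rank(num_tokens=30, pad_size=2, tp_world_size=8))  # 4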
vllm_ascend/ops/linear_op.py

Lines changed: 89 additions & 9 deletions
@@ -38,6 +38,7 @@
 
 from typing import Optional, Tuple, Union
 
+from torch import nn
 import torch
 import torch.distributed as dist
 import torch_npu
@@ -52,6 +53,7 @@
 from vllm_ascend.utils import (dense_optim_enable, enable_sp, flashcomm2_enable, get_flashcomm2_reorgnized_batch_ids,
                                matmul_allreduce_enable, mlp_tp_enable,
                                oproj_tp_enable)
+from vllm_ascend.ascend_config import get_ascend_config
 
 
 class CustomTensorParallelOp:
@@ -182,6 +184,69 @@ def apply_impl(
         return output, output_bias
 
 
+class Flashcomm2MergedColumnParallelOp(CustomColumnParallelOp):
+
+    def apply_impl(
+        self, input_: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        """Linear layer with column parallelism.
+
+        Implemented multiple optimization projects for dense models, such as FlashComm and
+        communication-computation fusion.
+        """
+
+        bias = self.bias if not self.skip_bias_add else None
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+
+        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
+        output_parallel = self.quant_method.apply(self.layer, input_, bias)
+
+        if self.gather_output:
+            # All-gather across the partitions.
+            output = self.comm_group.all_gather(output_parallel)
+        else:
+            output = output_parallel
+        output_bias = self.bias if self.skip_bias_add else None
+        return output, output_bias
+
+
+class Flashcomm2QKVParallelOp(CustomColumnParallelOp):
+
+    def __init__(self, layer, prefix):
+        super().__init__(layer)
+        self.prefix = prefix
+
+    def apply_impl(
+        self, input_: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        """Linear layer with column parallelism.
+
+        Implemented multiple optimization projects for dense models, such as FlashComm and
+        communication-computation fusion.
+        """
+
+        bias = self.bias if not self.skip_bias_add else None
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+
+        layer_num = self.prefix.split('.')[2]
+
+        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+            input_, layer_num != '0')
+        output_parallel = self.quant_method.apply(self.layer, input_, bias)
+
+        if self.gather_output:
+            # All-gather across the partitions.
+            output = self.comm_group.all_gather(output_parallel)
+        else:
+            output = output_parallel
+        output_bias = self.bias if self.skip_bias_add else None
+        return output, output_bias
+
+
 class SequenceQKVParallelOp(CustomColumnParallelOp):
 
     def __init__(self, layer, prefix):
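Flashcomm2QKVParallelOp decides whether to all-gather its input from the layer index embedded in the module prefix (prefix.split('.')[2]). The prefix layout below is an assumption for illustration; only the indexing and the comparison against '0' come from the diff:

# Sketch of the prefix-based gating in Flashcomm2QKVParallelOp.apply_impl.
# Assumed prefix layout: "model.layers.<idx>.self_attn.qkv_proj".
def needs_all_gather(prefix: str) -> bool:
    layer_num = prefix.split('.')[2]
    # Layer 0 skips the all-gather; presumably later layers must first gather the
    # sequence-scattered hidden states produced by the previous layer's o_proj.
    return layer_num != '0'

print(needs_all_gather("model.layers.0.self_attn.qkv_proj"))   # False
print(needs_all_gather("model.layers.12.self_attn.qkv_proj"))  # True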
@@ -316,7 +381,6 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
 
     def __init__(self, layer):
         super().__init__(layer)
-        self.forward_type = "flashcomm2_oproj_tp"
         self.odp_group = get_flashcomm2_odp_group()
         self.odp_size = self.odp_group.world_size
         self.reorgnized_batch_ids = get_flashcomm2_reorgnized_batch_ids(get_tp_group().world_size)
@@ -325,11 +389,27 @@ def __init__(self, layer):
     @property
     def comm_group(self):
         return get_flashcomm2_otp_group()
+
+    @property
+    def tp_rank(self):
+        if get_ascend_config().flashcomm2_oproj_tensor_parallel_size == 1:
+            return 0
+        return self.comm_group.rank_in_group
+
+    @property
+    def tp_size(self):
+        if get_ascend_config().flashcomm2_oproj_tensor_parallel_size == 1:
+            return 1
+        return self.comm_group.world_size
 
     def apply_impl(
         self,
         input_: torch.Tensor,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        """Linear layer for Flashcomm2.
+        Input shape:  [batchsize * seqlength, headnum * headdim / TP]
+        Output shape: [(batchsize * seqlength + padsize) / TP, hiddensize]
+        """
         # Handle input parallelism - split or use as-is
         if self.input_is_parallel:
             input_parallel = input_
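The new docstring states the o_proj shape contract under flashcomm2: the input arrives head-sharded over the global TP group, and the output leaves each rank with an equal slice of the padded token dimension. A quick shape check with made-up sizes:

# Shape arithmetic for the Flashcomm2 o_proj contract in the docstring above.
# All sizes are illustrative; only the two formulas come from the diff.
def flashcomm2_oproj_shapes(tokens, pad, tp, head_num, head_dim, hidden_size):
    in_shape = (tokens, head_num * head_dim // tp)       # per-rank o_proj input
    out_shape = ((tokens + pad) // tp, hidden_size)      # per-rank output after reduce-scatter
    return in_shape, out_shape

print(flashcomm2_oproj_shapes(tokens=30, pad=2, tp=8,
                              head_num=32, head_dim=128, hidden_size=4096))
# ((30, 512), (4, 4096))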
@@ -383,7 +463,7 @@ def apply_impl(
         # Only fuse bias add into GEMM for rank 0 (this ensures that
         # bias will not get added more than once in TP>1 case)
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
-        output_parallel = self.quant_method.apply(self,
+        output_parallel = self.quant_method.apply(self.layer,
                                                   input_parallel,
                                                   bias=bias_)
         # output_parallel shape: [bs/(TP/flashcomm2_otp_size), hiddenstate]
@@ -392,11 +472,6 @@
             output = self.comm_group.reduce_scatter(output_parallel, dim=0)
         else:
             output = output_parallel
-        if not forward_context.flashcomm1_ds_prefill:
-            # flashcomm1 not enabled
-            output = get_tp_group().all_gather(output, 0)
-            if num_padding_tokens > 0:
-                output = output[:-num_padding_tokens]
 
         output_bias = self.bias if self.skip_bias_add else None
 
@@ -510,22 +585,27 @@ def update_attrs(self):
 def get_column_parallel_op(
     disable_tp, prefix, layer
 ) -> Tuple[Optional[Union[MLPColumnParallelOp, SequenceMergedColumnParallelOp,
-           SequenceQKVParallelOp]], int, int]:
+           SequenceQKVParallelOp, Flashcomm2MergedColumnParallelOp, Flashcomm2QKVParallelOp]], int, int]:
     if disable_tp:
         return None, 0, 1
 
     custom_op: Optional[Union[
         MLPColumnParallelOp,
         SequenceMergedColumnParallelOp,
         SequenceQKVParallelOp,
+        Flashcomm2MergedColumnParallelOp,
+        Flashcomm2QKVParallelOp
     ]] = None
     if "gate_up_proj" in prefix and mlp_tp_enable():
         custom_op = MLPColumnParallelOp(layer)
     elif "gate_up_proj" in prefix and enable_sp():
         custom_op = SequenceMergedColumnParallelOp(layer)
+    elif "gate_up_proj" in prefix and flashcomm2_enable():
+        custom_op = Flashcomm2MergedColumnParallelOp(layer)
     elif enable_sp():
         custom_op = SequenceQKVParallelOp(layer, prefix)
-
+    elif flashcomm2_enable():
+        custom_op = Flashcomm2QKVParallelOp(layer, prefix)
     if custom_op is not None:
         return custom_op, custom_op.tp_rank, custom_op.tp_size

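The extended elif chain makes flashcomm2 a lower-priority fallback: mlp_tp and sequence parallelism keep precedence for gate_up_proj, and sequence parallelism keeps precedence for QKV. A tiny dispatch sketch (flag values are hypothetical; the precedence order mirrors the diff):

# Dispatch order of the updated get_column_parallel_op, reduced to strings for illustration.
def pick_column_op(prefix: str, mlp_tp: bool, sp: bool, fc2: bool) -> str:
    if "gate_up_proj" in prefix and mlp_tp:
        return "MLPColumnParallelOp"
    elif "gate_up_proj" in prefix and sp:
        return "SequenceMergedColumnParallelOp"
    elif "gate_up_proj" in prefix and fc2:
        return "Flashcomm2MergedColumnParallelOp"
    elif sp:
        return "SequenceQKVParallelOp"
    elif fc2:
        return "Flashcomm2QKVParallelOp"
    return "no custom op"

print(pick_column_op("model.layers.3.mlp.gate_up_proj", False, False, True))   # Flashcomm2MergedColumnParallelOp
print(pick_column_op("model.layers.3.self_attn.qkv_proj", False, False, True)) # Flashcomm2QKVParallelOp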
vllm_ascend/ops/register_custom_ops.py

Lines changed: 7 additions & 4 deletions
@@ -22,8 +22,9 @@ def _maybe_chunk_residual_impl(x: torch.Tensor,
 
     if x.size(0) != residual.size(0):
         sp_enabled = forward_context.sp_enabled
-        assert sp_enabled is True, ("Currently, this situation only occurs "
-                                    "when sp is enabled")
+        flashcomm_v2_enabled = forward_context.flashcomm_v2_enabled
+        assert sp_enabled or flashcomm_v2_enabled is True, ("Currently, this situation only occurs "
+                                                            "when sp or flashcomm_v2 is enabled")
         pad_size = forward_context.pad_size
         if pad_size > 0:
             residual = F.pad(residual, (0, 0, 0, pad_size))
@@ -42,7 +43,8 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
         return x
 
     sp_enabled = forward_context.sp_enabled
-    if sp_enabled and label:
+    flashcomm_v2_enabled = forward_context.flashcomm_v2_enabled
+    if (sp_enabled or flashcomm_v2_enabled) and label:
         x = tensor_model_parallel_all_gather(x, 0)
         pad_size = forward_context.pad_size
         if pad_size > 0:
@@ -57,7 +59,8 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
         return tensor_model_parallel_all_reduce(x)
 
     sp_enabled = forward_context.sp_enabled
-    if sp_enabled:
+    flashcomm_v2_enabled = forward_context.flashcomm_v2_enabled
+    if sp_enabled or flashcomm_v2_enabled:
         pad_size = forward_context.pad_size
         if pad_size > 0:
             x = F.pad(x, (0, 0, 0, pad_size))

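These hooks now treat flashcomm_v2 like sequence parallelism wherever the token dimension is padded and scattered. The round trip they guard looks roughly like this, written with plain torch ops instead of the vLLM custom ops and distributed collectives:

# Rough, single-process illustration of the pad / shard / all-gather / unpad round trip.
# tp_size and the cat() stand-in for all-gather are assumptions for the sketch.
import torch
import torch.nn.functional as F

num_tokens, hidden_size, tp_size = 30, 64, 8
pad_size = (tp_size - num_tokens % tp_size) % tp_size        # 2

x = torch.randn(num_tokens, hidden_size)
x_padded = F.pad(x, (0, 0, 0, pad_size))                     # [32, 64], pad rows as in _maybe_pad_and_reduce
shards = x_padded.chunk(tp_size, dim=0)                      # [4, 64] per rank

# _maybe_all_gather_and_maybe_unpad: gather the shards back, then drop the padding.
gathered = torch.cat(shards, dim=0)                          # stand-in for tensor_model_parallel_all_gather
unpadded = gathered[:-pad_size] if pad_size > 0 else gathered
assert unpadded.shape == x.shape                             # back to [30, 64]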