
Commit 32559d0

Author: Levi-JQ (committed)
[main] flashcomm_v2 optim solution
1 parent 3d21ed9 commit 32559d0

File tree: 7 files changed, +221 −9 lines

vllm_ascend/ascend_config.py

Lines changed: 19 additions & 0 deletions

@@ -92,6 +92,25 @@ def __init__(self, vllm_config):
             raise AssertionError(
                 "oproj_tensor_parallel_size is only supported in pd scenario and can only be used in D node."
             )
+        self.flashcomm2_oproj_tensor_parallel_size = additional_config.get(
+            "flashcomm2_oproj_tensor_parallel_size", None)
+        if self.flashcomm2_oproj_tensor_parallel_size is not None:
+            global_tp_size = vllm_config.parallel_config.tensor_parallel_size
+            logger.info(
+                f"Enable Flashcomm2 with flashcomm2_oproj_tensor_parallel_size={self.flashcomm2_oproj_tensor_parallel_size} and global_tp_size={global_tp_size}"
+            )
+            if self.oproj_tensor_parallel_size is not None:
+                raise AssertionError(
+                    "flashcomm2_oproj_tensor_parallel_size cannot be enabled simultaneously with oproj_tensor_parallel_size"
+                )
+            if global_tp_size <= self.flashcomm2_oproj_tensor_parallel_size:
+                raise AssertionError(
+                    f"flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size}) must be smaller than the global tensor parallel size ({global_tp_size})"
+                )
+            if global_tp_size % self.flashcomm2_oproj_tensor_parallel_size != 0:
+                raise AssertionError(
+                    f"Global tensor parallel size ({global_tp_size}) must be divisible by flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size})"
+                )
 
 
 class TorchairGraphConfig:
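The three checks above gate the feature: flashcomm2 and the plain oproj TP knob are mutually exclusive, and the flashcomm2 OTP size must be a strict divisor of the global TP size. A standalone restatement of those rules in plain Python (the helper name and example sizes are illustrative, not part of the commit):

def flashcomm2_otp_is_valid(flashcomm2_otp_size: int,
                            global_tp_size: int,
                            oproj_tp_size=None) -> bool:
    # Mirrors the constructor checks: mutually exclusive with
    # oproj_tensor_parallel_size, strictly smaller than the global TP size,
    # and an exact divisor of it.
    return (oproj_tp_size is None
            and flashcomm2_otp_size < global_tp_size
            and global_tp_size % flashcomm2_otp_size == 0)

print(flashcomm2_otp_is_valid(2, 8))   # True
print(flashcomm2_otp_is_valid(8, 8))   # False: must be smaller than global TP
print(flashcomm2_otp_is_valid(3, 8))   # False: 8 is not divisible by 3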

vllm_ascend/ascend_forward_context.py

Lines changed: 8 additions & 1 deletion

@@ -12,6 +12,7 @@
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.utils import enable_sp
+from vllm_ascend.utils import flashcomm2_enable
 
 
 class FusedMoEState(Enum):

@@ -109,12 +110,18 @@ def set_ascend_forward_context(
         sp_enabled = enable_sp() and \
                      tp_world_size > 1 and \
                      num_tokens is not None and num_tokens > 1000
+
+        flashcomm_v2_enabled = flashcomm2_enable() and \
+                               tp_world_size > 1 and \
+                               num_tokens is not None
 
-        if sp_enabled:
+        if sp_enabled or flashcomm_v2_enabled:
             pad_size = (tp_world_size -
                         (num_tokens % tp_world_size)) % tp_world_size
             forward_context.pad_size = pad_size
+
         forward_context.sp_enabled = sp_enabled
+        forward_context.flashcomm_v2_enabled = flashcomm_v2_enabled
 
         # set this for rope forward_oot using
         forward_context.is_first_layer = True
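SP and flashcomm_v2 share the same padding rule: the token count is rounded up to the next multiple of the TP world size and the remainder is stored as forward_context.pad_size. A quick standalone check of the formula (sizes are illustrative):

tp_world_size = 8
for num_tokens in (1000, 1001, 1008):
    pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
    print(num_tokens, pad_size, num_tokens + pad_size)
# 1000 0 1000
# 1001 7 1008
# 1008 0 1008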

vllm_ascend/distributed/parallel_state.py

Lines changed: 57 additions & 2 deletions

@@ -2,8 +2,11 @@
 
 import torch
 from vllm.config import ParallelConfig
-from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
+from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group, get_tp_group,
                                              init_model_parallel_group)
+import vllm_ascend.envs as envs_ascend
+from vllm.logger import logger
+from vllm_ascend.utils import flashcomm2_enable, oproj_tp_enable
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config

@@ -13,6 +16,9 @@
 _MLP_TP: Optional[GroupCoordinator] = None
 _OTP: Optional[GroupCoordinator] = None
 _LMTP: Optional[GroupCoordinator] = None
+_FLASHCOMM2_OTP: Optional[GroupCoordinator] = None
+_FLASHCOMM2_ODP: Optional[GroupCoordinator] = None
+
 
 
 def get_mc2_group() -> GroupCoordinator:

@@ -25,12 +31,18 @@ def get_otp_group() -> GroupCoordinator:
         "output tensor parallel group is not initialized")
     return _OTP
 
-
 def get_lmhead_tp_group() -> GroupCoordinator:
     assert _LMTP is not None, (
         "lm head tensor parallel group is not initialized")
     return _LMTP
 
+def get_flashcomm2_otp_group() -> GroupCoordinator:
+    return _FLASHCOMM2_OTP
+
+
+def get_flashcomm2_odp_group() -> GroupCoordinator:
+    assert _FLASHCOMM2_ODP is not None, (
+        "output data parallel group for flashcomm2 is not initialized")
+    return _FLASHCOMM2_ODP
 
 def get_mlp_tp_group() -> GroupCoordinator:
     assert _MLP_TP is not None, ("mlp group is not initialized")

@@ -110,6 +122,39 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
                                           get_world_group().local_rank,
                                           backend,
                                           group_name="lmheadtp")
+
+    if flashcomm2_enable():
+        flashcomm2_otp_size = get_ascend_config().flashcomm2_oproj_tensor_parallel_size
+        global_tp_size = get_tp_group().world_size
+        num_oproj_tensor_parallel_groups: int = (global_tp_size // flashcomm2_otp_size)
+
+        global _FLASHCOMM2_OTP
+        global _FLASHCOMM2_ODP
+
+        _FLASHCOMM2_OTP = None
+        _FLASHCOMM2_ODP = get_tp_group()
+
+        if flashcomm2_otp_size > 1:
+            otp_group_ranks = []
+            odp_group_ranks = [[] for _ in range(flashcomm2_otp_size)]
+            dp_group_index = torch.distributed.get_rank() // global_tp_size
+
+            for i in range(num_oproj_tensor_parallel_groups):
+                ranks = []
+                for j in range(flashcomm2_otp_size):
+                    rank_idx = dp_group_index * global_tp_size + i + j * num_oproj_tensor_parallel_groups
+                    ranks.append(rank_idx)
+                    odp_group_ranks[j].append(rank_idx)
+                otp_group_ranks.append(ranks)
+
+            _FLASHCOMM2_OTP = init_model_parallel_group(otp_group_ranks,
                                                         get_world_group().local_rank,
                                                         backend,
                                                         group_name="flashcomm2_otp")
+            _FLASHCOMM2_ODP = init_model_parallel_group(odp_group_ranks,
                                                         get_world_group().local_rank,
                                                         backend,
                                                         group_name="flashcomm2_odp")
 
 
 def get_mlp_tensor_model_parallel_world_size():

@@ -142,3 +187,13 @@ def destroy_ascend_model_parallel():
     if _OTP:
         _OTP.destroy()
     _OTP = None
+
+    global _FLASHCOMM2_OTP
+    if _FLASHCOMM2_OTP and get_ascend_config().flashcomm2_oproj_tensor_parallel_size != 1:
+        _FLASHCOMM2_OTP.destroy()
+    _FLASHCOMM2_OTP = None
+
+    global _FLASHCOMM2_ODP
+    if _FLASHCOMM2_ODP and get_ascend_config().flashcomm2_oproj_tensor_parallel_size != 1:
+        _FLASHCOMM2_ODP.destroy()
+    _FLASHCOMM2_ODP = None
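For one DP replica, the nested loop in init_ascend_model_parallel splits the TP ranks into strided OTP groups (used for the reduce-scatter over o_proj outputs) and complementary ODP groups (used for the all-to-all). A standalone recomputation with illustrative sizes (global TP 8, flashcomm2 OTP 2, first DP replica), not code from the commit:

global_tp_size = 8
flashcomm2_otp_size = 2
num_otp_groups = global_tp_size // flashcomm2_otp_size
dp_group_index = 0  # first DP replica

otp_group_ranks = []
odp_group_ranks = [[] for _ in range(flashcomm2_otp_size)]
for i in range(num_otp_groups):
    ranks = []
    for j in range(flashcomm2_otp_size):
        rank_idx = dp_group_index * global_tp_size + i + j * num_otp_groups
        ranks.append(rank_idx)
        odp_group_ranks[j].append(rank_idx)
    otp_group_ranks.append(ranks)

print(otp_group_ranks)  # [[0, 4], [1, 5], [2, 6], [3, 7]]  -> "flashcomm2_otp" groups
print(odp_group_ranks)  # [[0, 1, 2, 3], [4, 5, 6, 7]]      -> "flashcomm2_odp" groups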

vllm_ascend/ops/linear.py

Lines changed: 1 addition & 0 deletions

@@ -275,6 +275,7 @@ def forward(
         self,
         input_,
         is_prefill: bool = True,
+        is_force_scatter: bool = False,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
         if self.custom_op is not None:
             return self.custom_op.apply(input_)

vllm_ascend/ops/linear_op.py

Lines changed: 105 additions & 4 deletions

@@ -46,9 +46,10 @@
 from vllm.distributed import split_tensor_along_last_dim
 from vllm.distributed.parallel_state import get_tp_group
 
-from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
+from vllm.forward_context import get_forward_context
+from vllm_ascend.distributed.parallel_state import (get_flashcomm2_odp_group, get_flashcomm2_otp_group, get_mlp_tp_group,
                                                     get_otp_group)
-from vllm_ascend.utils import (dense_optim_enable, enable_sp,
+from vllm_ascend.utils import (dense_optim_enable, enable_sp, flashcomm2_enable, get_flashcomm2_reorgnized_batch_ids,
                                matmul_allreduce_enable, mlp_tp_enable,
                                oproj_tp_enable)

@@ -311,6 +312,104 @@ def update_attrs(self):
         self.input_size_per_partition = self.layer.input_size_per_partition
 
 
+class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
+
+    def __init__(self, layer):
+        super().__init__(layer)
+        self.forward_type = "flashcomm2_oproj_tp"
+        self.odp_group = get_flashcomm2_odp_group()
+        self.odp_size = self.odp_group.world_size
+        self.reorgnized_batch_ids = get_flashcomm2_reorgnized_batch_ids(get_tp_group().world_size)
+        self.group_indices = torch.tensor(self.reorgnized_batch_ids).npu()
+
+    @property
+    def comm_group(self):
+        return get_flashcomm2_otp_group()
+
+    def apply_impl(
+        self,
+        input_: torch.Tensor,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        # Handle input parallelism - split or use as-is
+        if self.input_is_parallel:
+            input_parallel = input_
+        else:
+            tp_rank = self.tp_rank
+            splitted_input = split_tensor_along_last_dim(
+                input_, num_partitions=self.tp_size)
+            input_parallel = splitted_input[tp_rank].contiguous()
+
+        # Padding for all-to-all
+        forward_context = get_forward_context()
+        num_padding_tokens = forward_context.pad_size
+        if num_padding_tokens > 0:
+            input_parallel = nn.functional.pad(input_parallel,
+                                               (0, 0, 0, num_padding_tokens))
+
+        # Reorganize the tensor so that the batch id and rank id correspond to each other.
+        chunk_num = len(self.reorgnized_batch_ids) * len(self.reorgnized_batch_ids[0])
+        batch_size = input_parallel.size(0)
+
+        assert batch_size % chunk_num == 0, f"Batch_size({batch_size}) must be divisible by chunk_num({chunk_num})"
+
+        batch_size_per_chunk = batch_size // chunk_num
+        # Indices of reorganized tensor
+        chunked = input_parallel.view(chunk_num, batch_size_per_chunk, input_parallel.shape[1])
+        reorganized_chunks = chunked[self.group_indices]
+        send_buf = reorganized_chunks.flatten(1, 2)
+
+        # all-to-all operation parameters
+        all2all_tp_size = self.odp_size
+        local_intermediate_size = input_parallel.size(1)
+        chunk_size = input_parallel.size(0) // all2all_tp_size
+        total_intermediate_size = local_intermediate_size * all2all_tp_size
+
+        # Create receive buffer
+        recv_buf = torch.empty(
+            total_intermediate_size * chunk_size,
+            dtype=input_parallel.dtype,
+            device=input_parallel.device)
+
+        # Perform all-to-all communication
+        dist.all_to_all_single(recv_buf, send_buf, group=self.odp_group.device_group)
+
+        input_parallel = recv_buf.view(
+            all2all_tp_size,
+            chunk_size,
+            -1
+        ).transpose(0, 1).reshape(chunk_size, -1)
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in TP>1 case)
+        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+        output_parallel = self.quant_method.apply(self,
+                                                  input_parallel,
+                                                  bias=bias_)
+        # output_parallel shape: [bs/(TP/flashcomm2_otp_size), hiddenstate]
+        if self.tp_size > 1:
+            # flashcomm2 with reduce-scatter
+            output = self.comm_group.reduce_scatter(output_parallel, dim=0)
+        else:
+            output = output_parallel
+        if not forward_context.flashcomm1_ds_prefill:
+            # flashcomm1 not enabled
+            output = get_tp_group().all_gather(output, 0)
+        if num_padding_tokens > 0:
+            output = output[:-num_padding_tokens]
+
+        output_bias = self.bias if self.skip_bias_add else None
+
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+    def update_attrs(self):
+        super().update_attrs()
+        self.input_is_parallel = self.layer.input_is_parallel
+        self.input_size_per_partition = self.layer.input_size_per_partition
+
+
 class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
     _HCOMM_INFO = None
 

@@ -437,17 +536,19 @@ def get_row_parallel_op(
     disable_tp, prefix, layer
 ) -> Tuple[Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
                           MatmulAllreduceRowParallelOp,
-                          SequenceRowParallelOp]], int, int]:
+                          SequenceRowParallelOp, Flashcomm2OProjRowParallelOp]], int, int]:
     if disable_tp:
         return None, 0, 1
 
     custom_op: Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
                               MatmulAllreduceRowParallelOp,
-                              SequenceRowParallelOp]] = None
+                              SequenceRowParallelOp, Flashcomm2OProjRowParallelOp]] = None
     if "down_proj" in prefix and mlp_tp_enable():
         custom_op = MLPRowParallelOp(layer)
     elif "o_proj" in prefix and oproj_tp_enable():
         custom_op = OProjRowParallelOp(layer)
+    elif "o_proj" in prefix and flashcomm2_enable():
+        custom_op = Flashcomm2OProjRowParallelOp(layer)
     elif matmul_allreduce_enable():
         custom_op = MatmulAllreduceRowParallelOp(layer)
     elif enable_sp():
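The data movement in apply_impl is easiest to follow by its shapes: tokens are chunked per TP rank, regrouped with group_indices, exchanged across the ODP group, and each rank then reassembles a full-width hidden slice for B/odp tokens before the GEMM, reduce-scatter, and all-gather. A shape-only, single-process sketch with CPU tensors and the all-to-all replaced by an identity copy; all sizes are made up and nothing here is code from the commit:

import torch

global_tp = 8              # assumed tensor-parallel degree
otp = 2                    # assumed flashcomm2_oproj_tensor_parallel_size
odp = global_tp // otp     # size of the all-to-all (ODP) group
B, H = 32, 16              # padded token count, per-rank hidden slice

# Same layout as get_flashcomm2_reorgnized_batch_ids()
group_indices = torch.tensor(
    [[i + j * odp for j in range(otp)] for i in range(odp)])    # [odp, otp]

x = torch.randn(B, H)                                 # input_parallel on this rank
chunked = x.view(global_tp, B // global_tp, H)        # one chunk per TP rank
send_buf = chunked[group_indices].flatten(1, 2)       # [odp, B // odp, H]

# dist.all_to_all_single would exchange the odp slices of send_buf here;
# reuse them locally so the reshape bookkeeping can be checked on one process.
recv_buf = send_buf.reshape(-1)                       # [odp * (B // odp) * H]

y = recv_buf.view(odp, B // odp, -1).transpose(0, 1).reshape(B // odp, -1)
print(send_buf.shape)  # torch.Size([4, 8, 16])
print(y.shape)         # torch.Size([8, 64]) == [B // odp, odp * H]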

vllm_ascend/quantization/quant_config.py

Lines changed: 9 additions & 2 deletions

@@ -36,12 +36,14 @@
 from vllm.model_executor.utils import set_weight_attrs
 
 from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
-                                                    get_otp_group)
+                                                    get_otp_group,
+                                                    get_flashcomm2_otp_group)
 from vllm_ascend.ops.fused_moe import AscendUnquantizedFusedMoEMethod
-from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable,
+from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, flashcomm2_enable, mlp_tp_enable,
                                oproj_tp_enable)
 
 from .utils import get_quant_method
+from vllm_ascend.ascend_config import get_ascend_config
 
 
 @register_quantization_config(ASCEND_QUANTIZATION_METHOD)

@@ -301,6 +303,11 @@ def apply(
                 tp_rank = get_otp_group().rank_in_group
             elif layer.prefix.find("down_proj") != -1 and mlp_tp_enable():
                 tp_rank = get_mlp_tp_group().rank_in_group
+            elif layer.prefix.find("o_proj") != -1 and flashcomm2_enable():
+                if get_ascend_config().flashcomm2_oproj_tensor_parallel_size == 1:
+                    tp_rank = 0
+                else:
+                    tp_rank = get_flashcomm2_otp_group().rank_in_group
             else:
                 tp_rank = get_tensor_model_parallel_rank()
         else:
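The new branch only changes which rank index is used when slicing the quantized o_proj weight: rank 0 when the flashcomm2 OTP size is 1, otherwise the rank inside the flashcomm2 OTP group. A toy restatement with the group objects stubbed out (the function name is illustrative, not part of the commit):

def oproj_weight_shard_rank(flashcomm2_otp_size: int,
                            flashcomm2_otp_rank: int) -> int:
    # Mirrors the elif branch above, without vLLM group objects.
    if flashcomm2_otp_size == 1:
        return 0
    return flashcomm2_otp_rank

print(oproj_weight_shard_rank(1, 0))  # 0
print(oproj_weight_shard_rank(2, 1))  # 1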

vllm_ascend/utils.py

Lines changed: 22 additions & 0 deletions

@@ -642,3 +642,25 @@ def npu_stream_switch(target_stream: torch.npu.Stream,
         return nullcontext()
     assert target_stream is not None
     return torch.npu.stream(target_stream)
+
+
+def flashcomm2_enable() -> bool:
+    return get_ascend_config().flashcomm2_oproj_tensor_parallel_size is not None
+
+
+def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]:
+    # Reorganize batch_ids so that, after the all2all and reduce-scatter operation, each batch_id corresponds to the rank_id within the DP domain.
+    # For example, when DP = [0, 1, 2, ..., 15] and flashcomm2_oproj_tensor_parallel_size = 2,
+    # the reorganized batch_ids will be [[batch0, batch8], [batch1, batch9], ..., [batch7, batch15]].
+    flashcomm2_otp_size = get_ascend_config().flashcomm2_oproj_tensor_parallel_size
+    num_oproj_tensor_parallel_groups: int = (global_tp_size // flashcomm2_otp_size)
+
+    reorgnized_batch_ids = []
+    for i in range(num_oproj_tensor_parallel_groups):
+        ranks = []
+        for j in range(flashcomm2_otp_size):
+            rank_idx = i + j * num_oproj_tensor_parallel_groups
+            ranks.append(rank_idx)
+        reorgnized_batch_ids.append(ranks)
+
+    return reorgnized_batch_ids
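As a sanity check, the example from the helper's comment (global TP 16, flashcomm2 OTP 2) can be reproduced standalone, without get_ascend_config(); this snippet is illustrative only:

global_tp_size = 16
flashcomm2_otp_size = 2
num_groups = global_tp_size // flashcomm2_otp_size

reorganized = [[i + j * num_groups for j in range(flashcomm2_otp_size)]
               for i in range(num_groups)]
print(reorganized)
# [[0, 8], [1, 9], [2, 10], [3, 11], [4, 12], [5, 13], [6, 14], [7, 15]]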
