diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index 65ea3ea0d2..ee642b0ab5 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -94,6 +94,25 @@ def __init__(self, vllm_config):
             raise AssertionError(
                 "oproj_tensor_parallel_size is only supported in pd scenario and can only be used in D node."
             )
+        self.flashcomm2_oproj_tensor_parallel_size = additional_config.get(
+            "flashcomm2_oproj_tensor_parallel_size", None)
+        if self.flashcomm2_oproj_tensor_parallel_size is not None:
+            global_tp_size = vllm_config.parallel_config.tensor_parallel_size
+            logger.info(
+                f"Enable Flashcomm2 with flashcomm2_oproj_tensor_parallel_size={self.flashcomm2_oproj_tensor_parallel_size} and global_tp_size={global_tp_size}"
+            )
+            if self.oproj_tensor_parallel_size is not None:
+                raise AssertionError(
+                    "flashcomm2_oproj_tensor_parallel_size cannot be enabled simultaneously with oproj_tensor_parallel_size"
+                )
+            if global_tp_size <= self.flashcomm2_oproj_tensor_parallel_size:
+                raise AssertionError(
+                    f"flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size}) must be smaller than the global tensor parallel size ({global_tp_size})"
+                )
+            if global_tp_size % self.flashcomm2_oproj_tensor_parallel_size != 0:
+                raise AssertionError(
+                    f"Global tensor parallel size ({global_tp_size}) must be divisible by flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size})"
+                )
 
 
 class TorchairGraphConfig:
diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py
index 607f02923f..8062979cb5 100644
--- a/vllm_ascend/ascend_forward_context.py
+++ b/vllm_ascend/ascend_forward_context.py
@@ -12,6 +12,7 @@
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.utils import enable_sp
+from vllm_ascend.utils import flashcomm2_enable
 
 
 class FusedMoEState(Enum):
@@ -109,12 +110,18 @@ def set_ascend_forward_context(
     sp_enabled = enable_sp(vllm_config) and \
                  tp_world_size > 1 and \
                  num_tokens is not None and num_tokens > 1000
+
+    flashcomm_v2_enabled = flashcomm2_enable() and \
+                 tp_world_size > 1 and \
+                 num_tokens is not None
 
-    if sp_enabled:
+    if sp_enabled or flashcomm_v2_enabled:
         pad_size = (tp_world_size -
                     (num_tokens % tp_world_size)) % tp_world_size
         forward_context.pad_size = pad_size
+    forward_context.sp_enabled = sp_enabled
+    forward_context.flashcomm_v2_enabled = flashcomm_v2_enabled
 
     # set this for rope forward_oot using
     forward_context.is_first_layer = True
diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py
index 07c707e3f5..0d8dd5a85d 100644
--- a/vllm_ascend/distributed/parallel_state.py
+++ b/vllm_ascend/distributed/parallel_state.py
@@ -2,8 +2,11 @@
 
 import torch
 from vllm.config import ParallelConfig
-from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
+from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group, get_tp_group,
                                              init_model_parallel_group)
+import vllm_ascend.envs as envs_ascend
+from vllm.logger import logger
+from vllm_ascend.utils import flashcomm2_enable, oproj_tp_enable
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
@@ -13,6 +16,9 @@
 _MLP_TP: Optional[GroupCoordinator] = None
 _OTP: Optional[GroupCoordinator] = None
 _LMTP: Optional[GroupCoordinator] = None
+_FLASHCOMM2_OTP: Optional[GroupCoordinator] = None
+_FLASHCOMM2_ODP: Optional[GroupCoordinator] = None
+
 
 
 def get_mc2_group() -> GroupCoordinator:
@@ -25,12 +31,18 @@ def get_otp_group() -> GroupCoordinator:
         "output tensor parallel group is not initialized")
     return _OTP
 
-
 def get_lmhead_tp_group() -> GroupCoordinator:
     assert _LMTP is not None, (
         "lm head tensor parallel group is not initialized")
     return _LMTP
 
+def get_flashcomm2_otp_group() -> GroupCoordinator:
+    return _FLASHCOMM2_OTP
+
+def get_flashcomm2_odp_group() -> GroupCoordinator:
+    assert _FLASHCOMM2_ODP is not None, (
+        "output data parallel group for flashcomm2 is not initialized")
+    return _FLASHCOMM2_ODP
 
 def get_mlp_tp_group() -> GroupCoordinator:
     assert _MLP_TP is not None, ("mlp group is not initialized")
@@ -110,6 +122,39 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
         get_world_group().local_rank,
         backend,
         group_name="lmheadtp")
+
+    if flashcomm2_enable():
+        flashcomm2_otp_size = get_ascend_config().flashcomm2_oproj_tensor_parallel_size
+        global_tp_size = get_tp_group().world_size
+        num_oproj_tensor_parallel_groups: int = (global_tp_size // flashcomm2_otp_size)
+
+        global _FLASHCOMM2_OTP
+        global _FLASHCOMM2_ODP
+
+        _FLASHCOMM2_OTP = None
+        _FLASHCOMM2_ODP = get_tp_group()
+
+        if flashcomm2_otp_size > 1:
+            otp_group_ranks = []
+            odp_group_ranks = [[] for _ in range(flashcomm2_otp_size)]
+            dp_group_index = torch.distributed.get_rank() // global_tp_size
+
+            for i in range(num_oproj_tensor_parallel_groups):
+                ranks = []
+                for j in range(flashcomm2_otp_size):
+                    rank_idx = dp_group_index * global_tp_size + i + j * num_oproj_tensor_parallel_groups
+                    ranks.append(rank_idx)
+                    odp_group_ranks[j].append(rank_idx)
+                otp_group_ranks.append(ranks)
+
+            _FLASHCOMM2_OTP = init_model_parallel_group(otp_group_ranks,
+                                                        get_world_group().local_rank,
+                                                        backend,
+                                                        group_name="flashcomm2_otp")
+            _FLASHCOMM2_ODP = init_model_parallel_group(odp_group_ranks,
+                                                        get_world_group().local_rank,
+                                                        backend,
+                                                        group_name="flashcomm2_odp")
 
 
 def get_mlp_tensor_model_parallel_world_size():
@@ -142,3 +187,13 @@ def destroy_ascend_model_parallel():
     if _OTP:
         _OTP.destroy()
         _OTP = None
+
+    global _FLASHCOMM2_OTP
+    if _FLASHCOMM2_OTP and get_ascend_config().flashcomm2_oproj_tensor_parallel_size != 1:
+        _FLASHCOMM2_OTP.destroy()
+        _FLASHCOMM2_OTP = None
+
+    global _FLASHCOMM2_ODP
+    if _FLASHCOMM2_ODP and get_ascend_config().flashcomm2_oproj_tensor_parallel_size != 1:
+        _FLASHCOMM2_ODP.destroy()
+        _FLASHCOMM2_ODP = None
diff --git a/vllm_ascend/models/layers/mla.py b/vllm_ascend/models/layers/mla.py
index 57c91bd278..b7a5bbeb5d 100644
--- a/vllm_ascend/models/layers/mla.py
+++ b/vllm_ascend/models/layers/mla.py
@@ -26,6 +26,7 @@
 from torch import nn
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, get_current_vllm_config
+from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.layers.mla import MultiHeadLatentAttention
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -133,6 +134,18 @@ def forward(
         if num_tokens % self.tp_size:
             rows += 1
         output_shape = (rows, hidden_states.shape[1])
+
+        forward_context = get_forward_context()
+        is_prefill = forward_context.with_prefill
+        if forward_context.flashcomm_v2_enabled and forward_context.flashcomm1_ds_prefill:
+            num_padding_tokens = forward_context.pad_size
+            if is_prefill and self.debug_layer_idx > 0 and self.debug_layer_idx < self.layers:
+                output_shape = hidden_states.shape
+            else:
+                B = (hidden_states.shape[0] + num_padding_tokens) // get_tensor_model_parallel_world_size()
+                H = hidden_states.shape[1]
+                output_shape = (B, H)
+
         # FIXME: This does not seem right, should make sure the buffer is fixed
         output = torch.empty(output_shape,
                              dtype=hidden_states.dtype,
diff --git a/vllm_ascend/ops/linear.py b/vllm_ascend/ops/linear.py
index 51399cc7fa..5bb589a697 100644
--- a/vllm_ascend/ops/linear.py
+++ b/vllm_ascend/ops/linear.py
@@ -275,6 +275,7 @@ def forward(
         self,
         input_,
         is_prefill: bool = True,
+        is_force_scatter: bool = False,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
         if self.custom_op is not None:
             return self.custom_op.apply(input_)
diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py
index 819af72242..ce8b7ca8a8 100644
--- a/vllm_ascend/ops/linear_op.py
+++ b/vllm_ascend/ops/linear_op.py
@@ -38,6 +38,7 @@
 
 from typing import Optional, Tuple, Union
 
+from torch import nn
 import torch
 import torch.distributed as dist
 import torch_npu
@@ -46,9 +47,10 @@
 from vllm.distributed import split_tensor_along_last_dim
 from vllm.distributed.parallel_state import get_tp_group
 
-from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
+from vllm.forward_context import get_forward_context
+from vllm_ascend.distributed.parallel_state import (get_flashcomm2_odp_group, get_flashcomm2_otp_group, get_mlp_tp_group,
                                                     get_otp_group)
-from vllm_ascend.utils import (dense_optim_enable, enable_sp,
+from vllm_ascend.utils import (dense_optim_enable, enable_sp, flashcomm2_enable, get_flashcomm2_reorgnized_batch_ids,
                                matmul_allreduce_enable, mlp_tp_enable,
                                oproj_tp_enable)
 
@@ -181,6 +183,69 @@ def apply_impl(
         return output, output_bias
 
 
+class Flashcomm2MergedColumnParallelOp(CustomColumnParallelOp):
+
+    def apply_impl(
+        self, input_: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        """Linear layer with column parallelism.
+
+        Implements several optimizations for dense models, such as FlashComm and
+        communication-computation fusion.
+        """
+
+        bias = self.bias if not self.skip_bias_add else None
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+
+        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
+        output_parallel = self.quant_method.apply(self.layer, input_, bias)
+
+        if self.gather_output:
+            # All-gather across the partitions.
+            output = self.comm_group.all_gather(output_parallel)
+        else:
+            output = output_parallel
+        output_bias = self.bias if self.skip_bias_add else None
+        return output, output_bias
+
+
+class Flashcomm2QKVParallelOp(CustomColumnParallelOp):
+
+    def __init__(self, layer, prefix):
+        super().__init__(layer)
+        self.prefix = prefix
+
+    def apply_impl(
+        self, input_: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        """Linear layer with column parallelism.
+
+        Implements several optimizations for dense models, such as FlashComm and
+        communication-computation fusion.
+        """
+
+        bias = self.bias if not self.skip_bias_add else None
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+
+        layer_num = self.prefix.split('.')[2]
+
+        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+            input_, layer_num != '0')
+        output_parallel = self.quant_method.apply(self.layer, input_, bias)
+
+        if self.gather_output:
+            # All-gather across the partitions.
+            output = self.comm_group.all_gather(output_parallel)
+        else:
+            output = output_parallel
+        output_bias = self.bias if self.skip_bias_add else None
+        return output, output_bias
+
+
 class SequenceQKVParallelOp(CustomColumnParallelOp):
 
     def __init__(self, layer, prefix):
@@ -311,6 +376,103 @@ def update_attrs(self):
         self.input_size_per_partition = self.layer.input_size_per_partition
 
 
+class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
+
+    def __init__(self, layer):
+        super().__init__(layer)
+        self.odp_group = get_flashcomm2_odp_group()
+        self.odp_size = self.odp_group.world_size
+        self.reorgnized_batch_ids = get_flashcomm2_reorgnized_batch_ids(get_tp_group().world_size)
+        self.group_indices = torch.tensor(self.reorgnized_batch_ids).npu()
+
+    @property
+    def comm_group(self):
+        # TODO: when flashcomm2_otp_size == 1, get_flashcomm2_otp_group() returns None; this case needs separate handling.
+        return get_flashcomm2_otp_group()
+
+    def apply_impl(
+        self,
+        input_: torch.Tensor,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        """Linear layer for Flashcomm2.
+        Input.shape = [batchsize*seqlength, headnum*headdim/TP]
+        Output.shape = [(batchsize*seqlength+padsize)/TP, hiddensize]
+        """
+        # Handle input parallelism - split or use as-is
+        if self.input_is_parallel:
+            input_parallel = input_
+        else:
+            tp_rank = self.tp_rank
+            splitted_input = split_tensor_along_last_dim(
+                input_, num_partitions=self.tp_size)
+            input_parallel = splitted_input[tp_rank].contiguous()
+
+        # padding for all-to-all
+        forward_context = get_forward_context()
+        num_padding_tokens = forward_context.pad_size
+        if num_padding_tokens > 0:
+            input_parallel = nn.functional.pad(input_parallel, (0, 0, 0, num_padding_tokens))
+
+        # Reorganize the tensor so that the batch id and rank id correspond to each other.
+        chunk_num = len(self.reorgnized_batch_ids) * len(self.reorgnized_batch_ids[0])
+        batch_size = input_parallel.size(0)
+
+        assert batch_size % chunk_num == 0, f"Batch_size({batch_size}) must be divisible by chunk_num({chunk_num})"
+
+        batch_size_per_chunk = batch_size // chunk_num
+        # Indices of reorganized tensor
+        chunked = input_parallel.view(chunk_num, batch_size_per_chunk, input_parallel.shape[1])
+        reorganized_chunks = chunked[self.group_indices]
+        send_buf = reorganized_chunks.flatten(1, 2)
+
+        # all-to-all operation parameters
+        all2all_tp_size = self.odp_size
+        local_intermediate_size = input_parallel.size(1)
+        chunk_size = input_parallel.size(0) // all2all_tp_size
+        total_intermediate_size = local_intermediate_size * all2all_tp_size
+
+        # Create receive buffer
+        recv_buf = torch.empty(
+            total_intermediate_size * chunk_size,
+            dtype=input_parallel.dtype,
+            device=input_parallel.device)
+
+        # Perform all-to-all communication
+        dist.all_to_all_single(recv_buf, send_buf, group=self.odp_group.device_group)
+
+        input_parallel = recv_buf.view(
+            all2all_tp_size,
+            chunk_size,
+            -1
+        ).transpose(0, 1).reshape(chunk_size, -1)
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in TP>1 case)
+        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+        output_parallel = self.quant_method.apply(self.layer,
+                                                  input_parallel,
+                                                  bias=bias_)
+        # output_parallel shape: [bs/(TP/flashcomm2_otp_size), hiddenstate]
+        if self.tp_size > 1:
+            # flashcomm2 with reduce-scatter
+            output = self.comm_group.reduce_scatter(output_parallel, dim=0)
+        else:
+            output = output_parallel
+
+        output_bias = self.bias if self.skip_bias_add else None
+
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+    def update_attrs(self):
+        super().update_attrs()
+        self.input_is_parallel = self.layer.input_is_parallel
+        self.input_size_per_partition = self.layer.input_size_per_partition
+
+
 class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
     _HCOMM_INFO = None
 
@@ -411,7 +573,7 @@ def update_attrs(self):
 def get_column_parallel_op(
     disable_tp, prefix, layer
 ) -> Tuple[Optional[Union[MLPColumnParallelOp, SequenceMergedColumnParallelOp,
-                          SequenceQKVParallelOp]], int, int]:
+                          SequenceQKVParallelOp, Flashcomm2MergedColumnParallelOp, Flashcomm2QKVParallelOp]], int, int]:
     if disable_tp:
         return None, 0, 1
 
@@ -419,14 +581,19 @@ def get_column_parallel_op(
     custom_op: Optional[Union[
         MLPColumnParallelOp,
         SequenceMergedColumnParallelOp,
         SequenceQKVParallelOp,
+        Flashcomm2MergedColumnParallelOp,
+        Flashcomm2QKVParallelOp
     ]] = None
     if "gate_up_proj" in prefix and mlp_tp_enable():
         custom_op = MLPColumnParallelOp(layer)
     elif "gate_up_proj" in prefix and enable_sp():
         custom_op = SequenceMergedColumnParallelOp(layer)
+    elif "gate_up_proj" in prefix and flashcomm2_enable():
+        custom_op = Flashcomm2MergedColumnParallelOp(layer)
     elif enable_sp():
         custom_op = SequenceQKVParallelOp(layer, prefix)
-
+    elif flashcomm2_enable():
+        custom_op = Flashcomm2QKVParallelOp(layer, prefix)
     if custom_op is not None:
         return custom_op, custom_op.tp_rank, custom_op.tp_size
 
@@ -437,17 +604,19 @@ def get_row_parallel_op(
     disable_tp, prefix, layer
 ) -> Tuple[Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
                           MatmulAllreduceRowParallelOp,
-                          SequenceRowParallelOp]], int, int]:
+                          SequenceRowParallelOp, Flashcomm2OProjRowParallelOp]], int, int]:
     if disable_tp:
         return None, 0, 1
 
     custom_op: Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
                               MatmulAllreduceRowParallelOp,
-                              SequenceRowParallelOp]] = None
+                              SequenceRowParallelOp, Flashcomm2OProjRowParallelOp]] = None
     if "down_proj" in prefix and mlp_tp_enable():
         custom_op = MLPRowParallelOp(layer)
     elif "o_proj" in prefix and oproj_tp_enable():
         custom_op = OProjRowParallelOp(layer)
+    elif "o_proj" in prefix and flashcomm2_enable():
+        custom_op = Flashcomm2OProjRowParallelOp(layer)
     elif matmul_allreduce_enable():
         custom_op = MatmulAllreduceRowParallelOp(layer)
     elif enable_sp():
diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py
index 438bff1935..2b71b2e6ca 100644
--- a/vllm_ascend/ops/register_custom_ops.py
+++ b/vllm_ascend/ops/register_custom_ops.py
@@ -22,8 +22,9 @@ def _maybe_chunk_residual_impl(x: torch.Tensor,
 
     if x.size(0) != residual.size(0):
         sp_enabled = forward_context.sp_enabled
-        assert sp_enabled is True, ("Currently, this situation only occurs "
-                                    "when sp is enabled")
+        flashcomm_v2_enabled = forward_context.flashcomm_v2_enabled
+        assert sp_enabled or flashcomm_v2_enabled, ("Currently, this situation only occurs "
+                                                    "when sp or flashcomm_v2 is enabled")
         pad_size = forward_context.pad_size
         if pad_size > 0:
             residual = F.pad(residual, (0, 0, 0, pad_size))
@@ -42,7 +43,8 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
         return x
 
     sp_enabled = forward_context.sp_enabled
-    if sp_enabled and label:
+    flashcomm_v2_enabled = forward_context.flashcomm_v2_enabled
+    if (sp_enabled or flashcomm_v2_enabled) and label:
         x = tensor_model_parallel_all_gather(x, 0)
         pad_size = forward_context.pad_size
         if pad_size > 0:
@@ -57,7 +59,8 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
         return tensor_model_parallel_all_reduce(x)
 
     sp_enabled = forward_context.sp_enabled
-    if sp_enabled:
+    flashcomm_v2_enabled = forward_context.flashcomm_v2_enabled
+    if sp_enabled or flashcomm_v2_enabled:
         pad_size = forward_context.pad_size
         if pad_size > 0:
             x = F.pad(x, (0, 0, 0, pad_size))
diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py
index 130251cdc8..740f91f85a 100644
--- a/vllm_ascend/quantization/quant_config.py
+++ b/vllm_ascend/quantization/quant_config.py
@@ -36,12 +36,14 @@
 from vllm.model_executor.utils import set_weight_attrs
 
 from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
-                                                    get_otp_group)
+                                                    get_otp_group,
+                                                    get_flashcomm2_otp_group)
 from vllm_ascend.ops.fused_moe import AscendUnquantizedFusedMoEMethod
-from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable,
+from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, flashcomm2_enable, mlp_tp_enable,
                                oproj_tp_enable)
 
 from .utils import get_quant_method
+from vllm_ascend.ascend_config import get_ascend_config
 
 
 @register_quantization_config(ASCEND_QUANTIZATION_METHOD)
@@ -301,6 +303,11 @@ def apply(
             tp_rank = get_otp_group().rank_in_group
         elif layer.prefix.find("down_proj") != -1 and mlp_tp_enable():
             tp_rank = get_mlp_tp_group().rank_in_group
+        elif layer.prefix.find("o_proj") != -1 and flashcomm2_enable():
+            if get_ascend_config().flashcomm2_oproj_tensor_parallel_size == 1:
+                tp_rank = 0
+            else:
+                tp_rank = get_flashcomm2_otp_group().rank_in_group
         else:
             tp_rank = get_tensor_model_parallel_rank()
     else:
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 805fd5793a..c979df5eff 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -650,3 +650,25 @@ def npu_stream_switch(target_stream: torch.npu.Stream,
         return nullcontext()
     assert target_stream is not None
     return torch.npu.stream(target_stream)
+
+
+def flashcomm2_enable() -> bool:
+    return get_ascend_config().flashcomm2_oproj_tensor_parallel_size is not None
+
+
+def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]:
+    # Reorganize batch_ids so that, after the all-to-all and reduce-scatter operations, each batch_id corresponds to the rank_id within the DP domain.
+    # For example, when DP = [0, 1, 2, ..., 15] and flashcomm2_oproj_tensor_parallel_size = 2,
+    # the reorganized batch_ids will be [[batch0, batch8], [batch1, batch9], ..., [batch7, batch15]].
+    flashcomm2_otp_size = get_ascend_config().flashcomm2_oproj_tensor_parallel_size
+    num_oproj_tensor_parallel_groups: int = (global_tp_size // flashcomm2_otp_size)
+
+    reorgnized_batch_ids = []
+    for i in range(num_oproj_tensor_parallel_groups):
+        ranks = []
+        for j in range(flashcomm2_otp_size):
+            rank_idx = i + j * num_oproj_tensor_parallel_groups
+            ranks.append(rank_idx)
+        reorgnized_batch_ids.append(ranks)
+
+    return reorgnized_batch_ids
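# --- Illustration only; not part of the patch above. ---
# A minimal, standalone sketch of the batch-id reorganization performed by
# vllm_ascend.utils.get_flashcomm2_reorgnized_batch_ids and of the OTP/ODP rank
# layout built in init_ascend_model_parallel. It assumes the same loop structure
# as the patch but replaces the ascend-config lookup with an explicit
# flashcomm2_otp_size argument so it runs without vLLM; the function names here
# are hypothetical helpers, not part of the vllm_ascend API.

def reorganized_batch_ids(global_tp_size: int, flashcomm2_otp_size: int) -> list[list[int]]:
    # Batch i is grouped with the batches that land on the same o_proj TP group
    # after the all-to-all, mirroring get_flashcomm2_reorgnized_batch_ids.
    num_oproj_tp_groups = global_tp_size // flashcomm2_otp_size
    return [[i + j * num_oproj_tp_groups for j in range(flashcomm2_otp_size)]
            for i in range(num_oproj_tp_groups)]


def flashcomm2_group_ranks(global_tp_size: int, flashcomm2_otp_size: int,
                           dp_group_index: int = 0):
    # Mirrors the group construction in init_ascend_model_parallel: each OTP
    # group holds flashcomm2_otp_size ranks, and the j-th member of every OTP
    # group joins the j-th ODP group.
    num_oproj_tp_groups = global_tp_size // flashcomm2_otp_size
    otp_group_ranks = []
    odp_group_ranks = [[] for _ in range(flashcomm2_otp_size)]
    for i in range(num_oproj_tp_groups):
        ranks = []
        for j in range(flashcomm2_otp_size):
            rank_idx = dp_group_index * global_tp_size + i + j * num_oproj_tp_groups
            ranks.append(rank_idx)
            odp_group_ranks[j].append(rank_idx)
        otp_group_ranks.append(ranks)
    return otp_group_ranks, odp_group_ranks


if __name__ == "__main__":
    # The example from the get_flashcomm2_reorgnized_batch_ids comment:
    # TP = 16, flashcomm2_oproj_tensor_parallel_size = 2.
    assert reorganized_batch_ids(16, 2) == [[0, 8], [1, 9], [2, 10], [3, 11],
                                            [4, 12], [5, 13], [6, 14], [7, 15]]
    otp, odp = flashcomm2_group_ranks(16, 2)
    assert otp == reorganized_batch_ids(16, 2)   # eight OTP groups of size 2
    assert odp == [[0, 1, 2, 3, 4, 5, 6, 7],     # two ODP groups of size 8
                   [8, 9, 10, 11, 12, 13, 14, 15]]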