 from vllm.distributed import split_tensor_along_last_dim
 from vllm.distributed.parallel_state import get_tp_group

-from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
+from vllm.forward_context import get_forward_context
+from vllm_ascend.distributed.parallel_state import (get_flashcomm2_odp_group,
+                                                     get_flashcomm2_otp_group,
+                                                     get_mlp_tp_group,
                                                      get_otp_group)
-from vllm_ascend.utils import (dense_optim_enable, enable_sp,
+from vllm_ascend.utils import (dense_optim_enable, enable_sp,
+                               flashcomm2_enable,
+                               get_flashcomm2_reorgnized_batch_ids,
                                matmul_allreduce_enable, mlp_tp_enable,
                                oproj_tp_enable)

@@ -311,6 +312,104 @@ def update_attrs(self):
         self.input_size_per_partition = self.layer.input_size_per_partition


+class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
+
+    def __init__(self, layer):
+        super().__init__(layer)
+        self.forward_type = "flashcomm2_oproj_tp"
+        self.odp_group = get_flashcomm2_odp_group()
+        self.odp_size = self.odp_group.world_size
+        self.reorgnized_batch_ids = get_flashcomm2_reorgnized_batch_ids(get_tp_group().world_size)
+        self.group_indices = torch.tensor(self.reorgnized_batch_ids).npu()
+
+    @property
+    def comm_group(self):
+        return get_flashcomm2_otp_group()
+
+    def apply_impl(
+        self,
+        input_: torch.Tensor,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        # Handle input parallelism: split the input or use it as-is.
+        if self.input_is_parallel:
+            input_parallel = input_
+        else:
+            tp_rank = self.tp_rank
+            splitted_input = split_tensor_along_last_dim(
+                input_, num_partitions=self.tp_size)
+            input_parallel = splitted_input[tp_rank].contiguous()
+
+        # Pad the token dimension for the all-to-all.
+        forward_context = get_forward_context()
+        num_padding_tokens = forward_context.pad_size
+        if num_padding_tokens > 0:
+            input_parallel = nn.functional.pad(input_parallel, (0, 0, 0, num_padding_tokens))
+
+        # Reorganize the tensor so that batch ids and rank ids correspond to each other.
+        chunk_num = len(self.reorgnized_batch_ids) * len(self.reorgnized_batch_ids[0])
+        batch_size = input_parallel.size(0)
+
+        assert batch_size % chunk_num == 0, f"batch_size ({batch_size}) must be divisible by chunk_num ({chunk_num})"
+
+        batch_size_per_chunk = batch_size // chunk_num
+        # Split into chunks and gather them in the reorganized order.
+        chunked = input_parallel.view(chunk_num, batch_size_per_chunk, input_parallel.shape[1])
+        reorganized_chunks = chunked[self.group_indices]
+        send_buf = reorganized_chunks.flatten(1, 2)
+
+        # all-to-all operation parameters
+        all2all_tp_size = self.odp_size
+        local_intermediate_size = input_parallel.size(1)
+        chunk_size = input_parallel.size(0) // all2all_tp_size
+        total_intermediate_size = local_intermediate_size * all2all_tp_size
+
+        # Create receive buffer
+        recv_buf = torch.empty(
+            total_intermediate_size * chunk_size,
+            dtype=input_parallel.dtype,
+            device=input_parallel.device)
+
+        # Perform all-to-all communication
+        dist.all_to_all_single(recv_buf, send_buf, group=self.odp_group.device_group)
+
+        input_parallel = recv_buf.view(
+            all2all_tp_size,
+            chunk_size,
+            -1
+        ).transpose(0, 1).reshape(chunk_size, -1)
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in TP>1 case)
+        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+        output_parallel = self.quant_method.apply(self,
+                                                  input_parallel,
+                                                  bias=bias_)
+        # output_parallel shape: [bs / (TP / flashcomm2_otp_size), hidden_size]
+        if self.tp_size > 1:
+            # flashcomm2 with reduce-scatter
+            output = self.comm_group.reduce_scatter(output_parallel, dim=0)
+        else:
+            output = output_parallel
+        if not forward_context.flashcomm1_ds_prefill:
+            # flashcomm1 not enabled
+            output = get_tp_group().all_gather(output, 0)
+        if num_padding_tokens > 0:
+            output = output[:-num_padding_tokens]
+
+        output_bias = self.bias if self.skip_bias_add else None
+
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+    def update_attrs(self):
+        super().update_attrs()
+        self.input_is_parallel = self.layer.input_is_parallel
+        self.input_size_per_partition = self.layer.input_size_per_partition
+
+
 class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
     _HCOMM_INFO = None

@@ -437,17 +536,19 @@ def get_row_parallel_op(
     disable_tp, prefix, layer
 ) -> Tuple[Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
                           MatmulAllreduceRowParallelOp,
-                          SequenceRowParallelOp]], int, int]:
+                          SequenceRowParallelOp, Flashcomm2OProjRowParallelOp]], int, int]:
     if disable_tp:
         return None, 0, 1

     custom_op: Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
                               MatmulAllreduceRowParallelOp,
-                              SequenceRowParallelOp]] = None
+                              SequenceRowParallelOp, Flashcomm2OProjRowParallelOp]] = None
     if "down_proj" in prefix and mlp_tp_enable():
         custom_op = MLPRowParallelOp(layer)
     elif "o_proj" in prefix and oproj_tp_enable():
         custom_op = OProjRowParallelOp(layer)
+    elif "o_proj" in prefix and flashcomm2_enable():
+        custom_op = Flashcomm2OProjRowParallelOp(layer)
     elif matmul_allreduce_enable():
         custom_op = MatmulAllreduceRowParallelOp(layer)
     elif enable_sp():
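For reference, the tensor bookkeeping around the all-to-all in Flashcomm2OProjRowParallelOp.apply_impl can be illustrated with a minimal, single-process sketch. The shapes, the [[0, 2], [1, 3]] layout standing in for get_flashcomm2_reorgnized_batch_ids, and the emulated (not actually performed) all-to-all are assumptions for illustration, not the vllm-ascend implementation:

import torch

tp_size = 4        # assumed tensor-parallel size
odp_size = 2       # assumed flashcomm2 o_proj data-parallel size
hidden = 8         # per-rank (partitioned) o_proj input size
tokens = 14        # tokens on this rank before padding

# Hypothetical stand-in for get_flashcomm2_reorgnized_batch_ids(tp_size):
# odp_size rows, each listing the chunk ids sent to one all-to-all peer.
reorganized_batch_ids = [[0, 2], [1, 3]]
group_indices = torch.tensor(reorganized_batch_ids)
chunk_num = group_indices.numel()                # 4
assert chunk_num == tp_size                      # holds for this assumed layout

x = torch.randn(tokens, hidden)

# Pad so the token dimension splits evenly into chunk_num chunks
# (the real op reads pad_size from the forward context instead).
pad = (-tokens) % chunk_num
if pad:
    x = torch.nn.functional.pad(x, (0, 0, 0, pad))
batch = x.size(0)                                # 16
per_chunk = batch // chunk_num                   # 4

# Regroup chunks so that row i of send_buf is the slab destined for peer i.
chunked = x.view(chunk_num, per_chunk, hidden)
send_buf = chunked[group_indices].flatten(1, 2)  # [odp_size, batch // odp_size, hidden]

# Emulate the all-to-all: with identical data on every peer, each rank would
# receive odp_size slabs of shape [chunk_size, hidden], concatenated flat.
chunk_size = batch // odp_size                   # 8
recv_buf = send_buf.reshape(-1)                  # stand-in for the real recv buffer

# Receive-side reshape, as in apply_impl: interleave the per-peer slabs
# along the hidden dimension.
out = recv_buf.view(odp_size, chunk_size, -1).transpose(0, 1).reshape(chunk_size, -1)
print(out.shape)  # torch.Size([8, 16]) == [chunk_size, odp_size * hidden]

In the real op, send_buf and recv_buf are exchanged with dist.all_to_all_single over the o_proj data-parallel group, and the [chunk_size, odp_size * hidden] result is what quant_method.apply consumes before the reduce-scatter / all-gather epilogue.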