|
38 | 38 |
|
39 | 39 | from typing import Optional, Tuple, Union
|
40 | 40 |
|
| 41 | +from torch import nn |
41 | 42 | import torch
|
42 | 43 | import torch.distributed as dist
|
43 | 44 | import torch_npu
|
@@ -182,6 +183,69 @@ def apply_impl(
|
182 | 183 | return output, output_bias
|
183 | 184 |
|
184 | 185 |
|
| 186 | +class Flashcomm2MergedColumnParallelOp(CustomColumnParallelOp): |
| 187 | + |
| 188 | + def apply_impl( |
| 189 | + self, input_: torch.Tensor |
| 190 | + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: |
| 191 | + """Linear layer with column parallelism. |
| 192 | +
|
| 193 | + Implements several optimizations for dense models, such as FlashComm and |
| 194 | + communication-computation fusion. |
| 195 | + """ |
| 196 | + |
| 197 | + bias = self.bias if not self.skip_bias_add else None |
| 198 | + |
| 199 | + # Matrix multiply. |
| 200 | + assert self.quant_method is not None |
| 201 | + |
| 202 | + input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True) |
| 203 | + output_parallel = self.quant_method.apply(self.layer, input_, bias) |
| 204 | + |
| 205 | + if self.gather_output: |
| 206 | + # All-gather across the partitions. |
| 207 | + output = self.comm_group.all_gather(output_parallel) |
| 208 | + else: |
| 209 | + output = output_parallel |
| 210 | + output_bias = self.bias if self.skip_bias_add else None |
| 211 | + return output, output_bias |
| 212 | + |
| 213 | + |
| 214 | +class Flashcomm2QKVParallelOp(CustomColumnParallelOp): |
| 215 | + |
| 216 | + def __init__(self, layer, prefix): |
| 217 | + super().__init__(layer) |
| 218 | + self.prefix = prefix |
| 219 | + |
| 220 | + def apply_impl( |
| 221 | + self, input_: torch.Tensor |
| 222 | + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: |
| 223 | + """Linear layer with column parallelism. |
| 224 | +
|
| 226 | + Implements several optimizations for dense models, such as FlashComm and |
| 226 | + communication-computation fusion. |
| 227 | + """ |
| 228 | + |
| 229 | + bias = self.bias if not self.skip_bias_add else None |
| 230 | + |
| 231 | + # Matrix multiply. |
| 232 | + assert self.quant_method is not None |
| 233 | + |
| 234 | + layer_num = self.prefix.split('.')[2] |
| 235 | + |
| 236 | + input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( |
| 237 | + input_, layer_num != '0') |
| 238 | + output_parallel = self.quant_method.apply(self.layer, input_, bias) |
| 239 | + |
| 240 | + if self.gather_output: |
| 241 | + # All-gather across the partitions. |
| 242 | + output = self.comm_group.all_gather(output_parallel) |
| 243 | + else: |
| 244 | + output = output_parallel |
| 245 | + output_bias = self.bias if self.skip_bias_add else None |
| 246 | + return output, output_bias |
| 247 | + |
| 248 | + |
185 | 249 | class SequenceQKVParallelOp(CustomColumnParallelOp):
|
186 | 250 |
|
187 | 251 | def __init__(self, layer, prefix):
|
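A minimal, single-process sketch of the column-parallel pattern that the two Flashcomm2 ops above implement. It is illustrative only: plain torch.cat stands in for the TP all-gather performed by torch.ops.vllm.maybe_all_gather_and_maybe_unpad, the tensor sizes are invented, and only rank 0's weight shard is used. The QKV variant in the diff differs only in that it skips the gather for the first layer (it passes layer_num != '0' as the flag), reading that index from self.prefix.split('.')[2], which presumably assumes prefixes of the form model.layers.<N>.self_attn.qkv_proj.

import torch

tp_size, tokens_per_rank, hidden_size, out_features = 2, 4, 8, 16

# Each TP rank enters with activations scattered along the token dimension.
scattered = [torch.randn(tokens_per_rank, hidden_size) for _ in range(tp_size)]

# Column parallelism: every rank owns a slice of the output columns.
weight_shards = [torch.randn(hidden_size, out_features // tp_size)
                 for _ in range(tp_size)]

# Step 1: gather the tokens so the local GEMM sees the full sequence
# (the role played by maybe_all_gather_and_maybe_unpad in the diff above).
full_input = torch.cat(scattered, dim=0)         # [tp_size * tokens_per_rank, hidden_size]

# Step 2: local matmul against this rank's column shard (rank 0 shown).
output_parallel = full_input @ weight_shards[0]  # [tp_size * tokens_per_rank, out_features // tp_size]

print(output_parallel.shape)                     # torch.Size([8, 8])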
@@ -330,6 +394,10 @@ def apply_impl(
|
330 | 394 | self,
|
331 | 395 | input_: torch.Tensor,
|
332 | 396 | ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
| 397 | + """Linear layer for Flashcomm2. |
| 398 | + Input.shape = [batchsize*seqlength, headnum*headdim/TP] |
| 399 | + Output.shape = [(batchsize*seqlength+padsize)/TP, hiddensize] |
| 400 | + """ |
333 | 401 | # Handle input parallelism - split or use as-is
|
334 | 402 | if self.input_is_parallel:
|
335 | 403 | input_parallel = input_
|
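To make the shape docstring above concrete, here is the arithmetic with invented numbers (TP size, token count, and head geometry are illustrative, not taken from this change):

tp_size = 4
num_tokens = 1002                     # batch_size * seq_length
num_heads, head_dim = 32, 128
hidden_size = num_heads * head_dim    # 4096

# Per-rank input to the row-parallel projection.
input_shape = (num_tokens, num_heads * head_dim // tp_size)       # (1002, 1024)

# Tokens are padded up to a multiple of TP so the reduce-scatter
# along dim 0 splits evenly across ranks.
pad_size = (-num_tokens) % tp_size                                # 2
output_shape = ((num_tokens + pad_size) // tp_size, hidden_size)  # (251, 4096)

print(input_shape, output_shape)

After the projection and reduce-scatter, each rank holds 1/TP of the padded tokens but the full hidden size, which is what the docstring's output shape expresses.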
@@ -392,11 +460,6 @@ def apply_impl(
|
392 | 460 | output = self.comm_group.reduce_scatter(output_parallel, dim=0)
|
393 | 461 | else:
|
394 | 462 | output = output_parallel
|
395 | | - if not forward_context.flashcomm1_ds_prefill: |
|
396 | | -     # flashcomm1 not enabled |
|
397 | | -     output = get_tp_group().all_gather(output, 0) |
|
398 | | -     if num_padding_tokens > 0: |
|
399 | | -         output = output[:-num_padding_tokens] |
400 | 463 |
|
401 | 464 | output_bias = self.bias if self.skip_bias_add else None
|
402 | 465 |
|
@@ -510,22 +573,27 @@ def update_attrs(self):
|
510 | 573 | def get_column_parallel_op(
|
511 | 574 | disable_tp, prefix, layer
|
512 | 575 | ) -> Tuple[Optional[Union[MLPColumnParallelOp, SequenceMergedColumnParallelOp,
|
513 | | - SequenceQKVParallelOp]], int, int]: |
| 576 | + SequenceQKVParallelOp, Flashcomm2MergedColumnParallelOp, Flashcomm2QKVParallelOp]], int, int]: |
514 | 577 | if disable_tp:
|
515 | 578 | return None, 0, 1
|
516 | 579 |
|
517 | 580 | custom_op: Optional[Union[
|
518 | 581 | MLPColumnParallelOp,
|
519 | 582 | SequenceMergedColumnParallelOp,
|
520 | 583 | SequenceQKVParallelOp,
|
| 584 | + Flashcomm2MergedColumnParallelOp, |
| 585 | + Flashcomm2QKVParallelOp |
521 | 586 | ]] = None
|
522 | 587 | if "gate_up_proj" in prefix and mlp_tp_enable():
|
523 | 588 | custom_op = MLPColumnParallelOp(layer)
|
524 | 589 | elif "gate_up_proj" in prefix and enable_sp():
|
525 | 590 | custom_op = SequenceMergedColumnParallelOp(layer)
|
| 591 | + elif "gate_up_proj" in prefix and flashcomm2_enable(): |
| 592 | + custom_op = Flashcomm2MergedColumnParallelOp(layer) |
526 | 593 | elif enable_sp():
|
527 | 594 | custom_op = SequenceQKVParallelOp(layer, prefix)
|
528 | | - |
| 595 | + elif flashcomm2_enable(): |
| 596 | + custom_op = Flashcomm2QKVParallelOp(layer, prefix) |
529 | 597 | if custom_op is not None:
|
530 | 598 | return custom_op, custom_op.tp_rank, custom_op.tp_size
|
531 | 599 |
|
|