from typing import Optional, Tuple, Union

+ from torch import nn
import torch
import torch.distributed as dist
import torch_npu
from vllm_ascend.utils import (dense_optim_enable, enable_sp, flashcomm2_enable, get_flashcomm2_reorgnized_batch_ids,
                               matmul_allreduce_enable, mlp_tp_enable,
                               oproj_tp_enable)
+ from vllm_ascend.ascend_config import get_ascend_config


class CustomTensorParallelOp:
@@ -182,6 +184,69 @@ def apply_impl(
        return output, output_bias


+ class Flashcomm2MergedColumnParallelOp(CustomColumnParallelOp):
+
+     def apply_impl(
+         self, input_: torch.Tensor
+     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+         """Linear layer with column parallelism.
+
+         Implements multiple optimizations for dense models, such as FlashComm
+         and communication-computation fusion.
+         """
+
+         bias = self.bias if not self.skip_bias_add else None
+
+         # Matrix multiply.
+         assert self.quant_method is not None
+
+         input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
+         output_parallel = self.quant_method.apply(self.layer, input_, bias)
+
+         if self.gather_output:
+             # All-gather across the partitions.
+             output = self.comm_group.all_gather(output_parallel)
+         else:
+             output = output_parallel
+         output_bias = self.bias if self.skip_bias_add else None
+         return output, output_bias
+
+
+ class Flashcomm2QKVParallelOp(CustomColumnParallelOp):
+
+     def __init__(self, layer, prefix):
+         super().__init__(layer)
+         self.prefix = prefix
+
+     def apply_impl(
+         self, input_: torch.Tensor
+     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+         """Linear layer with column parallelism.
+
+         Implements multiple optimizations for dense models, such as FlashComm
+         and communication-computation fusion.
+         """
+
+         bias = self.bias if not self.skip_bias_add else None
+
+         # Matrix multiply.
+         assert self.quant_method is not None
+
+         layer_num = self.prefix.split('.')[2]
+
+         input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+             input_, layer_num != '0')
+         output_parallel = self.quant_method.apply(self.layer, input_, bias)
+
+         if self.gather_output:
+             # All-gather across the partitions.
+             output = self.comm_group.all_gather(output_parallel)
+         else:
+             output = output_parallel
+         output_bias = self.bias if self.skip_bias_add else None
+         return output, output_bias
+
+
class SequenceQKVParallelOp(CustomColumnParallelOp):

    def __init__(self, layer, prefix):
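
For reference, here is a standalone sketch of the prefix handling used in Flashcomm2QKVParallelOp above; the example prefix string is illustrative (vLLM-style module naming), not taken from this diff:

    # Minimal sketch: the layer index is the third dot-separated component of the
    # module prefix, and only layer "0" skips the gather-and-unpad step.
    prefix = "model.layers.0.self_attn.qkv_proj"   # illustrative prefix
    layer_num = prefix.split('.')[2]
    need_gather = layer_num != '0'
    print(layer_num, need_gather)                  # -> 0 False
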
@@ -316,7 +381,6 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):

    def __init__(self, layer):
        super().__init__(layer)
-         self.forward_type = "flashcomm2_oproj_tp"
        self.odp_group = get_flashcomm2_odp_group()
        self.odp_size = self.odp_group.world_size
        self.reorgnized_batch_ids = get_flashcomm2_reorgnized_batch_ids(get_tp_group().world_size)
@@ -325,11 +389,27 @@ def __init__(self, layer):
    @property
    def comm_group(self):
        return get_flashcomm2_otp_group()
+
+     @property
+     def tp_rank(self):
+         if get_ascend_config().flashcomm2_oproj_tensor_parallel_size == 1:
+             return 0
+         return self.comm_group.rank_in_group
+
+     @property
+     def tp_size(self):
+         if get_ascend_config().flashcomm2_oproj_tensor_parallel_size == 1:
+             return 1
+         return self.comm_group.world_size

    def apply_impl(
        self,
        input_: torch.Tensor,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+         """Linear layer for Flashcomm2.
+         Input shape:  [batchsize * seqlength, headnum * headdim / TP]
+         Output shape: [(batchsize * seqlength + padsize) / TP, hiddensize]
+         """
        # Handle input parallelism - split or use as-is
        if self.input_is_parallel:
            input_parallel = input_
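
A tiny, self-contained sketch of the shape arithmetic implied by the docstring above; the concrete sizes are made up, and the padding rule is assumed to be "pad the token dimension to a multiple of TP", as the (batchsize * seqlength + padsize) / TP output shape suggests:

    tp_size = 8
    num_tokens = 1003                        # batchsize * seqlength (illustrative)
    hidden_size = 4096                       # hiddensize (illustrative)

    pad = (-num_tokens) % tp_size            # assumed rule: make the token dim divisible by TP
    rows_per_rank = (num_tokens + pad) // tp_size
    print(pad, rows_per_rank, hidden_size)   # -> 5 126 4096
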
@@ -383,7 +463,7 @@ def apply_impl(
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
-         output_parallel = self.quant_method.apply(self,
+         output_parallel = self.quant_method.apply(self.layer,
                                                   input_parallel,
                                                   bias=bias_)
        # output_parallel shape: [bs/(TP/flashcomm2_otp_size), hiddenstate]
@@ -392,11 +472,6 @@ def apply_impl(
            output = self.comm_group.reduce_scatter(output_parallel, dim=0)
        else:
            output = output_parallel
-         if not forward_context.flashcomm1_ds_prefill:
-             # flashcomm1 not enabled
-             output = get_tp_group().all_gather(output, 0)
-             if num_padding_tokens > 0:
-                 output = output[:-num_padding_tokens]

        output_bias = self.bias if self.skip_bias_add else None

@@ -510,22 +585,27 @@ def update_attrs(self):
def get_column_parallel_op(
    disable_tp, prefix, layer
) -> Tuple[Optional[Union[MLPColumnParallelOp, SequenceMergedColumnParallelOp,
-                           SequenceQKVParallelOp]], int, int]:
+                           SequenceQKVParallelOp, Flashcomm2MergedColumnParallelOp, Flashcomm2QKVParallelOp]], int, int]:
    if disable_tp:
        return None, 0, 1

    custom_op: Optional[Union[
        MLPColumnParallelOp,
        SequenceMergedColumnParallelOp,
        SequenceQKVParallelOp,
+         Flashcomm2MergedColumnParallelOp,
+         Flashcomm2QKVParallelOp
    ]] = None
    if "gate_up_proj" in prefix and mlp_tp_enable():
        custom_op = MLPColumnParallelOp(layer)
    elif "gate_up_proj" in prefix and enable_sp():
        custom_op = SequenceMergedColumnParallelOp(layer)
+     elif "gate_up_proj" in prefix and flashcomm2_enable():
+         custom_op = Flashcomm2MergedColumnParallelOp(layer)
    elif enable_sp():
        custom_op = SequenceQKVParallelOp(layer, prefix)
-
+     elif flashcomm2_enable():
+         custom_op = Flashcomm2QKVParallelOp(layer, prefix)
    if custom_op is not None:
        return custom_op, custom_op.tp_rank, custom_op.tp_size
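
The branch order above matters: the MLP-TP and sequence-parallel cases are checked before the FlashComm2 ones, and the gate_up_proj branches come before the QKV fallback. Below is a standalone mirror of that dispatch, with the real vllm_ascend enable helpers stubbed out as plain booleans for illustration:

    def pick_column_op(prefix, mlp_tp=False, sp=False, flashcomm2=True):
        # Same branch order as get_column_parallel_op, with stubbed enable flags.
        if "gate_up_proj" in prefix and mlp_tp:
            return "MLPColumnParallelOp"
        if "gate_up_proj" in prefix and sp:
            return "SequenceMergedColumnParallelOp"
        if "gate_up_proj" in prefix and flashcomm2:
            return "Flashcomm2MergedColumnParallelOp"
        if sp:
            return "SequenceQKVParallelOp"
        if flashcomm2:
            return "Flashcomm2QKVParallelOp"
        return None  # real code falls through to the default parallel op

    print(pick_column_op("model.layers.3.mlp.gate_up_proj"))    # Flashcomm2MergedColumnParallelOp
    print(pick_column_op("model.layers.3.self_attn.qkv_proj"))  # Flashcomm2QKVParallelOp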