 from typing import Optional, Tuple, Union

+from torch import nn
 import torch
 import torch.distributed as dist
 import torch_npu
@@ -182,6 +183,69 @@ def apply_impl(
         return output, output_bias


+class Flashcomm2MergedColumnParallelOp(CustomColumnParallelOp):
+
+    def apply_impl(
+        self, input_: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        """Linear layer with column parallelism.
+
+        Implements multiple dense-model optimizations, such as FlashComm and
+        communication-computation fusion.
+        """
+
+        bias = self.bias if not self.skip_bias_add else None
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+
+        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
+        output_parallel = self.quant_method.apply(self.layer, input_, bias)
+
+        if self.gather_output:
+            # All-gather across the partitions.
+            output = self.comm_group.all_gather(output_parallel)
+        else:
+            output = output_parallel
+        output_bias = self.bias if self.skip_bias_add else None
+        return output, output_bias
+
+
+class Flashcomm2QKVParallelOp(CustomColumnParallelOp):
+
+    def __init__(self, layer, prefix):
+        super().__init__(layer)
+        self.prefix = prefix
+
+    def apply_impl(
+        self, input_: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        """Linear layer with column parallelism.
+
+        Implements multiple dense-model optimizations, such as FlashComm and
+        communication-computation fusion.
+        """
+
+        bias = self.bias if not self.skip_bias_add else None
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+
+        layer_num = self.prefix.split('.')[2]
+
+        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+            input_, layer_num != '0')
+        output_parallel = self.quant_method.apply(self.layer, input_, bias)
+
+        if self.gather_output:
+            # All-gather across the partitions.
+            output = self.comm_group.all_gather(output_parallel)
+        else:
+            output = output_parallel
+        output_bias = self.bias if self.skip_bias_add else None
+        return output, output_bias
+
+
 class SequenceQKVParallelOp(CustomColumnParallelOp):

     def __init__(self, layer, prefix):
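Note: the layer_num check added in Flashcomm2QKVParallelOp.apply_impl assumes vLLM-style module prefixes of the form "model.layers.<idx>.self_attn.qkv_proj", where the third dot-separated field is the decoder-layer index. A minimal standalone sketch of that parsing; the prefix value below is a hypothetical example, not taken from this patch:

# Hypothetical prefix following the usual vLLM naming scheme.
prefix = "model.layers.3.self_attn.qkv_proj"

# Same parsing as in Flashcomm2QKVParallelOp.apply_impl.
layer_num = prefix.split('.')[2]            # -> "3"

# Flag passed to maybe_all_gather_and_maybe_unpad: False only for the first
# decoder layer ("0"), True for every later layer.
needs_gather_and_unpad = layer_num != '0'
print(layer_num, needs_gather_and_unpad)    # 3 True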
@@ -316,20 +380,24 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):

     def __init__(self, layer):
         super().__init__(layer)
-        self.forward_type = "flashcomm2_oproj_tp"
         self.odp_group = get_flashcomm2_odp_group()
         self.odp_size = self.odp_group.world_size
         self.reorgnized_batch_ids = get_flashcomm2_reorgnized_batch_ids(get_tp_group().world_size)
         self.group_indices = torch.tensor(self.reorgnized_batch_ids).npu()

     @property
     def comm_group(self):
+        # TODO: when otp_size == 1, get_flashcomm2_otp_group() returns None; that case needs separate handling.
         return get_flashcomm2_otp_group()

     def apply_impl(
         self,
         input_: torch.Tensor,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        """Linear layer for Flashcomm2.
+        Input.shape = [batchsize * seqlength, headnum * headdim / TP]
+        Output.shape = [(batchsize * seqlength + padsize) / TP, hiddensize]
+        """
         # Handle input parallelism - split or use as-is
         if self.input_is_parallel:
             input_parallel = input_
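Note: to make the shapes in the new Flashcomm2OProjRowParallelOp.apply_impl docstring concrete, here is a rough walk-through with made-up sizes. The padding rule (round the token count up to a multiple of TP) is an assumption for illustration; the patch only names the resulting padsize.

# Hypothetical sizes, for illustration only.
batch_size, seq_len = 4, 128
head_num, head_dim = 32, 128
hidden_size = head_num * head_dim                             # 4096
tp_size = 8

tokens = batch_size * seq_len                                 # 512
# Per-rank o_proj input: [batchsize * seqlength, headnum * headdim / TP]
input_shape = (tokens, head_num * head_dim // tp_size)        # (512, 512)
# Assumed padding: round the token count up to a multiple of TP.
pad_size = (-tokens) % tp_size                                # 0 here
# Per-rank output slice: [(batchsize * seqlength + padsize) / TP, hiddensize]
output_shape = ((tokens + pad_size) // tp_size, hidden_size)  # (64, 4096)
print(input_shape, output_shape)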
@@ -383,7 +451,7 @@ def apply_impl(
         # Only fuse bias add into GEMM for rank 0 (this ensures that
         # bias will not get added more than once in TP>1 case)
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
-        output_parallel = self.quant_method.apply(self,
+        output_parallel = self.quant_method.apply(self.layer,
                                                   input_parallel,
                                                   bias=bias_)
         # output_parallel shape: [bs/(TP/flashcomm2_otp_size), hiddenstate]
@@ -392,11 +460,6 @@ def apply_impl(
             output = self.comm_group.reduce_scatter(output_parallel, dim=0)
         else:
             output = output_parallel
-        if not forward_context.flashcomm1_ds_prefill:
-            # flashcomm1 not enabled
-            output = get_tp_group().all_gather(output, 0)
-            if num_padding_tokens > 0:
-                output = output[:-num_padding_tokens]

         output_bias = self.bias if self.skip_bias_add else None

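Note: a quick check that the output_parallel shape comment above and the docstring output shape agree once reduce_scatter runs over the otp group (reduce_scatter along dim 0 divides that dim by the group size). The concrete numbers are assumed, including flashcomm2_otp_size = 2.

# Assumed example sizes.
padded_tokens = 512          # batchsize * seqlength + padsize
tp_size = 8                  # full tensor-parallel world size
otp_size = 2                 # assumed flashcomm2_otp_size
hidden_size = 4096

# output_parallel shape: [bs / (TP / flashcomm2_otp_size), hiddensize]
parallel_rows = padded_tokens // (tp_size // otp_size)   # 128
# reduce_scatter over the otp group splits dim 0 by otp_size.
rows_after_rs = parallel_rows // otp_size                # 64 == padded_tokens // tp_size
print((parallel_rows, hidden_size), (rows_after_rs, hidden_size))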
@@ -510,22 +573,27 @@ def update_attrs(self):
 def get_column_parallel_op(
     disable_tp, prefix, layer
 ) -> Tuple[Optional[Union[MLPColumnParallelOp, SequenceMergedColumnParallelOp,
-                          SequenceQKVParallelOp]], int, int]:
+                          SequenceQKVParallelOp, Flashcomm2MergedColumnParallelOp, Flashcomm2QKVParallelOp]], int, int]:
     if disable_tp:
         return None, 0, 1

     custom_op: Optional[Union[
         MLPColumnParallelOp,
         SequenceMergedColumnParallelOp,
         SequenceQKVParallelOp,
+        Flashcomm2MergedColumnParallelOp,
+        Flashcomm2QKVParallelOp
     ]] = None
     if "gate_up_proj" in prefix and mlp_tp_enable():
         custom_op = MLPColumnParallelOp(layer)
     elif "gate_up_proj" in prefix and enable_sp():
         custom_op = SequenceMergedColumnParallelOp(layer)
+    elif "gate_up_proj" in prefix and flashcomm2_enable():
+        custom_op = Flashcomm2MergedColumnParallelOp(layer)
     elif enable_sp():
         custom_op = SequenceQKVParallelOp(layer, prefix)
-
+    elif flashcomm2_enable():
+        custom_op = Flashcomm2QKVParallelOp(layer, prefix)
     if custom_op is not None:
         return custom_op, custom_op.tp_rank, custom_op.tp_size

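Note: a small self-contained sketch of the dispatch priority that get_column_parallel_op implements after this change. The boolean parameters stand in for the real mlp_tp_enable(), enable_sp() and flashcomm2_enable() helpers, and the returned strings are just class names; this is an illustration, not the patched function itself.

from typing import Optional


def pick_op(prefix: str, mlp_tp: bool, sp: bool, flashcomm2: bool) -> Optional[str]:
    # Mirrors the if/elif chain in get_column_parallel_op.
    if "gate_up_proj" in prefix and mlp_tp:
        return "MLPColumnParallelOp"
    if "gate_up_proj" in prefix and sp:
        return "SequenceMergedColumnParallelOp"
    if "gate_up_proj" in prefix and flashcomm2:
        return "Flashcomm2MergedColumnParallelOp"
    if sp:
        return "SequenceQKVParallelOp"
    if flashcomm2:
        return "Flashcomm2QKVParallelOp"
    return None


# FlashComm2 alone now routes both merged gate_up and QKV projections to the new ops.
assert pick_op("model.layers.0.mlp.gate_up_proj", False, False, True) == "Flashcomm2MergedColumnParallelOp"
assert pick_op("model.layers.0.self_attn.qkv_proj", False, False, True) == "Flashcomm2QKVParallelOp"
# Sequence parallelism keeps precedence over FlashComm2 when both are enabled.
assert pick_op("model.layers.0.self_attn.qkv_proj", False, True, True) == "SequenceQKVParallelOp"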