Commit 0b1bdac

[Platform] Custom ops support for FusedMoe (#22509)
Signed-off-by: wangxiyuan <[email protected]>
1 parent d94e302 commit 0b1bdac

3 files changed: +11 -8 lines changed

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 2 additions & 1 deletion
@@ -682,7 +682,8 @@ def determine_expert_map(
     return (local_num_experts, expert_map)


-class FusedMoE(torch.nn.Module):
+@CustomOp.register("fused_moe")
+class FusedMoE(CustomOp):
     """FusedMoE layer for MoE models.

     This layer contains both MergedColumnParallel weights (gate_up_proj /
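
Note (not part of the diff): the decorator added above registers the class under the string name "fused_moe", so the op can later be looked up or toggled by name. The snippet below is a minimal, self-contained sketch of that registration pattern; it is not vLLM's actual custom_op.py, and the names SketchCustomOp, op_registry, and SketchFusedMoE are illustrative assumptions.

# Minimal sketch of name-based op registration, mirroring what
# @CustomOp.register("fused_moe") does. Illustrative only.
from torch import nn


class SketchCustomOp(nn.Module):
    # Maps op names (e.g. "fused_moe") to the classes implementing them.
    op_registry: dict[str, type["SketchCustomOp"]] = {}

    @classmethod
    def register(cls, name: str):
        def decorator(op_cls: type["SketchCustomOp"]):
            assert name not in cls.op_registry, f"duplicate op name: {name}"
            cls.op_registry[name] = op_cls
            return op_cls
        return decorator


@SketchCustomOp.register("fused_moe")
class SketchFusedMoE(SketchCustomOp):
    """Stand-in for FusedMoE, registered under the same name as in the diff."""


assert SketchCustomOp.op_registry["fused_moe"] is SketchFusedMoE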

vllm/model_executor/layers/linear.py

Lines changed: 6 additions & 6 deletions
@@ -16,6 +16,7 @@
                               tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce)
 from vllm.logger import init_logger
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
@@ -226,7 +227,7 @@ def apply(self,
         return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)


-class LinearBase(torch.nn.Module):
+class LinearBase(CustomOp):
     """Base linear layer.

     Args:
@@ -269,12 +270,8 @@ def __init__(
                                           prefix=prefix)
         self.return_bias = return_bias

-    def forward(
-        self, x: torch.Tensor
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
-        raise NotImplementedError
-

+@CustomOp.register("replicated_linear")
 class ReplicatedLinear(LinearBase):
     """Replicated linear layer.

@@ -443,6 +440,7 @@ def weight_loader(self,
             param[shard_offset:shard_offset + shard_size] = loaded_weight


+@CustomOp.register("column_parallel_linear")
 class ColumnParallelLinear(LinearBase):
     """Linear layer with column parallelism.

@@ -1229,6 +1227,7 @@ def weight_loader(self,
         param_data.copy_(loaded_weight)


+@CustomOp.register("row_parallel_linear")
 class RowParallelLinear(LinearBase):
     """Linear layer with row parallelism.

@@ -1405,6 +1404,7 @@ def extra_repr(self) -> str:
         return s


+@CustomOp.register("qkv_cross_parallel_linear")
 class QKVCrossParallelLinear(LinearBase):
     """Linear layers for efficient cross-attention's QKV transformation.
14101410

vllm/model_executor/layers/vocab_parallel_embedding.py

Lines changed: 3 additions & 1 deletion
@@ -12,6 +12,7 @@
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
 from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
@@ -159,7 +160,8 @@ def get_masked_input_and_mask(
     return input_, ~vocab_mask


-class VocabParallelEmbedding(torch.nn.Module):
+@CustomOp.register("vocab_parallel_embedding")
+class VocabParallelEmbedding(CustomOp):
     """Embedding parallelized in the vocabulary dimension.

     Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
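
Usage note: since each of these layers is now registered under a string name, implementations can in principle be enumerated or toggled by name. The lookup below is a hedged sketch: it assumes CustomOp keeps a class-level name-to-class registry populated by @CustomOp.register(...); the attribute name op_registry is an assumption, not something shown in this diff.

# Hedged lookup sketch; assumes CustomOp exposes an op_registry mapping.
# Importing the layer modules runs the register decorators from this commit.
import vllm.model_executor.layers.fused_moe.layer  # noqa: F401
import vllm.model_executor.layers.linear  # noqa: F401
import vllm.model_executor.layers.vocab_parallel_embedding  # noqa: F401
from vllm.model_executor.custom_op import CustomOp

for name in ("fused_moe", "row_parallel_linear", "vocab_parallel_embedding"):
    print(name, "->", CustomOp.op_registry[name].__name__)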
