@@ -31,6 +31,79 @@ class MoEArgs:
     load_balance_coeff: float | None = 1e-3
 
 
+# TODO: keeping this for-loop implementation for comparison
+# and readability, may remove later
+@expert_parallel
+def _run_experts_for_loop(
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w3: torch.Tensor,
+    x: torch.Tensor,
+    num_tokens_per_expert: torch.Tensor | None = None,
+) -> torch.Tensor:
+    if num_tokens_per_expert is not None:
+        # NOTE: this would incur a synchronization between device and host
+        num_tokens_per_expert = num_tokens_per_expert.tolist()
+
+        # side-effect code due to the usage of generate_permute_indices
+        num_padding = x.shape[0] - sum(num_tokens_per_expert)
+
+        # a tuple of tensors indexed by experts
+        # each with shape (tokens_per_expert(varying), dim)
+        x = torch.split(
+            x[: sum(num_tokens_per_expert)],
+            split_size_or_sections=num_tokens_per_expert,
+            dim=0,
+        )
+        out_experts_splits = []
+        for expert_idx, x_expert in enumerate(x):
+            h = F.silu(torch.matmul(x_expert, w1[expert_idx].transpose(-2, -1)))
+            h = h * torch.matmul(x_expert, w3[expert_idx].transpose(-2, -1))
+            h = torch.matmul(h, w2[expert_idx].transpose(-2, -1))
+            # h shape (tokens_per_expert(varying), dim)
+            out_experts_splits.append(h)
+        out = torch.cat(out_experts_splits, dim=0)
+
+        # side-effect code due to the usage of generate_permute_indices
+        out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
+    else:
+        # x shape (num_experts, tokens_per_expert, dim)
+        h = F.silu(torch.bmm(x, w1.transpose(-2, -1)))
+        h = h * torch.bmm(x, w3.transpose(-2, -1))
+        # out shape (num_experts, tokens_per_expert, dim)
+        out = torch.bmm(h, w2.transpose(-2, -1))
+
+    return out
+
+
+@expert_parallel
+def _run_experts_grouped_mm(
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w3: torch.Tensor,
+    x: torch.Tensor,
+    num_tokens_per_expert: torch.Tensor | None = None,
+) -> torch.Tensor:
+    if num_tokens_per_expert is not None:
+        offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
+        # grouped mm between a 2D tensor and a 3D tensor
+        assert x.dim() == 2
+    else:
+        offsets = None
+        # fall back to regular bmm between 3D tensors
+        assert x.dim() == 3
+
+    h = F.silu(
+        torch._grouped_mm(x.bfloat16(), w1.bfloat16().transpose(-2, -1), offs=offsets)
+    )
+    h = h * torch._grouped_mm(
+        x.bfloat16(), w3.bfloat16().transpose(-2, -1), offs=offsets
+    )
+    out = torch._grouped_mm(h, w2.bfloat16().transpose(-2, -1), offs=offsets).type_as(x)
+
+    return out
+
+
 class GroupedExperts(nn.Module):
     def __init__(
         self,
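
For reference, a minimal self-contained sketch of the shape contract the dense (bmm) branch above assumes; the sizes are illustrative, not taken from this repo, and plain torch ops stand in for the decorated helper:

import torch
import torch.nn.functional as F

# Hypothetical sizes, chosen only to illustrate the expected shapes.
num_experts, tokens_per_expert, dim, hidden_dim = 4, 8, 16, 32

x = torch.randn(num_experts, tokens_per_expert, dim)  # routed tokens, one group per expert
w1 = torch.randn(num_experts, hidden_dim, dim)        # gate projection weights
w2 = torch.randn(num_experts, dim, hidden_dim)        # down projection weights
w3 = torch.randn(num_experts, hidden_dim, dim)        # up projection weights

# SwiGLU-style expert MLP, batched over experts with bmm,
# as in the else-branch of _run_experts_for_loop.
h = F.silu(torch.bmm(x, w1.transpose(-2, -1)))
h = h * torch.bmm(x, w3.transpose(-2, -1))
out = torch.bmm(h, w2.transpose(-2, -1))
assert out.shape == (num_experts, tokens_per_expert, dim)
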
@@ -52,91 +125,14 @@ def forward(
         num_tokens_per_expert: torch.Tensor | None = None,
     ) -> torch.Tensor:
         if self.use_grouped_mm:
-            return GroupedExperts._run_experts_grouped_mm(
+            return _run_experts_grouped_mm(
                 self.w1, self.w2, self.w3, x, num_tokens_per_expert
             )
         else:
-            return GroupedExperts._run_experts_for_loop(
+            return _run_experts_for_loop(
                 self.w1, self.w2, self.w3, x, num_tokens_per_expert
             )
 
-    # TODO: keeping this for-loop implementation for comparison
-    # and readability, may remove later
-    @expert_parallel
-    @staticmethod
-    def _run_experts_for_loop(
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w3: torch.Tensor,
-        x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        if num_tokens_per_expert is not None:
-            # NOTE: this would incur a synchronization between device and host
-            num_tokens_per_expert = num_tokens_per_expert.tolist()
-
-            # side-effect code due to the usage of generate_permute_indices
-            num_padding = x.shape[0] - sum(num_tokens_per_expert)
-
-            # a tuple of tensors indexed by experts
-            # each with shape (tokens_per_expert(varying), dim)
-            x = torch.split(
-                x[: sum(num_tokens_per_expert)],
-                split_size_or_sections=num_tokens_per_expert,
-                dim=0,
-            )
-            out_experts_splits = []
-            for expert_idx, x_expert in enumerate(x):
-                h = F.silu(torch.matmul(x_expert, w1[expert_idx].transpose(-2, -1)))
-                h = h * torch.matmul(x_expert, w3[expert_idx].transpose(-2, -1))
-                h = torch.matmul(h, w2[expert_idx].transpose(-2, -1))
-                # h shape (tokens_per_expert(varying), dim)
-                out_experts_splits.append(h)
-            out = torch.cat(out_experts_splits, dim=0)
-
-            # side-effect code due to the usage of generate_permute_indices
-            out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
-        else:
-            # x shape (num_experts, tokens_per_expert, dim)
-            h = F.silu(torch.bmm(x, w1.transpose(-2, -1)))
-            h = h * torch.bmm(x, w3.transpose(-2, -1))
-            # out shape (num_experts, tokens_per_expert, dim)
-            out = torch.bmm(h, w2.transpose(-2, -1))
-
-        return out
-
-    @expert_parallel
-    @staticmethod
-    def _run_experts_grouped_mm(
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w3: torch.Tensor,
-        x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        if num_tokens_per_expert is not None:
-            offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
-            # grouped mm between a 2D tensor and a 3D tensor
-            assert x.dim() == 2
-        else:
-            offsets = None
-            # fall back to regular bmm between 3D tensors
-            assert x.dim() == 3
-
-        h = F.silu(
-            torch._grouped_mm(
-                x.bfloat16(), w1.bfloat16().transpose(-2, -1), offs=offsets
-            )
-        )
-        h = h * torch._grouped_mm(
-            x.bfloat16(), w3.bfloat16().transpose(-2, -1), offs=offsets
-        )
-        out = torch._grouped_mm(
-            h, w2.bfloat16().transpose(-2, -1), offs=offsets
-        ).type_as(x)
-
-        return out
-
     def init_weights(self, init_std: float):
         nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
         nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
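
And a rough sketch of the ragged path, assuming tokens are already permuted so each expert's chunk is contiguous along dim 0; the token counts and sizes here are made up, and the private torch._grouped_mm call is not exercised, only the for-loop equivalent plus the offsets the grouped-mm variant would receive:

import torch
import torch.nn.functional as F

num_experts, dim, hidden_dim = 4, 16, 32
w1 = torch.randn(num_experts, hidden_dim, dim)
w2 = torch.randn(num_experts, dim, hidden_dim)
w3 = torch.randn(num_experts, hidden_dim, dim)

# Hypothetical per-expert token counts from routing (an expert may receive 0 tokens).
num_tokens_per_expert = torch.tensor([3, 0, 5, 2])
total_tokens = int(num_tokens_per_expert.sum())
x = torch.randn(total_tokens, dim)  # 2D input: tokens grouped by expert along dim 0

# End offset of each expert's token group, as computed for the grouped-mm path.
offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)

# For-loop equivalent: split per expert and apply that expert's SwiGLU MLP.
outs = []
for expert_idx, x_expert in enumerate(
    torch.split(x, num_tokens_per_expert.tolist(), dim=0)
):
    h = F.silu(x_expert @ w1[expert_idx].transpose(-2, -1))
    h = h * (x_expert @ w3[expert_idx].transpose(-2, -1))
    outs.append(h @ w2[expert_idx].transpose(-2, -1))
out = torch.cat(outs, dim=0)
assert out.shape == (total_tokens, dim)
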