@@ -48,6 +48,73 @@ def init_weights(self, init_std: float = 0.02):
            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)


+# TODO: keeping this for-loop implementation for comparison
+# and readability, may remove later
+@expert_parallel
+def _run_experts_for_loop(
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w3: torch.Tensor,
+    x: torch.Tensor,
+    num_tokens_per_expert: torch.Tensor | None = None,
+) -> torch.Tensor:
+    if num_tokens_per_expert is not None:
+        # NOTE: this would incur a synchronization between device and host
+        num_tokens_per_expert = num_tokens_per_expert.tolist()
+
+        # side-effect code due to the usage of generate_permute_indices
+        num_padding = x.shape[0] - sum(num_tokens_per_expert)
+
+        # a tuple of tensors indexed by experts
+        # each with shape (tokens_per_expert(varying), dim)
+        x = torch.split(
+            x[: sum(num_tokens_per_expert)],
+            split_size_or_sections=num_tokens_per_expert,
+            dim=0,
+        )
+        out_experts_splits = []
+        for expert_idx, x_expert in enumerate(x):
+            h = F.silu(torch.matmul(x_expert, w1[expert_idx]))
+            h = h * torch.matmul(x_expert, w3[expert_idx])
+            h = torch.matmul(h, w2[expert_idx])
+            # h shape (tokens_per_expert(varying), dim)
+            out_experts_splits.append(h)
+        out = torch.cat(out_experts_splits, dim=0)
+
+        # side-effect code due to the usage of generate_permute_indices
+        out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
+    else:
+        # x shape (num_experts, tokens_per_expert, dim)
+        h = F.silu(torch.bmm(x, w1))
+        h = h * torch.bmm(x, w3)
+        # out shape (num_experts, tokens_per_expert, dim)
+        out = torch.bmm(h, w2)
+
+    return out
+
+@expert_parallel
+def _run_experts_grouped_mm(
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w3: torch.Tensor,
+    x: torch.Tensor,
+    num_tokens_per_expert: torch.Tensor | None = None,
+) -> torch.Tensor:
+    if num_tokens_per_expert is not None:
+        offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
+        # grouped mm between a 2D tensor and a 3D tensor
+        assert x.dim() == 2
+    else:
+        offsets = None
+        # fall back to regular bmm between 3D tensors
+        assert x.dim() == 3
+
+    h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets))
+    h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets)
+    out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x)
+
+    return out
+
class GroupedExperts(nn.Module):
    def __init__(
        self,
@@ -69,83 +136,14 @@ def forward(
        num_tokens_per_expert: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if self.use_grouped_mm:
-            return GroupedExperts._run_experts_grouped_mm(
+            return _run_experts_grouped_mm(
                self.w1, self.w2, self.w3, x, num_tokens_per_expert
            )
        else:
-            return GroupedExperts._run_experts_for_loop(
+            return _run_experts_for_loop(
                self.w1, self.w2, self.w3, x, num_tokens_per_expert
            )

-    # TODO: keeping this for-loop implementation for comparison
-    # and readability, may remove later
-    @expert_parallel
-    @staticmethod
-    def _run_experts_for_loop(
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w3: torch.Tensor,
-        x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        if num_tokens_per_expert is not None:
-            # NOTE: this would incur a synchronization between device and host
-            num_tokens_per_expert = num_tokens_per_expert.tolist()
-
-            # side-effect code due to the usage of generate_permute_indices
-            num_padding = x.shape[0] - sum(num_tokens_per_expert)
-
-            # a tuple of tensors indexed by experts
-            # each with shape (tokens_per_expert(varying), dim)
-            x = torch.split(
-                x[: sum(num_tokens_per_expert)],
-                split_size_or_sections=num_tokens_per_expert,
-                dim=0,
-            )
-            out_experts_splits = []
-            for expert_idx, x_expert in enumerate(x):
-                h = F.silu(torch.matmul(x_expert, w1[expert_idx]))
-                h = h * torch.matmul(x_expert, w3[expert_idx])
-                h = torch.matmul(h, w2[expert_idx])
-                # h shape (tokens_per_expert(varying), dim)
-                out_experts_splits.append(h)
-            out = torch.cat(out_experts_splits, dim=0)
-
-            # side-effect code due to the usage of generate_permute_indices
-            out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
-        else:
-            # x shape (num_experts, tokens_per_expert, dim)
-            h = F.silu(torch.bmm(x, w1))
-            h = h * torch.bmm(x, w3)
-            # out shape (num_experts, tokens_per_expert, dim)
-            out = torch.bmm(h, w2)
-
-        return out
-
-    @expert_parallel
-    @staticmethod
-    def _run_experts_grouped_mm(
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w3: torch.Tensor,
-        x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        if num_tokens_per_expert is not None:
-            offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
-            # grouped mm between a 2D tensor and a 3D tensor
-            assert x.dim() == 2
-        else:
-            offsets = None
-            # fall back to regular bmm between 3D tensors
-            assert x.dim() == 3
-
-        h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets))
-        h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets)
-        out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x)
-
-        return out
-
    def init_weights(self, init_std: float):
        nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
        nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
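
For readers skimming the diff, here is a minimal standalone sketch (not from this commit; the token counts and feature size are made up) of the routed-path bookkeeping the two implementations share: the for-loop variant moves `num_tokens_per_expert` to the host so `torch.split` gets Python ints and the padding rows can be counted, while the grouped-mm variant keeps the counts on device as inclusive int32 offsets for `torch._grouped_mm`.

```python
import torch

# Hypothetical per-expert token counts and feature size, for illustration only.
num_tokens_per_expert = torch.tensor([3, 0, 5, 2])
dim = 8
# Routed tokens with 2 trailing padding rows, mimicking the padded layout
# implied by generate_permute_indices.
x = torch.randn(int(num_tokens_per_expert.sum()) + 2, dim)

# For-loop path: .tolist() forces a device-to-host sync so torch.split sees Python ints.
split_sizes = num_tokens_per_expert.tolist()    # [3, 0, 5, 2]
num_padding = x.shape[0] - sum(split_sizes)     # 2 rows to re-append as zeros at the end
chunks = torch.split(x[: sum(split_sizes)], split_sizes, dim=0)

# Grouped-mm path: the same counts stay on device as inclusive int32 offsets.
offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)

print([c.shape[0] for c in chunks], num_padding, offsets.tolist())
# -> [3, 0, 5, 2] 2 [3, 3, 8, 10]
```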
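
And a similarly self-contained sketch of the dense fallback both functions take when `num_tokens_per_expert` is None: a SiLU-gated MLP applied per expert via batched matmuls. The sizes below are placeholders, not values from the repo.

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes, for illustration only.
num_experts, tokens_per_expert, dim, hidden_dim = 4, 6, 16, 32

x = torch.randn(num_experts, tokens_per_expert, dim)
w1 = torch.randn(num_experts, dim, hidden_dim)
w3 = torch.randn(num_experts, dim, hidden_dim)
w2 = torch.randn(num_experts, hidden_dim, dim)

# Mirrors the else branch above: gate with SiLU(x @ w1), scale by x @ w3, project with w2.
h = F.silu(torch.bmm(x, w1)) * torch.bmm(x, w3)
out = torch.bmm(h, w2)
assert out.shape == (num_experts, tokens_per_expert, dim)
```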