from __future__ import annotations
from typing_extensions import override
import torch
import torch.nn.functional as F
from torch import nn
from ..models import Config
from ..util.tensor import to2
from . import Module, Linear
from ..ext import exllamav3_ext as ext
from ..constants import MAX_MLP_INTERMEDIATE
from ..util import first_not_none


class MultiLinear:
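    """
    Wraps a set of identically shaped, EXL3-quantized Linear modules (one per expert) and
    collects device pointers to their tensors so the batched ext.exl3_mgemm kernel can address
    them as a group.
    """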
    def __init__(
        self,
        device: torch.device,
        linears: list[Linear]
    ):
        self.device = device
        self.linears = linears
        self.num_linears = len(linears)

        assert all(l.quant_type == "exl3" for l in linears)
        assert all(l.inner.bias is None for l in linears)
        assert all(not l.softcap for l in linears)
        assert all(l.post_scale == 1.0 for l in linears)

        self.in_features = linears[0].in_features
        self.out_features = linears[0].out_features
        self.K = linears[0].inner.K
        assert all(l.inner.K == self.K for l in linears)
        assert all(l.in_features == self.in_features for l in linears)
        assert all(l.out_features == self.out_features for l in linears)

        self.ptrs_suh = torch.tensor([l.inner.suh.data_ptr() for l in linears], dtype = torch.long, device = device)
        self.ptrs_svh = torch.tensor([l.inner.svh.data_ptr() for l in linears], dtype = torch.long, device = device)
        self.ptrs_trellis = torch.tensor([l.inner.trellis.data_ptr() for l in linears], dtype = torch.long, device = device)
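        # These pointer tables are what ext.exl3_mgemm consumes: for each selected expert it
        # looks up the corresponding suh/svh/trellis tensors directly on the device instead of
        # iterating over the individual Linear modules in Python.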

    def unload(self):
        pass


class BlockSparseMLP(Module):
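    """
    Block-sparse (mixture-of-experts) MLP layer. A routing gate scores all experts, the top-k
    experts are selected per token, and each selected expert's gate/up/down projections are
    applied and combined with the renormalized routing weights. When every expert tensor is
    EXL3-quantized, single-token inference uses a fused multi-GEMM path.
    """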

    def __init__(
        self,
        config: Config,
        key: str,
        hidden_size: int,
        intermediate_size: int,
        num_experts: int,
        num_experts_per_tok: int,
        key_up: str | None = None,
        key_gate: str | None = None,
        key_down: str | None = None,
        key_routing_gate: str | None = None,
        qmap: str | None = None,
        out_dtype: torch.dtype | None = None,
        activation_fn: str = "silu",
        interm_dtype: torch.dtype | None = None,
    ):
        super().__init__(config, key, None)

        self.out_dtype = out_dtype
        self.interm_dtype = interm_dtype
        self.activation_fn = activation_fn
        self.intermediate_size = intermediate_size
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.hidden_size = hidden_size

        self.routing_gate = Linear(
            config = config,
            key = f"{key}.{key_routing_gate}",
            in_features = hidden_size,
            out_features = num_experts,
            qmap = None,
            out_dtype = torch.half,
        )
        self.register_submodule(self.routing_gate)

        self.gates = []
        self.ups = []
        self.downs = []

        for idx in range(num_experts):

            gate = Linear(
                config = config,
                key = f"{key}.{key_gate}".replace("{expert_idx}", str(idx)),
                in_features = hidden_size,
                out_features = intermediate_size,
                qmap = qmap + ".input",
                out_dtype = self.interm_dtype
            )
            up = Linear(
                config = config,
                key = f"{key}.{key_up}".replace("{expert_idx}", str(idx)),
                in_features = hidden_size,
                out_features = intermediate_size,
                qmap = qmap + ".input",
                out_dtype = self.interm_dtype
            )
            down = Linear(
                config = config,
                key = f"{key}.{key_down}".replace("{expert_idx}", str(idx)),
                in_features = intermediate_size,
                out_features = hidden_size,
                qmap = qmap + f".{idx}.down",
                out_dtype = torch.half,
                allow_input_padding = True,
            )

            self.ups.append(up)
            self.gates.append(gate)
            self.downs.append(down)

            self.register_submodule(up)
            self.register_submodule(gate)
            self.register_submodule(down)

        match activation_fn:
            case "silu": self.activation_fn_call = ext.silu_mul
            case "gelu": self.activation_fn_call = ext.gelu_mul

        self.is_quantized = False
        self.multi_gate = None
        self.multi_up = None
        self.multi_down = None


    @override
    def load(self, device: torch.device, **kwargs):
        super().load(device, **kwargs)

        # Test if experts can be fused
        num_exl3_tensors = 0
        num_nonexl3_tensors = 0
        for l in self.gates + self.ups + self.downs:
            if l.quant_type == "exl3":
                num_exl3_tensors += 1
            else:
                num_nonexl3_tensors += 1
        if num_exl3_tensors and num_nonexl3_tensors:
            print(f" !! Warning, partially quantized block-sparse MLP layer: {self.key}")
        self.is_quantized = (num_exl3_tensors > 0 and num_nonexl3_tensors == 0)
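        # Only a fully EXL3-quantized expert set can use the fused kernel path; a partially
        # quantized layer falls back to the per-expert torch path in forward().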

        # Make fused modules
        if self.is_quantized:
            self.multi_gate = MultiLinear(self.device, self.gates)
            self.multi_up = MultiLinear(self.device, self.ups)
            self.multi_down = MultiLinear(self.device, self.downs)


    @override
    def unload(self):
        if self.multi_gate is not None:
            self.multi_gate.unload()
            self.multi_gate = None
        if self.multi_up is not None:
            self.multi_up.unload()
            self.multi_up = None
        if self.multi_down is not None:
            self.multi_down.unload()
            self.multi_down = None
        super().unload()


    @override
    def forward(
        self,
        x: torch.Tensor,
        params: dict,
        out_dtype: torch.dtype | None = None
    ) -> torch.Tensor:
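        """
        Route each token in x to its top-k experts and return the routing-weighted sum of the
        expert MLP outputs, reshaped to x's original shape. Setting
        params["activate_all_experts"] routes every token through all experts instead of top-k.
        """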

        activate_all_experts = params.get("activate_all_experts", False)

        y = x.view(-1, self.hidden_size)
        bsz = y.shape[0]

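        # Routing: softmax over the expert logits, keep the top-k probabilities per token, then
        # renormalize the kept weights so they sum to 1.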
        router_logits = self.routing_gate.forward(y, params)
        routing_weights = F.softmax(router_logits, dim = -1)
        routing_weights, selected_experts = torch.topk(
            routing_weights,
            self.num_experts if activate_all_experts else self.num_experts_per_tok,
            dim = -1
        )
        routing_weights /= routing_weights.sum(dim = -1, keepdim = True)

        # Torch path
        if bsz > 1 or not self.is_quantized:
            final_hidden_states = torch.zeros_like(y)

            expert_mask = torch.nn.functional.one_hot(
                selected_experts,
                num_classes = self.num_experts
            )
            expert_count = expert_mask.view(-1, self.num_experts).sum(dim = 0).cpu()
            expert_mask = expert_mask.permute(2, 1, 0)
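            # expert_mask is now (num_experts, top_k, bsz); torch.where below recovers, for each
            # expert, which tokens were routed to it and in which top-k slot.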

            def mlp(exp_i, xc):
                g = self.gates[exp_i].forward(xc, params)
                u = self.ups[exp_i].forward(xc, params)
                self.activation_fn_call(g, u, u)
                return self.downs[exp_i].forward(u, params)

            for expert_idx in range(self.num_experts):
                if expert_count[expert_idx] == 0:
                    continue
                idx, top_x = torch.where(expert_mask[expert_idx])
                current_state = y[None, top_x].reshape(-1, self.hidden_size)
                current_state = mlp(expert_idx, current_state) * routing_weights[top_x, idx, None]
                final_hidden_states.index_add_(0, top_x, current_state)

            final_hidden_states = final_hidden_states.reshape(x.shape)
            return to2(final_hidden_states, out_dtype, self.out_dtype)

        # Fused path
        # TODO: Find good solution for 1 < bsz < 32
        else:
            y = y.unsqueeze(0)
            yh = torch.empty(
                (self.num_experts_per_tok, bsz, y.shape[-1]),
                dtype = y.dtype,
                device = y.device
            )
            interm_g = torch.empty(
                (self.num_experts_per_tok, bsz, self.intermediate_size),
                dtype = self.interm_dtype,
                device = y.device
            )
            interm_u = torch.empty_like(interm_g)
            interm_a = torch.empty_like(interm_u, dtype = torch.half) if self.interm_dtype != torch.half else interm_u
            out_d = torch.empty(
                (self.num_experts_per_tok, bsz, self.hidden_size),
                dtype = first_not_none(out_dtype, self.out_dtype, torch.half),
                device = y.device
            )
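            # Buffer layout for the fused path: yh is scratch consumed by exl3_mgemm together
            # with the suh pointers, interm_g/interm_u receive the gate and up projections for
            # each selected expert, interm_a holds the activated product (aliasing interm_u when
            # the intermediate dtype is already fp16), and out_d receives the per-expert down
            # projections.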

            # Gate
            ext.exl3_mgemm(
                y,
                self.multi_gate.ptrs_trellis,
                interm_g,
                self.multi_gate.ptrs_suh,
                yh,
                self.multi_gate.ptrs_svh,
                selected_experts,
                None,
                self.multi_gate.K,
                -1
            )

            # Up
            ext.exl3_mgemm(
                y,
                self.multi_up.ptrs_trellis,
                interm_u,
                self.multi_up.ptrs_suh,
                yh,
                self.multi_up.ptrs_svh,
                selected_experts,
                None,
                self.multi_up.K,
                -1
            )

            # Activation
            self.activation_fn_call(interm_g, interm_u, interm_a)

            # Down
            ext.exl3_mgemm(
                interm_a,
                self.multi_down.ptrs_trellis,
                out_d,
                self.multi_down.ptrs_suh,
                interm_a,
                self.multi_down.ptrs_svh,
                selected_experts,
                routing_weights,
                self.multi_down.K,
                -1
            )

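            # routing_weights were passed to the down mgemm above, so each slice of out_d is
            # already scaled; summing over the expert dimension yields the final mixture output.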
            final_hidden_states = out_d.sum(dim = 0)
            return final_hidden_states.view(x.shape)
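

# Hypothetical usage sketch (values and key templates are illustrative, not taken from any
# particular model):
#
#   mlp = BlockSparseMLP(
#       config = config,
#       key = "model.layers.0.mlp",
#       hidden_size = 4096,
#       intermediate_size = 14336,
#       num_experts = 8,
#       num_experts_per_tok = 2,
#       key_gate = "experts.{expert_idx}.gate_proj",
#       key_up = "experts.{expert_idx}.up_proj",
#       key_down = "experts.{expert_idx}.down_proj",
#       key_routing_gate = "gate",
#       qmap = "block.mlp",
#   )
#   mlp.load(torch.device("cuda:0"))
#   out = mlp.forward(hidden_states, params = {})
#
# In practice the keys, qmap and sizes come from the model's architecture definition, which
# constructs and loads these modules.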