
Commit 6fc153d

Add block-sparse MLP module
1 parent 538dcae commit 6fc153d

6 files changed: +309, -7 lines

exllamav3/conversion/allocation.py

Lines changed: 2 additions & 2 deletions
@@ -44,7 +44,7 @@ def allocate_transformer(
     assert d
     if isinstance(g, list):
         for m in (g, u, d):
-            key_ = m[0].key.replace(".slice.0", ".slice.*")
+            key_ = m[0].key.replace(".slice.0", ".slice.*").replace(".experts.0.", ".experts.*.")
             keys += [key_]
             numels += [sum(mm.weights_numel() for mm in m)]
             for mm in m:
@@ -65,7 +65,7 @@ def allocate_transformer(
     assert d
     if isinstance(u, list):
         for m in (u, d):
-            key_ = m[0].key.replace(".slice.0", ".slice.*")
+            key_ = m[0].key.replace(".slice.0", ".slice.*").replace(".experts.0.", ".experts.*.")
             keys += [m]
             numels += [sum(mm.weights_numel() for mm in m)]
             for mm in m:
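
The effect of the added .replace(".experts.0.", ".experts.*.") is that every per-expert tensor in a block-sparse MLP collapses onto one wildcard key, and the summed weights_numel() is recorded against that single key, so the allocator budgets the whole expert group as one unit. A minimal, self-contained sketch of that normalization (the tensor keys below are hypothetical, not taken from the repo):

    # Hypothetical keys for the down projections of a 3-expert block-sparse MLP
    keys_in = [
        "model.layers.0.mlp.experts.0.down_proj.slice.0",
        "model.layers.0.mlp.experts.1.down_proj.slice.0",
        "model.layers.0.mlp.experts.2.down_proj.slice.0",
    ]

    # Same normalization as the patched line: take the first module's key and
    # wildcard both the slice index and the expert index
    key_ = keys_in[0].replace(".slice.0", ".slice.*").replace(".experts.0.", ".experts.*.")
    print(key_)  # -> model.layers.0.mlp.experts.*.down_proj.slice.*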

exllamav3/conversion/convert_model.py

Lines changed: 11 additions & 1 deletion
@@ -278,7 +278,9 @@ def main(args, job_state):
         qmaps = module.get_qmaps()
         if len(qmaps) > 0:
 
-            # Capture calibration input states during forward pass
+            # Capture calibration input states during forward pass. For block-sparse models, all expert layers
+            # are activated to ensure all down projections capture at least some calibration data. When the
+            # state is advanced later, only selected experts will be used.
             with ProgressBar(f" -- Capturing: {module.key}" + slice_str, len(state)) as progress:
                 capture_H = {}
                 ref_states = []
@@ -287,12 +289,20 @@ def main(args, job_state):
                    params = {
                        "attn_mode": "flash_attn_nc",
                        "capture": capture_H,
+                        "activate_all_experts": model.calibration_all_experts,
                    }
                    if slicing:
                        params["q_mlp_slice"] = current_slice
                    rs = module.prepare_for_device(state[i], params)
                    rs = module.forward(rs, params)
                    if i < num_ref_states:
+                        if model.calibration_all_experts:
+                            # Reference state for measuring error needs to use only the selected experts
+                            params = { "attn_mode": "flash_attn_nc" }
+                            if slicing:
+                                params["q_mlp_slice"] = current_slice
+                            rs = module.prepare_for_device(state[i], params)
+                            rs = module.forward(rs, params)
                        ref_states.append(rs.cpu())
                    rs = None
            print(f" -- Captured: {module.key}" + slice_str)

exllamav3/models/model.py

Lines changed: 3 additions & 0 deletions
@@ -27,6 +27,9 @@ def __init__(
         self.logit_layer_idx = None
         self.first_block_idx = None
 
+        # Calibration options
+        self.calibration_all_experts = False
+
 
     def __iter__(self):
         for module in self.modules:
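
A sketch of how a mixture-of-experts architecture definition would be expected to opt in; the class name and constructor signature here are hypothetical, since no architecture file is touched by this commit:

    class SomeMoEModel(Model):                 # hypothetical architecture class
        def __init__(self, config, **kwargs):  # assumed signature
            super().__init__(config, **kwargs)
            # Ask the converter to activate every expert while capturing calibration data
            self.calibration_all_experts = True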

exllamav3/modules/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 from .module import Module
 from .linear import Linear
 from .mlp import MLP, GatedMLP
+from .block_sparse_mlp import BlockSparseMLP
 from .rmsnorm import RMSNorm
 from .layernorm import LayerNorm
 from .embedding import Embedding
exllamav3/modules/block_sparse_mlp.py

Lines changed: 288 additions & 0 deletions (new file)

@@ -0,0 +1,288 @@
from __future__ import annotations
from typing_extensions import override
import torch
import torch.nn.functional as F
from torch import nn
from ..models import Config
from ..util.tensor import to2
from . import Module, Linear
from ..ext import exllamav3_ext as ext
from ..constants import MAX_MLP_INTERMEDIATE
from ..util import first_not_none


class MultiLinear:
    def __init__(
        self,
        device: torch.device,
        linears: list[Linear]
    ):
        self.device = device
        self.linears = linears
        self.num_linears = len(linears)

        assert all(l.quant_type == "exl3" for l in linears)
        assert all(l.inner.bias is None for l in linears)
        assert all(not l.softcap for l in linears)
        assert all(l.post_scale == 1.0 for l in linears)

        self.in_features = linears[0].in_features
        self.out_features = linears[0].out_features
        self.K = linears[0].inner.K
        assert all(l.inner.K == self.K for l in linears)
        assert all(l.in_features == self.in_features for l in linears)
        assert all(l.out_features == self.out_features for l in linears)

        self.ptrs_suh = torch.tensor([l.inner.suh.data_ptr() for l in linears], dtype = torch.long, device = device)
        self.ptrs_svh = torch.tensor([l.inner.svh.data_ptr() for l in linears], dtype = torch.long, device = device)
        self.ptrs_trellis = torch.tensor([l.inner.trellis.data_ptr() for l in linears], dtype = torch.long, device = device)

    def unload(self):
        pass


class BlockSparseMLP(Module):

    def __init__(
        self,
        config: Config,
        key: str,
        hidden_size: int,
        intermediate_size: int,
        num_experts: int,
        num_experts_per_tok: int,
        key_up: str | None = None,
        key_gate: str | None = None,
        key_down: str | None = None,
        key_routing_gate: str | None = None,
        qmap: str | None = None,
        out_dtype: torch.dtype = None,
        activation_fn: str = "silu",
        interm_dtype: torch.dtype = None,
    ):
        super().__init__(config, key, None)

        self.out_dtype = out_dtype
        self.interm_dtype = interm_dtype
        self.activation_fn = activation_fn
        self.intermediate_size = intermediate_size
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.hidden_size = hidden_size

        self.routing_gate = Linear(
            config = config,
            key = f"{key}.{key_routing_gate}",
            in_features = hidden_size,
            out_features = num_experts,
            qmap = None,
            out_dtype = torch.half,
        )
        self.register_submodule(self.routing_gate)

        self.gates = []
        self.ups = []
        self.downs = []

        for idx in range(num_experts):

            gate = Linear(
                config = config,
                key = f"{key}.{key_gate}".replace("{expert_idx}", str(idx)),
                in_features = hidden_size,
                out_features = intermediate_size,
                qmap = qmap + ".input",
                out_dtype = self.interm_dtype
            )
            up = Linear(
                config = config,
                key = f"{key}.{key_up}".replace("{expert_idx}", str(idx)),
                in_features = hidden_size,
                out_features = intermediate_size,
                qmap = qmap + ".input",
                out_dtype = self.interm_dtype
            )
            down = Linear(
                config = config,
                key = f"{key}.{key_down}".replace("{expert_idx}", str(idx)),
                in_features = intermediate_size,
                out_features = hidden_size,
                qmap = qmap + f".{idx}.down",
                out_dtype = torch.half,
                allow_input_padding = True,
            )

            self.ups.append(up)
            self.gates.append(gate)
            self.downs.append(down)

            self.register_submodule(up)
            self.register_submodule(gate)
            self.register_submodule(down)

        match activation_fn:
            case "silu": self.activation_fn_call = ext.silu_mul
            case "gelu": self.activation_fn_call = ext.gelu_mul

        self.is_quantized = False
        self.multi_gate = None
        self.multi_up = None
        self.multi_down = None


    @override
    def load(self, device: torch.device, **kwargs):
        super().load(device, **kwargs)

        # Test if experts can be fused
        num_exl3_tensors = 0
        num_nonexl3_tensors = 0
        for l in self.gates + self.ups + self.downs:
            if l.quant_type == "exl3":
                num_exl3_tensors += 1
            else:
                num_nonexl3_tensors += 1
        if num_exl3_tensors and num_nonexl3_tensors:
            print(f" !! Warning, partially quantized block-sparse MLP layer: {self.key}")
        self.is_quantized = (num_exl3_tensors > 0 and num_nonexl3_tensors == 0)

        # Make fused modules
        if self.is_quantized:
            self.multi_gate = MultiLinear(self.device, self.gates)
            self.multi_up = MultiLinear(self.device, self.ups)
            self.multi_down = MultiLinear(self.device, self.downs)


    @override
    def unload(self):
        if self.multi_gate is not None:
            self.multi_gate.unload()
            self.multi_gate = None
        if self.multi_up is not None:
            self.multi_up.unload()
            self.multi_up = None
        if self.multi_down is not None:
            self.multi_down.unload()
            self.multi_down = None
        super().unload()


    @override
    def forward(
        self,
        x: torch.Tensor,
        params: dict,
        out_dtype: torch.dtype | None = None
    ) -> torch.Tensor:

        activate_all_experts = params.get("activate_all_experts", False)

        y = x.view(-1, self.hidden_size)
        bsz = y.shape[0]

        router_logits = self.routing_gate.forward(y, params)
        routing_weights = F.softmax(router_logits, dim = -1)
        routing_weights, selected_experts = torch.topk(
            routing_weights,
            self.num_experts if activate_all_experts else self.num_experts_per_tok,
            dim = -1
        )
        routing_weights /= routing_weights.sum(dim = -1, keepdim = True)

        # Torch path
        if bsz > 1 or not self.is_quantized:
            final_hidden_states = torch.zeros_like(y)

            expert_mask = torch.nn.functional.one_hot(
                selected_experts,
                num_classes = self.num_experts
            )
            expert_count = expert_mask.view(-1, self.num_experts).sum(dim = 0).cpu()
            expert_mask = expert_mask.permute(2, 1, 0)

            def mlp(exp_i, xc):
                g = self.gates[exp_i].forward(xc, params)
                u = self.ups[exp_i].forward(xc, params)
                self.activation_fn_call(g, u, u)
                return self.downs[exp_i].forward(u, params)

            for expert_idx in range(self.num_experts):
                if expert_count[expert_idx] == 0:
                    continue
                idx, top_x = torch.where(expert_mask[expert_idx])
                current_state = y[None, top_x].reshape(-1, self.hidden_size)
                current_state = mlp(expert_idx, current_state) * routing_weights[top_x, idx, None]
                final_hidden_states.index_add_(0, top_x, current_state)

            final_hidden_states = final_hidden_states.reshape(x.shape)
            return to2(final_hidden_states, out_dtype, self.out_dtype)

        # Fused path
        # TODO: Find good solution for 1 < bsz < 32
        else:
            y = y.unsqueeze(0)
            yh = torch.empty(
                (self.num_experts_per_tok, bsz, y.shape[-1]),
                dtype = y.dtype,
                device = y.device
            )
            interm_g = torch.empty(
                (self.num_experts_per_tok, bsz, self.intermediate_size),
                dtype = self.interm_dtype,
                device = y.device
            )
            interm_u = torch.empty_like(interm_g)
            interm_a = torch.empty_like(interm_u, dtype = torch.half) if self.interm_dtype != torch.half else interm_u
            out_d = torch.empty(
                (self.num_experts_per_tok, bsz, self.hidden_size),
                dtype = first_not_none(out_dtype, self.out_dtype, torch.half),
                device = y.device
            )

            # Gate
            ext.exl3_mgemm(
                y,
                self.multi_gate.ptrs_trellis,
                interm_g,
                self.multi_gate.ptrs_suh,
                yh,
                self.multi_gate.ptrs_svh,
                selected_experts,
                None,
                self.multi_gate.K,
                -1
            )

            # Up
            ext.exl3_mgemm(
                y,
                self.multi_up.ptrs_trellis,
                interm_u,
                self.multi_up.ptrs_suh,
                yh,
                self.multi_up.ptrs_svh,
                selected_experts,
                None,
                self.multi_up.K,
                -1
            )

            # Activation
            self.activation_fn_call(interm_g, interm_u, interm_a)

            # Down
            ext.exl3_mgemm(
                interm_a,
                self.multi_down.ptrs_trellis,
                out_d,
                self.multi_down.ptrs_suh,
                interm_a,
                self.multi_down.ptrs_svh,
                selected_experts,
                routing_weights,
                self.multi_down.K,
                -1
            )

            final_hidden_states = out_d.sum(dim = 0)
            return final_hidden_states.view(x.shape)
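
The torch path above uses the common one-hot dispatch pattern: one-hot the selected experts, gather each expert's tokens with torch.where, and scatter the weighted outputs back with index_add_. A standalone sketch of that routing arithmetic, with plain nn.Linear layers standing in for the quantized expert projections (toy sizes, random weights, silu activation):

    import torch
    import torch.nn.functional as F
    from torch import nn

    torch.manual_seed(0)
    hidden, interm, num_experts, top_k, bsz = 16, 32, 4, 2, 5

    # Dense stand-ins for the quantized gate/up/down projections of each expert
    gates  = [nn.Linear(hidden, interm, bias = False) for _ in range(num_experts)]
    ups    = [nn.Linear(hidden, interm, bias = False) for _ in range(num_experts)]
    downs  = [nn.Linear(interm, hidden, bias = False) for _ in range(num_experts)]
    router = nn.Linear(hidden, num_experts, bias = False)

    with torch.no_grad():
        y = torch.randn(bsz, hidden)

        probs = F.softmax(router(y), dim = -1)
        weights, selected = torch.topk(probs, top_k, dim = -1)
        weights /= weights.sum(dim = -1, keepdim = True)

        out = torch.zeros_like(y)
        # expert_mask[e, j, t] == 1 when expert e is token t's j-th selection
        expert_mask = F.one_hot(selected, num_classes = num_experts).permute(2, 1, 0)

        for e in range(num_experts):
            idx, top_x = torch.where(expert_mask[e])   # idx: slot within top-k, top_x: token index
            if top_x.numel() == 0:
                continue
            xc = y[top_x]                              # tokens routed to expert e
            h = F.silu(gates[e](xc)) * ups[e](xc)      # gate/up with silu, as in the module's silu_mul
            out.index_add_(0, top_x, downs[e](h) * weights[top_x, idx, None])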

exllamav3/modules/transformer.py

Lines changed: 4 additions & 4 deletions
@@ -5,7 +5,7 @@
 from torch import nn
 from ..util.tensor import to2
 from ..models import Config
-from . import Module, RMSNorm, LayerNorm, Attention, GatedMLP, MLP
+from . import Module, RMSNorm, LayerNorm, Attention, GatedMLP, MLP, BlockSparseMLP
 from ..conversion.allocation import allocate_transformer
 
 class TransformerBlock(Module):
@@ -18,7 +18,7 @@ def __init__(
         attn: Attention | None = None,
         attn_post_norm: RMSNorm | None = None,
         mlp_norm: RMSNorm | LayerNorm | None = None,
-        mlp: MLP | GatedMLP | None = None,
+        mlp: MLP | GatedMLP | BlockSparseMLP | None = None,
         mlp_post_norm: RMSNorm | None = None,
         qmap: str | None = None,
         qbits_key: str = "bits",
@@ -83,7 +83,7 @@ def allocate_q(self, quant_args: dict, surplus_bits: int):
            self.attn.k_proj if self.attn else None,
            self.attn.v_proj if self.attn else None,
            self.attn.o_proj if self.attn else None,
-            self.mlp.gates if isinstance(self.mlp, GatedMLP) else None,
+            self.mlp.gates if any(isinstance(self.mlp, x) for x in [GatedMLP, BlockSparseMLP]) else None,
            self.mlp.ups if self.mlp else None,
            self.mlp.downs if self.mlp else None,
        )
@@ -152,7 +152,7 @@ def allocate_q(self, quant_args: dict, surplus_bits: int):
            self.attn.k_proj if self.attn else None,
            self.attn.v_proj if self.attn else None,
            self.attn.o_proj if self.attn else None,
-            self.mlp.gates if isinstance(self.mlp, GatedMLP) else None,
+            self.mlp.gates if any(isinstance(self.mlp, x) for x in [GatedMLP, BlockSparseMLP]) else None,
            self.mlp.ups if self.mlp else None,
            self.mlp.downs if self.mlp else None,
        )
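
For reference, a hypothetical model-definition snippet showing how the new module would plug into a TransformerBlock's mlp slot; all key templates, sizes and dtypes below are illustrative, not taken from any architecture file in this commit:

    mlp = BlockSparseMLP(
        config = config,
        key = f"model.layers.{idx}.mlp",              # assumed key prefix
        hidden_size = 4096,
        intermediate_size = 14336,
        num_experts = 8,
        num_experts_per_tok = 2,
        key_up = "experts.{expert_idx}.up_proj",      # "{expert_idx}" is expanded per expert
        key_gate = "experts.{expert_idx}.gate_proj",
        key_down = "experts.{expert_idx}.down_proj",
        key_routing_gate = "gate",
        qmap = f"block.{idx}.mlp",                    # assumed qmap naming
        interm_dtype = torch.float,
        out_dtype = torch.half,
    )

The resulting object would then be passed as the mlp argument of TransformerBlock, matching the widened type annotation above.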
