Skip to content

Commit d9ac8d2

Browse files
committed
Fix regression #71
1 parent 4cc5c7a commit d9ac8d2

File tree

4 files changed

+16
-4
lines changed

4 files changed

+16
-4
lines changed

examples/multimodal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
model_dir = "/mnt/str/models/gemma3-4b-it/exl3/5.0bpw/"
1919
case "mistral3":
2020
prompt_format = "mistral"
21-
model_dir = "/mnt/str/models/mistral-small-3.1-24b-instruct/exl3/8.0bpw_H8"
21+
model_dir = "/mnt/str/models/mistral-small-3.1-24b-instruct-2503/exl3/4.0bpw/"
2222

2323
images = [
2424
# Cat

exllamav3/architecture/gemma3.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -582,4 +582,9 @@ def get_image_embeddings(
582582

583583
mmes.append(mme)
584584

585-
return mmes if return_batch else mmes[0]
585+
return mmes if return_batch else mmes[0]
586+
587+
588+
@override
589+
def prepare_inputs(self, input_ids: torch.Tensor, params: dict) -> torch.Tensor:
590+
return input_ids

exllamav3/architecture/mistral3.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,3 +517,8 @@ def get_image_embeddings(
517517
})
518518

519519
return mme
520+
521+
522+
@override
523+
def prepare_inputs(self, input_ids: torch.Tensor, params: dict) -> torch.Tensor:
524+
return input_ids

exllamav3/modules/mlp.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def __init__(
2020
key: str,
2121
hidden_size: int,
2222
intermediate_size: int,
23+
out_size: int | None = None,
2324
key_up: str | None = None,
2425
key_down: str | None = None,
2526
qmap: str | None = None,
@@ -40,6 +41,7 @@ def __init__(
4041
self.activation_fn = activation_fn
4142
self.intermediate_size = intermediate_size
4243
self.intermediate_split_size = intermediate_split_size
44+
self.out_size = out_size or hidden_size
4345

4446
fkey, frange_up = None, None
4547

@@ -99,9 +101,9 @@ def __init__(
99101
config = config,
100102
key = s_key_d,
101103
in_features = b - a,
102-
out_features = hidden_size,
104+
out_features = self.out_size,
103105
full_in_features = intermediate_size,
104-
full_out_features = hidden_size,
106+
full_out_features = self.out_size,
105107
first_in_feature = a,
106108
first_out_feature = 0,
107109
qmap = qmap + ".down",

0 commit comments

Comments
 (0)