
Commit a546276

fix mamba
1 parent c1224c4 commit a546276

3 files changed: 20 additions & 4 deletions

_unittests/ut_tasks/test_tasks.py

Lines changed: 1 addition & 1 deletion
@@ -120,7 +120,7 @@ def test_falcon_mamba_dev(self):
         model, inputs = data["model"], data["inputs"]
         print(self.string_type(inputs, with_shape=True))
         model(**inputs)
-        self.assertIn((data["size"], data["n_weights"]), [(62461440, 15615360)])
+        self.assertIn((data["size"], data["n_weights"]), [(138640384, 34660096)])


 if __name__ == "__main__":
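
The updated expected pair is at least internally consistent if data["size"] counts bytes and data["n_weights"] counts parameters: 138640384 / 34660096 = 4, i.e. four bytes per weight for a float32 model. A minimal sketch of how such numbers can be derived for any torch.nn.Module (the helper name is hypothetical; onnx_diagnostic may compute them differently):

import torch

def model_size_and_n_weights(model: torch.nn.Module) -> tuple[int, int]:
    # Hypothetical helper: total parameter bytes and total parameter count.
    # Shown only to illustrate why size == 4 * n_weights for float32 weights.
    n_weights = sum(p.numel() for p in model.parameters())
    size = sum(p.numel() * p.element_size() for p in model.parameters())
    return size, n_weights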

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 8 additions & 0 deletions
@@ -157,6 +157,14 @@ def __init__(self):
         device=key_value_pairs[0][0].device,
     )
     for i in range(len(key_value_pairs)):
+        assert cache.conv_states[i].shape == key_value_pairs[i][0].shape, (
+            f"Shape mismatch, expected {cache.conv_states[i].shape}, "
+            f"got {key_value_pairs[i][0].shape}"
+        )
         cache.conv_states[i][:, :, :] = key_value_pairs[i][0]
+        assert cache.ssm_states[i].shape == key_value_pairs[i][1].shape, (
+            f"Shape mismatch, expected {cache.ssm_states[i].shape}, "
+            f"got {key_value_pairs[i][1].shape}"
+        )
         cache.ssm_states[i][:, :, :] = key_value_pairs[i][1]
     return cache
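
The two assertions guard the in-place copies that fill the cache: if the provided conv or ssm state tensors do not match the shapes the cache allocated, the failure now reports both shapes instead of surfacing later as an opaque broadcasting error. A minimal sketch of the same pattern outside the helper, with made-up shapes rather than real FalconMamba dimensions:

import torch

# Stand-ins for cache.conv_states / cache.ssm_states (shapes are illustrative).
conv_states = [torch.zeros(1, 768, 4)]
ssm_states = [torch.zeros(1, 768, 16)]
key_value_pairs = [(torch.randn(1, 768, 4), torch.randn(1, 768, 16))]

for i, (conv, ssm) in enumerate(key_value_pairs):
    assert conv_states[i].shape == conv.shape, (
        f"Shape mismatch, expected {conv_states[i].shape}, got {conv.shape}"
    )
    conv_states[i][:, :, :] = conv
    assert ssm_states[i].shape == ssm.shape, (
        f"Shape mismatch, expected {ssm_states[i].shape}, got {ssm.shape}"
    )
    ssm_states[i][:, :, :] = ssm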

onnx_diagnostic/tasks/text_generation.py

Lines changed: 11 additions & 3 deletions
@@ -88,6 +88,15 @@ def get_inputs(
     cache_length = torch.export.Dim("cache_length", min=1, max=4096)

     if config is not None and config.__class__.__name__ == "FalconMambaConfig":
+        seq_length_multiple = 8
+        sequence_length = (
+            (sequence_length + seq_length_multiple)
+            // seq_length_multiple
+            * seq_length_multiple
+        )
+        # sequence_inc = seq_length_multiple
+        sequence_length2 = seq_length_multiple
+
     shapes = {
         "input_ids": {0: batch, 1: torch.export.Dim.DYNAMIC},
         "attention_mask": {
@@ -110,9 +119,8 @@ def get_inputs(
         attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
             torch.int64
         ),
-        cache_position=torch.arange(0, sequence_length + sequence_length2)
-        .to(torch.int64)
-        .expand((batch_size, -1)),
+        cache_position=torch.arange(0, kwargs["conv_kernel"]).to(torch.int64),
+        # .expand((batch_size, -1))
         cache_params=make_mamba_cache(
             [
                 (
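
For FalconMamba, sequence_length is rounded up to a multiple of seq_length_multiple = 8 and sequence_length2 is pinned to 8, presumably so the dummy inputs line up with what the Mamba convolution expects; cache_position likewise becomes a 1-D arange over kwargs["conv_kernel"] entries, with the per-batch expand left commented out. Note that (n + m) // m * m as written bumps a value that is already a multiple of m up to the next one. A small sketch of just that arithmetic, independent of the rest of get_inputs:

def round_up(n: int, m: int = 8) -> int:
    # Same arithmetic as the FalconMamba branch above; values already a
    # multiple of m are pushed to the next multiple (30 -> 32, 32 -> 40).
    return (n + m) // m * m

assert round_up(30) == 32
assert round_up(32) == 40
assert round_up(1) == 8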
