Skip to content

Commit 2c91c4a

Browse files
committed
fix mamba dtype
1 parent 750d08f commit 2c91c4a

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,21 +155,31 @@ def make_mamba_cache(
155155
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
156156
) -> transformers.cache_utils.MambaCache:
157157
"Creates a :class:`transformers.cache_utils.MambaCache`."
158+
dtype = key_value_pairs[0][0].dtype
158159

159160
class _config:
160161
def __init__(self):
161162
self.intermediate_size = key_value_pairs[0][0].shape[1]
162163
self.conv_kernel = key_value_pairs[0][0].shape[-1]
163164
self.state_size = key_value_pairs[0][1].shape[-1]
164165
self.num_hidden_layers = len(key_value_pairs)
165-
self.dtype = key_value_pairs[0][0].dtype
166+
self.dtype = dtype
166167

167168
cache = transformers.cache_utils.MambaCache(
168169
_config(),
169170
max_batch_size=key_value_pairs[0][0].shape[0],
170171
device=key_value_pairs[0][0].device,
172+
dtype=dtype,
171173
)
172174
for i in range(len(key_value_pairs)):
175+
assert cache.conv_states[i].dtype == dtype, (
176+
f"Type mismatch for cache.conv_states[{i}].dtype="
177+
f"{cache.conv_states[i].dtype} != {dtype}"
178+
)
179+
assert cache.ssm_states[i].dtype == dtype, (
180+
f"Type mismatch for cache.ssm_states[{i}].dtype="
181+
f"{cache.ssm_states[i].dtype} != {dtype}"
182+
)
173183
assert cache.conv_states[i].shape == key_value_pairs[i][0].shape, (
174184
f"Shape mismatch, expected {cache.conv_states[i].shape}, "
175185
f"got {key_value_pairs[i][0].shape}"

0 commit comments

Comments
 (0)