
Commit c1224c4: add mamba
1 parent dc00d07 commit c1224c4

9 files changed: +284 -61 lines changed

_doc/examples/plot_export_tiny_llm.py

Lines changed: 3 additions & 3 deletions
@@ -31,7 +31,7 @@
 import transformers
 from onnx_diagnostic import doc
 from onnx_diagnostic.helpers import string_type
-from onnx_diagnostic.helpers.torch_test_helper import steel_forward
+from onnx_diagnostic.helpers.torch_test_helper import steal_forward
 from onnx_diagnostic.torch_models.llms import get_tiny_llm
 
 
@@ -77,9 +77,9 @@ def _forward_(*args, _f=None, **kwargs):
 model.forward = keep_model_forward
 
 # %%
-# Another syntax with :func:`onnx_diagnostic.helpers.torch_test_helper.steel_forward`.
+# Another syntax with :func:`onnx_diagnostic.helpers.torch_test_helper.steal_forward`.
 
-with steel_forward(model):
+with steal_forward(model):
     model.generate(inputs, max_length=50, temperature=1, top_k=50, top_p=0.95, do_sample=True)
 
 # %%
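Note: the first syntax in this example patches model.forward by hand (a wrapper such as _forward_ calls the saved original, and model.forward = keep_model_forward restores it afterwards); steal_forward packages the same patch-and-restore into a context manager so the restore cannot be forgotten.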

_unittests/ut_helpers/test_torch_test_helper.py

Lines changed: 3 additions & 3 deletions
@@ -8,7 +8,7 @@
     dummy_llm,
     to_numpy,
     is_torchdynamo_exporting,
-    steel_forward,
+    steal_forward,
     replace_string_by_dynamic,
     to_any,
     torch_deepcopy,
@@ -43,14 +43,14 @@ def test_to_numpy(self):
         self.assertEqual(a.dtype, ml_dtypes.bfloat16)
 
     @hide_stdout()
-    def test_steel_forward(self):
+    def test_steal_forward(self):
         class Model(torch.nn.Module):
             def forward(self, x, y):
                 return x + y
 
         inputs = torch.rand(3, 4), torch.rand(3, 4)
         model = Model()
-        with steel_forward(model):
+        with steal_forward(model):
             model(*inputs)
 
     def test_replace_string_by_dynamic(self):

_unittests/ut_tasks/test_tasks.py

Lines changed: 9 additions & 0 deletions
@@ -113,6 +113,15 @@ def test_sentence_similary(self):
         model, inputs = data["model"], data["inputs"]
         model(**inputs)
 
+    @hide_stdout()
+    def test_falcon_mamba_dev(self):
+        mid = "tiiuae/falcon-mamba-tiny-dev"
+        data = get_untrained_model_with_inputs(mid, verbose=1)
+        model, inputs = data["model"], data["inputs"]
+        print(self.string_type(inputs, with_shape=True))
+        model(**inputs)
+        self.assertIn((data["size"], data["n_weights"]), [(62461440, 15615360)])
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
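Note: the two expected values are mutually consistent, 15,615,360 weights at 4 bytes each (float32) give 62,461,440, assuming data["size"] is reported in bytes.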

_unittests/ut_tasks/try_tasks.py

Lines changed: 78 additions & 4 deletions
@@ -1,7 +1,7 @@
 import unittest
 from onnx_diagnostic.ext_test_case import ExtTestCase, never_test
 from onnx_diagnostic.helpers import string_type
-from onnx_diagnostic.helpers.torch_test_helper import steel_forward
+from onnx_diagnostic.helpers.torch_test_helper import steal_forward
 
 
 class TestHuggingFaceHubModel(ExtTestCase):
@@ -92,7 +92,7 @@ def test_text2text_generation(self):
 
         # simply generate a single sequence
         print()
-        with steel_forward(model):
+        with steal_forward(model):
             generated_ids = model.generate(
                 decoder_input_ids=input_ids, attention_mask=mask, max_length=100
             )
@@ -121,7 +121,7 @@ def test_imagetext2text_generation(self):
             ["<image>", "<fake_token_around_image>"], add_special_tokens=False
         ).input_ids
         print()
-        with steel_forward(model):
+        with steal_forward(model):
             generated_ids = model.generate(
                 **inputs, max_new_tokens=10, bad_words_ids=bad_words_ids
             )
@@ -184,7 +184,7 @@ def test_automatic_speech_recognition(self):
 
         # generate token ids
         print()
-        with steel_forward(model):
+        with steal_forward(model):
             predicted_ids = model.generate(
                 input_features, forced_decoder_ids=forced_decoder_ids
             )
@@ -285,6 +285,80 @@ def mean_pooling(model_output, attention_mask):
         print("Sentence embeddings:")
         print(sentence_embeddings)
 
+    @never_test()
+    def test_falcon_mamba_dev(self):
+        # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k falcon_mamba_dev
+        # https://huggingface.co/tiiuae/falcon-mamba-tiny-dev
+
+        from transformers import AutoTokenizer
+        import transformers
+        import torch
+
+        model = "tiiuae/falcon-mamba-tiny-dev"
+
+        tokenizer = AutoTokenizer.from_pretrained(model)
+        pipeline = transformers.pipeline(
+            "text-generation",
+            model=model,
+            tokenizer=tokenizer,
+            torch_dtype=torch.bfloat16,
+            trust_remote_code=True,
+            device_map="auto",
+        )
+        print()
+        with steal_forward(pipeline.model):
+            sequences = pipeline(
+                "Girafatron is obsessed with giraffes, "
+                "the most glorious animal on the face of this Earth. "
+                "Giraftron believes all other animals are irrelevant "
+                "when compared to the glorious majesty of the giraffe."
+                "\nDaniel: Hello, Girafatron!\nGirafatron:",
+                max_length=200,
+                do_sample=True,
+                top_k=10,
+                num_return_sequences=1,
+                eos_token_id=tokenizer.eos_token_id,
+            )
+        for seq in sequences:
+            print(f"Result: {seq['generated_text']}")
+
+    @never_test()
+    def test_falcon_mamba_7b(self):
+        # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k falcon_mamba_7b
+        # https://huggingface.co/tiiuae/falcon-mamba-7b
+
+        from transformers import AutoTokenizer
+        import transformers
+        import torch
+
+        model = "tiiuae/falcon-mamba-7b"
+
+        tokenizer = AutoTokenizer.from_pretrained(model)
+        pipeline = transformers.pipeline(
+            "text-generation",
+            model=model,
+            tokenizer=tokenizer,
+            torch_dtype=torch.bfloat16,
+            trust_remote_code=True,
+            device_map="auto",
+        )
+        print()
+        with steal_forward(pipeline.model):
+            sequences = pipeline(
+                "Girafatron is obsessed with giraffes, "
+                "the most glorious animal on the face of this Earth. "
+                "Giraftron believes all other animals are irrelevant "
+                "when compared to the glorious majesty of the giraffe."
+                "\nDaniel: Hello, Girafatron!\nGirafatron:",
+                max_length=200,
+                do_sample=True,
+                top_k=10,
+                num_return_sequences=1,
+                eos_token_id=tokenizer.eos_token_id,
+            )
+        for seq in sequences:
+            print(f"Result: {seq['generated_text']}")
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
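Note: steal_forward is attached to pipeline.model, the underlying torch.nn.Module driven by the transformers pipeline, so every forward call issued during generation gets printed.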

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 24 additions & 0 deletions
@@ -136,3 +136,27 @@ def make_encoder_decoder_cache(
     return transformers.cache_utils.EncoderDecoderCache(
         self_attention_cache=self_attention_cache, cross_attention_cache=cross_attention_cache
     )
+
+
+def make_mamba_cache(
+    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
+) -> transformers.cache_utils.MambaCache:
+    "Creates a :class:`transformers.cache_utils.MambaCache`."
+
+    class _config:
+        def __init__(self):
+            self.intermediate_size = key_value_pairs[0][0].shape[1]
+            self.conv_kernel = key_value_pairs[0][0].shape[-1]
+            self.state_size = key_value_pairs[0][1].shape[-1]
+            self.num_hidden_layers = len(key_value_pairs)
+            self.dtype = key_value_pairs[0][0].dtype
+
+    cache = transformers.cache_utils.MambaCache(
+        _config(),
+        max_batch_size=key_value_pairs[0][0].shape[0],
+        device=key_value_pairs[0][0].device,
+    )
+    for i in range(len(key_value_pairs)):
+        cache.conv_states[i][:, :, :] = key_value_pairs[i][0]
+        cache.ssm_states[i][:, :, :] = key_value_pairs[i][1]
+    return cache
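A minimal usage sketch for the new helper (the dimensions below are illustrative; as the _config class shows, the helper infers intermediate_size, conv_kernel, and state_size back from the tensor shapes, so exact runnability depends on the installed transformers version of MambaCache):

import torch
from onnx_diagnostic.helpers.cache_helper import make_mamba_cache

# Illustrative sizes; make_mamba_cache reads them back from the tensors:
# conv state: (batch, intermediate_size, conv_kernel)
# ssm state:  (batch, intermediate_size, state_size)
batch, intermediate_size, conv_kernel, state_size, n_layers = 2, 64, 4, 16, 2
pairs = [
    (
        torch.randn(batch, intermediate_size, conv_kernel),  # conv state
        torch.randn(batch, intermediate_size, state_size),   # ssm state
    )
    for _ in range(n_layers)
]
cache = make_mamba_cache(pairs)  # a MambaCache covering n_layers layers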

onnx_diagnostic/helpers/torch_test_helper.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ def _forward_(*args, _f=None, _context=None, **kwargs):
 
 
 @contextlib.contextmanager
-def steel_forward(model: torch.nn.Module, with_shape: bool = True, with_min_max: bool = False):
+def steal_forward(model: torch.nn.Module, with_shape: bool = True, with_min_max: bool = False):
     """
     The necessary modification to steal the forward method and print out inputs
     and outputs. See example :ref:`l-plot-tiny-llm-export`.
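For readers unfamiliar with the pattern, a simplified sketch of what such a context manager does (the real helper additionally formats shapes and min/max values according to with_shape and with_min_max; the name below is hypothetical):

import contextlib
import torch

@contextlib.contextmanager
def steal_forward_sketch(model: torch.nn.Module):
    # Swap in a logging wrapper; restore the original forward on exit.
    original_forward = model.forward

    def wrapped(*args, **kwargs):
        print("inputs:", args, kwargs)
        result = original_forward(*args, **kwargs)
        print("outputs:", result)
        return result

    model.forward = wrapped
    try:
        yield
    finally:
        model.forward = original_forward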
