|
1 | 1 | import unittest |
2 | 2 | from onnx_diagnostic.ext_test_case import ExtTestCase, never_test |
3 | 3 | from onnx_diagnostic.helpers import string_type |
4 | | -from onnx_diagnostic.helpers.torch_test_helper import steel_forward |
| 4 | +from onnx_diagnostic.helpers.torch_test_helper import steal_forward |
5 | 5 |
|
6 | 6 |
|
7 | 7 | class TestHuggingFaceHubModel(ExtTestCase): |
@@ -92,7 +92,7 @@ def test_text2text_generation(self): |
92 | 92 |
|
93 | 93 | # simply generate a single sequence |
94 | 94 | print() |
95 | | - with steel_forward(model): |
| 95 | + with steal_forward(model): |
96 | 96 | generated_ids = model.generate( |
97 | 97 | decoder_input_ids=input_ids, attention_mask=mask, max_length=100 |
98 | 98 | ) |
@@ -121,7 +121,7 @@ def test_imagetext2text_generation(self): |
121 | 121 | ["<image>", "<fake_token_around_image>"], add_special_tokens=False |
122 | 122 | ).input_ids |
123 | 123 | print() |
124 | | - with steel_forward(model): |
| 124 | + with steal_forward(model): |
125 | 125 | generated_ids = model.generate( |
126 | 126 | **inputs, max_new_tokens=10, bad_words_ids=bad_words_ids |
127 | 127 | ) |
@@ -184,7 +184,7 @@ def test_automatic_speech_recognition(self): |
184 | 184 |
|
185 | 185 | # generate token ids |
186 | 186 | print() |
187 | | - with steel_forward(model): |
| 187 | + with steal_forward(model): |
188 | 188 | predicted_ids = model.generate( |
189 | 189 | input_features, forced_decoder_ids=forced_decoder_ids |
190 | 190 | ) |
@@ -285,6 +285,80 @@ def mean_pooling(model_output, attention_mask): |
285 | 285 | print("Sentence embeddings:") |
286 | 286 | print(sentence_embeddings) |
287 | 287 |
|
| 288 | + @never_test() |
| 289 | + def test_falcon_mamba_dev(self): |
| 290 | + # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k falcon_mamba_dev |
| 291 | + # https://huggingface.co/tiiuae/falcon-mamba-tiny-dev |
| 292 | + |
| 293 | + from transformers import AutoTokenizer |
| 294 | + import transformers |
| 295 | + import torch |
| 296 | + |
| 297 | + model = "tiiuae/falcon-mamba-tiny-dev" |
| 298 | + |
| 299 | + tokenizer = AutoTokenizer.from_pretrained(model) |
| 300 | + pipeline = transformers.pipeline( |
| 301 | + "text-generation", |
| 302 | + model=model, |
| 303 | + tokenizer=tokenizer, |
| 304 | + torch_dtype=torch.bfloat16, |
| 305 | + trust_remote_code=True, |
| 306 | + device_map="auto", |
| 307 | + ) |
| 308 | + print() |
| 309 | + with steal_forward(pipeline.model): |
| 310 | + sequences = pipeline( |
| 311 | + "Girafatron is obsessed with giraffes, " |
| 312 | + "the most glorious animal on the face of this Earth. " |
| 313 | + "Giraftron believes all other animals are irrelevant " |
| 314 | + "when compared to the glorious majesty of the giraffe." |
| 315 | + "\nDaniel: Hello, Girafatron!\nGirafatron:", |
| 316 | + max_length=200, |
| 317 | + do_sample=True, |
| 318 | + top_k=10, |
| 319 | + num_return_sequences=1, |
| 320 | + eos_token_id=tokenizer.eos_token_id, |
| 321 | + ) |
| 322 | + for seq in sequences: |
| 323 | + print(f"Result: {seq['generated_text']}") |
| 324 | + |
| 325 | + @never_test() |
| 326 | + def test_falcon_mamba_7b(self): |
| 327 | + # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k falcon_mamba_7b |
| 328 | + # https://huggingface.co/tiiuae/falcon-mamba-7b |
| 329 | + |
| 330 | + from transformers import AutoTokenizer |
| 331 | + import transformers |
| 332 | + import torch |
| 333 | + |
| 334 | + model = "tiiuae/falcon-mamba-7b" |
| 335 | + |
| 336 | + tokenizer = AutoTokenizer.from_pretrained(model) |
| 337 | + pipeline = transformers.pipeline( |
| 338 | + "text-generation", |
| 339 | + model=model, |
| 340 | + tokenizer=tokenizer, |
| 341 | + torch_dtype=torch.bfloat16, |
| 342 | + trust_remote_code=True, |
| 343 | + device_map="auto", |
| 344 | + ) |
| 345 | + print() |
| 346 | + with steal_forward(pipeline.model): |
| 347 | + sequences = pipeline( |
| 348 | + "Girafatron is obsessed with giraffes, " |
| 349 | + "the most glorious animal on the face of this Earth. " |
| 350 | + "Giraftron believes all other animals are irrelevant " |
| 351 | + "when compared to the glorious majesty of the giraffe." |
| 352 | + "\nDaniel: Hello, Girafatron!\nGirafatron:", |
| 353 | + max_length=200, |
| 354 | + do_sample=True, |
| 355 | + top_k=10, |
| 356 | + num_return_sequences=1, |
| 357 | + eos_token_id=tokenizer.eos_token_id, |
| 358 | + ) |
| 359 | + for seq in sequences: |
| 360 | + print(f"Result: {seq['generated_text']}") |
| 361 | + |
288 | 362 |
|
if __name__ == "__main__":
    # Allow running this test file directly; verbosity=2 prints one line per test.
    unittest.main(verbosity=2)
0 commit comments