From dc00d0786c3323998788de707d7c280aae723331 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 15 Apr 2025 15:07:58 +0200 Subject: [PATCH 1/3] Add support for sentence similarity --- .../tasks/automatic_speech_recognition.rst | 4 +- _doc/api/tasks/fill_mask.rst | 4 +- _doc/api/tasks/image_classification.rst | 4 +- _doc/api/tasks/index.rst | 1 + _doc/api/tasks/sentence_similarity.rst | 7 ++ _doc/api/tasks/text2text_generation.rst | 4 +- _doc/api/tasks/text_classification.rst | 4 +- _doc/api/tasks/text_generation.rst | 4 +- .../tasks/zero_shot_image_classification.rst | 4 +- _unittests/ut_tasks/test_tasks.py | 12 ++-- _unittests/ut_tasks/try_tasks.py | 49 ++++++++++++++ .../test_documentation_examples.py | 5 +- onnx_diagnostic/tasks/__init__.py | 2 + onnx_diagnostic/tasks/sentence_similarity.py | 67 +++++++++++++++++++ onnx_diagnostic/tasks/text_classification.py | 2 +- .../torch_models/hghub/hub_data.py | 2 + .../hghub/hub_data_cached_configs.py | 28 ++++++++ 17 files changed, 183 insertions(+), 20 deletions(-) create mode 100644 _doc/api/tasks/sentence_similarity.rst create mode 100644 onnx_diagnostic/tasks/sentence_similarity.py diff --git a/_doc/api/tasks/automatic_speech_recognition.rst b/_doc/api/tasks/automatic_speech_recognition.rst index 5c3f64ea..66eab05c 100644 --- a/_doc/api/tasks/automatic_speech_recognition.rst +++ b/_doc/api/tasks/automatic_speech_recognition.rst @@ -1,6 +1,6 @@ -onnx_diagnostic.export.automatic_speech_recognition -=================================================== +onnx_diagnostic.tasks.automatic_speech_recognition +================================================== .. automodule:: onnx_diagnostic.tasks.automatic_speech_recognition :members: diff --git a/_doc/api/tasks/fill_mask.rst b/_doc/api/tasks/fill_mask.rst index 58c6402a..af33a231 100644 --- a/_doc/api/tasks/fill_mask.rst +++ b/_doc/api/tasks/fill_mask.rst @@ -1,6 +1,6 @@ -onnx_diagnostic.export.fill_mask -================================ +onnx_diagnostic.tasks.fill_mask +=============================== .. automodule:: onnx_diagnostic.tasks.fill_mask :members: diff --git a/_doc/api/tasks/image_classification.rst b/_doc/api/tasks/image_classification.rst index 3643b2f5..9ff88587 100644 --- a/_doc/api/tasks/image_classification.rst +++ b/_doc/api/tasks/image_classification.rst @@ -1,6 +1,6 @@ -onnx_diagnostic.export.image_classification -=========================================== +onnx_diagnostic.tasks.image_classification +========================================== .. automodule:: onnx_diagnostic.tasks.image_classification :members: diff --git a/_doc/api/tasks/index.rst b/_doc/api/tasks/index.rst index 4004b8f8..4d83b274 100644 --- a/_doc/api/tasks/index.rst +++ b/_doc/api/tasks/index.rst @@ -9,6 +9,7 @@ onnx_diagnostic.tasks fill_mask image_classification image_text_to_text + sentence_similarity text_classification text_generation text2text_generation diff --git a/_doc/api/tasks/sentence_similarity.rst b/_doc/api/tasks/sentence_similarity.rst new file mode 100644 index 00000000..151ccfa0 --- /dev/null +++ b/_doc/api/tasks/sentence_similarity.rst @@ -0,0 +1,7 @@ + +onnx_diagnostic.tasks.sentence_similarity +========================================= + +.. 
automodule:: onnx_diagnostic.tasks.sentence_similarity + :members: + :no-undoc-members: diff --git a/_doc/api/tasks/text2text_generation.rst b/_doc/api/tasks/text2text_generation.rst index c148d174..f222184f 100644 --- a/_doc/api/tasks/text2text_generation.rst +++ b/_doc/api/tasks/text2text_generation.rst @@ -1,6 +1,6 @@ -onnx_diagnostic.export.text2text_generation -=========================================== +onnx_diagnostic.tasks.text2text_generation +========================================== .. automodule:: onnx_diagnostic.tasks.text2text_generation :members: diff --git a/_doc/api/tasks/text_classification.rst b/_doc/api/tasks/text_classification.rst index 22b53799..ef7759d6 100644 --- a/_doc/api/tasks/text_classification.rst +++ b/_doc/api/tasks/text_classification.rst @@ -1,6 +1,6 @@ -onnx_diagnostic.export.text_classification -========================================== +onnx_diagnostic.tasks.text_classification +========================================= .. automodule:: onnx_diagnostic.tasks.text_classification :members: diff --git a/_doc/api/tasks/text_generation.rst b/_doc/api/tasks/text_generation.rst index 3f125381..ae8f2227 100644 --- a/_doc/api/tasks/text_generation.rst +++ b/_doc/api/tasks/text_generation.rst @@ -1,6 +1,6 @@ -onnx_diagnostic.export.text_generation -====================================== +onnx_diagnostic.tasks.text_generation +===================================== .. automodule:: onnx_diagnostic.tasks.text_generation :members: diff --git a/_doc/api/tasks/zero_shot_image_classification.rst b/_doc/api/tasks/zero_shot_image_classification.rst index 74d9e619..7dc9e7d1 100644 --- a/_doc/api/tasks/zero_shot_image_classification.rst +++ b/_doc/api/tasks/zero_shot_image_classification.rst @@ -1,6 +1,6 @@ -onnx_diagnostic.export.zero_shot_image_classification -===================================================== +onnx_diagnostic.tasks.zero_shot_image_classification +==================================================== .. 
automodule:: onnx_diagnostic.tasks.zero_shot_image_classification
    :members:
diff --git a/_unittests/ut_tasks/test_tasks.py b/_unittests/ut_tasks/test_tasks.py
index 30139048..ba195428 100644
--- a/_unittests/ut_tasks/test_tasks.py
+++ b/_unittests/ut_tasks/test_tasks.py
@@ -9,7 +9,6 @@ class TestTasks(ExtTestCase):
     @hide_stdout()
     def test_text2text_generation(self):
         mid = "sshleifer/tiny-marian-en-de"
-        # mid = "Salesforce/codet5-small"
         data = get_untrained_model_with_inputs(mid, verbose=1)
         self.assertIn((data["size"], data["n_weights"]), [(473928, 118482)])
         model, inputs = data["model"], data["inputs"]
@@ -85,7 +84,6 @@ def test_automatic_speech_recognition(self):
     @hide_stdout()
     def test_imagetext2text_generation(self):
         mid = "HuggingFaceM4/tiny-random-idefics"
-        # mid = "Salesforce/codet5-small"
         data = get_untrained_model_with_inputs(mid, verbose=1)
         self.assertIn((data["size"], data["n_weights"]), [(12742888, 3185722)])
         model, inputs = data["model"], data["inputs"]
@@ -94,7 +92,6 @@ def test_imagetext2text_generation(self):
     @hide_stdout()
     def test_fill_mask(self):
         mid = "google-bert/bert-base-multilingual-cased"
-        # mid = "Salesforce/codet5-small"
         data = get_untrained_model_with_inputs(mid, verbose=1)
         self.assertIn((data["size"], data["n_weights"]), [(428383212, 107095803)])
         model, inputs = data["model"], data["inputs"]
@@ -103,12 +100,19 @@ def test_text_classification(self):
     @hide_stdout()
     def test_text_classification(self):
         mid = "Intel/bert-base-uncased-mrpc"
-        # mid = "Salesforce/codet5-small"
         data = get_untrained_model_with_inputs(mid, verbose=1)
         self.assertIn((data["size"], data["n_weights"]), [(154420232, 38605058)])
         model, inputs = data["model"], data["inputs"]
         model(**inputs)
 
+    @hide_stdout()
+    def test_sentence_similarity(self):
+        mid = "sentence-transformers/all-MiniLM-L6-v1"
+        data = get_untrained_model_with_inputs(mid, verbose=1)
+        self.assertIn((data["size"], data["n_weights"]), [(62461440, 15615360)])
+        model, inputs = data["model"], data["inputs"]
+        model(**inputs)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/_unittests/ut_tasks/try_tasks.py b/_unittests/ut_tasks/try_tasks.py
index 358eb40d..389dfbf2 100644
--- a/_unittests/ut_tasks/try_tasks.py
+++ b/_unittests/ut_tasks/try_tasks.py
@@ -236,6 +236,55 @@ def test_text_classification(self):
         encoded_input["input_ids"][0]
         tokenizer.convert_ids_to_tokens(encoded_input["input_ids"][0])
 
+    @never_test()
+    def test_sentence_similarity(self):
+        # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k ce_sim
+        # https://huggingface.co/sentence-transformers/all-MiniLM-L6-v1
+
+        from transformers import AutoTokenizer, AutoModel
+        import torch
+        import torch.nn.functional as F
+
+        # Mean Pooling - Take attention mask into account for correct averaging
+        def mean_pooling(model_output, attention_mask):
+            token_embeddings = model_output[
+                0
+            ]  # First element of model_output contains all token embeddings
+            input_mask_expanded = (
+                attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+            )
+            return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
+                input_mask_expanded.sum(1), min=1e-9
+            )
+
+        # Sentences we want sentence embeddings for
+        sentences = ["This is an example sentence", "Each sentence is converted"]
+
+        # Load model from HuggingFace Hub
+        tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v1")
+        model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v1")
+
+        # Tokenize sentences
+        encoded_input = tokenizer(
+            sentences,
padding=True, truncation=True, return_tensors="pt" + ) + + # Compute token embeddings + with torch.no_grad(): + print() + print("-- inputs", string_type(encoded_input, with_shape=True, with_min_max=True)) + model_output = model(**encoded_input) + print("-- outputs", string_type(model_output, with_shape=True, with_min_max=True)) + + # Perform pooling + sentence_embeddings = mean_pooling(model_output, encoded_input["attention_mask"]) + + # Normalize embeddings + sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1) + + print("Sentence embeddings:") + print(sentence_embeddings) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_xrun_doc/test_documentation_examples.py b/_unittests/ut_xrun_doc/test_documentation_examples.py index 77891e73..2382e721 100644 --- a/_unittests/ut_xrun_doc/test_documentation_examples.py +++ b/_unittests/ut_xrun_doc/test_documentation_examples.py @@ -54,7 +54,10 @@ def run_test(self, fold: str, name: str, verbose=0) -> int: # dot not installed, this part # is tested in onnx framework raise unittest.SkipTest(f"failed: {name!r} due to missing dot.") - if "We couldn't connect to 'https://huggingface.co'" in st: + if ( + "We couldn't connect to 'https://huggingface.co'" in st + or "Cannot access content at: https://huggingface.co/" in st + ): raise unittest.SkipTest(f"Connectivity issues due to\n{err}") raise AssertionError( # noqa: B904 "Example '{}' (cmd: {} - exec_prefix='{}') " diff --git a/onnx_diagnostic/tasks/__init__.py b/onnx_diagnostic/tasks/__init__.py index 0a770514..7e75482a 100644 --- a/onnx_diagnostic/tasks/__init__.py +++ b/onnx_diagnostic/tasks/__init__.py @@ -4,6 +4,7 @@ fill_mask, image_classification, image_text_to_text, + sentence_similarity, text_classification, text_generation, text2text_generation, @@ -15,6 +16,7 @@ fill_mask, image_classification, image_text_to_text, + sentence_similarity, text_classification, text_generation, text2text_generation, diff --git a/onnx_diagnostic/tasks/sentence_similarity.py b/onnx_diagnostic/tasks/sentence_similarity.py new file mode 100644 index 00000000..220287c7 --- /dev/null +++ b/onnx_diagnostic/tasks/sentence_similarity.py @@ -0,0 +1,67 @@ +from typing import Any, Callable, Dict, Optional, Tuple +import torch +from ..helpers.config_helper import update_config, check_hasattr + +__TASK__ = "sentence-similarity" + + +def reduce_model_config(config: Any, task: str) -> Dict[str, Any]: + """Reduces a model size.""" + check_hasattr(config, "num_attention_heads", "num_hidden_layers") + kwargs = dict( + num_hidden_layers=min(config.num_hidden_layers, 2), + num_attention_heads=min(config.num_attention_heads, 4), + ) + update_config(config, kwargs) + return kwargs + + +def get_inputs( + model: torch.nn.Module, + config: Optional[Any], + batch_size: int, + sequence_length: int, + dummy_max_token_id: int, + **kwargs, # unused +): + """ + Generates inputs for task ``sentence-similarity``. 
+ Example: + + :: + + input_ids:T7s1x13[101,72654:A16789.23076923077], + token_type_ids:T7s1x13[0,0:A0.0], + attention_mask:T7s1x13[1,1:A1.0]) + """ + batch = torch.export.Dim("batch", min=1, max=1024) + seq_length = "seq_length" + shapes = { + "input_ids": {0: batch, 1: seq_length}, + "token_type_ids": {0: batch, 1: seq_length}, + "attention_mask": {0: batch, 1: seq_length}, + } + inputs = dict( + input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length)).to( + torch.int64 + ), + token_type_ids=torch.zeros((batch_size, sequence_length)).to(torch.int64), + attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64), + ) + return dict(inputs=inputs, dynamic_shapes=shapes) + + +def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callable]: + """ + Inputs kwargs. + + If the configuration is None, the function selects typical dimensions. + """ + if config is not None: + check_hasattr(config, "vocab_size") + kwargs = dict( + batch_size=2, + sequence_length=30, + dummy_max_token_id=31999 if config is None else (config.vocab_size - 1), + ) + return kwargs, get_inputs diff --git a/onnx_diagnostic/tasks/text_classification.py b/onnx_diagnostic/tasks/text_classification.py index 5ed6329b..37d262f0 100644 --- a/onnx_diagnostic/tasks/text_classification.py +++ b/onnx_diagnostic/tasks/text_classification.py @@ -25,7 +25,7 @@ def get_inputs( **kwargs, # unused ): """ - Generates inputs for task ``fill-mask``. + Generates inputs for task ``text-classification``. Example: :: diff --git a/onnx_diagnostic/torch_models/hghub/hub_data.py b/onnx_diagnostic/torch_models/hghub/hub_data.py index d5a70925..d335b170 100644 --- a/onnx_diagnostic/torch_models/hghub/hub_data.py +++ b/onnx_diagnostic/torch_models/hghub/hub_data.py @@ -15,6 +15,7 @@ BeitForImageClassification,image-classification BertForMaskedLM,fill-mask BertForSequenceClassification,text-classification + BertModel,sentence-similarity BigBirdModel,feature-extraction BlenderbotModel,feature-extraction BloomModel,feature-extraction @@ -146,6 +147,7 @@ "no-pipeline-tag", "object-detection", "reinforcement-learning", + "sentence-similarity", "text-classification", "text-generation", "text-to-audio", diff --git a/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py b/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py index 0daf939a..5b037563 100644 --- a/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +++ b/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py @@ -3502,3 +3502,31 @@ def _ccached_intel_bert_base_uncased_mrpc(): "vocab_size": 30522, } ) + + +def _ccached_sentence_transformers_all_MiniLM_L6_v1(): + "sentence-transformers/all-MiniLM-L6-v1" + return transformers.BertConfig( + **{ + "_name_or_path": "nreimers/MiniLM-L6-H384-uncased", + "architectures": ["BertModel"], + "attention_probs_dropout_prob": 0.1, + "gradient_checkpointing": false, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 384, + "initializer_range": 0.02, + "intermediate_size": 1536, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 12, + "num_hidden_layers": 6, + "pad_token_id": 0, + "position_embedding_type": "absolute", + "transformers_version": "4.8.2", + "type_vocab_size": 2, + "use_cache": true, + "vocab_size": 30522, + } + ) From c1224c4e4fe32cf14f7a79713be726ec9c685433 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 15 Apr 2025 16:27:45 +0200 Subject: [PATCH 2/3] add mamba --- 
 _doc/examples/plot_export_tiny_llm.py | 6 +-
 .../ut_helpers/test_torch_test_helper.py | 6 +-
 _unittests/ut_tasks/test_tasks.py | 9 +
 _unittests/ut_tasks/try_tasks.py | 82 +++++++-
 onnx_diagnostic/helpers/cache_helper.py | 24 +++
 onnx_diagnostic/helpers/torch_test_helper.py | 2 +-
 onnx_diagnostic/tasks/text_generation.py | 176 +++++++++++++-----
 .../torch_models/hghub/hub_data.py | 1 +
 .../hghub/hub_data_cached_configs.py | 39 ++++
 9 files changed, 284 insertions(+), 61 deletions(-)
diff --git a/_doc/examples/plot_export_tiny_llm.py b/_doc/examples/plot_export_tiny_llm.py
index 8c1b8035..84303a5b 100644
--- a/_doc/examples/plot_export_tiny_llm.py
+++ b/_doc/examples/plot_export_tiny_llm.py
@@ -31,7 +31,7 @@
 import transformers
 from onnx_diagnostic import doc
 from onnx_diagnostic.helpers import string_type
-from onnx_diagnostic.helpers.torch_test_helper import steel_forward
+from onnx_diagnostic.helpers.torch_test_helper import steal_forward
 from onnx_diagnostic.torch_models.llms import get_tiny_llm
 
 
@@ -77,9 +77,9 @@ def _forward_(*args, _f=None, **kwargs):
 model.forward = keep_model_forward
 
 # %%
-# Another syntax with :func:`onnx_diagnostic.helpers.torch_test_helper.steel_forward`.
+# Another syntax with :func:`onnx_diagnostic.helpers.torch_test_helper.steal_forward`.
 
-with steel_forward(model):
+with steal_forward(model):
     model.generate(inputs, max_length=50, temperature=1, top_k=50, top_p=0.95, do_sample=True)
 
 # %%
diff --git a/_unittests/ut_helpers/test_torch_test_helper.py b/_unittests/ut_helpers/test_torch_test_helper.py
index 9dac5ffc..f67e87c8 100644
--- a/_unittests/ut_helpers/test_torch_test_helper.py
+++ b/_unittests/ut_helpers/test_torch_test_helper.py
@@ -8,7 +8,7 @@
     dummy_llm,
     to_numpy,
     is_torchdynamo_exporting,
-    steel_forward,
+    steal_forward,
     replace_string_by_dynamic,
     to_any,
     torch_deepcopy,
@@ -43,14 +43,14 @@ def test_to_numpy(self):
         self.assertEqual(a.dtype, ml_dtypes.bfloat16)
 
     @hide_stdout()
-    def test_steel_forward(self):
+    def test_steal_forward(self):
         class Model(torch.nn.Module):
             def forward(self, x, y):
                 return x + y
 
         inputs = torch.rand(3, 4), torch.rand(3, 4)
         model = Model()
-        with steel_forward(model):
+        with steal_forward(model):
             model(*inputs)
 
     def test_replace_string_by_dynamic(self):
diff --git a/_unittests/ut_tasks/test_tasks.py b/_unittests/ut_tasks/test_tasks.py
index ba195428..8d9da4d6 100644
--- a/_unittests/ut_tasks/test_tasks.py
+++ b/_unittests/ut_tasks/test_tasks.py
@@ -113,6 +113,15 @@ def test_sentence_similarity(self):
         model, inputs = data["model"], data["inputs"]
         model(**inputs)
 
+    @hide_stdout()
+    def test_falcon_mamba_dev(self):
+        mid = "tiiuae/falcon-mamba-tiny-dev"
+        data = get_untrained_model_with_inputs(mid, verbose=1)
+        model, inputs = data["model"], data["inputs"]
+        print(self.string_type(inputs, with_shape=True))
+        model(**inputs)
+        self.assertIn((data["size"], data["n_weights"]), [(62461440, 15615360)])
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/_unittests/ut_tasks/try_tasks.py b/_unittests/ut_tasks/try_tasks.py
index 389dfbf2..8deedf95 100644
--- a/_unittests/ut_tasks/try_tasks.py
+++ b/_unittests/ut_tasks/try_tasks.py
@@ -1,7 +1,7 @@
 import unittest
 from onnx_diagnostic.ext_test_case import ExtTestCase, never_test
 from onnx_diagnostic.helpers import string_type
-from onnx_diagnostic.helpers.torch_test_helper import steel_forward
+from onnx_diagnostic.helpers.torch_test_helper import steal_forward
 
 
 class TestHuggingFaceHubModel(ExtTestCase):
@@ -92,7 +92,7 @@ def test_text2text_generation(self):
     # simply generate a
single sequence print() - with steel_forward(model): + with steal_forward(model): generated_ids = model.generate( decoder_input_ids=input_ids, attention_mask=mask, max_length=100 ) @@ -121,7 +121,7 @@ def test_imagetext2text_generation(self): ["", ""], add_special_tokens=False ).input_ids print() - with steel_forward(model): + with steal_forward(model): generated_ids = model.generate( **inputs, max_new_tokens=10, bad_words_ids=bad_words_ids ) @@ -184,7 +184,7 @@ def test_automatic_speech_recognition(self): # generate token ids print() - with steel_forward(model): + with steal_forward(model): predicted_ids = model.generate( input_features, forced_decoder_ids=forced_decoder_ids ) @@ -285,6 +285,80 @@ def mean_pooling(model_output, attention_mask): print("Sentence embeddings:") print(sentence_embeddings) + @never_test() + def test_falcon_mamba_dev(self): + # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k falcon_mamba_dev + # https://huggingface.co/tiiuae/falcon-mamba-tiny-dev + + from transformers import AutoTokenizer + import transformers + import torch + + model = "tiiuae/falcon-mamba-tiny-dev" + + tokenizer = AutoTokenizer.from_pretrained(model) + pipeline = transformers.pipeline( + "text-generation", + model=model, + tokenizer=tokenizer, + torch_dtype=torch.bfloat16, + trust_remote_code=True, + device_map="auto", + ) + print() + with steal_forward(pipeline.model): + sequences = pipeline( + "Girafatron is obsessed with giraffes, " + "the most glorious animal on the face of this Earth. " + "Giraftron believes all other animals are irrelevant " + "when compared to the glorious majesty of the giraffe." + "\nDaniel: Hello, Girafatron!\nGirafatron:", + max_length=200, + do_sample=True, + top_k=10, + num_return_sequences=1, + eos_token_id=tokenizer.eos_token_id, + ) + for seq in sequences: + print(f"Result: {seq['generated_text']}") + + @never_test() + def test_falcon_mamba_7b(self): + # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k falcon_mamba_7b + # https://huggingface.co/tiiuae/falcon-mamba-7b + + from transformers import AutoTokenizer + import transformers + import torch + + model = "tiiuae/falcon-mamba-7b" + + tokenizer = AutoTokenizer.from_pretrained(model) + pipeline = transformers.pipeline( + "text-generation", + model=model, + tokenizer=tokenizer, + torch_dtype=torch.bfloat16, + trust_remote_code=True, + device_map="auto", + ) + print() + with steal_forward(pipeline.model): + sequences = pipeline( + "Girafatron is obsessed with giraffes, " + "the most glorious animal on the face of this Earth. " + "Giraftron believes all other animals are irrelevant " + "when compared to the glorious majesty of the giraffe." 
+ "\nDaniel: Hello, Girafatron!\nGirafatron:", + max_length=200, + do_sample=True, + top_k=10, + num_return_sequences=1, + eos_token_id=tokenizer.eos_token_id, + ) + for seq in sequences: + print(f"Result: {seq['generated_text']}") + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index 3c0bf9b7..ac27594f 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -136,3 +136,27 @@ def make_encoder_decoder_cache( return transformers.cache_utils.EncoderDecoderCache( self_attention_cache=self_attention_cache, cross_attention_cache=cross_attention_cache ) + + +def make_mamba_cache( + key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]], +) -> transformers.cache_utils.MambaCache: + "Creates a :class:`transformers.cache_utils.MambaCache`." + + class _config: + def __init__(self): + self.intermediate_size = key_value_pairs[0][0].shape[1] + self.conv_kernel = key_value_pairs[0][0].shape[-1] + self.state_size = key_value_pairs[0][1].shape[-1] + self.num_hidden_layers = len(key_value_pairs) + self.dtype = key_value_pairs[0][0].dtype + + cache = transformers.cache_utils.MambaCache( + _config(), + max_batch_size=key_value_pairs[0][0].shape[0], + device=key_value_pairs[0][0].device, + ) + for i in range(len(key_value_pairs)): + cache.conv_states[i][:, :, :] = key_value_pairs[i][0] + cache.ssm_states[i][:, :, :] = key_value_pairs[i][1] + return cache diff --git a/onnx_diagnostic/helpers/torch_test_helper.py b/onnx_diagnostic/helpers/torch_test_helper.py index 3c482bdd..d7bec618 100644 --- a/onnx_diagnostic/helpers/torch_test_helper.py +++ b/onnx_diagnostic/helpers/torch_test_helper.py @@ -31,7 +31,7 @@ def _forward_(*args, _f=None, _context=None, **kwargs): @contextlib.contextmanager -def steel_forward(model: torch.nn.Module, with_shape: bool = True, with_min_max: bool = False): +def steal_forward(model: torch.nn.Module, with_shape: bool = True, with_min_max: bool = False): """ The necessary modification to steem forward method and prints out inputs and outputs. See example :ref:`l-plot-tiny-llm-export`. 
diff --git a/onnx_diagnostic/tasks/text_generation.py b/onnx_diagnostic/tasks/text_generation.py index 7cb249a5..9fbbe920 100644 --- a/onnx_diagnostic/tasks/text_generation.py +++ b/onnx_diagnostic/tasks/text_generation.py @@ -1,6 +1,6 @@ -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, Optional, Tuple, Union import torch -from ..helpers.cache_helper import make_dynamic_cache +from ..helpers.cache_helper import make_dynamic_cache, make_mamba_cache from ..helpers.config_helper import update_config, check_hasattr, _pick __TASK__ = "text-generation" @@ -8,33 +8,48 @@ def reduce_model_config(config: Any, task: str) -> Dict[str, Any]: """Reduces a model size.""" + # FalconMambaConfig: use_mambapy check_hasattr( config, - ("head_dim", ("hidden_size", "num_attention_heads")), + ("head_dim", ("hidden_size", "num_attention_heads"), "use_mambapy"), "num_hidden_layers", - ("num_key_value_heads", "num_attention_heads"), + ("num_key_value_heads", "num_attention_heads", "use_mambapy"), "intermediate_size", "hidden_size", + "vocab_size", ) - kwargs = dict( - head_dim=getattr(config, "head_dim", config.hidden_size // config.num_attention_heads), - num_hidden_layers=min(config.num_hidden_layers, 2), - num_key_value_heads=( - config.num_key_value_heads - if hasattr(config, "num_key_value_heads") - else config.num_attention_heads - ), - intermediate_size=( - min(config.intermediate_size, 24576 // 4) - if config.intermediate_size % 4 == 0 - else config.intermediate_size - ), - hidden_size=( - min(config.hidden_size, 3072 // 4) - if config.hidden_size % 4 == 0 - else config.hidden_size - ), - ) + if config.__class__.__name__ == "FalconMambaConfig": + check_hasattr(config, "conv_kernel", "state_size") # 4 and 8 + kwargs = dict( + num_hidden_layers=min(config.num_hidden_layers, 2), + intermediate_size=256 if config is None else min(512, config.intermediate_size), + hidden_size=256 if config is None else min(256, config.hidden_size), + cls_cache="MambaCache", + state_size=8 if config is None else getattr(config, "state_size", None), + conv_kernel=4 if config is None else getattr(config, "conv_kernel", None), + ) + else: + kwargs = dict( + head_dim=getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ), + num_hidden_layers=min(config.num_hidden_layers, 2), + num_key_value_heads=( + config.num_key_value_heads + if hasattr(config, "num_key_value_heads") + else config.num_attention_heads + ), + intermediate_size=( + min(config.intermediate_size, 24576 // 4) + if config.intermediate_size % 4 == 0 + else config.intermediate_size + ), + hidden_size=( + min(config.hidden_size, 3072 // 4) + if config.hidden_size % 4 == 0 + else config.hidden_size + ), + ) update_config(config, kwargs) return kwargs @@ -43,13 +58,14 @@ def get_inputs( model: torch.nn.Module, config: Optional[Any], dummy_max_token_id: int, - num_key_value_heads: int, num_hidden_layers: int, - head_dim: int, batch_size: int = 2, sequence_length: int = 30, sequence_length2: int = 3, dynamic_rope: bool = False, + num_key_value_heads: Optional[int] = None, + head_dim: Optional[int] = None, + cls_cache: Optional[Union[type, str]] = None, **kwargs, # unused ): """ @@ -63,15 +79,60 @@ def get_inputs( :param sequence_length: sequence length :param sequence_length2: new sequence length :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`) + :param cls_cache: cache class, by default it is + :class:`transformers.cache_utils.DynamicCache` :return: dictionary """ - if 
head_dim is None: - assert config, "head_dim is None, the value cannot be set without a configuration" - head_dim = config.hidden_size // config.num_attention_heads batch = torch.export.Dim("batch", min=1, max=1024) seq_length = torch.export.Dim("seq_length", min=1, max=4096) cache_length = torch.export.Dim("cache_length", min=1, max=4096) + if config is not None and config.__class__.__name__ == "FalconMambaConfig": + shapes = { + "input_ids": {0: batch, 1: torch.export.Dim.DYNAMIC}, + "attention_mask": { + 0: batch, + 1: torch.export.Dim.DYNAMIC, # cache_length + seq_length + }, + "cache_position": { + 0: batch, + 1: torch.export.Dim.DYNAMIC, # cache_length + seq_length + }, + "cache_params": [ + [{0: batch} for _ in range(num_hidden_layers)], + [{0: batch} for _ in range(num_hidden_layers)], + ], + } + inputs = dict( + input_ids=torch.randint( + 0, dummy_max_token_id, (batch_size, sequence_length + sequence_length2) + ).to(torch.int64), + attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to( + torch.int64 + ), + cache_position=torch.arange(0, sequence_length + sequence_length2) + .to(torch.int64) + .expand((batch_size, -1)), + cache_params=make_mamba_cache( + [ + ( + torch.randn( + batch_size, kwargs["intermediate_size"], kwargs["conv_kernel"] + ), + torch.randn( + batch_size, kwargs["intermediate_size"], kwargs["state_size"] + ), + ) + for i in range(num_hidden_layers) + ] + ), + ) + return dict(inputs=inputs, dynamic_shapes=shapes) + + if head_dim is None: + assert config, "head_dim is None, the value cannot be set without a configuration" + head_dim = config.hidden_size // config.num_attention_heads + shapes = { "input_ids": {0: batch, 1: seq_length}, "attention_mask": { @@ -120,29 +181,44 @@ def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callabl check_hasattr( config, "vocab_size", - "hidden_size", - "num_attention_heads", - ("num_key_value_heads", "num_attention_heads"), + ("num_attention_heads", "use_mambapy"), + ("num_key_value_heads", "num_attention_heads", "use_mambapy"), "intermediate_size", "hidden_size", ) - kwargs = dict( - batch_size=2, - sequence_length=30, - sequence_length2=3, - head_dim=( - 16 - if config is None - else getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - ), - dummy_max_token_id=31999 if config is None else (config.vocab_size - 1), - num_hidden_layers=4 if config is None else config.num_hidden_layers, - num_key_value_heads=( - 24 - if config is None - else _pick(config, "num_key_value_heads", "num_attention_heads") - ), - intermediate_size=1024 if config is None else config.intermediate_size, - hidden_size=512 if config is None else config.hidden_size, - ) + if config.__class__.__name__ == "FalconMambaConfig": + check_hasattr(config, "conv_kernel", "state_size") # 4 and 8 + kwargs = dict( + batch_size=2, + sequence_length=30, + sequence_length2=3, + dummy_max_token_id=31999 if config is None else (config.vocab_size - 1), + num_hidden_layers=4 if config is None else config.num_hidden_layers, + intermediate_size=256 if config is None else config.intermediate_size, + cls_cache="MambaCache", + state_size=8 if config is None else getattr(config, "state_size", None), + conv_kernel=8 if config is None else getattr(config, "conv_kernel", None), + ) + else: + kwargs = dict( + batch_size=2, + sequence_length=30, + sequence_length2=3, + head_dim=( + 16 + if config is None + else getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + ), + 
dummy_max_token_id=31999 if config is None else (config.vocab_size - 1), + num_hidden_layers=4 if config is None else config.num_hidden_layers, + num_key_value_heads=( + 24 + if config is None + else _pick(config, "num_key_value_heads", "num_attention_heads") + ), + intermediate_size=1024 if config is None else config.intermediate_size, + hidden_size=512 if config is None else config.hidden_size, + ) return kwargs, get_inputs diff --git a/onnx_diagnostic/torch_models/hghub/hub_data.py b/onnx_diagnostic/torch_models/hghub/hub_data.py index d335b170..5c5590ff 100644 --- a/onnx_diagnostic/torch_models/hghub/hub_data.py +++ b/onnx_diagnostic/torch_models/hghub/hub_data.py @@ -41,6 +41,7 @@ DonutSwinModel,feature-extraction ElectraModel,feature-extraction EsmModel,feature-extraction + FalconMambaForCausalLM,text-generation GLPNModel,image-feature-extraction GPTBigCodeModel,feature-extraction GPTJModel,feature-extraction diff --git a/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py b/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py index 5b037563..d69b23cd 100644 --- a/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +++ b/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py @@ -3530,3 +3530,42 @@ def _ccached_sentence_transformers_all_MiniLM_L6_v1(): "vocab_size": 30522, } ) + + +def _ccached_tiiuae_falcon_mamba_tiny_dev(): + "tiiuae/falcon-mamba-tiny-dev" + return transformers.FalconMambaConfig( + **{ + "architectures": ["FalconMambaForCausalLM"], + "bos_token_id": 0, + "conv_kernel": 4, + "eos_token_id": 11, + "expand": 16, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.1, + "intermediate_size": 8192, + "layer_norm_epsilon": 1e-05, + "mixer_rms_eps": 1e-06, + "model_type": "falcon_mamba", + "num_hidden_layers": 64, + "pad_token_id": 11, + "rescale_prenorm_residual": false, + "residual_in_fp32": true, + "state_size": 16, + "tie_word_embeddings": false, + "time_step_floor": 0.0001, + "time_step_init_scheme": "random", + "time_step_max": 0.1, + "time_step_min": 0.001, + "time_step_rank": 256, + "time_step_scale": 1.0, + "torch_dtype": "bfloat16", + "transformers_version": "4.52.0.dev0", + "use_bias": false, + "use_cache": true, + "use_conv_bias": true, + "use_mambapy": false, + "vocab_size": 65024, + } + ) From a5462761f78979aa7398f1fd4518a17e0015a42e Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 15 Apr 2025 16:55:39 +0200 Subject: [PATCH 3/3] fix mamba --- _unittests/ut_tasks/test_tasks.py | 2 +- onnx_diagnostic/helpers/cache_helper.py | 8 ++++++++ onnx_diagnostic/tasks/text_generation.py | 14 +++++++++++--- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/_unittests/ut_tasks/test_tasks.py b/_unittests/ut_tasks/test_tasks.py index 8d9da4d6..415132f9 100644 --- a/_unittests/ut_tasks/test_tasks.py +++ b/_unittests/ut_tasks/test_tasks.py @@ -120,7 +120,7 @@ def test_falcon_mamba_dev(self): model, inputs = data["model"], data["inputs"] print(self.string_type(inputs, with_shape=True)) model(**inputs) - self.assertIn((data["size"], data["n_weights"]), [(62461440, 15615360)]) + self.assertIn((data["size"], data["n_weights"]), [(138640384, 34660096)]) if __name__ == "__main__": diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index ac27594f..b2417ebb 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -157,6 +157,14 @@ def __init__(self): device=key_value_pairs[0][0].device, ) for i in range(len(key_value_pairs)): + 
assert cache.conv_states[i].shape == key_value_pairs[i][0].shape, ( + f"Shape mismatch, expected {cache.conv_states[i].shape}, " + f"got {key_value_pairs[i][0].shape}" + ) cache.conv_states[i][:, :, :] = key_value_pairs[i][0] + assert cache.ssm_states[i].shape == key_value_pairs[i][1].shape, ( + f"Shape mismatch, expected {cache.ssm_states[i].shape}, " + f"got {key_value_pairs[i][1].shape}" + ) cache.ssm_states[i][:, :, :] = key_value_pairs[i][1] return cache diff --git a/onnx_diagnostic/tasks/text_generation.py b/onnx_diagnostic/tasks/text_generation.py index 9fbbe920..eb81bc78 100644 --- a/onnx_diagnostic/tasks/text_generation.py +++ b/onnx_diagnostic/tasks/text_generation.py @@ -88,6 +88,15 @@ def get_inputs( cache_length = torch.export.Dim("cache_length", min=1, max=4096) if config is not None and config.__class__.__name__ == "FalconMambaConfig": + seq_length_multiple = 8 + sequence_length = ( + (sequence_length + seq_length_multiple) + // seq_length_multiple + * seq_length_multiple + ) + # sequence_inc = seq_length_multiple + sequence_length2 = seq_length_multiple + shapes = { "input_ids": {0: batch, 1: torch.export.Dim.DYNAMIC}, "attention_mask": { @@ -110,9 +119,8 @@ def get_inputs( attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to( torch.int64 ), - cache_position=torch.arange(0, sequence_length + sequence_length2) - .to(torch.int64) - .expand((batch_size, -1)), + cache_position=torch.arange(0, kwargs["conv_kernel"]).to(torch.int64), + # .expand((batch_size, -1)) cache_params=make_mamba_cache( [ (
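
PATCH 3 above rounds the dummy prompt length up to a multiple of ``seq_length_multiple`` (8) before building FalconMamba inputs. A small illustration of that arithmetic, with illustrative values; note that the formula always moves to the next multiple, so an already-aligned length is still increased by a full step:

seq_length_multiple = 8
for sequence_length in (30, 32, 33):
    rounded = (
        (sequence_length + seq_length_multiple)
        // seq_length_multiple
        * seq_length_multiple
    )
    print(sequence_length, "->", rounded)  # 30 -> 32, 32 -> 40, 33 -> 40
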