 from onnx_diagnostic.helpers import string_type
 from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
 from onnx_diagnostic.helpers.torch_helper import steal_forward
-from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
 from onnx_diagnostic.torch_export_patches import torch_export_patches
+from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
 
 
 class TestHuggingFaceHubModel(ExtTestCase):
@@ -132,6 +132,59 @@ def test_text2text_generation_static(self):
         )
         print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
 
+    @never_test()
+    def test_text_generation_tiny_llm(self):
+        # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k tiny_llm
+        """
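+        Input shapes printed by steal_forward below (string_type notation:
+        T7 = int64 tensor, T1 = float32 tensor, sAxB = shape AxB, #N[...] = N tensors):
+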
+        dict(cache_position:T7s21,
+             past_key_values:DynamicCache(key_cache=#0[], value_cache=#0[]),
+             input_ids:T7s1x21,
+             position_ids:T7s1x21,
+             attention_mask:T1s1x21)
+        dict(cache_position:T7s1,
+             past_key_values:DynamicCache(key_cache=#32[T1s1x8x21x128,...],
+                                          value_cache=#32[T1s1x8x21x128,...]),
+             input_ids:T7s1x1,
+             position_ids:T7s1x1,
+             attention_mask:T1s1x22)
+        """
+        from transformers import AutoTokenizer, AutoModelForCausalLM
+
+        tokenizer = AutoTokenizer.from_pretrained("arnir0/Tiny-LLM")
+        model = AutoModelForCausalLM.from_pretrained("arnir0/Tiny-LLM")
+
+        text = "def greet(user): print(f'hello <extra_id_0>!')"
+        input_ids = tokenizer(text, return_tensors="pt").input_ids.reshape((1, -1))
+        # attention mask of ones: every prompt token is attended to
+        mask = torch.ones(input_ids.shape, dtype=torch.int64)
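+        # position ids 0..n-1, one for each prompt token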
+        position_ids = torch.arange(input_ids.shape[1], dtype=torch.int64).reshape((1, -1))
+
+        # simply generate a single sequence
+        print()
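+        # torch_export_patches(patch_transformers=True) patches transformers for
+        # export; steal_forward wraps model.forward and prints the inputs of
+        # every call, producing the dicts shown in the docstring above.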
+        with (
+            torch_export_patches(
+                patch_transformers=True, patch_torch=False, patch_sympy=False
+            ),
+            steal_forward(model),
+        ):
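+            # generate calls forward in a loop: the first call consumes the
+            # whole prompt with an empty cache, later calls feed one new token
+            # plus the growing DynamicCache (the two dicts in the docstring).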
+            generated_ids = model.generate(
+                input_ids=input_ids,
+                max_length=100,
+                attention_mask=mask,
+                position_ids=position_ids,
+            )
+        print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
+
     @never_test()
     def test_text_generation_phi4_mini(self):
         # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k phi4_mini