diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index 3c7eb2fc..51747935 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
 0.7.8
 +++++
 
+* :pr:`210`: add utilities to investigate models
 * :pr:`208`: add a patch for Qwen3 (rewrite a loop)
 
 0.7.7
diff --git a/_unittests/ut_tasks/test_tasks_image_text_to_text.py b/_unittests/ut_tasks/test_tasks_image_text_to_text.py
index 360caf3f..512edded 100644
--- a/_unittests/ut_tasks/test_tasks_image_text_to_text.py
+++ b/_unittests/ut_tasks/test_tasks_image_text_to_text.py
@@ -20,7 +20,6 @@ def test_image_text_to_text_idefics(self):
         mid = "HuggingFaceM4/tiny-random-idefics"
         data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
         self.assertEqual(data["task"], "image-text-to-text")
-        self.assertIn((data["size"], data["n_weights"]), [(12628776, 3157194)])
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         model(**torch_deepcopy(inputs))
         model(**data["inputs2"])
@@ -30,7 +29,7 @@
             )
 
     @hide_stdout()
-    @requires_transformers("4.56")
+    @requires_transformers("4.56.99")
     @requires_torch("2.7.99")
     def test_image_text_to_text_gemma3(self):
         """
@@ -53,6 +52,28 @@
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
             )
 
+    @hide_stdout()
+    @requires_transformers("4.56.99")
+    @requires_torch("2.7.99")
+    def test_image_text_to_text_zai_glm(self):
+        """
+        If the model fails because of
+        ``if inputs_embeds[special_image_mask].numel() != image_features.numel():``,
+        make sure this PR was merged:
+        https://github.com/huggingface/transformers/pull/39962.
+        """
+        mid = "zai-org/GLM-4.5V"
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "image-text-to-text")
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        print("--", self.string_type(data["inputs"], with_shape=True))
+        model(**torch_deepcopy(inputs))
+        model(**data["inputs2"])
+        with torch_export_patches(patch_transformers=True, verbose=10):
+            torch.export.export(
+                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
+            )
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/_unittests/ut_tasks/try_tasks.py b/_unittests/ut_tasks/try_tasks.py
index f4a3c141..91f7e3b7 100644
--- a/_unittests/ut_tasks/try_tasks.py
+++ b/_unittests/ut_tasks/try_tasks.py
@@ -4,6 +4,7 @@
 from onnx_diagnostic.helpers import string_type
 from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
 from onnx_diagnostic.helpers.torch_helper import steal_forward
+from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
 
 
 class TestHuggingFaceHubModel(ExtTestCase):
@@ -712,6 +713,48 @@ def test_text_to_image(self):
         # time_step=T7s=101
         # encoder_hidden_states:T10s2x77x1024[-6.58203125,13.0234375:A-0.16780663634440257]
 
+    @never_test()
+    def test_imagetext2text_generation_zai_glm(self):
+        """
+        clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k zai_glm
+        """
+        from transformers import AutoProcessor
+
+        model_id = "zai-org/GLM-4.5V"
+        data = get_untrained_model_with_inputs(model_id, verbose=1, add_second_input=True)
+        model = data["model"]
+        processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
+
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "url": "http://images.cocodataset.org/val2017/000000039769.jpg",
+                    },
+                    {"type": "text", "text": "describe this image"},
+                ],
+            }
+        ]
+        inputs = processor.apply_chat_template(
+            messages,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_dict=True,
+            return_tensors="pt",
+        ).to(model.device)
+        inputs.pop("token_type_ids", None)
+
+        print()
+        # steal forward creates a bug...
+        with steal_forward(model):  # , torch.inference_mode():
+            generated_ids = model.generate(**inputs, max_new_tokens=8192)
+        output_text = processor.decode(
+            generated_ids[0][inputs["input_ids"].shape[1] :], skip_special_tokens=False
+        )
+        print(output_text)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/_unittests/ut_torch_models/test_hghub_model.py b/_unittests/ut_torch_models/test_hghub_model.py
index 39657e8b..e695d8bf 100644
--- a/_unittests/ut_torch_models/test_hghub_model.py
+++ b/_unittests/ut_torch_models/test_hghub_model.py
@@ -30,6 +30,7 @@ def test_get_untrained_model_with_inputs_tiny_llm(self):
                 "input_kwargs",
                 "model_kwargs",
                 "task",
+                "dump_info",
             },
         )
         model, inputs = data["model"], data["inputs"]
diff --git a/onnx_diagnostic/helpers/config_helper.py b/onnx_diagnostic/helpers/config_helper.py
index 3a5b71d9..02a025b2 100644
--- a/onnx_diagnostic/helpers/config_helper.py
+++ b/onnx_diagnostic/helpers/config_helper.py
@@ -126,3 +126,44 @@ def default_num_hidden_layers():
     if capa[0] < 9:
         return 2
     return 2 if os.environ.get("UNITTEST_GOING", "0") == "1" else 4
+
+
+def build_diff_config(config0, config1):
+    """
+    Returns all the modified values between two configurations.
+    """
+    import torch
+
+    diff = {}
+    for k in config0:
+        assert isinstance(k, str), f"k={k!r}, wrong type in {config0}"
+        if k not in config1:
+            v0 = getattr(config0, k) if hasattr(config0, k) else config0[k]
+            diff[k] = f"-{v0}"
+    for k in config1:
+        assert isinstance(k, str), f"k={k!r}, wrong type in {config1}"
+        if k not in config0:
+            v1 = getattr(config1, k) if hasattr(config1, k) else config1[k]
+            diff[k] = f"+{v1}"
+    for k in config0:
+        if k not in config1:
+            continue
+        v0 = getattr(config0, k) if hasattr(config0, k) else config0[k]
+        v1 = getattr(config1, k) if hasattr(config1, k) else config1[k]
+        if (
+            v0 is None
+            or v1 is None
+            or isinstance(v1, (float, int, bool, str, list, tuple, torch.dtype))
+            or (
+                isinstance(v0, dict)
+                and isinstance(v1, dict)
+                and all(isinstance(k, int) for k in v1)
+            )
+        ):
+            if v1 != v0:
+                diff[k] = f"{v0} -> {v1}"
+        else:
+            d = build_diff_config(v0, v1)
+            if d:
+                diff[k] = d
+    return diff
diff --git a/onnx_diagnostic/tasks/image_text_to_text.py b/onnx_diagnostic/tasks/image_text_to_text.py
index 13df970f..585b6ddf 100644
--- a/onnx_diagnostic/tasks/image_text_to_text.py
+++ b/onnx_diagnostic/tasks/image_text_to_text.py
@@ -23,14 +23,20 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
             config.vision_config.num_hidden_layers = min(
                 config.vision_config.num_hidden_layers, 2
            )
+        if hasattr(config.vision_config, "num_heads"):
+            config.vision_config.num_heads = min(config.vision_config.num_heads, 4)
         if hasattr(config.vision_config, "image_size"):
-            config.vision_config.image_size = min(config.vision_config.image_size, 96)
+            config.vision_config.image_size = min(config.vision_config.image_size, 168 // 2)
         if hasattr(config.vision_config, "intermediate_size"):
             config.vision_config.intermediate_size = min(
                 config.vision_config.intermediate_size, 1076
             )
         if hasattr(config.vision_config, "patch_size"):
-            config.vision_config.patch_size = min(config.vision_config.patch_size, 2)
+            config.vision_config.patch_size = min(config.vision_config.patch_size, 1)
+        if hasattr(config.vision_config, "temporal_patch_size"):
+            config.vision_config.temporal_patch_size = min(
+                config.vision_config.temporal_patch_size, 8
+            )
         if hasattr(config.vision_config, "hidden_size"):
             config.vision_config.hidden_size = min(config.vision_config.hidden_size, 16)
         if hasattr(config, "text_config"):
@@ -245,6 +251,7 @@ def get_inputs(
                 else {0: batch_img}
             ),
             "image_attention_mask": {0: batch, 1: seq_length, 2: images},
+            "image_grid_thw": {0: batch},
             "use_cache": None,
         }
 
@@ -256,6 +263,11 @@
     # input_ids[input_ids == image_token_index] = pad_token_id
     token_type_ids = torch.zeros_like(input_ids)
     token_type_ids[input_ids == image_token_index] = 1
+    image_grid_thw = torch.zeros((n_images, 3), dtype=torch.int64)
+    image_grid_thw[:, 1] = height
+    image_grid_thw[:, 2] = width
+    image_grid_thw[0, :] //= 2
+    image_grid_thw[:, 0] = torch.arange(n_images, dtype=image_grid_thw.dtype)
 
     inputs = dict(
         input_ids=input_ids,
@@ -291,6 +303,7 @@
             torch.int64
         ),
         token_type_ids=token_type_ids,
+        image_grid_thw=image_grid_thw,
         use_cache=True,  # Gemma3 does not set this value to true when a cache is provided
     )
     res = dict(inputs=inputs, dynamic_shapes=shapes)
diff --git a/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py b/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py
index be5cd153..ff6cdb5e 100644
--- a/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py
+++ b/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py
@@ -4562,7 +4562,7 @@ def _ccached_diffusers_tiny_torch_full_checker_unet():
     }
 
 
-def _ccached_riny_random_gemma_3():
+def _ccached_tiny_random_gemma_3():
     "tiny-random/gemma-3"
     return transformers.Gemma3Config(
         **{
@@ -4618,3 +4618,72 @@ def _ccached_riny_random_gemma_3():
             },
         }
     )
+
+
+def _ccached_zai_glm_45():
+    "zai-org/GLM-4.5V"
+    return transformers.Glm4vMoeConfig(
+        **{
+            "architectures": ["Glm4vMoeForConditionalGeneration"],
+            "model_type": "glm4v_moe",
+            "text_config": {
+                "pad_token_id": 151329,
+                "vocab_size": 151552,
+                "eos_token_id": [151329, 151336, 151338],
+                "image_end_token_id": 151340,
+                "image_start_token_id": 151339,
+                "image_token_id": 151363,
+                "head_dim": 128,
+                "attention_bias": True,
+                "attention_dropout": 0.0,
+                "first_k_dense_replace": 1,
+                "hidden_act": "silu",
+                "hidden_size": 4096,
+                "initializer_range": 0.02,
+                "intermediate_size": 10944,
+                "max_position_embeddings": 65536,
+                "model_type": "glm4v_moe_text",
+                "moe_intermediate_size": 1408,
+                "n_group": 1,
+                "n_routed_experts": 128,
+                "n_shared_experts": 1,
+                "norm_topk_prob": True,
+                "num_attention_heads": 96,
+                "num_experts_per_tok": 8,
+                "num_hidden_layers": 46,
+                "num_key_value_heads": 8,
+                "partial_rotary_factor": 0.5,
+                "rms_norm_eps": 1e-05,
+                "torch_dtype": "bfloat16",
+                "rope_scaling": {"rope_type": "default", "mrope_section": [8, 12, 12]},
+                "rope_theta": 10000.0,
+                "routed_scaling_factor": 1.0,
+                "topk_group": 1,
+                "use_cache": True,
+                "use_qk_norm": False,
+            },
+            "torch_dtype": "bfloat16",
+            "transformers_version": "4.55.0.dev0",
+            "video_end_token_id": 151342,
+            "video_start_token_id": 151341,
+            "video_token_id": 151364,
+            "vision_config": {
+                "attention_bias": False,
+                "attention_dropout": 0.0,
+                "depth": 24,
+                "hidden_act": "silu",
+                "hidden_size": 1536,
+                "image_size": 336,
+                "in_channels": 3,
+                "initializer_range": 0.02,
+                "intermediate_size": 10944,
+                "model_type": "glm4v_moe",
+                "num_heads": 12,
+                "out_hidden_size": 4096,
+                "patch_size": 14,
+                "rms_norm_eps": 1e-05,
+                "spatial_merge_size": 2,
+                "temporal_patch_size": 2,
+            },
+        }
+    )
diff --git a/onnx_diagnostic/torch_models/hghub/model_inputs.py b/onnx_diagnostic/torch_models/hghub/model_inputs.py
index c9d97a99..96350892 100644
--- a/onnx_diagnostic/torch_models/hghub/model_inputs.py
+++ b/onnx_diagnostic/torch_models/hghub/model_inputs.py
@@ -1,10 +1,11 @@
+import copy
 import inspect
 import os
 import pprint
 from typing import Any, Dict, Optional, Tuple
 import torch
 import transformers
-from ...helpers.config_helper import update_config
+from ...helpers.config_helper import update_config, build_diff_config
 from ...tasks import reduce_model_config, random_input_kwargs
 from .hub_api import task_from_arch, task_from_id, get_pretrained_config, download_code_modelid
 
@@ -121,6 +122,7 @@
     )
 
     # updating the configuration
+    config0 = copy.deepcopy(config)
     mkwargs = reduce_model_config(config, task) if not same_as_pretrained else {}
     if model_kwargs:
         for k, v in model_kwargs.items():
@@ -133,6 +135,15 @@
                 mkwargs[k] = v
     if mkwargs:
         update_config(config, mkwargs)
+    try:
+        diff_config = build_diff_config(config0, config)
+    except (ValueError, AttributeError, TypeError) as e:
+        diff_config = f"DIFF CONFIG ERROR {e}"
+    if verbose:
+        if diff_config:
+            print("[get_untrained_model_with_inputs] -- updated config")
+            pprint.pprint(diff_config)
+            print("[get_untrained_model_with_inputs] --")
 
     # SDPA
     if model_kwargs and "attn_implementation" in model_kwargs:
@@ -232,6 +243,7 @@
 
     res["input_kwargs"] = kwargs
     res["model_kwargs"] = mkwargs
+    res["dump_info"] = dict(config_diff=diff_config)
 
     sizes = compute_model_size(model)
     res["model"] = model
diff --git a/onnx_diagnostic/torch_models/validate.py b/onnx_diagnostic/torch_models/validate.py
index 528584d0..43e54307 100644
--- a/onnx_diagnostic/torch_models/validate.py
+++ b/onnx_diagnostic/torch_models/validate.py
@@ -478,6 +478,11 @@ def validate_model(
                     else data["configuration"].to_dict()
                 )
             )
+        dump_info = data.get("dump_info", None)
+        if dump_info:
+            with open(os.path.join(dump_folder, "model_dump_info.txt"), "w") as f:
+                f.write(f"model_id: {model_id}\n------\n")
+                f.write(pprint.pformat(dump_info))
 
     if exporter == "modelbuilder":
         # Models used with ModelBuilder do not like batch size > 1.