
Commit d677cf8

Add utilities to investigate zai_model

1 parent ddea895

7 files changed: +197 −3 lines changed

_unittests/ut_tasks/test_tasks_image_text_to_text.py
Lines changed: 23 additions & 1 deletion

@@ -30,7 +30,7 @@ def test_image_text_to_text_idefics(self):
         )
 
     @hide_stdout()
-    @requires_transformers("4.56")
+    @requires_transformers("4.55.99")
     @requires_torch("2.7.99")
     def test_image_text_to_text_gemma3(self):
         """
@@ -53,6 +53,28 @@ def test_image_text_to_text_gemma3(self):
             model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
         )
 
+    @hide_stdout()
+    @requires_transformers("4.55.99")
+    @requires_torch("2.7.99")
+    def test_image_text_to_text_zai_glm(self):
+        """
+        If the model fails because of
+        ``if inputs_embeds[special_image_mask].numel() != image_features.numel():``,
+        make sure this PR was merged:
+        https://github.com/huggingface/transformers/pull/39962.
+        """
+        mid = "zai-org/GLM-4.5V"
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "image-text-to-text")
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        print("--", self.string_type(data["inputs"], with_shape=True))
+        model(**torch_deepcopy(inputs))
+        model(**data["inputs2"])
+        with torch_export_patches(patch_transformers=True, verbose=10):
+            torch.export.export(
+                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
+            )
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
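
The new test follows the same pattern as the Gemma3 test above: build an untrained GLM-4.5V, run it eagerly on both generated input sets, then export it under `torch_export_patches`. Assuming a standard pytest setup, it can be run in isolation with:

    python -m pytest _unittests/ut_tasks/test_tasks_image_text_to_text.py -k zai_glm -v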

_unittests/ut_tasks/try_tasks.py
Lines changed: 43 additions & 0 deletions

@@ -4,6 +4,7 @@
 from onnx_diagnostic.helpers import string_type
 from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
 from onnx_diagnostic.helpers.torch_helper import steal_forward
+from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
 
 
 class TestHuggingFaceHubModel(ExtTestCase):
@@ -712,6 +713,48 @@ def test_text_to_image(self):
         # time_step=T7s=101
         # encoder_hidden_states:T10s2x77x1024[-6.58203125,13.0234375:A-0.16780663634440257]
 
+    @never_test()
+    def test_imagetext2text_generation_zai_glm(self):
+        """
+        clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k zai_glm
+        """
+        from transformers import AutoProcessor
+
+        model_id = "zai-org/GLM-4.5V"
+        data = get_untrained_model_with_inputs(model_id, verbose=1, add_second_input=True)
+        model = data["model"]
+        processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
+
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "url": "http://images.cocodataset.org/val2017/000000039769.jpg",
+                    },
+                    {"type": "text", "text": "describe this image"},
+                ],
+            }
+        ]
+        inputs = processor.apply_chat_template(
+            messages,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_dict=True,
+            return_tensors="pt",
+        ).to(model.device)
+        inputs.pop("token_type_ids", None)
+
+        print()
+        # steal forward creates a bug...
+        with steal_forward(model):  # , torch.inference_mode():
+            generated_ids = model.generate(**inputs, max_new_tokens=8192)
+        output_text = processor.decode(
+            generated_ids[0][inputs["input_ids"].shape[1] :], skip_special_tokens=False
+        )
+        print(output_text)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
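
The `steal_forward` helper imported at the top of this file intercepts the model's forward calls so the inputs flowing through generation can be inspected. For context, here is a generic sketch of the forward-stealing pattern (a simplified illustration, not the library's actual implementation; the name `stolen_forward_sketch` is made up):

    import contextlib
    import torch

    @contextlib.contextmanager
    def stolen_forward_sketch(model: torch.nn.Module):
        """Temporarily replaces model.forward to log the inputs it receives."""
        original = model.forward

        def wrapped(*args, **kwargs):
            print("forward called with:", sorted(kwargs))  # log argument names
            return original(*args, **kwargs)

        model.forward = wrapped
        try:
            yield
        finally:
            # restore the real forward even if generate() raises
            model.forward = original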

onnx_diagnostic/helpers/config_helper.py
Lines changed: 39 additions & 0 deletions

@@ -126,3 +126,42 @@ def default_num_hidden_layers():
     if capa[0] < 9:
         return 2
     return 2 if os.environ.get("UNITTEST_GOING", "0") == "1" else 4
+
+
+def build_diff_config(config0, config1):
+    """
+    Returns all the modified values between two configurations.
+    """
+    import torch
+
+    diff = {}
+    for k in config0:
+        assert isinstance(k, str), f"k={k!r}, wrong type in {config0}"
+        if k not in config1:
+            diff[k] = f"-{config0[k]}"
+    for k in config1:
+        assert isinstance(k, str), f"k={k!r}, wrong type in {config1}"
+        if k not in config0:
+            diff[k] = f"+{config1[k]}"
+    for k in config0:
+        if k not in config1:
+            continue
+        v0 = getattr(config0, k) if hasattr(config0, k) else config0[k]
+        v1 = getattr(config1, k) if hasattr(config1, k) else config1[k]
+        if (
+            v0 is None
+            or v1 is None
+            or isinstance(v1, (float, int, bool, str, list, tuple, torch.dtype))
+            or (
+                isinstance(v0, dict)
+                and isinstance(v1, dict)
+                and all(isinstance(k, int) for k in v1)
+            )
+        ):
+            if v1 != v0:
+                diff[k] = f"{v0} -> {v1}"
+        else:
+            d = build_diff_config(v0, v1)
+            if d:
+                diff[k] = d
+    return diff
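
A minimal illustration of the helper's output format on plain dicts (values made up): keys only in `config0` get a `-` prefix, keys only in `config1` get `+`, and changed values are rendered as `old -> new`, with nested configurations diffed recursively.

    from onnx_diagnostic.helpers.config_helper import build_diff_config

    config0 = {"num_hidden_layers": 46, "head_dim": 128, "vocab_size": 151552}
    config1 = {"num_hidden_layers": 2, "intermediate_size": 64, "vocab_size": 151552}

    print(build_diff_config(config0, config1))
    # {'head_dim': '-128', 'intermediate_size': '+64', 'num_hidden_layers': '46 -> 2'}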

onnx_diagnostic/tasks/image_text_to_text.py
Lines changed: 7 additions & 0 deletions

@@ -245,6 +245,7 @@ def get_inputs(
                 else {0: batch_img}
             ),
             "image_attention_mask": {0: batch, 1: seq_length, 2: images},
+            "image_grid_thw": {0: batch},
             "use_cache": None,
         }
 
@@ -256,6 +257,11 @@ def get_inputs(
         # input_ids[input_ids == image_token_index] = pad_token_id
         token_type_ids = torch.zeros_like(input_ids)
         token_type_ids[input_ids == image_token_index] = 1
+        image_grid_thw = torch.zeros((n_images, 3), dtype=torch.int64)
+        image_grid_thw[:, 1] = height
+        image_grid_thw[:, 2] = width
+        image_grid_thw[0, :] //= 2
+        image_grid_thw[:, 0] = torch.arange(n_images, dtype=image_grid_thw.dtype)
 
         inputs = dict(
             input_ids=input_ids,
@@ -291,6 +297,7 @@ def get_inputs(
                 torch.int64
             ),
             token_type_ids=token_type_ids,
+            image_grid_thw=image_grid_thw,
            use_cache=True,  # Gemma3 does not set this value to true when a cache is provided
         )
         res = dict(inputs=inputs, dynamic_shapes=shapes)
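
`image_grid_thw` appears to follow the convention used by Qwen2-VL-style vision towers: one `(t, h, w)` grid row per image. A small standalone sketch of the dummy values built above, assuming `n_images=2` and `height=width=4` (illustrative numbers, not the task's real defaults):

    import torch

    n_images, height, width = 2, 4, 4
    image_grid_thw = torch.zeros((n_images, 3), dtype=torch.int64)
    image_grid_thw[:, 1] = height  # grid height per image
    image_grid_thw[:, 2] = width   # grid width per image
    image_grid_thw[0, :] //= 2     # first image gets a halved grid
    image_grid_thw[:, 0] = torch.arange(n_images, dtype=image_grid_thw.dtype)
    print(image_grid_thw)
    # tensor([[0, 2, 2],
    #         [1, 4, 4]])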

onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py
Lines changed: 70 additions & 1 deletion

@@ -4562,7 +4562,7 @@ def _ccached_diffusers_tiny_torch_full_checker_unet():
     }
 
 
-def _ccached_riny_random_gemma_3():
+def _ccached_tiny_random_gemma_3():
     "tiny-random/gemma-3"
     return transformers.Gemma3Config(
         **{
@@ -4618,3 +4618,72 @@ def _ccached_riny_random_gemma_3():
             },
         }
     )
+
+
+def _ccached_zai_glm_45():
+    "zai-org/GLM-4.5V"
+    return transformers.Glm4vMoeConfig(
+        **{
+            "architectures": ["Glm4vMoeForConditionalGeneration"],
+            "model_type": "glm4v_moe",
+            "text_config": {
+                "pad_token_id": 151329,
+                "vocab_size": 151552,
+                "eos_token_id": [151329, 151336, 151338],
+                "image_end_token_id": 151340,
+                "image_start_token_id": 151339,
+                "image_token_id": 151363,
+                "head_dim": 128,
+                "attention_bias": True,
+                "attention_dropout": 0.0,
+                "first_k_dense_replace": 1,
+                "hidden_act": "silu",
+                "hidden_size": 4096,
+                "initializer_range": 0.02,
+                "intermediate_size": 10944,
+                "max_position_embeddings": 65536,
+                "model_type": "glm4v_moe_text",
+                "moe_intermediate_size": 1408,
+                "n_group": 1,
+                "n_routed_experts": 128,
+                "n_shared_experts": 1,
+                "norm_topk_prob": True,
+                "num_attention_heads": 96,
+                "num_experts_per_tok": 8,
+                "num_hidden_layers": 46,
+                "num_key_value_heads": 8,
+                "partial_rotary_factor": 0.5,
+                "rms_norm_eps": 1e-05,
+                "torch_dtype": "bfloat16",
+                "rope_scaling": {"rope_type": "default", "mrope_section": [8, 12, 12]},
+                "rope_theta": 10000.0,
+                "routed_scaling_factor": 1.0,
+                "topk_group": 1,
+                "use_cache": True,
+                "use_qk_norm": False,
+            },
+            "torch_dtype": "bfloat16",
+            "transformers_version": "4.55.0.dev0",
+            "video_end_token_id": 151342,
+            "video_start_token_id": 151341,
+            "video_token_id": 151364,
+            "vision_config": {
+                "attention_bias": False,
+                "attention_dropout": 0.0,
+                "depth": 24,
+                "hidden_act": "silu",
+                "hidden_size": 1536,
+                "image_size": 336,
+                "in_channels": 3,
+                "initializer_range": 0.02,
+                "intermediate_size": 10944,
+                "model_type": "glm4v_moe",
+                "num_heads": 12,
+                "out_hidden_size": 4096,
+                "patch_size": 14,
+                "rms_norm_eps": 1e-05,
+                "spatial_merge_size": 2,
+                "temporal_patch_size": 2,
+            },
+        }
+    )
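
A quick sanity check of the new cached config against the values above (assumes a transformers version that ships `Glm4vMoeConfig`, which converts the nested dicts into sub-config objects):

    config = _ccached_zai_glm_45()
    assert config.model_type == "glm4v_moe"
    assert config.text_config.num_hidden_layers == 46
    assert config.vision_config.spatial_merge_size == 2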

onnx_diagnostic/torch_models/hghub/model_inputs.py
Lines changed: 10 additions & 1 deletion

@@ -1,10 +1,11 @@
+import copy
 import inspect
 import os
 import pprint
 from typing import Any, Dict, Optional, Tuple
 import torch
 import transformers
-from ...helpers.config_helper import update_config
+from ...helpers.config_helper import update_config, build_diff_config
 from ...tasks import reduce_model_config, random_input_kwargs
 from .hub_api import task_from_arch, task_from_id, get_pretrained_config, download_code_modelid
 
@@ -121,6 +122,7 @@ def get_untrained_model_with_inputs(
     )
 
     # updating the configuration
+    config0 = copy.deepcopy(config)
     mkwargs = reduce_model_config(config, task) if not same_as_pretrained else {}
     if model_kwargs:
         for k, v in model_kwargs.items():
@@ -133,6 +135,12 @@ def get_untrained_model_with_inputs(
             mkwargs[k] = v
     if mkwargs:
         update_config(config, mkwargs)
+    diff_config = build_diff_config(config0, config)
+    if verbose:
+        if diff_config:
+            print("[get_untrained_model_with_inputs] -- updated config")
+            pprint.pprint(diff_config)
+            print("[get_untrained_model_with_inputs] --")
 
     # SDPA
     if model_kwargs and "attn_implementation" in model_kwargs:
@@ -232,6 +240,7 @@ def get_untrained_model_with_inputs(
 
     res["input_kwargs"] = kwargs
     res["model_kwargs"] = mkwargs
+    res["dump_info"] = dict(config_diff=diff_config)
 
     sizes = compute_model_size(model)
     res["model"] = model
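
With this change, the diff between the pretrained configuration and the reduced one is printed when `verbose` is set and also travels with the returned dictionary under `dump_info`. A sketch of how a caller sees it (the diff values shown are illustrative):

    from onnx_diagnostic.torch_models.hghub.model_inputs import (
        get_untrained_model_with_inputs,
    )

    data = get_untrained_model_with_inputs("zai-org/GLM-4.5V", verbose=1)
    # [get_untrained_model_with_inputs] -- updated config
    # {'text_config': {'num_hidden_layers': '46 -> 2'}}   (illustrative)
    # [get_untrained_model_with_inputs] --
    print(data["dump_info"]["config_diff"])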

onnx_diagnostic/torch_models/validate.py
Lines changed: 5 additions & 0 deletions

@@ -478,6 +478,11 @@ def validate_model(
             else data["configuration"].to_dict()
         )
     )
+    dump_info = data.get("dump_info", None)
+    if dump_info:
+        with open(os.path.join(dump_folder, "model_dump_info.txt"), "w") as f:
+            f.write(f"model_id: {model_id}\n------\n")
+            f.write(pprint.pformat(dump_info))
 
     if exporter == "modelbuilder":
         # Models used with ModelBuilder do not like batch size > 1.
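
Given the writes above, the dumped `model_dump_info.txt` would look roughly like this (the diff content is illustrative):

    model_id: zai-org/GLM-4.5V
    ------
    {'config_diff': {'text_config': {'num_hidden_layers': '46 -> 2'}}}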
