Skip to content

Commit 581cbb5

Browse files
committed
fix a few things
1 parent 92dbf02 commit 581cbb5

File tree

3 files changed

+34
-28
lines changed

3 files changed

+34
-28
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Change Logs
44
0.7.13
55
++++++
66

7-
* :pr:`237`: dummy inputs for gemma-3-4b-it
7+
* :pr:`237`: dummy inputs for google/gemma-3-4b-it
88
* :pr:`244`: add a patch to bypass the exception raised when the dynamic dimension is in {0,1}
99

1010
0.7.12

onnx_diagnostic/tasks/image_text_to_text.py

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def _get_inputs_gemma3(
144144
"sliding_attention": {0: batch, 2: seq_length, 3: tot_length},
145145
},
146146
"position_ids": {0: batch, 1: seq_length},
147-
"cache_position": {1: seq_length},
147+
"cache_position": {0: seq_length},
148148
"past_key_values": [
149149
[{0: batch} for _ in range(num_hidden_layers)],
150150
[{0: batch} for _ in range(num_hidden_layers)],
@@ -159,31 +159,37 @@ def _get_inputs_gemma3(
159159
dummies = dummies[("", 0, "I")][1]
160160
dummies = {k: v for k, v in dummies.items() if k in shapes}
161161
expected = {"input_ids", "token_type_ids", "position_ids", "cache_position"}
162-
assert expected & set(
163-
dummies
164-
), f"Unable to find expected inputs {expected} in loaded inputs {set(dummies)}"
165-
assert sequence_length == dummies["input_ids"].shape[-1], (
166-
f"sequence_length={sequence_length} != {dummies['input_ids'].shape[-1]} for "
167-
f"model class {model.__class__.__name__}"
168-
)
169-
assert batch_size == dummies["input_ids"].shape[0], (
170-
f"batch_size={batch_size} != {dummies['input_ids'].shape[0]} for "
171-
f"model class {model.__class__.__name__}"
172-
)
173-
assert max_sequence_length == 580, (
174-
f"max_sequence_length={max_sequence_length} != 580 "
175-
f"for model {model.__class__.__name__}"
176-
)
177-
assert total_sequence_length == 860, (
178-
f"total_sequence_length={total_sequence_length} != 860 "
179-
f"for model {model.__class__.__name__}"
180-
)
181-
assert head_dim == 256, f"head_dim={head_dim} != 256 for model {model.__class__.__name__}"
182-
assert n_images == 1, f"n_images={n_images} != 1 for model {model.__class__.__name__}"
183-
assert num_key_value_heads == 4, (
184-
f"num_key_value_heads={num_key_value_heads} != 256 "
185-
f"for this model {model.__class__.__name__}"
186-
)
162+
163+
def _check_():
164+
assert expected & set(
165+
dummies
166+
), f"Unable to find expected inputs {expected} in loaded inputs {set(dummies)}"
167+
assert sequence_length == dummies["input_ids"].shape[-1], (
168+
f"sequence_length={sequence_length} != {dummies['input_ids'].shape[-1]} for "
169+
f"model class {model.__class__.__name__}"
170+
)
171+
assert batch_size == dummies["input_ids"].shape[0], (
172+
f"batch_size={batch_size} != {dummies['input_ids'].shape[0]} for "
173+
f"model class {model.__class__.__name__}"
174+
)
175+
assert max_sequence_length == 580, (
176+
f"max_sequence_length={max_sequence_length} != 580 "
177+
f"for model {model.__class__.__name__}"
178+
)
179+
assert total_sequence_length == 860, (
180+
f"total_sequence_length={total_sequence_length} != 860 "
181+
f"for model {model.__class__.__name__}"
182+
)
183+
assert (
184+
head_dim == 256
185+
), f"head_dim={head_dim} != 256 for model {model.__class__.__name__}"
186+
assert n_images == 1, f"n_images={n_images} != 1 for model {model.__class__.__name__}"
187+
assert num_key_value_heads == 4, (
188+
f"num_key_value_heads={num_key_value_heads} != 256 "
189+
f"for this model {model.__class__.__name__}"
190+
)
191+
192+
_check_()
187193

188194
inputs = dict(
189195
input_ids=dummies["input_ids"],

onnx_diagnostic/torch_models/validate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -823,7 +823,7 @@ def validate_model(
823823
for key in ["model", "onnx_program", "config"]:
824824
if key in data:
825825
del data[key]
826-
if "cuda" in device.lower():
826+
if device is not None and "cuda" in str(device).lower():
827827
torch.cuda.empty_cache()
828828
gc.collect()
829829
print("[validation_model] -- done")

0 commit comments

Comments (0)