Skip to content

Commit 94266ec

Browse files
committed
fix a few things
1 parent 1eeb807 commit 94266ec

File tree

2 files changed

+21
-10
lines changed

2 files changed

+21
-10
lines changed

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,14 @@ def __init__(
5656
self.kwargs = kwargs
5757
self.dynamic_shapes = dynamic_shapes
5858
self.args_names = args_names
59+
if not self.kwargs and isinstance(self.dynamic_shapes, dict):
60+
# This assumes the dictionary for the dynamic shapes is ordered
61+
# the same way the args are. The input names are not known.
62+
assert len(self.dynamic_shapes) == len(self.args), (
63+
f"Length mismatch, kwargs is empty, len(dynamic_shapes)="
64+
f"{len(self.dynamic_shapes)}, len(args)={len(self.args)}"
65+
)
66+
self.dynamic_shapes = tuple(self.dynamic_shapes.values())
5967

6068
def __str__(self) -> str:
6169
return "\n".join(

onnx_diagnostic/tasks/image_text_to_text.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ def get_inputs_default(
256256
max_sequence_length = 43 if max_sequence_length is None else max_sequence_length
257257
total_sequence_length = 43 if total_sequence_length is None else total_sequence_length
258258

259+
assert batch_size > 0, "batch_size cannot be null"
259260
assert (
260261
"cls_cache" not in kwargs
261262
), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
@@ -287,19 +288,22 @@ def get_inputs_default(
287288
input_ids = torch.randint(0, dummy_max_token_id, (batch_size, total_sequence_length)).to(
288289
torch.int64
289290
)
290-
input_ids[0, 0] = image_token_index
291-
input_ids[1, 1] = image_token_index
291+
if total_sequence_length > 0:
292+
input_ids[0, 0] = image_token_index
293+
input_ids[1, 1] = image_token_index
292294
# input_ids[input_ids == image_token_index] = pad_token_id
293295
token_type_ids = torch.zeros_like(input_ids)
294296
token_type_ids[input_ids == image_token_index] = 1
295297
image_grid_thw = torch.zeros((n_images, 3), dtype=torch.int64)
296-
image_grid_thw[:, 1] = height
297-
image_grid_thw[:, 2] = width
298-
image_grid_thw[0, :] //= 2
299-
image_grid_thw[:, 0] = torch.arange(n_images, dtype=image_grid_thw.dtype)
298+
if n_images > 0:
299+
image_grid_thw[:, 1] = height
300+
image_grid_thw[:, 2] = width
301+
image_grid_thw[0, :] //= 2
302+
image_grid_thw[:, 0] = torch.arange(n_images, dtype=image_grid_thw.dtype)
300303

301304
inputs = dict(
302305
input_ids=input_ids,
306+
token_type_ids=token_type_ids,
303307
attention_mask=torch.cat(
304308
[
305309
torch.ones((batch_size, sequence_length), dtype=torch.int64),
@@ -324,10 +328,9 @@ def get_inputs_default(
324328
if model.__class__.__name__ == "IdeficsForVisionText2Text"
325329
else torch.randn(n_images, num_channels, width, height).clamp(-1, 1)
326330
),
327-
# image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
328-
# torch.int64
329-
# ),
330-
token_type_ids=token_type_ids,
331+
image_attention_mask=torch.ones((batch_size, total_sequence_length, n_images)).to(
332+
torch.int64
333+
),
331334
image_grid_thw=image_grid_thw,
332335
use_cache=True, # Gemma3 does not set this value to true when a cache is provided
333336
)

0 commit comments

Comments
 (0)