Skip to content

Commit 4bdaa1c

Browse files
committed
fix
1 parent 94266ec commit 4bdaa1c

File tree

3 files changed

+7
-5
lines changed

3 files changed

+7
-5
lines changed

_unittests/ut_tasks/test_tasks_image_text_to_text.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def test_image_text_to_text_idefics(self):
2222
self.assertEqual(data["task"], "image-text-to-text")
2323
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
2424
model(**torch_deepcopy(inputs))
25+
print("***", self.string_type(data["inputs2"], with_shape=True))
2526
model(**data["inputs2"])
2627
with torch_export_patches(patch_transformers=True, verbose=10):
2728
torch.export.export(

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def __init__(
5757
self.dynamic_shapes = dynamic_shapes
5858
self.args_names = args_names
5959
if not self.kwargs and isinstance(self.dynamic_shapes, dict):
60-
# This assumes the dicionary for the dynamic shapes is ordered
60+
# This assumes the dictionary for the dynamic shapes is ordered
6161
# the same way the args are. The input names are not known.
6262
assert len(self.dynamic_shapes) == len(self.args), (
6363
f"Length mismatch, kwargs is empty, len(dynamic_shapes)="

onnx_diagnostic/tasks/image_text_to_text.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,8 @@ def get_inputs_default(
290290
)
291291
if total_sequence_length > 0:
292292
input_ids[0, 0] = image_token_index
293-
input_ids[1, 1] = image_token_index
293+
if min(input_ids.shape) > 1:
294+
input_ids[1, 1] = image_token_index
294295
# input_ids[input_ids == image_token_index] = pad_token_id
295296
token_type_ids = torch.zeros_like(input_ids)
296297
token_type_ids[input_ids == image_token_index] = 1
@@ -439,9 +440,9 @@ def get_inputs(
439440
height=height,
440441
num_channels=num_channels,
441442
batch_size=3,
442-
sequence_length=0,
443-
max_sequence_length=0,
444-
total_sequence_length=0,
443+
sequence_length=1,
444+
max_sequence_length=1,
445+
total_sequence_length=1,
445446
n_images=0,
446447
pad_token_id=pad_token_id,
447448
image_token_index=image_token_index,

0 commit comments

Comments
 (0)