diff --git a/onnx_diagnostic/tasks/automatic_speech_recognition.py b/onnx_diagnostic/tasks/automatic_speech_recognition.py index 43aac84c..c122c086 100644 --- a/onnx_diagnostic/tasks/automatic_speech_recognition.py +++ b/onnx_diagnostic/tasks/automatic_speech_recognition.py @@ -76,7 +76,7 @@ def get_inputs( assert ( "cls_cache" not in kwargs ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}." - batch = torch.export.Dim("batch", min=1, max=1024) + batch = "batch" seq_length = "seq_length" shapes = { diff --git a/onnx_diagnostic/tasks/feature_extraction.py b/onnx_diagnostic/tasks/feature_extraction.py index 13861ba4..b049a5b6 100644 --- a/onnx_diagnostic/tasks/feature_extraction.py +++ b/onnx_diagnostic/tasks/feature_extraction.py @@ -47,7 +47,7 @@ def get_inputs( assert ( "cls_cache" not in kwargs ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}." - batch = torch.export.Dim("batch", min=1, max=1024) + batch = "batch" seq_length = "sequence_length" shapes = { "input_ids": {0: batch, 1: seq_length}, diff --git a/onnx_diagnostic/tasks/fill_mask.py b/onnx_diagnostic/tasks/fill_mask.py index a59365e1..0e790ee8 100644 --- a/onnx_diagnostic/tasks/fill_mask.py +++ b/onnx_diagnostic/tasks/fill_mask.py @@ -42,7 +42,7 @@ def get_inputs( assert ( "cls_cache" not in kwargs ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}." - batch = torch.export.Dim("batch", min=1, max=1024) + batch = "batch" seq_length = "sequence_length" shapes = { "input_ids": {0: batch, 1: seq_length}, diff --git a/onnx_diagnostic/tasks/image_text_to_text.py b/onnx_diagnostic/tasks/image_text_to_text.py index 585b6ddf..fe3b6d4a 100644 --- a/onnx_diagnostic/tasks/image_text_to_text.py +++ b/onnx_diagnostic/tasks/image_text_to_text.py @@ -107,7 +107,7 @@ def _get_inputs_gemma3( assert ( "cls_cache" not in kwargs ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}." - batch = torch.export.Dim("batch", min=1, max=1024) + batch = "batch" seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096) # cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096) @@ -230,7 +230,7 @@ def get_inputs( assert ( "cls_cache" not in kwargs ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}." - batch = torch.export.Dim("batch", min=1, max=1024) + batch = "batch" batch_img = torch.export.Dim("batch_img", min=1, max=1024) seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096) cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096) diff --git a/onnx_diagnostic/tasks/sentence_similarity.py b/onnx_diagnostic/tasks/sentence_similarity.py index 7bcdb889..7386ec5a 100644 --- a/onnx_diagnostic/tasks/sentence_similarity.py +++ b/onnx_diagnostic/tasks/sentence_similarity.py @@ -42,7 +42,7 @@ def get_inputs( assert ( "cls_cache" not in kwargs ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}." - batch = torch.export.Dim("batch", min=1, max=1024) + batch = "batch" seq_length = "seq_length" shapes = { "input_ids": {0: batch, 1: seq_length}, diff --git a/onnx_diagnostic/tasks/summarization.py b/onnx_diagnostic/tasks/summarization.py index 6c7fde27..5760c41c 100644 --- a/onnx_diagnostic/tasks/summarization.py +++ b/onnx_diagnostic/tasks/summarization.py @@ -70,7 +70,7 @@ def get_inputs( assert ( "cls_cache" not in kwargs ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}." - batch = torch.export.Dim("batch", min=1, max=1024) + batch = "batch" seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096) cache_length = "cache_length_key" # torch.export.Dim("cache_length", min=1, max=4096) cache_length2 = "cache_length_val" # torch.export.Dim("cache_length2", min=1, max=4096) diff --git a/onnx_diagnostic/tasks/text2text_generation.py b/onnx_diagnostic/tasks/text2text_generation.py index 94d0d888..fc8cd2e0 100644 --- a/onnx_diagnostic/tasks/text2text_generation.py +++ b/onnx_diagnostic/tasks/text2text_generation.py @@ -72,7 +72,7 @@ def get_inputs( assert ( "cls_cache" not in kwargs ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}." - batch = torch.export.Dim("batch", min=1, max=1024) + batch = "batch" seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096) cache_length = "cache_length_key" cache_length2 = "cache_length_val" diff --git a/onnx_diagnostic/tasks/text_classification.py b/onnx_diagnostic/tasks/text_classification.py index 380b15cc..4b82155d 100644 --- a/onnx_diagnostic/tasks/text_classification.py +++ b/onnx_diagnostic/tasks/text_classification.py @@ -42,7 +42,7 @@ def get_inputs( assert ( "cls_cache" not in kwargs ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}." - batch = torch.export.Dim("batch", min=1, max=1024) + batch = "batch" seq_length = "seq_length" # torch.export.Dim("sequence_length", min=1, max=1024) shapes = { "input_ids": {0: batch, 1: seq_length}, diff --git a/onnx_diagnostic/tasks/text_generation.py b/onnx_diagnostic/tasks/text_generation.py index 6b8ae5ef..6e6e29ba 100644 --- a/onnx_diagnostic/tasks/text_generation.py +++ b/onnx_diagnostic/tasks/text_generation.py @@ -83,7 +83,7 @@ def get_inputs( :class:`transformers.cache_utils.DynamicCache` :return: dictionary """ - batch = torch.export.Dim("batch", min=1, max=1024) + batch = "batch" seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096) cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096) diff --git a/onnx_diagnostic/tasks/zero_shot_image_classification.py b/onnx_diagnostic/tasks/zero_shot_image_classification.py index 61fee29e..d19e3c9e 100644 --- a/onnx_diagnostic/tasks/zero_shot_image_classification.py +++ b/onnx_diagnostic/tasks/zero_shot_image_classification.py @@ -65,7 +65,7 @@ def get_inputs( input_width, int ), f"Unexpected type for input_height {type(input_height)}{config}" - batch = torch.export.Dim("batch", min=1, max=1024) + batch = "batch" seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096) shapes = { "input_ids": {0: batch, 1: seq_length},