Skip to content

Commit e53fe05

Browse files
authored
Use DYNAMIC on batch size (#213)
* use DYNAMIC(str) * delete check-urls.yml * Revert "delete check-urls.yml" This reverts commit 6983c2c.
1 parent a6caa0a commit e53fe05

10 files changed

+11
-11
lines changed

onnx_diagnostic/tasks/automatic_speech_recognition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def get_inputs(
7676
assert (
7777
"cls_cache" not in kwargs
7878
), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
79-
batch = torch.export.Dim("batch", min=1, max=1024)
79+
batch = "batch"
8080
seq_length = "seq_length"
8181

8282
shapes = {

onnx_diagnostic/tasks/feature_extraction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def get_inputs(
4747
assert (
4848
"cls_cache" not in kwargs
4949
), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
50-
batch = torch.export.Dim("batch", min=1, max=1024)
50+
batch = "batch"
5151
seq_length = "sequence_length"
5252
shapes = {
5353
"input_ids": {0: batch, 1: seq_length},

onnx_diagnostic/tasks/fill_mask.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def get_inputs(
4242
assert (
4343
"cls_cache" not in kwargs
4444
), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
45-
batch = torch.export.Dim("batch", min=1, max=1024)
45+
batch = "batch"
4646
seq_length = "sequence_length"
4747
shapes = {
4848
"input_ids": {0: batch, 1: seq_length},

onnx_diagnostic/tasks/image_text_to_text.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def _get_inputs_gemma3(
107107
assert (
108108
"cls_cache" not in kwargs
109109
), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
110-
batch = torch.export.Dim("batch", min=1, max=1024)
110+
batch = "batch"
111111
seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096)
112112
# cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096)
113113

@@ -230,7 +230,7 @@ def get_inputs(
230230
assert (
231231
"cls_cache" not in kwargs
232232
), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
233-
batch = torch.export.Dim("batch", min=1, max=1024)
233+
batch = "batch"
234234
batch_img = torch.export.Dim("batch_img", min=1, max=1024)
235235
seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096)
236236
cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096)

onnx_diagnostic/tasks/sentence_similarity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def get_inputs(
4242
assert (
4343
"cls_cache" not in kwargs
4444
), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
45-
batch = torch.export.Dim("batch", min=1, max=1024)
45+
batch = "batch"
4646
seq_length = "seq_length"
4747
shapes = {
4848
"input_ids": {0: batch, 1: seq_length},

onnx_diagnostic/tasks/summarization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def get_inputs(
7070
assert (
7171
"cls_cache" not in kwargs
7272
), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
73-
batch = torch.export.Dim("batch", min=1, max=1024)
73+
batch = "batch"
7474
seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096)
7575
cache_length = "cache_length_key" # torch.export.Dim("cache_length", min=1, max=4096)
7676
cache_length2 = "cache_length_val" # torch.export.Dim("cache_length2", min=1, max=4096)

onnx_diagnostic/tasks/text2text_generation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def get_inputs(
7272
assert (
7373
"cls_cache" not in kwargs
7474
), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
75-
batch = torch.export.Dim("batch", min=1, max=1024)
75+
batch = "batch"
7676
seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096)
7777
cache_length = "cache_length_key"
7878
cache_length2 = "cache_length_val"

onnx_diagnostic/tasks/text_classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def get_inputs(
4242
assert (
4343
"cls_cache" not in kwargs
4444
), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
45-
batch = torch.export.Dim("batch", min=1, max=1024)
45+
batch = "batch"
4646
seq_length = "seq_length" # torch.export.Dim("sequence_length", min=1, max=1024)
4747
shapes = {
4848
"input_ids": {0: batch, 1: seq_length},

onnx_diagnostic/tasks/text_generation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def get_inputs(
8383
:class:`transformers.cache_utils.DynamicCache`
8484
:return: dictionary
8585
"""
86-
batch = torch.export.Dim("batch", min=1, max=1024)
86+
batch = "batch"
8787
seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096)
8888
cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096)
8989

onnx_diagnostic/tasks/zero_shot_image_classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def get_inputs(
6565
input_width, int
6666
), f"Unexpected type for input_height {type(input_height)}{config}"
6767

68-
batch = torch.export.Dim("batch", min=1, max=1024)
68+
batch = "batch"
6969
seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096)
7070
shapes = {
7171
"input_ids": {0: batch, 1: seq_length},

0 commit comments

Comments
 (0)