Skip to content

Commit 33fd641

Browse files
committed
fix unittest
1 parent e1cd7cb commit 33fd641

File tree

1 file changed

+14
-46
lines changed

1 file changed

+14
-46
lines changed

_unittests/ut_tasks/test_tasks.py

Lines changed: 14 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -57,37 +57,21 @@ def test_automatic_speech_recognition_float32(self):
5757
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
5858
model(**data["inputs"])
5959
model(**data["inputs2"])
60-
Dim = torch.export.Dim
6160
self.maxDiff = None
62-
self.assertIn("{0:Dim(batch),1:DYN(seq_length)}", self.string_type(ds))
61+
self.assertIn("{0:DYN(batch),1:DYN(seq_length)}", self.string_type(ds))
6362
self.assertEqualAny(
6463
{
65-
"decoder_input_ids": {
66-
0: Dim("batch", min=1, max=1024),
67-
1: "seq_length",
68-
},
64+
"decoder_input_ids": {0: "batch", 1: "seq_length"},
6965
"cache_position": {0: "seq_length"},
70-
"encoder_outputs": [{0: Dim("batch", min=1, max=1024)}],
66+
"encoder_outputs": [{0: "batch"}],
7167
"past_key_values": [
7268
[
73-
[
74-
{0: Dim("batch", min=1, max=1024)},
75-
{0: Dim("batch", min=1, max=1024)},
76-
],
77-
[
78-
{0: Dim("batch", min=1, max=1024)},
79-
{0: Dim("batch", min=1, max=1024)},
80-
],
69+
[{0: "batch"}, {0: "batch"}],
70+
[{0: "batch"}, {0: "batch"}],
8171
],
8272
[
83-
[
84-
{0: Dim("batch", min=1, max=1024)},
85-
{0: Dim("batch", min=1, max=1024)},
86-
],
87-
[
88-
{0: Dim("batch", min=1, max=1024)},
89-
{0: Dim("batch", min=1, max=1024)},
90-
],
73+
[{0: "batch"}, {0: "batch"}],
74+
[{0: "batch"}, {0: "batch"}],
9175
],
9276
],
9377
},
@@ -134,37 +118,21 @@ def test_automatic_speech_recognition_float16(self):
134118
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
135119
model = to_any(model, torch.float16)
136120
model(**data["inputs2"])
137-
Dim = torch.export.Dim
138121
self.maxDiff = None
139-
self.assertIn("{0:Dim(batch),1:DYN(seq_length)}", self.string_type(ds))
122+
self.assertIn("{0:DYN(batch),1:DYN(seq_length)}", self.string_type(ds))
140123
self.assertEqualAny(
141124
{
142-
"decoder_input_ids": {
143-
0: Dim("batch", min=1, max=1024),
144-
1: "seq_length",
145-
},
125+
"decoder_input_ids": {0: "batch", 1: "seq_length"},
146126
"cache_position": {0: "seq_length"},
147-
"encoder_outputs": [{0: Dim("batch", min=1, max=1024)}],
127+
"encoder_outputs": [{0: "batch"}],
148128
"past_key_values": [
149129
[
150-
[
151-
{0: Dim("batch", min=1, max=1024)},
152-
{0: Dim("batch", min=1, max=1024)},
153-
],
154-
[
155-
{0: Dim("batch", min=1, max=1024)},
156-
{0: Dim("batch", min=1, max=1024)},
157-
],
130+
[{0: "batch"}, {0: "batch"}],
131+
[{0: "batch"}, {0: "batch"}],
158132
],
159133
[
160-
[
161-
{0: Dim("batch", min=1, max=1024)},
162-
{0: Dim("batch", min=1, max=1024)},
163-
],
164-
[
165-
{0: Dim("batch", min=1, max=1024)},
166-
{0: Dim("batch", min=1, max=1024)},
167-
],
134+
[{0: "batch"}, {0: "batch"}],
135+
[{0: "batch"}, {0: "batch"}],
168136
],
169137
],
170138
},

0 commit comments

Comments
 (0)