Skip to content

Commit 627a596

Browse files
committed
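other fixes: flatten one nesting level of the `past_key_values` dynamic shapes, run the `onnx_generate` decode loop with the updated `feeds` (instead of the stale `input_feeds`), and compare `onnx_generate` against a torch reference generation loop in the unit test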
1 parent 5195875 commit 627a596

File tree

7 files changed

+62
-22
lines changed

7 files changed

+62
-22
lines changed

_unittests/ut_export/test_shape_helper.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,8 @@ def test_all_dynamic_shapes_from_inputs_dynamic_cache(self):
155155
"attention_mask": {0: "d_1_0", 1: "d_1_1"},
156156
"position_ids": {0: "d_2_0", 1: "d_2_1"},
157157
"past_key_values": [
158-
[{0: "d_3_0", 1: "d_3_1", 2: "d_3_2", 3: "d_3_3"}],
159-
[{0: "d_4_0", 1: "d_4_1", 2: "d_4_2", 3: "d_4_3"}],
158+
{0: "d_3_0", 1: "d_3_1", 2: "d_3_2", 3: "d_3_3"},
159+
{0: "d_4_0", 1: "d_4_1", 2: "d_4_2", 3: "d_4_3"},
160160
],
161161
},
162162
ds,
@@ -176,8 +176,8 @@ def test_guess_dynamic_shapes_from_inputs(self):
176176
"attention_mask": {0: "dd_0I0", 1: "dd_0I1"},
177177
"input_ids": {0: "dd_1I0", 1: "dd_1I1"},
178178
"past_key_values": [
179-
[{0: "dd_2I_0o_0l0", 2: "dd_2I_0o_0l2"}],
180-
[{0: "dd_2I_1o_0l0", 2: "dd_2I_1o_0l2"}],
179+
{0: "dd_2I_0o_0l0", 2: "dd_2I_0o_0l2"},
180+
{0: "dd_2I_1o_0l0", 2: "dd_2I_1o_0l2"},
181181
],
182182
"position_ids": {0: "dd_3I0", 1: "dd_3I1"},
183183
},

_unittests/ut_helpers/test_rt_helper.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,50 @@
55
from onnx_diagnostic.helpers.rt_helper import onnx_generate
66
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
77
from onnx_diagnostic.torch_export_patches import torch_export_patches
8+
from onnx_diagnostic.export.api import to_onnx
89

910

1011
class TestRtSession(ExtTestCase):
12+
def simple_generate_with_cache(
13+
self, model, input_ids: torch.Tensor, eos_token_id: int, max_new_tokens: int = 100
14+
):
15+
# First call: prefill
16+
outputs = model(
17+
input_ids,
18+
attention_mask=torch.ones(
19+
input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
20+
),
21+
use_cache=True,
22+
)
23+
24+
# Next calls: decode
25+
for _ in range(max_new_tokens):
26+
next_token_logits = outputs.logits[:, -1, :]
27+
past_key_values = outputs.past_key_values
28+
next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
29+
if next_token_id.item() == eos_token_id:
30+
break
31+
input_ids = torch.cat([input_ids, next_token_id], dim=-1)
32+
outputs = model(
33+
next_token_id,
34+
use_cache=True,
35+
past_key_values=past_key_values,
36+
attention_mask=torch.ones(
37+
input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
38+
),
39+
)
40+
return input_ids
41+
1142
@hide_stdout()
1243
def test_onnx_generate(self):
13-
from experimental_experiment.torch_interpreter import to_onnx
14-
1544
mid = "arnir0/Tiny-LLM"
1645
print("-- test_onnx_generate: get model")
1746
data = get_untrained_model_with_inputs(mid)
1847
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
1948
del inputs["position_ids"]
2049
del ds["position_ids"]
2150
input_ids = inputs["input_ids"]
51+
print("----", input_ids.shape)
2252
folder = self.get_dump_folder("test_onnx_generate")
2353
model_name = os.path.join(folder, "model.onnx")
2454
print("-- test_onnx_generate: export model")
@@ -29,13 +59,24 @@ def test_onnx_generate(self):
2959
kwargs=inputs,
3060
dynamic_shapes=ds,
3161
filename=model_name,
62+
exporter="custom",
3263
)
3364

3465
print("-- test_onnx_generate: generate")
3566
res = onnx_generate(model_name, input_ids[:1], 2, max_new_tokens=10)
67+
n_inputs = input_ids.shape[1]
68+
self.assertEqualArray(input_ids[:1], res[:, :n_inputs])
3669
self.assertEqual(res.dtype, torch.int64)
3770
self.assertEqual(res.shape, (1, 13))
3871
print("-- test_onnx_generate: done")
72+
# expected = model.generate(input_ids[:1], max_new_tokens=10)
73+
expected = self.simple_generate_with_cache(model, input_ids[:1], 2, max_new_tokens=10)
74+
self.assertEqualArray(input_ids[:1], expected[:, :n_inputs])
75+
print("******", res)
76+
print("******", expected)
77+
self.assertEqual(expected.dtype, torch.int64)
78+
self.assertEqual(expected.shape, (1, 13))
79+
self.assertEqualArray(expected, res)
3980

4081

4182
if __name__ == "__main__":

_unittests/ut_tasks/test_tasks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,8 @@ def test_automatic_speech_recognition_float32(self):
113113
"cache_position": {0: "seq_length"},
114114
"encoder_outputs": [{0: "batch"}],
115115
"past_key_values": [
116-
[[{0: "batch"}, {0: "batch"}], [{0: "batch"}, {0: "batch"}]],
117-
[[{0: "batch"}, {0: "batch"}], [{0: "batch"}, {0: "batch"}]],
116+
[{0: "batch"}, {0: "batch"}, {0: "batch"}, {0: "batch"}],
117+
[{0: "batch"}, {0: "batch"}, {0: "batch"}, {0: "batch"}],
118118
],
119119
},
120120
ds,

_unittests/ut_torch_export_patches/test_patch_inputs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ def test_convert_dynamic_axes_into_dynamic_shapes_1(self):
4848
"attention_mask": {0: "batch_size", 1: "total_sequence_length"},
4949
"input_ids": {0: "batch_size", 1: "sequence_length"},
5050
"past_key_values": [
51-
[{0: "batch_size", 2: "past_sequence_length"}],
52-
[{0: "batch_size", 2: "past_sequence_length"}],
51+
{0: "batch_size", 2: "past_sequence_length"},
52+
{0: "batch_size", 2: "past_sequence_length"},
5353
],
5454
"position_ids": {0: "batch_size", 1: "sequence_length"},
5555
},
@@ -98,8 +98,8 @@ def test_convert_dynamic_axes_into_dynamic_shapes_2(self):
9898
"attention_mask": {0: "batch_size", 1: "sequence_length"},
9999
"input_ids": {0: "batch_size", 1: "sequence_length"},
100100
"past_key_values": [
101-
[{0: "batch_size", 2: "past_sequence_length"}],
102-
[{0: "batch_size", 2: "past_sequence_length"}],
101+
{0: "batch_size", 2: "past_sequence_length"},
102+
{0: "batch_size", 2: "past_sequence_length"},
103103
],
104104
"position_ids": {0: "batch_size", 1: "sequence_length"},
105105
},

onnx_diagnostic/export/shape_helper.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,8 +260,10 @@ def make_fake_with_dynamic_dimensions(
260260
"attention_mask": {0: "batch", 1: "cache+seq"},
261261
"position_ids": {0: "batch", 1: "seq_length"},
262262
"past_key_values": [
263-
[{0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}],
264-
[{0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}],
263+
{0: "batch", 2: "cache_length"},
264+
{0: "batch", 2: "cache_length"},
265+
{0: "batch", 2: "cache_length"},
266+
{0: "batch", 2: "cache_length"},
265267
],
266268
},
267269
)

onnx_diagnostic/helpers/rt_helper.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def onnx_generate(
206206
), f"Only text generation is supported but input_names == {input_names}"
207207

208208
# First call: prefill
209-
input_feeds = dict(
209+
feeds = dict(
210210
input_ids=input_ids,
211211
attention_mask=torch.ones(
212212
input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
@@ -216,9 +216,9 @@ def onnx_generate(
216216
new_shape = tuple(
217217
_get_dim(i, s, batch=input_ids.shape[0]) for i, s in enumerate(shape)
218218
)
219-
input_feeds[name] = torch.empty(new_shape, dtype=rt_type_to_torch_dtype(dtype))
219+
feeds[name] = torch.empty(new_shape, dtype=rt_type_to_torch_dtype(dtype))
220220

221-
outputs = session.run(None, input_feeds)
221+
outputs = session.run(None, feeds)
222222

223223
# Next calls: decode
224224
for _ in range(max_new_tokens):
@@ -241,7 +241,7 @@ def onnx_generate(
241241
),
242242
)
243243
feeds.update(dict(zip(input_names[2:], outputs[1:])))
244-
outputs = session.run(None, input_feeds)
244+
outputs = session.run(None, feeds)
245245

246246
if return_session:
247247
return input_ids, session

onnx_diagnostic/torch_models/untrained/llm_phi2.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,7 @@ def get_phi2(
8484
0: batch,
8585
1: torch.export.Dim.DYNAMIC, # cache_length + seq_length
8686
},
87-
"past_key_values": [
88-
[{0: batch, 2: cache_length} for _ in range(n_layers)],
89-
[{0: batch, 2: cache_length} for _ in range(n_layers)],
90-
],
87+
"past_key_values": [{0: batch, 2: cache_length} for _ in range(n_layers * 2)],
9188
}
9289
inputs = dict(
9390
input_ids=torch.randint(0, max_token_id, (batch_size, sequence_length2)).to(

0 commit comments

Comments
 (0)