
Commit eabc4ef

fix position ids

1 parent 727dcac

File tree: 9 files changed (+224, -21 lines)

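In short, this commit makes the untrained LLM helpers pass position_ids explicitly (next to input_ids, attention_mask and past_key_values) and declares both of its dimensions as dynamic instead of a single untied one. A minimal sketch of the relationship the new example inputs encode, using the get_phi2 defaults (30 cached tokens, 3 new tokens) as assumed sizes:

import torch

batch_size, past_len, new_len = 1, 30, 3  # defaults used by get_phi2 in this commit

# Tokens decoded in this step.
input_ids = torch.randint(0, 51200, (batch_size, new_len), dtype=torch.int64)
# The mask covers the cached tokens plus the new ones.
attention_mask = torch.ones((batch_size, past_len + new_len), dtype=torch.int64)
# Positions continue where the cache stops: past_len .. past_len + new_len - 1,
# one row per batch element, hence the second dynamic dimension.
position_ids = torch.arange(past_len, past_len + new_len).expand(batch_size, -1)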

_doc/api/torch_models/llms.rst

Lines changed: 1 addition & 1 deletion
@@ -3,4 +3,4 @@ onnx_diagnostic.torch_models.llms
 =================================
 
 .. automodule:: onnx_diagnostic.torch_models.llms
-    :members: get_tiny_llm
+    :members: get_phi2, get_tiny_llm
Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+import unittest
+import torch
+from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_transformers
+from onnx_diagnostic.torch_models.llms import get_phi2
+from onnx_diagnostic.helpers import string_type
+from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
+
+
+class TestLlmPhi(ExtTestCase):
+    def test_get_phi2(self):
+        data = get_phi2(num_hidden_layers=2)
+        model, inputs = data["model"], data["inputs"]
+        self.assertIn("DynamicCache", string_type(inputs))
+        model(**inputs)
+
+    @ignore_warnings(UserWarning)
+    @requires_transformers("4.52")
+    def test_export_phi2_1(self):
+        data = get_phi2(num_hidden_layers=2)
+        model, inputs = data["model"], data["inputs"]
+        self.assertEqual(
+            {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
+        )
+        ep = torch.export.export(
+            model, (), kwargs=inputs, dynamic_shapes=data["dynamic_shapes"]
+        )
+        assert ep
+
+    @ignore_warnings(UserWarning)
+    def test_export_phi2_2_bypassed(self):
+        data = get_phi2(num_hidden_layers=2)
+        model, inputs = data["model"], data["inputs"]
+        self.assertEqual(
+            {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
+        )
+        with bypass_export_some_errors(patch_transformers=True) as modificator:
+            inputs = modificator(inputs)
+            ep = torch.export.export(
+                model, (), kwargs=inputs, dynamic_shapes=data["dynamic_shapes"], strict=False
+            )
+            assert ep
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
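
For readers who want to reproduce the new tests outside the unittest harness, a minimal sketch mirroring test_export_phi2_2_bypassed above (assuming onnx_diagnostic 0.3.0 and a transformers build that exposes PhiConfig):

import torch
from onnx_diagnostic.torch_models.llms import get_phi2
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors

# Build a small, untrained phi-2-like model together with example inputs
# (including the new position_ids) and matching dynamic shapes.
data = get_phi2(num_hidden_layers=2)
model, inputs, dshapes = data["model"], data["inputs"], data["dynamic_shapes"]

# Patch transformers so the DynamicCache inputs survive torch.export, then export.
with bypass_export_some_errors(patch_transformers=True) as modificator:
    inputs = modificator(inputs)
    ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=dshapes, strict=False)
print(ep)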

_unittests/ut_torch_models/test_tiny_llms.py

Lines changed: 30 additions & 5 deletions
@@ -6,7 +6,7 @@
 from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
 
 
-class TestLlms(ExtTestCase):
+class TestTinyLlm(ExtTestCase):
     def test_get_tiny_llm(self):
         data = get_tiny_llm()
         model, inputs = data["model"], data["inputs"]
@@ -18,7 +18,9 @@ def test_get_tiny_llm(self):
     def test_export_tiny_llm_1(self):
         data = get_tiny_llm()
         model, inputs = data["model"], data["inputs"]
-        self.assertEqual({"attention_mask", "past_key_values", "input_ids"}, set(inputs))
+        self.assertEqual(
+            {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
+        )
         ep = torch.export.export(
             model, (), kwargs=inputs, dynamic_shapes=data["dynamic_shapes"]
         )
@@ -28,11 +30,34 @@ def test_export_tiny_llm_1(self):
     def test_export_tiny_llm_2_bypassed(self):
         data = get_tiny_llm()
         model, inputs = data["model"], data["inputs"]
-        self.assertEqual({"attention_mask", "past_key_values", "input_ids"}, set(inputs))
-        with bypass_export_some_errors(patch_transformers=True) as modificator:
+        self.assertEqual(
+            {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
+        )
+
+        with bypass_export_some_errors(
+            patch_torch=False, patch_transformers=True, catch_constraints=False
+        ) as modificator:
             inputs = modificator(inputs)
+
+            def debug():
+                print("***", string_type(inputs, with_shape=True))
+                print("***", data["dynamic_shapes"])
+                import torch.export._draft_export
+
+                ep, report = torch.export._draft_export.draft_export(
+                    model,
+                    (),
+                    kwargs=inputs,
+                    dynamic_shapes=data["dynamic_shapes"],
+                    strict=False,
+                )
+                print(report)
+
+            if self._debug():
+                debug()
+
             ep = torch.export.export(
-                model, (), kwargs=inputs, dynamic_shapes=data["dynamic_shapes"]
+                model, (), kwargs=inputs, dynamic_shapes=data["dynamic_shapes"], strict=False
             )
             assert ep
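The debug() branch above relies on torch.export._draft_export.draft_export, a private PyTorch helper that performs a relaxed export and returns a report of the issues it encountered (import path and signature are private and may change between releases); a standalone sketch of the same pattern:

import torch
import torch.export._draft_export  # private API, subject to change
from onnx_diagnostic.torch_models.llms import get_tiny_llm
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors

data = get_tiny_llm()
model, inputs, dshapes = data["model"], data["inputs"], data["dynamic_shapes"]
with bypass_export_some_errors(patch_transformers=True) as modificator:
    inputs = modificator(inputs)
    # draft_export returns the exported program together with a report describing
    # the shape guards and data-dependent issues it had to work around.
    ep, report = torch.export._draft_export.draft_export(
        model, (), kwargs=inputs, dynamic_shapes=dshapes, strict=False
    )
print(report)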
_unittests/ut_torch_models/test_tiny_llms_onnx.py

Lines changed: 20 additions & 7 deletions
@@ -11,19 +11,21 @@
 from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
 
 try:
-    from experimental_experiment.torch_interpreter import to_onnx
+    from experimental_experiment.torch_interpreter import to_onnx, ExportOptions
 except ImportError:
     to_onnx = None
 
 
-class TestLlmsOnnx(ExtTestCase):
+class TestTinyLlmOnnx(ExtTestCase):
     @ignore_warnings((UserWarning, DeprecationWarning, FutureWarning))
     @requires_transformers("4.50.9999")
     @hide_stdout()
     def test_onnx_export_tiny_llm_official(self):
         data = get_tiny_llm()
         model, inputs = data["model"], data["inputs"]
-        self.assertEqual({"attention_mask", "past_key_values", "input_ids"}, set(inputs))
+        self.assertEqual(
+            {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
+        )
         ep = torch.onnx.export(
             model,
             (),
@@ -43,7 +45,9 @@ def test_onnx_export_tiny_llm_official(self):
     def test_onnx_export_tiny_llm_xdbg(self):
         data = get_tiny_llm()
         model, inputs = data["model"], data["inputs"]
-        self.assertEqual({"attention_mask", "past_key_values", "input_ids"}, set(inputs))
+        self.assertEqual(
+            {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
+        )
         onx = to_onnx(
             model, (), kwargs=inputs, dynamic_shapes=data["dynamic_shapes"], verbose=1
         )
@@ -56,7 +60,9 @@ def test_onnx_export_tiny_llm_xdbg(self):
     def test_bypass_onnx_export_tiny_llm_official(self):
         data = get_tiny_llm()
         model, inputs = data["model"], data["inputs"]
-        self.assertEqual({"attention_mask", "past_key_values", "input_ids"}, set(inputs))
+        self.assertEqual(
+            {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
+        )
         with bypass_export_some_errors(patch_transformers=True, verbose=1) as modificator:
             new_inputs = modificator(inputs)
             ep = torch.onnx.export(
@@ -77,11 +83,18 @@ def test_bypass_onnx_export_tiny_llm_official(self):
     def test_bypass_onnx_export_tiny_llm_xdbg(self):
         data = get_tiny_llm()
         model, inputs = data["model"], data["inputs"]
-        self.assertEqual({"attention_mask", "past_key_values", "input_ids"}, set(inputs))
+        self.assertEqual(
+            {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
+        )
         with bypass_export_some_errors(patch_transformers=True, verbose=1) as modificator:
             new_inputs = modificator(inputs)
             onx = to_onnx(
-                model, (), kwargs=new_inputs, dynamic_shapes=data["dynamic_shapes"], verbose=1
+                model,
+                (),
+                kwargs=new_inputs,
+                dynamic_shapes=data["dynamic_shapes"],
+                verbose=1,
+                export_options=ExportOptions(strict=False),
             )
             self.assert_onnx_disc(
                 inspect.currentframe().f_code.co_name, onx, model, inputs, verbose=1
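
The same experimental export path, pulled out of the test for reference; a sketch that assumes the optional experimental_experiment package is installed (the tests above skip when its import fails and to_onnx stays None):

from experimental_experiment.torch_interpreter import to_onnx, ExportOptions
from onnx_diagnostic.torch_models.llms import get_tiny_llm
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors

data = get_tiny_llm()
model, inputs, dshapes = data["model"], data["inputs"], data["dynamic_shapes"]

with bypass_export_some_errors(patch_transformers=True, verbose=1) as modificator:
    new_inputs = modificator(inputs)
    onx = to_onnx(
        model,
        (),
        kwargs=new_inputs,
        dynamic_shapes=dshapes,
        export_options=ExportOptions(strict=False),  # mirrors the non-strict export used above
    )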

onnx_diagnostic/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -3,5 +3,5 @@
 Functions, classes to dig into a model when this one is right, slow, wrong...
 """
 
-__version__ = "0.2.1"
+__version__ = "0.3.0"
 __author__ = "Xavier Dupré"

onnx_diagnostic/ext_test_case.py

Lines changed: 4 additions & 0 deletions
@@ -1090,3 +1090,7 @@ def assert_onnx_disc(
             and not numpy.isnan(diff["rel"])
             and diff["rel"] <= rtol
         ), f"discrepancies in {test_name!r}, diff={string_diff(diff)}"
+
+    def _debug(self):
+        "Tells if DEBUG=1 is set up."
+        return os.environ.get("DEBUG") in BOOLEAN_VALUES
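
This new hook is what gates the debug() branch in test_export_tiny_llm_2_bypassed above; a small sketch of the intended workflow (BOOLEAN_VALUES is assumed to accept the usual truthy strings such as "1"):

import os

# Setting DEBUG=1 before running the tests makes ExtTestCase._debug() return True,
# which switches on the extra draft-export report in the bypassed export test above.
os.environ["DEBUG"] = "1"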
Lines changed: 2 additions & 1 deletion
@@ -1 +1,2 @@
-from .untrained.tiny_llm import get_tiny_llm
+from .untrained.llm_phi2 import get_phi2
+from .untrained.llm_tiny_llm import get_tiny_llm
Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
+from typing import Any, Dict
+import torch
+import transformers
+from ...cache_helpers import make_dynamic_cache
+
+
+def get_phi2(
+    batch_size: int = 1,
+    sequence_length: int = 30,
+    sequence_length2: int = 3,
+    dynamic_rope: bool = False,
+    **kwargs,
+) -> Dict[str, Any]:
+    """
+    Gets a non initialized model
+    similar to `microsoft/phi-2 <https://huggingface.co/microsoft/phi-2>`_
+
+    :param batch_size: batch size
+    :param sequence_length: sequence length
+    :param sequence_length2: new sequence length
+    :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
+    :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
+    :return: dictionary
+
+    See :ref:`l-plot-tiny-llm-export-patched` for an example with a similar model.
+    """
+    config = {
+        "_name_or_path": "microsoft/phi-2",
+        "architectures": ["PhiForCausalLM"],
+        "attention_dropout": 0.0,
+        "bos_token_id": 50256,
+        "embd_pdrop": 0.0,
+        "eos_token_id": 50256,
+        "hidden_act": "gelu_new",
+        "hidden_size": 2560,
+        "initializer_range": 0.02,
+        "intermediate_size": 10240,
+        "layer_norm_eps": 1e-05,
+        "max_position_embeddings": 2048,
+        "model_type": "phi",
+        "num_attention_heads": 32,
+        "num_hidden_layers": 32,
+        "num_key_value_heads": 32,
+        "partial_rotary_factor": 0.4,
+        "qk_layernorm": False,
+        "resid_pdrop": 0.1,
+        "rope_scaling": {"rope_type": "dynamic", "factor": 10.0} if dynamic_rope else None,
+        "rope_theta": 10000.0,
+        "tie_word_embeddings": False,
+        "torch_dtype": "float16",
+        "transformers_version": "4.37.0",
+        "use_cache": True,
+        "vocab_size": 51200,
+    }
+    config.update(**kwargs)
+    conf = transformers.PhiConfig(**config)
+    model = transformers.PhiForCausalLM(conf)
+    model.eval()
+
+    # now the inputs
+    cache_last_dim = 80
+    max_token_id = config["vocab_size"] - 1
+    n_layers = config["num_hidden_layers"]
+    num_key_value_heads = config["num_key_value_heads"]
+
+    batch = torch.export.Dim("batch", min=1, max=1024)
+    seq_length = torch.export.Dim("seq_length", min=1, max=4096)
+    cache_length = torch.export.Dim("cache_length", min=1, max=4096)
+
+    shapes = {
+        "input_ids": {0: batch, 1: seq_length},
+        "position_ids": {
+            0: batch,
+            1: torch.export.Dim.DYNAMIC,  # cache_length + seq_length
+        },
+        "attention_mask": {
+            0: batch,
+            1: torch.export.Dim.DYNAMIC,  # cache_length + seq_length
+        },
+        "past_key_values": [
+            [{0: batch, 2: cache_length} for _ in range(n_layers)],
+            [{0: batch, 2: cache_length} for _ in range(n_layers)],
+        ],
+    }
+    inputs = dict(
+        input_ids=torch.randint(0, max_token_id, (batch_size, sequence_length2)).to(
+            torch.int64
+        ),
+        attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
+            torch.int64
+        ),
+        position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
+        .to(torch.int64)
+        .expand((batch_size, -1)),
+        past_key_values=make_dynamic_cache(
+            [
+                (
+                    torch.randn(
+                        batch_size, num_key_value_heads, sequence_length, cache_last_dim
+                    ),
+                    torch.randn(
+                        batch_size, num_key_value_heads, sequence_length, cache_last_dim
+                    ),
+                )
+                for i in range(n_layers)
+            ]
+        ),
+    )
+    return dict(inputs=inputs, model=model, dynamic_shapes=shapes)
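
To see how these concrete inputs line up with the declared dynamic shapes, a short sketch using the string_type helper the new test already imports (the exact output format is indicative only):

from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_models.llms import get_phi2

data = get_phi2(num_hidden_layers=2)
# Expected pattern: input_ids is (batch_size, sequence_length2), attention_mask and
# position_ids are (batch_size, sequence_length + sequence_length2), and each cache
# tensor is (batch_size, num_key_value_heads, sequence_length, 80).
print(string_type(data["inputs"], with_shape=True))
print(data["dynamic_shapes"])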

onnx_diagnostic/torch_models/untrained/tiny_llm.py renamed to onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py

Lines changed: 12 additions & 6 deletions
@@ -6,18 +6,23 @@
 
 def get_tiny_llm(
     batch_size: int = 2,
+    sequence_length: int = 30,
+    sequence_length2: int = 3,
     dynamic_rope: bool = False,
     **kwargs,
 ) -> Dict[str, Any]:
     """
-    Gets a non initialized model.
+    Gets a non initialized model
+    similar to `arnir0/Tiny-LLM <https://huggingface.co/arnir0/Tiny-LLM>`_
 
     :param batch_size: batch size
+    :param sequence_length: sequence length
+    :param sequence_length2: new sequence length
     :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
     :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
     :return: dictionary
 
-    See :ref:`l-plot-tiny-llm-export` for an example.
+    See :ref:`l-plot-tiny-llm-export` or :ref:`l-plot-tiny-llm-export-patched` for examples.
     """
     config = {
         "architectures": ["LlamaForCausalLM"],
@@ -49,19 +54,20 @@ def get_tiny_llm(
 
     # now the inputs
     cache_last_dim = 96
-    sequence_length = 30
-    sequence_length2 = 3
-    num_key_value_heads = 1
     max_token_id = config["vocab_size"] - 1
     n_layers = config["num_hidden_layers"]
+    num_key_value_heads = config["num_key_value_heads"]
 
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = torch.export.Dim("seq_length", min=1, max=4096)
     cache_length = torch.export.Dim("cache_length", min=1, max=4096)
 
     shapes = {
         "input_ids": {0: batch, 1: seq_length},
-        "position_ids": {0: torch.export.Dim.DYNAMIC},
+        "position_ids": {
+            0: batch,
+            1: torch.export.Dim.DYNAMIC,  # cache_length + seq_length
+        },
         "attention_mask": {
             0: batch,
             1: torch.export.Dim.DYNAMIC,  # cache_length + seq_length