Skip to content

Commit 5c55755

Browse files
committed
resolve conflicts
2 parents 94a9b10 + f86c55d commit 5c55755

File tree

8 files changed

+152
-52
lines changed

8 files changed

+152
-52
lines changed

CHANGELOGS.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
Change Logs
22
===========
33

4+
0.7.12
5+
++++++
6+
7+
* :pr:`227`: better support for ``model_id//pretrained``, adds speed up when running command validate
8+
* :pr:`226`: fix input order for models created with modelbuilder
9+
410
0.7.11
511
++++++
612

_doc/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def linkcode_resolve(domain, info):
122122
("py:class", "CacheProcessor"),
123123
("py:class", "default=sklearn.utils.metadata_routing.UNCHANGED"),
124124
("py:class", "diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput"),
125+
("py:class", "MambaCache"),
125126
("py:class", "ModelProto"),
126127
("py:class", "Model"),
127128
("py:class", "Module"),

_doc/index.rst

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,9 +239,8 @@ The function replaces dynamic dimensions defined as strings by
239239
Older versions
240240
==============
241241

242+
* `0.7.12 <../v0.7.12/index.html>`_
242243
* `0.7.11 <../v0.7.11/index.html>`_
243-
* `0.7.10 <../v0.7.10/index.html>`_
244-
* `0.7.9 <../v0.7.9/index.html>`_
245244
* `0.6.3 <../v0.6.3/index.html>`_
246245
* `0.5.0 <../v0.5.0/index.html>`_
247246
* `0.4.4 <../v0.4.4/index.html>`_

_unittests/ut_torch_models/test_validate_whole_models.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,17 +195,18 @@ def test_k_filter_inputs(self):
195195
@ignore_warnings(FutureWarning)
196196
@requires_transformers("4.51")
197197
def test_l_validate_model_modelbuilder(self):
198-
mid = "arnir0/Tiny-LLM"
198+
mid = "microsoft/phi-2"
199199
summary, data = validate_model(
200200
mid,
201201
do_run=True,
202202
verbose=10,
203203
exporter="modelbuilder",
204204
dump_folder="dump_test/validate_model_modelbuilder",
205+
patch=True,
205206
)
206207
self.assertIsInstance(summary, dict)
207208
self.assertIsInstance(data, dict)
208-
self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-4)
209+
self.assertLess(summary["disc_onnx_ort_run_abs"], 3e-2)
209210
onnx_filename = data["onnx_filename"]
210211
self.assertExists(onnx_filename)
211212

onnx_diagnostic/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@
33
Functions, classes to dig into a model when this one is right, slow, wrong...
44
"""
55

6-
__version__ = "0.7.11"
6+
__version__ = "0.7.12"
77
__author__ = "Xavier Dupré"

onnx_diagnostic/helpers/rt_helper.py

Lines changed: 53 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import onnx
44
import torch
55
from .helper import string_type, flatten_object
6-
from .onnx_helper import dtype_to_tensor_dtype
76
from .cache_helper import is_cache_dynamic_registered
87

98

@@ -23,6 +22,7 @@ def make_feeds(
2322
use_numpy: bool = False,
2423
copy: bool = False,
2524
check_flatten: bool = True,
25+
is_modelbuilder: bool = False,
2626
) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
2727
"""
2828
Serializes the inputs to produce feeds expected
@@ -35,10 +35,15 @@ def make_feeds(
3535
by ``OrtValue``
3636
:param check_flatten: if True, checks the ``torch.utils._pytree.tree_flatten``
3737
returns the same number of outputs
38+
:param is_modelbuilder: if True, the exporter is ModelBuilder, and we need to reorder
39+
the past_key_values inputs to match the expected order, and get rid of position_ids.
3840
:return: feeds dictionary
3941
"""
40-
# position_ids is a special case because ModelBuilder does not usually use it.
41-
# We use types to detect the best inputs.
42+
# NOTE: position_ids is a special case because ModelBuilder does not usually use it,
43+
# because it's fued into rotary embedding in GQA.
44+
if is_modelbuilder and isinstance(inputs, dict):
45+
inputs.pop("position_ids", None) # Ensure 'position_ids' absent before removing.
46+
4247
flat = flatten_object(inputs, drop_keys=True)
4348
assert (
4449
not check_flatten
@@ -76,39 +81,6 @@ def make_feeds(
7681
f"\n-- inputs={string_type(inputs, with_shape=True)}"
7782
f"\n-- names={names}"
7883
)
79-
if len(names) < len(flat) and (
80-
isinstance(proto, onnx.ModelProto) or hasattr(proto, "get_inputs")
81-
):
82-
83-
typed_names = (
84-
[(i.name, i.type.tensor_type.elem_type) for i in proto.graph.input]
85-
if isinstance(proto, onnx.ModelProto)
86-
else [(i.name, name_type_to_onnx_dtype(i.type)) for i in proto.get_inputs()]
87-
)
88-
89-
new_flat = []
90-
pos = 0
91-
for _name, dtype in typed_names:
92-
assert isinstance(
93-
dtype, int
94-
), f"Unexpected value for dtype={dtype!r}, type(proto)={type(proto)}"
95-
itype = dtype_to_tensor_dtype(flat[pos].dtype)
96-
while dtype != itype:
97-
pos += 1
98-
if pos >= len(flat):
99-
break
100-
itype = dtype_to_tensor_dtype(flat[pos].dtype)
101-
if pos >= len(flat):
102-
break
103-
new_flat.append(flat[pos])
104-
pos += 1
105-
assert len(new_flat) == len(names), (
106-
f"Unable to align expected input {names} with the given input, "
107-
f"type(proto)={type(proto)}"
108-
f"\n-- inputs: {string_type(inputs, with_shape=True)}"
109-
f"\n-- typed_names: {typed_names}"
110-
)
111-
flat = new_flat
11284

11385
if copy:
11486
flat = [t.copy() if hasattr(t, "copy") else t.clone() for t in flat]
@@ -122,4 +94,49 @@ def make_feeds(
12294
elif isinstance(i, float):
12395
i = np.array(i, dtype=np.float32)
12496
new_flat.append(i)
97+
98+
# NOTE: model builder has a different order for past_key_values
99+
# we need to reorder them to match the expected order
100+
if is_modelbuilder:
101+
# We assume that if "past_key_values" is in the names when it's
102+
# modelbuilder
103+
non_past_kv_input_names = [n for n in names if "past_key_values" not in n]
104+
past_kv_names = [n for n in names if "past_key_values" in n]
105+
reorder_past_kv_names = reorder_modelbuilder_cache_to_torch(past_kv_names)
106+
names = non_past_kv_input_names + reorder_past_kv_names
125107
return dict(zip(names, new_flat))
108+
109+
110+
def reorder_modelbuilder_cache_to_torch(past_kv: List[Any]) -> List[Any]:
111+
"""
112+
Reorders the past_kvs for ModelBuilder to match the expected order
113+
by PyTorch exported models.
114+
115+
.. note::
116+
This function can take either the names or the actual tensors
117+
as long as they are in a list.
118+
119+
Conceptually,
120+
121+
From::
122+
123+
[past_key_values.0.key, past_key_values.0.value,
124+
past_key_values.1.key, past_key_values.1.value, ...]
125+
126+
To::
127+
128+
[past_key_values.0.key, past_key_values.1.key,
129+
..., past_key_values.0.value, past_key_values.1.value, ...]
130+
131+
:param past_kv: list of flattened inputs
132+
:return: reordered list of flattened inputs
133+
"""
134+
total_len = len(past_kv)
135+
if total_len % 2 != 0:
136+
raise ValueError("The length of past_key_values should be even.")
137+
keys = []
138+
values = []
139+
for i in range(0, total_len, 2):
140+
keys.append(past_kv[i])
141+
values.append(past_kv[i + 1])
142+
return keys + values

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def get_untrained_model_with_inputs(
189189
f"subfolder={subfolder!r}"
190190
)
191191
model = transformers.AutoModel.from_pretrained(
192-
model_id, subfolder=subfolder, trust_remote_code=True, **mkwargs
192+
model_id, subfolder=subfolder or "", trust_remote_code=True, **mkwargs
193193
)
194194
if verbose:
195195
print(

onnx_diagnostic/torch_models/validate.py

Lines changed: 86 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from ..export import CoupleInputsDynamicShapes
1212
from ..helpers import max_diff, string_type, string_diff
1313
from ..helpers.helper import flatten_object
14-
from ..helpers.rt_helper import make_feeds
14+
from ..helpers.rt_helper import make_feeds, reorder_modelbuilder_cache_to_torch
1515
from ..helpers.torch_helper import to_any, torch_deepcopy
1616
from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
1717
from ..tasks import random_input_kwargs
@@ -264,14 +264,18 @@ def shrink_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
264264
return new_cfg
265265

266266

267-
def _preprocess_model_id(model_id, subfolder):
267+
def _preprocess_model_id(
268+
model_id: str, subfolder: Optional[str], same_as_pretrained: bool, use_pretrained: bool
269+
) -> Tuple[str, Optional[str], bool, bool]:
268270
if subfolder or "//" not in model_id:
269-
return model_id, subfolder
271+
return model_id, subfolder, same_as_pretrained, use_pretrained
270272
spl = model_id.split("//")
273+
if spl[-1] == "pretrained":
274+
return _preprocess_model_id("//".join(spl[:-1]), "", True, True)
271275
if spl[-1] in {"transformer", "vae"}:
272276
# known subfolder
273-
return "//".join(spl[:-1]), spl[-1]
274-
return model_id, subfolder
277+
return "//".join(spl[:-1]), spl[-1], same_as_pretrained, use_pretrained
278+
return model_id, subfolder, same_as_pretrained, use_pretrained
275279

276280

277281
def validate_model(
@@ -384,7 +388,12 @@ def validate_model(
384388
if ``runtime == 'ref'``,
385389
``orteval10`` increases the verbosity.
386390
"""
387-
model_id, subfolder = _preprocess_model_id(model_id, subfolder)
391+
model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
392+
model_id,
393+
subfolder,
394+
same_as_pretrained=same_as_pretrained,
395+
use_pretrained=use_pretrained,
396+
)
388397
if isinstance(patch, bool):
389398
patch_kwargs = (
390399
dict(patch_transformers=True, patch_diffusers=True, patch=True)
@@ -812,6 +821,8 @@ def validate_model(
812821
)
813822
summary.update(summary_valid)
814823

824+
_compute_final_statistics(summary)
825+
815826
if verbose:
816827
print("[validate_model] -- done (final)")
817828
if dump_stats:
@@ -824,15 +835,24 @@ def validate_model(
824835
def compute_statistics(onnx_filename: str) -> Dict[str, Union[float, int]]:
825836
"""Computes some statistics on the model itself."""
826837
onx = onnx.load(onnx_filename, load_external_data=False)
838+
cache_functions = {(f.domain, f.name): f for f in onx.functions}
839+
local_domains = set(f.domain for f in onx.functions)
827840

828841
def node_iter(proto):
829842
if isinstance(proto, onnx.ModelProto):
830-
yield from node_iter(proto.graph)
831843
for f in proto.functions:
832844
yield from node_iter(f)
845+
yield from node_iter(proto.graph)
833846
elif isinstance(proto, (onnx.FunctionProto, onnx.GraphProto)):
834847
for node in proto.node:
835848
yield node
849+
850+
# Let's inline the function
851+
key = node.domain, node.op_type
852+
if key in cache_functions:
853+
yield from node_iter(cache_functions[key])
854+
855+
# Let's continue
836856
for att in node.attribute:
837857
if att.type == onnx.AttributeProto.GRAPH:
838858
yield from node_iter(att.g)
@@ -850,6 +870,11 @@ def node_iter(proto):
850870
n_nodes += 1
851871
if proto.op_type != "Constant":
852872
n_nodes_nocst += 1
873+
if proto.domain in local_domains:
874+
key = "n_node_local_function"
875+
if key not in counts:
876+
counts[key] = 0
877+
counts[key] += 1
853878
else:
854879
key = f"n_node_initializer_{proto.data_type}"
855880

@@ -1298,7 +1323,13 @@ def _mk(key, flavour=flavour):
12981323
print(
12991324
f"[validate_onnx_model] inputs={string_type(data[k_input], with_shape=True)}"
13001325
)
1301-
feeds = make_feeds(sess, data[k_input], use_numpy=True, check_flatten=False)
1326+
feeds = make_feeds(
1327+
sess,
1328+
data[k_input],
1329+
use_numpy=True,
1330+
check_flatten=False,
1331+
is_modelbuilder=data["exporter"] == "modelbuilder",
1332+
)
13021333
if verbose:
13031334
print(f"[validate_onnx_model] ort inputs={string_type(feeds, with_shape=True)}")
13041335
summary[_mk(f"onnx_ort_inputs{suffix}")] = string_type(feeds, with_shape=True)
@@ -1318,6 +1349,13 @@ def _mk(key, flavour=flavour):
13181349
repeat=repeat,
13191350
warmup=warmup,
13201351
)
1352+
# NOTE: modelbuilder has different order on past_kv outputs
1353+
if data["exporter"] == "modelbuilder":
1354+
logits = got[:1]
1355+
past_key_values = got[1:]
1356+
reorder_past_key_values = reorder_modelbuilder_cache_to_torch(past_key_values)
1357+
got = logits + reorder_past_key_values
1358+
13211359
if f"ERR_{_mk(f'time_onnx_ort_run{suffix}')}" in summary:
13221360
return summary, data
13231361

@@ -1358,7 +1396,7 @@ def call_torch_export_onnx(
13581396
:return: two dictionaries, one with some metrics,
13591397
another one with whatever the function produces
13601398
"""
1361-
available = {None, "", "ir", "os_ort"}
1399+
available = {None, "", "ir", "os_ort", "ir+default"}
13621400
assert (
13631401
optimization in available
13641402
), f"unexpected value for optimization={optimization}, available={available}"
@@ -1448,11 +1486,31 @@ def call_torch_export_onnx(
14481486
print(epo)
14491487
print("[call_torch_export_onnx] -- End of ONNXProgram")
14501488

1451-
if optimization in {"ir", "os_ort"}:
1489+
if optimization in {"ir", "os_ort", "ir+default"}:
14521490
if verbose:
14531491
print(f"[call_torch_export_onnx] starts optimization={optimization!r}...")
14541492
if optimization == "ir":
14551493
label, f_optim = "export_onnx_opt_ir", (lambda epo=epo: epo.optimize())
1494+
elif optimization == "ir+default":
1495+
import onnxscript
1496+
from experimental_experiment.xbuilder import GraphBuilder, OptimizationOptions
1497+
1498+
def _ir_default_opt(epo):
1499+
onnxscript.optimizer.optimize_ir(epo.model)
1500+
onx = epo.model_proto
1501+
# not very efficient
1502+
gr = GraphBuilder(
1503+
onx,
1504+
infer_shapes_options=True,
1505+
optimization_options=OptimizationOptions(patterns="default"),
1506+
)
1507+
cont = gr.to_onnx(large_model=True)
1508+
epo.model = cont.to_ir()
1509+
1510+
label, f_optim = "export_onnx_opt_ir_default", (
1511+
lambda epo=epo: _ir_default_opt(epo)
1512+
)
1513+
14561514
else:
14571515
import onnxscript
14581516
import onnxscript.rewriter.ort_fusions as ort_fusions
@@ -1851,3 +1909,21 @@ def run_ort_fusion(
18511909
f"opt_ort_{model_type}_duration": duration,
18521910
f"opt_ort_{model_type}_duration_save": d,
18531911
}, {f"opt_ort_{model_type}": output_path}
1912+
1913+
1914+
def _compute_final_statistics(summary: Dict[str, Any]):
1915+
"""
1916+
Updates inline the list of statistics. It adds:
1917+
1918+
- speedup
1919+
"""
1920+
stats = {}
1921+
if (
1922+
"time_run_latency" in summary
1923+
and "time_run_onnx_ort_latency" in summary
1924+
and summary["time_run_onnx_ort_latency"] > 0
1925+
):
1926+
stats["stat_estimated_speedup_ort"] = (
1927+
summary["time_run_latency"] / summary["time_run_onnx_ort_latency"]
1928+
)
1929+
summary.update(stats)

0 commit comments

Comments
 (0)