Skip to content

Commit 830f584

Browse files
committed
add option drop_names
1 parent 8f66b33 commit 830f584

File tree

4 files changed

+121
-15
lines changed

4 files changed

+121
-15
lines changed

.github/workflows/check-urls.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,5 +43,5 @@ jobs:
4343
timeout: 2
4444
retry_count# : 2
4545
exclude_urls: https://hal.archives-,ouvertes.fr/hal-00990252/document,http://badge.fury.io/py/onnx-diagnostic,https://azure.microsoft.com/en-us/products/devops/pipelines,https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670,https://github.com/NVIDIA/TransformerEngine.git@6a9edc38bf9b941b7d369af5103fa8fe0b121d61,https://medium.com/@msouza.os/llm-from-scratch-with-pytorch-9f21808c6319,https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/symbolic_shapes.py#L5965,https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-04.html,https://badge.fury.io/py/onnx-diagnostic.svg,https://github.com/huggingface/transformers/pull/36311
46-
exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/,https://azure.microsoft.com/en-us/products/devops/pipelines,https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670,https://github.com/NVIDIA/TransformerEngine.git@6a9edc38bf9b941b7d369af5103fa8fe0b121d61,https://github.com/pytorch/pytorch/blob/main/torch/,https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-04.html,https://badge.fury.io/py/onnx-diagnostic.svg,https://github.com/huggingface/transformers/pull/36311
46+
exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/,https://azure.microsoft.com/en-us/products/devops/pipelines,https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670,https://github.com/NVIDIA/TransformerEngine.git@6a9edc38bf9b941b7d369af5103fa8fe0b121d61,https://github.com/pytorch/pytorch/blob/main/torch/,https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-04.html,https://badge.fury.io/py/onnx-diagnostic.svg,https://github.com/huggingface/transformers/pull/36311,https://codecov.io/
4747
# force_pass : true

_unittests/ut_torch_models/test_test_helpers.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
import copy
22
import unittest
33
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, ignore_warnings
4-
from onnx_diagnostic.torch_models.test_helper import get_inputs_for_task, validate_model
4+
from onnx_diagnostic.torch_models.test_helper import (
5+
get_inputs_for_task,
6+
validate_model,
7+
filter_inputs,
8+
)
59
from onnx_diagnostic.torch_models.hghub.model_inputs import get_get_inputs_function_for_tasks
610

711

@@ -63,6 +67,27 @@ def test_validate_model_onnx(self):
6367
self.assertIsInstance(data, dict)
6468
self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-4)
6569

70+
def test_filter_inputs(self):
71+
inputs, ds = {"a": 1, "b": 2}, {"a": 20, "b": 30}
72+
ni, nd = filter_inputs(inputs, dynamic_shapes=ds, drop_names=["a"])
73+
self.assertEqual((ni, nd), ({"b": 2}, {"b": 30}))
74+
75+
inputs, ds = (1, 2), {"a": 20, "b": 30}
76+
ni, nd = filter_inputs(inputs, dynamic_shapes=ds, drop_names=["b"], model=["a", "b"])
77+
self.assertEqual((ni, nd), ((1, None), {"a": 20}))
78+
79+
inputs, ds = (1, 2), (20, 30)
80+
ni, nd = filter_inputs(inputs, dynamic_shapes=ds, drop_names=["b"], model=["a", "b"])
81+
self.assertEqual((ni, nd), ((1, None), (20, None)))
82+
83+
inputs, ds = ((1,), {"b": 4}), {"a": 20, "b": 30}
84+
ni, nd = filter_inputs(inputs, dynamic_shapes=ds, drop_names=["b"], model=["a", "b"])
85+
self.assertEqual((ni, nd), ((1,), {"a": 20}))
86+
87+
inputs, ds = ((1,), {"b": 4}), {"a": 20, "b": 30}
88+
ni, nd = filter_inputs(inputs, dynamic_shapes=ds, drop_names=["a"], model=["a", "b"])
89+
self.assertEqual((ni, nd), (((None,), {"b": 4}), {"b": 30}))
90+
6691

6792
if __name__ == "__main__":
6893
unittest.main(verbosity=2)

onnx_diagnostic/_command_lines_parser.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,11 @@ def get_parser_validate() -> ArgumentParser:
276276
help="if not empty, a folder is created to dumps statistics, "
277277
"exported program, onnx...",
278278
)
279+
parser.add_argument(
280+
"--drop",
281+
help="drops the following inputs names, it should be a list "
282+
"with comma separated values",
283+
)
279284
parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity")
280285
parser.add_argument("--dtype", help="changes dtype if necessary")
281286
parser.add_argument("--device", help="changes the device if necessary")
@@ -317,6 +322,7 @@ def _cmd_validate(argv: List[Any]):
317322
optimization=args.opt,
318323
exporter=args.export,
319324
dump_folder=args.dump_folder,
325+
drop_inputs=None if not args.drop else args.drop.split(","),
320326
)
321327
print("")
322328
print("-- summary --")

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 88 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
import inspect
12
import os
2-
from typing import Any, Dict, Optional, Tuple, Union
3+
from typing import Any, Dict, List, Optional, Tuple, Union
34
import time
45
import torch
56
from ..helpers import max_diff, string_type, string_diff
@@ -45,6 +46,69 @@ def get_inputs_for_task(task: str, config: Optional[Any] = None) -> Dict[str, An
4546
return f(model=None, config=config, **kwargs)
4647

4748

49+
def split_args_kwargs(inputs: Any) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
50+
"""Splits into args, kwargs."""
51+
if isinstance(inputs, dict):
52+
return (), inputs
53+
if isinstance(inputs, tuple) and len(inputs) == 2 and isinstance(inputs[1], dict):
54+
return inputs
55+
assert isinstance(inputs, tuple), f"Unexpected inputs {string_type(inputs)}"
56+
return inputs, {}
57+
58+
59+
def make_inputs(
60+
args: Optional[Tuple[Any, ...]], kwargs: Optional[Dict[str, Any]] = None
61+
) -> Any:
62+
"""Returns either args, kwargs or both depending on which ones are empty."""
63+
assert args or kwargs, "No input was given."
64+
if not args:
65+
return kwargs
66+
if not kwargs:
67+
return args
68+
return args, kwargs
69+
70+
71+
def filter_inputs(
72+
inputs: Any,
73+
drop_names: List[str],
74+
model: Optional[Union[torch.nn.Module, List[str]]] = None,
75+
dynamic_shapes: Optional[Any] = None,
76+
):
77+
"""
78+
Drops some inputs from the given inputs.
79+
It updates the dynamic shapes as well.
80+
"""
81+
args, kwargs = split_args_kwargs(inputs)
82+
set_drop_names = set(drop_names)
83+
kwargs = {k: v for k, v in kwargs.items() if k not in set_drop_names}
84+
dyn = (
85+
{k: v for k, v in dynamic_shapes.items() if k not in set_drop_names}
86+
if dynamic_shapes and isinstance(dynamic_shapes, dict)
87+
else dynamic_shapes
88+
)
89+
if not args or all(i in kwargs for i in set_drop_names):
90+
return make_inputs(args, kwargs), dyn
91+
assert model, (
92+
f"we need the model to get the parameter name but model is None, "
93+
f"input_names={drop_names} and args={string_type(args)}"
94+
)
95+
pnames = (
96+
list(inspect.signature(model.forward).parameters)
97+
if isinstance(model, torch.nn.Module)
98+
else model
99+
)
100+
new_args = []
101+
new_ds = []
102+
for i, a in enumerate(args):
103+
if isinstance(dynamic_shapes, tuple):
104+
new_ds.append(None if pnames[i] in set_drop_names else dynamic_shapes[i])
105+
new_args.append(None if pnames[i] in set_drop_names else a)
106+
new_inputs = make_inputs(tuple(new_args), kwargs)
107+
if new_ds:
108+
return new_inputs, tuple(new_ds)
109+
return new_inputs, dyn
110+
111+
48112
def validate_model(
49113
model_id: str,
50114
task: Optional[str] = None,
@@ -59,6 +123,7 @@ def validate_model(
59123
quiet: bool = False,
60124
patch: bool = False,
61125
dump_folder: Optional[str] = None,
126+
drop_inputs: Optional[List[str]] = None,
62127
) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
63128
"""
64129
Validates a model.
@@ -80,6 +145,7 @@ def validate_model(
80145
:param quiet: if quiet, catches exception if any issue
81146
:param patch: applies patches before exporting
82147
:param dump_folder: dumps everything in a subfolder of this one
148+
:param drop_inputs: drops this list of inputs (given their names)
83149
:return: two dictionaries, one with some metrics,
84150
another one with whatever the function produces
85151
"""
@@ -112,6 +178,27 @@ def validate_model(
112178
else:
113179
data = get_untrained_model_with_inputs(model_id, verbose=verbose, task=task)
114180

181+
if drop_inputs:
182+
if verbose:
183+
print(f"[validate_model] drop inputs {drop_inputs!r}")
184+
print(f"[validate_model] current inputs: {string_type(data["inputs"])}")
185+
print(
186+
f"[validate_model] current dynnamic_shapes: "
187+
f"{_ds_clean(data["dynamic_shapes"])}"
188+
)
189+
data["inputs"], data["dynamic_shapes"] = filter_inputs(
190+
data["inputs"],
191+
drop_names=drop_inputs,
192+
model=data["model"],
193+
dynamic_shapes=data["dynamic_shapes"],
194+
)
195+
if verbose:
196+
print(f"[validate_model] new inputs: {string_type(data["inputs"])}")
197+
print(
198+
f"[validate_model] new dynnamic_shapes: "
199+
f"{_ds_clean(data["dynamic_shapes"])}"
200+
)
201+
115202
if not empty(dtype):
116203
if isinstance(dtype, str):
117204
dtype = getattr(torch, dtype)
@@ -338,18 +425,6 @@ def call_exporter(
338425
)
339426

340427

341-
def split_args_kwargs(inputs: Any) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
342-
"""
343-
Splits into args, kwargs.
344-
"""
345-
if isinstance(inputs, dict):
346-
return (), inputs
347-
if isinstance(inputs, tuple) and len(inputs) == 2 and isinstance(inputs[1], dict):
348-
return inputs
349-
assert isinstance(inputs, tuple), f"Unexpected inputs {string_type(inputs)}"
350-
return inputs, {}
351-
352-
353428
def call_torch_export_export(
354429
data: Dict[str, Any],
355430
exporter: str,

0 commit comments

Comments
 (0)