Skip to content

Commit f7f1390

Browse files
committed
refactoring
1 parent 6eda5cc commit f7f1390

File tree

9 files changed

+614
-579
lines changed

9 files changed

+614
-579
lines changed

_doc/api/helpers/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ onnx_diagnostic.helpers
2222
onnx_helper
2323
ort_session
2424
rt_helper
25+
torch_fx_graph_helper
2526
torch_helper
2627

2728
.. autofunction:: onnx_diagnostic.helpers.flatten_object
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.helpers.torch_fx_graph_helper
3+
=============================================
4+
5+
.. automodule:: onnx_diagnostic.helpers.torch_fx_graph_helper
6+
:members:
7+
:no-undoc-members:

_doc/api/torch_onnx/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ onnx_diagnostic.torch_onnx
77

88
runtime_info
99
sbs
10+
sbs_dataclasses
1011

1112
.. automodule:: onnx_diagnostic.torch_onnx
1213
:members:
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
2+
onnx_diagnostic.torch_onnx.sbs
3+
==============================
4+
5+
.. automodule:: onnx_diagnostic.torch_onnx.sbs
6+
:members:
7+
:no-undoc-members:
8+

_unittests/ut_tasks/try_export.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,17 @@ def test_imagetext2text_qwen_2_5_vl_instruct_visual(self):
3434
use_cache:bool,
3535
return_dict:bool
3636
)
37+
38+
.. code-block:: bash
39+
40+
QWEN25ATTENTION=BIGMASK \\
41+
PRETRAINED=1 \\
42+
TESTDEVICE=cuda \\
43+
TESTDTYPE=float16 \\
44+
EXPORTER=custom \\
45+
NEVERTEST=1 \\
46+
DROPPATTERN=SkipSimplifiedLayerNormalizationMulPattern,SkipSimplifiedLayerNormalizationPattern \\
47+
python _unittests/ut_tasks/try_export.py -k qwen_2_5_vl_instruct_visual
3748
"""
3849
device = os.environ.get("TESTDEVICE", "cpu")
3950
dtype = os.environ.get("TESTDTYPE", "float32")
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
from python import Any, Dict, Optional, Tuple
2+
import torch
3+
from .helper import string_type
4+
5+
6+
def validate_fx_tensor(
7+
node: torch.fx.Node, tensor: torch.Tensor, expected_shape: Tuple[Any, ...]
8+
) -> None:
9+
"""
10+
Validates the shape of tensor is expected.
11+
12+
:param node: node
13+
:param tensor: tensor
14+
:param expected_shape: expected shape
15+
"""
16+
assert len(tensor.shape) == len(expected_shape), (
17+
f"Shape mismatch, got {tensor.shape} expected {expected_shape}, "
18+
f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
19+
f"node.args={node.args}, node.kwargs={node.kwargs}, "
20+
f"node.meta={node.meta}"
21+
)
22+
for a, b in zip(tensor.shape, expected_shape):
23+
assert not isinstance(b, int) or a == b or {a, b} == {0, 1}, (
24+
f"Dimension mismatch, got {tensor.shape} expected {expected_shape}, "
25+
f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
26+
f"node.args={node.args}, node.kwargs={node.kwargs}, "
27+
f"node.meta={node.meta}"
28+
)
29+
30+
31+
def validate_fx_outputs(node: torch.fx.Node, outputs: Tuple[Any, ...]) -> None:
32+
"""
33+
Validates the outputs of a node using metadata stored in the node.
34+
35+
:param node: node
36+
:param outputs: outputs
37+
"""
38+
if "val" not in node.meta:
39+
return
40+
if isinstance(outputs, torch.Tensor):
41+
validate_fx_tensor(node, outputs, node.meta["val"].shape)
42+
return
43+
if isinstance(outputs, (tuple, list)):
44+
assert isinstance(node.meta["val"], (list, tuple)), (
45+
f"Unexpected type {string_type(node.meta['val'])} for node.meta['val'], "
46+
f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
47+
f"node.args={node.args}, node.kwargs={node.kwargs}, "
48+
f"node.meta={node.meta}"
49+
)
50+
assert len(outputs) == len(node.meta["val"]), (
51+
f"Length mismatch, got {len(outputs)} expected {len(node.meta['val'])}, "
52+
f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
53+
f"node.args={node.args}, node.kwargs={node.kwargs}, "
54+
f"node.meta={node.meta}"
55+
)
56+
for a, b in zip(outputs, node.meta["val"]):
57+
validate_fx_tensor(node, a, b.shape)
58+
return
59+
if isinstance(outputs, int):
60+
assert (
61+
isinstance(node.meta["val"], (torch.SymInt, torch.SymBool, torch.SymFloat))
62+
or outputs == node.meta["val"]
63+
), (
64+
f"Int mismatch, got {outputs} expected {node.meta['val']}, "
65+
f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
66+
f"node.args={node.args}, node.kwargs={node.kwargs}, "
67+
f"node.meta={node.meta}"
68+
)
69+
return
70+
if outputs is None:
71+
assert node.meta["val"] is None, (
72+
f"None mismatch, got {outputs} expected {node.meta['val']}, "
73+
f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
74+
f"node.args={node.args}, node.kwargs={node.kwargs}, "
75+
f"node.meta={node.meta}"
76+
)
77+
return
78+
raise NotImplementedError(
79+
f"Validation for output type {type(outputs)} is not implemented, "
80+
f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
81+
f"node.args={node.args}, node.kwargs={node.kwargs}, "
82+
f"node.meta={node.meta}"
83+
)
84+
85+
86+
def run_fx_node(
87+
node: torch.fx.Node, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None
88+
) -> Tuple[Any, ...]:
89+
"""
90+
Executes a node
91+
92+
:param node: runs a node
93+
:param args: unnamed inputs to the node
94+
:param kwargs: named inputs to the node
95+
:return: results
96+
"""
97+
if node.op == "output":
98+
assert len(args) == 1 and not kwargs, (
99+
f"Unexpected inputs: args={string_type(args, limit=20)} "
100+
f"kwargs={string_type(kwargs, limit=20)}"
101+
)
102+
return args
103+
if node.op == "call_function":
104+
assert callable(node.target), f"{node.target!r} not callable in node {node!r}"
105+
for a, ea in zip(args, node.args):
106+
if isinstance(a, torch.Tensor) and hasattr(ea, "meta") and "val" in ea.meta:
107+
ta = ea.meta["val"]
108+
assert (
109+
isinstance(ta, torch.Tensor)
110+
and len(a.shape) == len(ta.shape)
111+
and a.dtype == ta.dtype
112+
), (
113+
f"Unable to run node {node!r}, target={node.target!r}, "
114+
f"node.args={node.args!r}, node.kwargs={node.kwargs!r}, "
115+
f"args={string_type(args, with_shape=True, with_device=True)}, "
116+
f"kwargs={string_type(kwargs, with_shape=True, with_device=True)}"
117+
)
118+
try:
119+
outputs = node.target(*args, **(kwargs or {}))
120+
except RuntimeError as e:
121+
raise RuntimeError(
122+
f"Unable to run node {node!r}, target={node.target!r}, "
123+
f"args={string_type(args, with_shape=True, with_device=True)}, "
124+
f"kwargs={string_type(kwargs, with_shape=True, with_device=True)}"
125+
) from e
126+
validate_fx_outputs(node, outputs)
127+
return outputs
128+
raise NotImplementedError(
129+
f"node.op={node.op!r} is not implemented, node.name={node.name!r}"
130+
)
131+
132+
133+
def _pick_result(torch_results: Dict[str, Any], ref: Any) -> Any:
134+
"See :func:`prepare_args_kwargs`."
135+
if isinstance(ref, torch.fx.Node):
136+
return torch_results[ref.name]
137+
if isinstance(ref, list):
138+
return [_pick_result(torch_results, n) for n in ref]
139+
if isinstance(ref, tuple):
140+
return tuple(_pick_result(torch_results, n) for n in ref)
141+
if isinstance(ref, dict):
142+
return {k: _pick_result(torch_results, v) for k, v in ref.items()}
143+
if isinstance(ref, (bool, int, float, str, torch.device, torch.dtype)):
144+
return ref
145+
if ref is None:
146+
return None
147+
if isinstance(ref, torch.layout):
148+
return ref
149+
raise NotImplementedError(f"Unable to process args type {type(ref)}")
150+
151+
152+
def prepare_args_kwargs(
153+
torch_results: Dict[str, Any], node: torch.fx.Node
154+
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
155+
"""
156+
Prepares args and kwargs before executing a fx node.
157+
158+
:param torch_results: existing results
159+
:param node: node to execute
160+
:return: new args and kwargs
161+
"""
162+
new_args = _pick_result(torch_results, node.args)
163+
new_kwargs = _pick_result(torch_results, node.kwargs)
164+
return new_args, new_kwargs

0 commit comments

Comments
 (0)