Skip to content

Commit ea6e646

Browse files
committed
fix sbs
1 parent 78bb5de commit ea6e646

File tree

3 files changed

+11
-9
lines changed

3 files changed

+11
-9
lines changed

.github/workflows/check-urls.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,6 @@ jobs:
4242
print_all: false
4343
timeout: 2
4444
retry_count# : 2
45-
exclude_urls: https://hal.archives-,ouvertes.fr/hal-00990252/document,http://badge.fury.io/py/onnx-diagnostic,https://azure.microsoft.com/en-us/products/devops/pipelines,https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670,https://github.com/NVIDIA/TransformerEngine.git@6a9edc38bf9b941b7d369af5103fa8fe0b121d61,https://medium.com/@msouza.os/llm-from-scratch-with-pytorch-9f21808c6319,https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/symbolic_shapes.py#L5965,https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-04.html,https://badge.fury.io/py/onnx-diagnostic.svg,https://github.com/huggingface/transformers/pull/36311,https://www.linux.org/
46-
exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/,https://azure.microsoft.com/en-us/products/devops/pipelines,https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670,https://github.com/NVIDIA/TransformerEngine.git@6a9edc38bf9b941b7d369af5103fa8fe0b121d61,https://github.com/pytorch/pytorch/blob/main/torch/,https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-04.html,https://badge.fury.io/py/onnx-diagnostic.svg,https://github.com/,https://codecov.io/,https://huggingface.co/,https://www.linux.org/
45+
exclude_urls: https://hal.archives-,ouvertes.fr/hal-00990252/document,http://badge.fury.io/py/onnx-diagnostic,https://azure.microsoft.com/en-us/products/devops/pipelines,https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670,https://github.com/NVIDIA/TransformerEngine.git@6a9edc38bf9b941b7d369af5103fa8fe0b121d61,https://medium.com/@msouza.os/llm-from-scratch-with-pytorch-9f21808c6319,https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/symbolic_shapes.py#L5965,https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-04.html,https://badge.fury.io/py/onnx-diagnostic.svg,https://github.com/huggingface/transformers/pull/36311,https://www.linux.org/,https://docs.scipy.org/doc/scipy/
46+
exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/,https://azure.microsoft.com/en-us/products/devops/pipelines,https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670,https://github.com/NVIDIA/TransformerEngine.git@6a9edc38bf9b941b7d369af5103fa8fe0b121d61,https://github.com/pytorch/pytorch/blob/main/torch/,https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-04.html,https://badge.fury.io/py/onnx-diagnostic.svg,https://github.com/,https://codecov.io/,https://huggingface.co/,https://www.linux.org/,https://docs.scipy.org/doc/scipy/
4747
# force_pass : true

onnx_diagnostic/reference/ort_evaluator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@ def enumerate_nodes(self, nodes: List[NodeProto]) -> Iterator[NodeProto]:
473473
yield node
474474

475475
@classmethod
476-
def _get_hidden_inputs(self, graph: GraphProto) -> Set[str]:
476+
def _get_hidden_inputs(cls, graph: GraphProto) -> Set[str]:
477477
"""
478478
Returns the hidden inputs (inputs coming from an upper context)
479479
used by a subgraph.
@@ -490,21 +490,21 @@ def _get_hidden_inputs(self, graph: GraphProto) -> Set[str]:
490490
hidden.add(i)
491491
for att in node.attribute:
492492
if att.type == AttributeProto.GRAPH and att.g:
493-
hid = self._get_hidden_inputs(att.g)
493+
hid = cls._get_hidden_inputs(att.g)
494494
less = set(h for h in hid if h not in memo)
495495
hidden |= less
496496
memo |= set(node.output)
497497
return hidden
498498

499499
@classmethod
500-
def _get_hidden_node_inputs(self, node: NodeProto) -> Set[str]:
500+
def _get_hidden_node_inputs(cls, node: NodeProto) -> Set[str]:
501501
"""Calls multiple _get_hidden_inputs on every attribute."""
502502
if node.op_type not in {"Loop", "Scan", "If"}:
503503
return set()
504504
hidden = set()
505505
for att in node.attribute:
506506
if att.type == AttributeProto.GRAPH:
507-
hidden |= self._get_hidden_inputs(att.g)
507+
hidden |= cls._get_hidden_inputs(att.g)
508508
return hidden - (hidden & set(node.input))
509509

510510
def _get_sess(

onnx_diagnostic/torch_onnx/sbs.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from ..helpers.onnx_helper import pretty_onnx
1010
from ..helpers.torch_helper import to_numpy, from_numpy, to_tensor, torch_dtype_to_onnx_dtype
1111
from ..helpers.torch_fx_graph_helper import prepare_args_kwargs, run_fx_node
12-
from ..reference.ort_evaluator import OnnxList
12+
from ..reference.ort_evaluator import OnnxList, OnnxruntimeEvaluator
1313
from .sbs_dataclasses import (
1414
ReplayConfiguration,
1515
RunAlignedRecord,
@@ -176,10 +176,12 @@ def _loop_onnx_node(
176176

177177
ref = run_cls(node, **run_cls_kwargs)
178178
# We need to clone because the runtime maybe using dlpack to create OrtValue
179+
hidden_inputs = OnnxruntimeEvaluator._get_hidden_node_inputs(node)
180+
all_inputs = [*node.input, *hidden_inputs] if hidden_inputs else node.input
179181
feeds = (
180-
{k: onnx_results[k].clone() for k in node.input if k}
182+
{k: onnx_results[k].clone() for k in all_inputs if k}
181183
if use_tensor
182-
else {k: onnx_results[k].copy() for k in node.input if k}
184+
else {k: onnx_results[k].copy() for k in all_inputs if k}
183185
)
184186
assert "" not in feeds, f"Unexpected feeds={string_type(feeds, **str_kws)}"
185187
if verbose > 1:

0 commit comments

Comments
 (0)