Skip to content

Commit f56929f

Browse files
committed
union
1 parent 6d19c02 commit f56929f

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

onnx_diagnostic/export/validate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import inspect
22
import itertools
33
import time
4-
from typing import Any, Dict, List, Optional, Tuple
4+
from typing import Any, Dict, List, Optional, Tuple, Union
55
import torch
66
from ..helpers import string_type, max_diff, string_diff
77
from ..helpers.torch_test_helper import torch_deepcopy
@@ -103,7 +103,7 @@ def _get(a):
103103

104104

105105
def validate_ep(
106-
ep: torch.export.ExportedProgram,
106+
ep: Union[torch.nn.Module, torch.export.ExportedProgram],
107107
mod: Optional[torch.nn.Module] = None,
108108
args: Optional[Tuple[Any, ...]] = None,
109109
kwargs: Optional[Dict[str, Any]] = None,
@@ -131,7 +131,7 @@ def validate_ep(
131131
:param rtol: relative tolerance
132132
:return: dictionary with inputs, outputs and tolerance
133133
"""
134-
modep = ep.module()
134+
modep = ep.module() if isinstance(ep, torch.export.ExportedProgram) else ep
135135

136136
results = [
137137
compare_modules(

0 commit comments

Comments
 (0)