Skip to content

Commit a37a6d5

Browse files
authored
device (#124)
1 parent d5712f5 commit a37a6d5

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

onnx_diagnostic/reference/torch_evaluator.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,17 @@ def __init__(self, name: str, type: int, shape: Tuple[Union[str, int], ...]):
7272
self.type = type
7373
self.shape = shape
7474

75+
@classmethod
76+
def _on_cuda(cls, providers) -> int:
77+
if not providers:
78+
return -1
79+
for p in providers:
80+
if p == "CUDAExecutionProvider":
81+
return 0
82+
if isinstance(p, tuple) and p[0] == "CUDAExecutionProvider":
83+
return p[1]["device_id"]
84+
return -1
85+
7586
def __init__(
7687
self,
7788
proto: Union[onnx.FunctionProto, onnx.GraphProto, onnx.ModelProto],
@@ -86,12 +97,13 @@ def __init__(
8697
self.functions = local_functions.copy() if local_functions else {}
8798
self.CPU = torch.tensor([0]).to("cpu").device
8899
self.verbose = verbose
89-
if "CUDAExecutionProvider" in providers:
90-
self.CUDA = torch.tensor([0]).to("cuda").device
91-
self.default_device = self.CUDA
92-
else:
100+
dev = self._on_cuda(providers)
101+
if dev < 0:
93102
self.default_device = self.CPU
94103
self.CUDA = None
104+
else:
105+
self.CUDA = torch.tensor([0]).to(f"cuda:{dev}").device
106+
self.default_device = self.CUDA
95107

96108
if isinstance(proto, str):
97109
proto = onnx.load(proto)

0 commit comments

Comments
 (0)