Skip to content

Commit 84eda70

Browse files
committed
ut
1 parent 7a2f386 commit 84eda70

File tree

3 files changed

+14
-13
lines changed

3 files changed

+14
-13
lines changed

_scripts/test_backend_onnxruntime.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,13 @@ def run(self, inputs, **kwargs):
2626
if isinstance(inputs, numpy.ndarray):
2727
inputs = [inputs]
2828
if isinstance(inputs, list):
29-
if len(inputs) == len(self._session.input_names):
30-
feeds = dict(zip(self._session.input_names, inputs))
29+
if len(inputs) == len(self._session.get_inputs()):
30+
feeds = dict(zip([i.name for i in self._session.get_inputs()], inputs))
3131
else:
32+
input_names = [i.name for i in self._session.get_inputs()]
3233
feeds = {}
3334
pos_inputs = 0
34-
for inp, tshape in zip(self._session.input_names, self._session.input_types):
35+
for inp, tshape in zip(input_names, self._session.input_types):
3536
shape = tuple(d.dim_value for d in tshape.tensor_type.shape.dim)
3637
if shape == inputs[pos_inputs].shape:
3738
feeds[inp] = inputs[pos_inputs]
@@ -54,20 +55,20 @@ def is_compatible(cls, model) -> bool:
5455
@classmethod
5556
def supports_device(cls, device: str) -> bool:
5657
d = Device(device)
57-
if d == DeviceType.CPU:
58+
if d.type == DeviceType.CPU:
5859
return True
59-
if d == DeviceType.CUDA:
60-
import torch
61-
62-
return torch.cuda.is_available()
60+
# if d.type == DeviceType.CUDA:
61+
# import torch
62+
#
63+
# return torch.cuda.is_available()
6364
return False
6465

6566
@classmethod
6667
def create_inference_session(cls, model, device):
6768
d = Device(device)
68-
if d == DeviceType.CUDA:
69+
if d.type == DeviceType.CUDA:
6970
providers = ["CUDAExecutionProvider"]
70-
elif d == DeviceType.CPU:
71+
elif d.type == DeviceType.CPU:
7172
providers = ["CPUExecutionProvider"]
7273
else:
7374
raise ValueError(f"Unrecognized device {device!r} or {d!r}")

_unittests/ut_reference/test_backend_onnxruntime_evaluator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ def is_compatible(cls, model) -> bool:
5050
@classmethod
5151
def supports_device(cls, device: str) -> bool:
5252
d = Device(device)
53-
if d == DeviceType.CPU:
53+
if d.type == DeviceType.CPU:
5454
return True
55-
if d == DeviceType.CUDA:
55+
if d.type == DeviceType.CUDA:
5656
import torch
5757

5858
return torch.cuda.is_available()

_unittests/ut_torch_export_patches/test_patch_torch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def forward(self, batch_arange, head_arange, cache_position, kv_arange):
154154
self.assertEqualArray(causal_mask, ep.moule(*inputs))
155155

156156
@requires_torch("2.8")
157-
@requires_transformers("4.52")
157+
@requires_transformers("4.53")
158158
def test_vmap_transformers_scenario_novmap(self):
159159
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
160160
patched__vmap_for_bhqkv as _vmap_for_bhqkv2,

0 commit comments

Comments
 (0)