@@ -26,12 +26,13 @@ def run(self, inputs, **kwargs):
2626 if isinstance (inputs , numpy .ndarray ):
2727 inputs = [inputs ]
2828 if isinstance (inputs , list ):
29- if len (inputs ) == len (self ._session .input_names ):
30- feeds = dict (zip (self ._session .input_names , inputs ))
29+ if len (inputs ) == len (self ._session .get_inputs () ):
30+ feeds = dict (zip ([ i . name for i in self ._session .get_inputs ()] , inputs ))
3131 else :
32+ input_names = [i .name for i in self ._session .get_inputs ()]
3233 feeds = {}
3334 pos_inputs = 0
34- for inp , tshape in zip (self . _session . input_names , self ._session .input_types ):
35+ for inp , tshape in zip (input_names , self ._session .input_types ):
3536 shape = tuple (d .dim_value for d in tshape .tensor_type .shape .dim )
3637 if shape == inputs [pos_inputs ].shape :
3738 feeds [inp ] = inputs [pos_inputs ]
@@ -54,20 +55,20 @@ def is_compatible(cls, model) -> bool:
5455 @classmethod
5556 def supports_device (cls , device : str ) -> bool :
5657 d = Device (device )
57- if d == DeviceType .CPU :
58+ if d . type == DeviceType .CPU :
5859 return True
59- if d == DeviceType .CUDA :
60- import torch
61-
62- return torch .cuda .is_available ()
60+ # if d.type == DeviceType.CUDA:
61+ # import torch
62+ #
63+ # return torch.cuda.is_available()
6364 return False
6465
6566 @classmethod
6667 def create_inference_session (cls , model , device ):
6768 d = Device (device )
68- if d == DeviceType .CUDA :
69+ if d . type == DeviceType .CUDA :
6970 providers = ["CUDAExecutionProvider" ]
70- elif d == DeviceType .CPU :
71+ elif d . type == DeviceType .CPU :
7172 providers = ["CPUExecutionProvider" ]
7273 else :
7374 raise ValueError (f"Unrecognized device { device !r} or { d !r} " )
0 commit comments