66from torch ._C import _from_dlpack
77import onnxruntime
88from onnxruntime .capi import _pybind_state as ORTC
9+ from .helpers import (
10+ torch_dtype_to_onnx_dtype ,
11+ onnx_dtype_to_np_dtype ,
12+ np_dtype_to_tensor_dtype ,
13+ onnx_dtype_name ,
14+ size_type ,
15+ )
916
1017DEVICES = {- 1 : ORTC .OrtDevice (ORTC .OrtDevice .cpu (), ORTC .OrtDevice .default_memory (), 0 )}
1118
@@ -48,7 +55,14 @@ def __init__(
4855 ):
4956 # onnxruntime is importing when needed as it takes a
5057 # couple of seconds if it contains CUDA EP.
58+ can_use_training_api = True
5159 if isinstance (sess , (onnx .ModelProto , str )):
60+ if isinstance (sess , onnx .ModelProto ):
61+ for i in sess .graph .initializer :
62+ if i .data_type >= onnx .TensorProto .BFLOAT16 :
63+ # Cannot use training api as it relies too much on numpy.
64+ can_use_training_api = False
65+ break
5266 assert session_options is None or (
5367 providers is None
5468 and graph_optimization_level is None
@@ -113,7 +127,7 @@ def __init__(
113127 if log_verbosity_level is not None :
114128 self .run_options .log_verbosity_level = log_verbosity_level
115129
116- self .use_training_api = (
130+ self .use_training_api = can_use_training_api and (
117131 self .has_onnxruntime_training () if use_training_api is None else use_training_api
118132 )
119133
def run(
    self, output_names: Optional[List[str]], feeds: Dict[str, npt.ArrayLike]
) -> List[npt.ArrayLike]:
    """Calls :meth:`onnxruntime.InferenceSession.run`.

    If any feed has an ONNX element type greater than or equal to
    ``onnx.TensorProto.BFLOAT16``, the call is delegated to
    :meth:`run_dlpack` and that method's result is returned as is.

    :param output_names: names of the requested outputs, forwarded unchanged
    :param feeds: mapping from input name to input value; every value must
        expose a ``dtype`` attribute
    :return: the outputs produced by the session
    """
    if any(
        np_dtype_to_tensor_dtype(v.dtype) >= onnx.TensorProto.BFLOAT16
        for v in feeds.values()
    ):
        # bfloat16 not supported by onnxruntime
        return self.run_dlpack(output_names, feeds)

    if not self.nvtx:
        return self.sess.run(output_names, feeds)

    self.torch.cuda.nvtx.range_push("run")
    try:
        return self.sess.run(output_names, feeds)
    finally:
        # Always close the range, otherwise a failing run leaves the
        # profiler with an unbalanced push.
        self.torch.cuda.nvtx.range_pop()
205+
def run_dlpack(
    self, output_names: Optional[List[str]], feeds: Dict[str, np.ndarray]
) -> Tuple[Optional[np.ndarray], ...]:
    """
    Same as :meth:`onnxruntime.InferenceSession.run` except that
    feeds is a dictionary of :class:`np.ndarray`.
    The output device is CPU even if the outputs are on CUDA.

    Every feed is wrapped into an ``OrtValue`` tagged with the ONNX element
    type derived from its numpy dtype, the session is run on these values
    and the results are converted by :meth:`_ortvalues_to_numpy_tensor`.

    :param output_names: names of the requested outputs; when empty or
        None, ``self.output_names`` is used instead
    :param feeds: mapping from input name to numpy array
    :return: one entry per output, as returned by
        :meth:`_ortvalues_to_numpy_tensor`
    """
    new_feeds = {
        name: ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
            value, np_dtype_to_tensor_dtype(value.dtype)
        )
        for name, value in feeds.items()
    }
    requested = output_names or self.output_names

    if not self.nvtx:
        ort_outputs = self.sess._sess.run_with_ort_values(
            new_feeds, requested, self.run_options
        )
    else:
        self.torch.cuda.nvtx.range_push("run_with_ort_values")
        try:
            ort_outputs = self.sess._sess.run_with_ort_values(
                new_feeds, requested, self.run_options
            )
        finally:
            # Keep push/pop balanced even when the session raises.
            self.torch.cuda.nvtx.range_pop()

    return self._ortvalues_to_numpy_tensor(ort_outputs)
228+
def _ortvalues_to_numpy_tensor(
    self,
    ortvalues: Union[List[ORTC.OrtValue], ORTC.OrtValueVector],
) -> Tuple[Optional[np.ndarray], ...]:
    """
    Converts a sequence of OrtValue into a tuple of numpy arrays.

    Values without content are returned as None. Element types below
    ``onnx.TensorProto.BFLOAT16`` are converted with ``np.from_dlpack``.
    The other ones go through torch: the tensor is reinterpreted as
    ``torch.uint16``, moved to numpy and finally viewed with the numpy
    dtype given by ``onnx_dtype_to_np_dtype``.

    :param ortvalues: outputs returned by onnxruntime
    :return: a tuple with one entry per value in *ortvalues*
    :raises NotImplementedError: if an element type handled through torch
        is not stored on two bytes
    """
    if len(ortvalues) == 0:
        return ()

    if self.nvtx:
        self.torch.cuda.nvtx.range_push("_ortvalues_to_numpy_tensor")
    try:
        res = []
        # Index-based access: the container may be a plain list or an
        # OrtValueVector, only len() and [] are relied upon.
        for i in range(len(ortvalues)):
            value = ortvalues[i]
            if not value.has_value():
                res.append(None)
                continue

            el_type = value.element_type()
            if el_type < onnx.TensorProto.BFLOAT16:
                res.append(np.from_dlpack(value))
                continue

            # no easy conversion, let's use torch
            if size_type(el_type) != 2:
                # Explicit exception: an assert would vanish under -O.
                raise NotImplementedError(
                    f"Not implemented for type {onnx_dtype_name(el_type)}"
                )
            tch = torch.from_dlpack(value.to_dlpack())
            # Reinterpret the 2-byte elements as uint16 so that numpy
            # accepts the buffer, then view it with the target dtype.
            raw = tch.view(torch.uint16).numpy()
            res.append(raw.view(onnx_dtype_to_np_dtype(el_type)))
    finally:
        if self.nvtx:
            # Keep push/pop balanced even when a conversion fails.
            self.torch.cuda.nvtx.range_pop()
    return tuple(res)
180263
181264
182265class InferenceSessionForTorch (_InferenceSession ):
@@ -225,33 +308,6 @@ def __init__(
225308 use_training_api = use_training_api ,
226309 )
227310
228- self .TORCH_DTYPE_TO_ONNX_DTYPE = {
229- torch .float16 : onnx .TensorProto .FLOAT16 ,
230- torch .bfloat16 : onnx .TensorProto .BFLOAT16 ,
231- torch .float32 : onnx .TensorProto .FLOAT ,
232- torch .float64 : onnx .TensorProto .DOUBLE ,
233- torch .uint32 : onnx .TensorProto .UINT32 ,
234- torch .uint16 : onnx .TensorProto .UINT16 ,
235- torch .uint8 : onnx .TensorProto .UINT8 ,
236- torch .int8 : onnx .TensorProto .INT8 ,
237- torch .int16 : onnx .TensorProto .INT16 ,
238- torch .int32 : onnx .TensorProto .INT32 ,
239- torch .int64 : onnx .TensorProto .INT64 ,
240- torch .bool : onnx .TensorProto .BOOL ,
241- }
242-
243- self .TORCH_DTYPE_TO_NUMPY_DTYPE = {
244- torch .float16 : np .float16 ,
245- torch .float32 : np .float32 ,
246- torch .float64 : np .float64 ,
247- torch .uint8 : np .uint8 ,
248- torch .int8 : np .int8 ,
249- torch .int16 : np .int16 ,
250- torch .int32 : np .int32 ,
251- torch .int64 : np .int64 ,
252- torch .bool : np .bool_ ,
253- }
254-
255311 def _get_ortvalues_from_torch_tensors (
256312 self , tensors : Tuple [torch .Tensor , ...], n_outputs : int
257313 ) -> Tuple [ORTC .OrtValueVector , List [onnxruntime .OrtDevice ]]:
@@ -269,7 +325,7 @@ def _get_ortvalues_from_torch_tensors(
269325 new_tensors = []
270326 for tensor in tensors :
271327 assert isinstance (tensor , self .torch .Tensor ), f"Unexpected type { type (tensor )} "
272- dtypes .append (self . TORCH_DTYPE_TO_NUMPY_DTYPE [ tensor .dtype ] )
328+ dtypes .append (onnx_dtype_to_np_dtype ( torch_dtype_to_onnx_dtype ( tensor .dtype )) )
273329 shapes .append (tensor .size ())
274330 data_ptrs .append (tensor .data_ptr ())
275331 d = tensor .get_device ()
0 commit comments