11import copy
22import unittest
3+ import numpy as np
4+ import onnx
35import torch
46import onnxruntime
7+ from onnxruntime .capi import _pybind_state as ORTC
58from onnx_diagnostic .ext_test_case import ExtTestCase , hide_stdout , ignore_warnings
69from onnx_diagnostic .helpers import max_diff
710from onnx_diagnostic .helpers .ort_session import (
1215from onnx_diagnostic .torch_export_patches import bypass_export_some_errors
1316from onnx_diagnostic .torch_models .llms import get_tiny_llm
1417from onnx_diagnostic .reference import ExtendedReferenceEvaluator
18+ from onnx_diagnostic .helpers .onnx_helper import np_dtype_to_tensor_dtype
1519
1620
1721class TestOrtSessionTinyLLM (ExtTestCase ):
1822
23+ def test_ort_value (self ):
24+ val = np .array ([30 , 31 , 32 ], dtype = np .int64 )
25+ ort = ORTC .OrtValue .ortvalue_from_numpy_with_onnx_type (val , onnx .TensorProto .INT64 )
26+ self .assertEqual (np_dtype_to_tensor_dtype (val .dtype ), onnx .TensorProto .INT64 )
27+ val2 = ort .numpy ()
28+ self .assertEqualArray (val , val2 )
29+ ort = ORTC .OrtValue .ortvalue_from_numpy_with_onnx_type (
30+ val , np_dtype_to_tensor_dtype (val .dtype )
31+ )
32+ val2 = ort .numpy ()
33+ self .assertEqualArray (val , val2 )
34+
35+ def test_ort_value_py (self ):
36+ data = get_tiny_llm ()
37+ inputs = data ["inputs" ]
38+ feeds = make_feeds (
39+ ["input_ids" , "attention_mask" , "position_ids" , "key0" , "value0" ],
40+ inputs ,
41+ use_numpy = True ,
42+ copy = True ,
43+ )
44+ new_feeds = {}
45+ for k , v in feeds .items ():
46+ new_feeds [k ] = onnxruntime .OrtValue .ortvalue_from_numpy_with_onnx_type (
47+ v , np_dtype_to_tensor_dtype (v .dtype )
48+ )
49+ other_feeds = {k : v .numpy () for k , v in new_feeds .items ()}
50+ self .assertEqualAny (feeds , other_feeds )
51+
52+ def test_ort_value_more (self ):
53+ data = get_tiny_llm ()
54+ inputs = data ["inputs" ]
55+ feeds = make_feeds (
56+ ["input_ids" , "attention_mask" , "position_ids" , "key0" , "value0" ],
57+ inputs ,
58+ use_numpy = True ,
59+ copy = True ,
60+ )
61+ feeds = {
62+ k : feeds [k ].copy ()
63+ for k in ["input_ids" , "attention_mask" , "key0" , "value0" , "position_ids" ]
64+ }
65+ new_feeds = {}
66+ for k , v in feeds .items ():
67+ new_feeds [k ] = ORTC .OrtValue .ortvalue_from_numpy_with_onnx_type (
68+ v , np_dtype_to_tensor_dtype (v .dtype )
69+ )
70+ other_feeds = {k : v .numpy () for k , v in new_feeds .items ()}
71+ self .assertEqualAny (feeds , other_feeds )
72+
1973 @ignore_warnings ((UserWarning , DeprecationWarning , FutureWarning ))
2074 @hide_stdout ()
2175 def test_check_allruntimes_on_tiny_llm (self ):
@@ -30,7 +84,7 @@ def test_check_allruntimes_on_tiny_llm(self):
3084
3185 proto = ep .model_proto
3286 self .dump_onnx ("test_check_allruntimes_on_tiny_llm.onnx" , proto )
33- feeds = make_feeds (proto , inputs , use_numpy = True )
87+ feeds = make_feeds (proto , inputs , use_numpy = True , copy = True )
3488 sess = onnxruntime .InferenceSession (
3589 proto .SerializeToString (), providers = ["CPUExecutionProvider" ]
3690 )
@@ -45,10 +99,10 @@ def test_check_allruntimes_on_tiny_llm(self):
4599 self .assertEqualArray (got [0 ], all_outputs ["linear_7" ])
46100
47101 sess = InferenceSessionForNumpy (proto )
48- got = sess .run (None , feeds , expected = all_outputs )
102+ got = sess .run (None , feeds )
49103 self .assertLess (max_diff (expected , got , flatten = True )["abs" ], 1e-5 )
50104
51- feeds = make_feeds (proto , inputs )
105+ feeds = make_feeds (proto , inputs , copy = True )
52106 sess = InferenceSessionForTorch (proto )
53107 got = sess .run (None , feeds )
54108 self .assertLess (max_diff (expected , got , flatten = True )["abs" ], 1e-5 )
0 commit comments