@@ -7,7 +7,7 @@
 
 
 class LayerNormalizationOrt(OpRunKernel):
-    "LayerNormalization"
+    "LayerNormalization with onnxruntime"
 
     @classmethod
     def device_dependent(cls) -> bool:
@@ -25,7 +25,7 @@ def __init__(
         self.axis = self.get_attribute_int(node, "axis", -1)
         self.epsilon = self.get_attribute_float(node, "epsilon", 1e-5)
         self.device = device
-        self.stash_type = onnx_dtype_to_torch_dtype(
+        self.stash_type = onnx_dtype_to_torch_dtype(  # type: ignore[arg-type]
             self.get_attribute_int(node, "stash_type", onnx.TensorProto.FLOAT)
         )
         self.compute_std = len(node.output) > 1
@@ -36,7 +36,7 @@ def __init__(
         self._cache: Dict[Tuple[int, int], "onnxruntime.InferenceSession"] = {}
         self.is_cpu = torch.device("cpu") == self.device
 
-    def _make_model(self, dtype: int, rank: int) -> onnx.ModelProto:
+    def _make_model(self, itype: int, rank: int) -> "onnxruntime.InferenceSession":
         shape = [*[f"d{i}" for i in range(rank - 1)], "last"]  # f-string: one symbolic name per dim
         layer_model = oh.make_model(
             oh.make_graph(
@@ -51,14 +51,14 @@ def _make_model(self, dtype: int, rank: int) -> onnx.ModelProto:
                 ],
                 "dummy",
                 [
-                    oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT16, shape),
-                    oh.make_tensor_value_info("W", onnx.TensorProto.FLOAT16, ["last"]),
-                    oh.make_tensor_value_info("B", onnx.TensorProto.FLOAT16, ["last"]),
+                    oh.make_tensor_value_info("X", itype, shape),
+                    oh.make_tensor_value_info("W", itype, ["last"]),
+                    oh.make_tensor_value_info("B", itype, ["last"]),
                 ],
-                [oh.make_tensor_value_info("Z", onnx.TensorProto.FLOAT16, shape)],
+                [oh.make_tensor_value_info("Z", itype, shape)],
             ),
             ir_version=9,
-            opset_imports=[oh.make_opsetid("", 17)],
+            opset_imports=[oh.make_opsetid("", 18)],
         )
         import onnxruntime
 
@@ -80,3 +80,58 @@ def run(self, x, scale, bias=None):
         feeds = {k: v.tensor.detach().cpu().numpy() for k, v in feeds.items()}
         got = sess.run(None, feeds)[0]
         return OpRunTensor(torch.from_numpy(got).to(x.dtype).to(x.device))
+
+
+class MatMulOrt(OpRunKernel):
+    "MatMul with onnxruntime"
+
+    @classmethod
+    def device_dependent(cls) -> bool:
+        "Needs device."
+        return True
+
+    def __init__(
+        self,
+        node: onnx.NodeProto,
+        version=None,
+        device: Optional[torch.device] = None,
+        verbose=0,
+    ):
+        super().__init__(node, version, verbose=verbose)
+        self.device = device
+        self._cache: Dict[Tuple[int, int, int], "onnxruntime.InferenceSession"] = {}  # keyed by (dtype, rank(A), rank(B))
+        self.is_cpu = torch.device("cpu") == self.device
+
+    def _make_model(self, itype: int, ranka: int, rankb: int) -> "onnxruntime.InferenceSession":
+        shapea = [f"a{i}" for i in range(ranka)]  # f-strings: a distinct symbolic name per dim
+        shapeb = [f"b{i}" for i in range(rankb)]
+        shapec = [f"c{i}" for i in range(max(ranka, rankb))]
+        model = oh.make_model(
+            oh.make_graph(
+                [oh.make_node("MatMul", ["A", "B"], ["C"])],
+                "dummy",
+                [
+                    oh.make_tensor_value_info("A", itype, shapea),
+                    oh.make_tensor_value_info("B", itype, shapeb),
+                ],
+                [oh.make_tensor_value_info("C", itype, shapec)],
+            ),
+            ir_version=9,
+            opset_imports=[oh.make_opsetid("", 17)],
+        )
+        import onnxruntime
+
+        provider = "CPUExecutionProvider" if self.is_cpu else "CUDAExecutionProvider"
+        return onnxruntime.InferenceSession(model.SerializeToString(), providers=[provider])
+
+    def run(self, a, b):
+        itype = torch_dtype_to_onnx_dtype(a.dtype)
+        ranka, rankb = len(a.shape), len(b.shape)
+        key = itype, ranka, rankb
+        if key not in self._cache:
+            self._cache[key] = self._make_model(itype, ranka, rankb)
+        sess = self._cache[key]
+        feeds = dict(A=a, B=b)
+        feeds = {k: v.tensor.detach().cpu().numpy() for k, v in feeds.items()}  # unwrap OpRunTensor to numpy for onnxruntime
+        got = sess.run(None, feeds)[0]
+        return OpRunTensor(torch.from_numpy(got).to(a.dtype).to(a.device))
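
Two usage sketches follow. They are illustrative only: they assume `OpRunTensor(t)` wraps a `torch.Tensor` `t` (consistent with `run` returning `OpRunTensor(torch.from_numpy(...))`) and that the two kernel classes are importable from the module this commit edits; everything else sticks to calls visible in the diff. First, `LayerNormalizationOrt` with `axis=-1` should agree with `torch.nn.functional.layer_norm` over the last axis:

```python
import onnx.helper as oh
import torch

node = oh.make_node("LayerNormalization", ["X", "W", "B"], ["Z"], axis=-1, epsilon=1e-5)
kernel = LayerNormalizationOrt(node, device=torch.device("cpu"))

x = OpRunTensor(torch.randn(4, 8))  # float32 input, rank 2
w = OpRunTensor(torch.ones(8))      # scale over the last axis
b = OpRunTensor(torch.zeros(8))     # bias over the last axis
z = kernel.run(x, w, b)             # first call builds and caches one session per (dtype, rank)

ref = torch.nn.functional.layer_norm(x.tensor, (8,), weight=w.tensor, bias=b.tensor, eps=1e-5)
assert torch.allclose(z.tensor, ref, atol=1e-4)
```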
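The `MatMulOrt` counterpart, under the same assumptions, also shows why the cache key includes both ranks: a rank-2 and then a rank-3 left operand compile two distinct sessions:

```python
import onnx.helper as oh
import torch

node = oh.make_node("MatMul", ["A", "B"], ["C"])
mm = MatMulOrt(node, device=torch.device("cpu"))

a = OpRunTensor(torch.randn(2, 3))
b = OpRunTensor(torch.randn(3, 4))
c = mm.run(a, b)                    # cached under (FLOAT, 2, 2)
assert c.tensor.shape == (2, 4)

a3 = OpRunTensor(torch.randn(5, 2, 3))
c3 = mm.run(a3, b)                  # broadcast matmul, cached under (FLOAT, 3, 2)
assert c3.tensor.shape == (5, 2, 4)
assert len(mm._cache) == 2
```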