|
1 | 1 | import unittest |
2 | | -from typing import Optional |
| 2 | +from typing import Any, Dict, Optional, Tuple |
3 | 3 | import numpy as np |
| 4 | +import ml_dtypes |
4 | 5 | from onnx import ModelProto, TensorProto |
5 | 6 | from onnx.checker import check_model |
6 | 7 | import onnx.helper as oh |
|
12 | 13 | ignore_warnings, |
13 | 14 | requires_cuda, |
14 | 15 | ) |
| 16 | +from onnx_diagnostic.helpers import ( |
| 17 | + from_array_extended, |
| 18 | + onnx_dtype_to_torch_dtype, |
| 19 | + onnx_dtype_to_np_dtype, |
| 20 | +) |
15 | 21 | from onnx_diagnostic.reference import ExtendedReferenceEvaluator, OnnxruntimeEvaluator |
16 | 22 |
|
17 | 23 | TFLOAT = TensorProto.FLOAT |
@@ -163,6 +169,85 @@ def test_local_function(self): |
163 | 169 | got = ort_eval.run(None, feeds) |
164 | 170 | self.assertEqualArray(expected[0], got[0]) |
165 | 171 |
|
| 172 | + @classmethod |
| 173 | + def _trange(cls, *shape, bias: Optional[float] = None): |
| 174 | + n = np.prod(shape) |
| 175 | + x = np.arange(n).astype(np.float32) / n |
| 176 | + if bias: |
| 177 | + x = x + bias |
| 178 | + return torch.from_numpy(x.reshape(tuple(shape)).astype(np.float32)) |
| 179 | + |
| 180 | + @classmethod |
| 181 | + def _get_model_init(cls, itype) -> Tuple[ModelProto, Dict[str, Any], Tuple[Any, ...]]: |
| 182 | + dtype = onnx_dtype_to_np_dtype(itype) |
| 183 | + ttype = onnx_dtype_to_torch_dtype(itype) |
| 184 | + cst = np.arange(6).astype(dtype) |
| 185 | + model = oh.make_model( |
| 186 | + oh.make_graph( |
| 187 | + [ |
| 188 | + oh.make_node("IsNaN", ["x"], ["xi"]), |
| 189 | + oh.make_node("IsNaN", ["y"], ["yi"]), |
| 190 | + oh.make_node("Cast", ["xi"], ["xii"], to=TensorProto.INT64), |
| 191 | + oh.make_node("Cast", ["yi"], ["yii"], to=TensorProto.INT64), |
| 192 | + oh.make_node("Add", ["xii", "yii"], ["gggg"]), |
| 193 | + oh.make_node("Cast", ["gggg"], ["final"], to=itype), |
| 194 | + ], |
| 195 | + "dummy", |
| 196 | + [oh.make_tensor_value_info("x", itype, [None, None])], |
| 197 | + [oh.make_tensor_value_info("final", itype, [None, None])], |
| 198 | + [from_array_extended(cst, name="y")], |
| 199 | + ), |
| 200 | + opset_imports=[oh.make_opsetid("", 20)], |
| 201 | + ir_version=10, |
| 202 | + ) |
| 203 | + feeds = {"x": cls._trange(5, 6).to(ttype)} |
| 204 | + expected = torch.isnan(feeds["x"]).to(int) + torch.isnan( |
| 205 | + torch.from_numpy(cst.astype(float)) |
| 206 | + ).to(int) |
| 207 | + return (model, feeds, (expected.to(ttype),)) |
| 208 | + |
| 209 | + @hide_stdout() |
| 210 | + def test_init_numpy_afloat32(self): |
| 211 | + model, feeds, expected = self._get_model_init(TensorProto.FLOAT) |
| 212 | + wrap = OnnxruntimeEvaluator( |
| 213 | + model, providers="cpu", graph_optimization_level=False, verbose=10 |
| 214 | + ) |
| 215 | + got = wrap.run(None, {k: v.numpy() for k, v in feeds.items()}) |
| 216 | + self.assertIsInstance(got[0], np.ndarray) |
| 217 | + self.assertEqualArray(expected[0], got[0]) |
| 218 | + |
| 219 | + @hide_stdout() |
| 220 | + def test_init_numpy_bfloat16(self): |
| 221 | + model, feeds, expected = self._get_model_init(TensorProto.BFLOAT16) |
| 222 | + wrap = OnnxruntimeEvaluator( |
| 223 | + model, providers="cpu", graph_optimization_level=False, verbose=10 |
| 224 | + ) |
| 225 | + got = wrap.run( |
| 226 | + None, {k: v.to(float).numpy().astype(ml_dtypes.bfloat16) for k, v in feeds.items()} |
| 227 | + ) |
| 228 | + self.assertIsInstance(got[0], np.ndarray) |
| 229 | + self.assertEqualArray(expected[0], got[0]) |
| 230 | + |
| 231 | + @hide_stdout() |
| 232 | + def test_init_torch_afloat32(self): |
| 233 | + model, feeds, expected = self._get_model_init(TensorProto.FLOAT) |
| 234 | + wrap = OnnxruntimeEvaluator( |
| 235 | + model, providers="cpu", graph_optimization_level=False, verbose=10 |
| 236 | + ) |
| 237 | + got = wrap.run(None, feeds) |
| 238 | + self.assertIsInstance(got[0], torch.Tensor) |
| 239 | + self.assertEqualArray(expected[0], got[0]) |
| 240 | + |
| 241 | + @hide_stdout() |
| 242 | + def test_init_torch_bfloat16(self): |
| 243 | + model, feeds, expected = self._get_model_init(TensorProto.BFLOAT16) |
| 244 | + wrap = OnnxruntimeEvaluator( |
| 245 | + model, providers="cpu", graph_optimization_level=False, verbose=10 |
| 246 | + ) |
| 247 | + got = wrap.run(None, feeds) |
| 248 | + self.assertIsInstance(got[0], torch.Tensor) |
| 249 | + self.assertEqualArray(expected[0], got[0]) |
| 250 | + |
166 | 251 |
|
167 | 252 | if __name__ == "__main__": |
168 | 253 | unittest.main(verbosity=2) |
0 commit comments