|
| 1 | +import textwrap |
| 2 | +import unittest |
| 3 | +import onnx |
| 4 | +import onnx.helper as oh |
| 5 | +from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers |
| 6 | +from onnx_diagnostic.helpers.dot_helper import to_dot |
| 7 | +from onnx_diagnostic.export.api import to_onnx |
| 8 | +from onnx_diagnostic.torch_export_patches import torch_export_patches |
| 9 | +from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs |
| 10 | + |
| 11 | + |
class TestDotHelper(ExtTestCase):
    """Tests for the DOT-graph rendering helper ``to_dot``."""

    def test_custom_doc_kernels_layer_normalization(self):
        """Pins the exact DOT text produced for a tiny two-node graph.

        The ``Add`` node deliberately carries non-schema attributes
        (``axis``, ``epsilon``) so the test also covers attribute
        rendering in node labels.
        """
        f16 = onnx.TensorProto.FLOAT16
        nodes = [
            oh.make_node(
                "LayerNormalization",
                ["X", "W", "B"],
                ["ln"],
                axis=-1,
                epsilon=9.999999974752427e-7,
            ),
            oh.make_node(
                "Add", ["ln", "W"], ["Z"], axis=-1, epsilon=9.999999974752427e-7
            ),
        ]
        graph_inputs = [
            oh.make_tensor_value_info("X", f16, ["b", "c", "d"]),
            oh.make_tensor_value_info("W", f16, ["d"]),
            oh.make_tensor_value_info("B", f16, ["d"]),
        ]
        graph_outputs = [oh.make_tensor_value_info("Z", f16, ["b", "c", "d"])]
        proto = oh.make_model(
            oh.make_graph(nodes, "dummy", graph_inputs, graph_outputs),
            ir_version=9,
            opset_imports=[oh.make_opsetid("", 18)],
        )
        rendered = to_dot(proto)
        expected = textwrap.dedent(
            """
            digraph {
            graph [rankdir=TB, splines=true, overlap=false, nodesep=0.2, ranksep=0.2, fontsize=8];
            node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box];
            edge [arrowhead=vee, fontsize=7, labeldistance=-5, labelangle=0];
            I_0 [label="X\\nFLOAT16(b,c,d)", fillcolor="#aaeeaa"];
            I_1 [label="W\\nFLOAT16(d)", fillcolor="#aaeeaa"];
            I_2 [label="B\\nFLOAT16(d)", fillcolor="#aaeeaa"];
            LayerNormalization_3 [label="LayerNormalization(., ., ., axis=-1)", fillcolor="#cccccc"];
            Add_4 [label="Add(., ., axis=-1)", fillcolor="#cccccc"];
            I_0 -> LayerNormalization_3 [label="FLOAT16(b,c,d)"];
            I_1 -> LayerNormalization_3 [label="FLOAT16(d)"];
            I_2 -> LayerNormalization_3 [label="FLOAT16(d)"];
            LayerNormalization_3 -> Add_4 [label="FLOAT16(b,c,d)"];
            I_1 -> Add_4 [label="FLOAT16(d)"];
            O_5 [label="Z\\nFLOAT16(d)", fillcolor="#aaaaee"];
            Add_4 -> O_5;
            }
            """
        )
        self.maxDiff = None
        self.assertEqual(expected.strip("\n "), rendered.strip("\n "))

    @requires_transformers("4.57")
    def test_dot_plot_tiny(self):
        """Exports a tiny untrained LLM to ONNX and smoke-tests its DOT output."""
        bundle = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
        torch_model = bundle["model"]
        torch_inputs = bundle["inputs"]
        shapes = bundle["dynamic_shapes"]
        # Patches are required for transformers models to export cleanly.
        with torch_export_patches(patch_transformers=True):
            exported = to_onnx(
                torch_model, torch_inputs, dynamic_shapes=shapes, exporter="custom"
            )
        rendered = to_dot(exported.model_proto)
        dump_name = self.get_dump_file("test_dot_plot_tiny.dot")
        with open(dump_name, "w") as f:
            f.write(rendered)
        # dot -Tpng dump_test/test_dot_plot_tiny.dot -o dump_test/test_dot_plot_tiny.png
        self.assertIn("-> Add", rendered)
| 77 | + |
| 78 | + |
if __name__ == "__main__":
    # Allow running this test file directly; verbosity=2 prints one line per test.
    unittest.main(verbosity=2)