@@ -62,6 +62,63 @@ def test_custom_doc_kernels_layer_normalization(self):
6262 self .maxDiff = None
6363 self .assertEqual (expected .strip ("\n " ), dot .strip ("\n " ))
6464
65+ def test_custom_doc_kernels_layer_normalization_constant (self ):
66+ TFLOAT16 = onnx .TensorProto .FLOAT16
67+ model = oh .make_model (
68+ oh .make_graph (
69+ [
70+ oh .make_node (
71+ "LayerNormalization" ,
72+ ["X" , "W" , "B" ],
73+ ["ln" ],
74+ axis = - 1 ,
75+ epsilon = 9.999999974752427e-7 ,
76+ ),
77+ oh .make_node ("Constant" , [], ["cst" ], value_float = [1 ]),
78+ oh .make_node ("Cast" , ["cst" ], ["cst16" ], to = onnx .TensorProto .FLOAT16 ),
79+ oh .make_node ("Add" , ["ln" , "cst16" ], ["Z" ], axis = - 1 ),
80+ ],
81+ "dummy" ,
82+ [
83+ oh .make_tensor_value_info ("X" , TFLOAT16 , ["b" , "c" , "d" ]),
84+ oh .make_tensor_value_info ("W" , TFLOAT16 , ["d" ]),
85+ oh .make_tensor_value_info ("B" , TFLOAT16 , ["d" ]),
86+ ],
87+ [oh .make_tensor_value_info ("Z" , TFLOAT16 , ["b" , "c" , "d" ])],
88+ ),
89+ ir_version = 9 ,
90+ opset_imports = [oh .make_opsetid ("" , 18 )],
91+ )
92+ dot = to_dot (model )
93+ expected = (
94+ textwrap .dedent (
95+ """
96+ digraph {
97+ graph [rankdir=TB, splines=true, overlap=false, nodesep=0.2, ranksep=0.2, fontsize=8];
98+ node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box];
99+ edge [arrowhead=vee, fontsize=7, labeldistance=-5, labelangle=0];
100+ I_0 [label="X\\ nFLOAT16(b,c,d)", fillcolor="#aaeeaa"];
101+ I_1 [label="W\\ nFLOAT16(d)", fillcolor="#aaeeaa"];
102+ I_2 [label="B\\ nFLOAT16(d)", fillcolor="#aaeeaa"];
103+ LayerNormalization_3 [label="LayerNormalization(., ., ., axis=-1)", fillcolor="#cccccc"];
104+ Cast_4 [label="Cast([1.0], to=FLOAT16)", fillcolor="#cccccc"];
105+ Add_5[label="Add(.,.,axis=-1)",fillcolor="#cccccc"];
106+ I_0 -> LayerNormalization_3 [label="FLOAT16(b,c,d)"];
107+ I_1 -> LayerNormalization_3 [label="FLOAT16(d)"];
108+ I_2 -> LayerNormalization_3 [label="FLOAT16(d)"];
109+ LayerNormalization_3 -> Add_5 [label="FLOAT16(b,c,d)"];
110+ Cast_4->Add_5[label="FLOAT16()"];
111+ O_6 [label="Z\\ nFLOAT16(b,c,d)", fillcolor="#aaaaee"];
112+ Add_5 -> O_6;
113+ }
114+ """
115+ )
116+ .strip ("\n " )
117+ .replace (" " , "" )
118+ )
119+ self .maxDiff = None
120+ self .assertEqual (expected , dot .strip ("\n " ).replace (" " , "" ))
121+
65122 @requires_transformers ("4.57" )
66123 def test_dot_plot_tiny (self ):
67124 data = get_untrained_model_with_inputs ("arnir0/Tiny-LLM" )
0 commit comments