@@ -956,6 +956,119 @@ def _mkv_(name):
956956 torch .tensor ([1 ], dtype = torch .float32 ),
957957 )
958958
959+ def test_loop (self ):
960+ x = np .array ([1 , 2 , 3 , 4 , 5 ]).astype (np .float32 )
961+
962+ model = oh .make_model (
963+ graph = oh .make_graph (
964+ name = "loop_test" ,
965+ inputs = [
966+ oh .make_tensor_value_info ("trip_count" , TINT64 , ["a" ]),
967+ oh .make_tensor_value_info ("cond" , onnx .TensorProto .BOOL , [1 ]),
968+ ],
969+ outputs = [oh .make_tensor_value_info ("res" , TFLOAT , [])],
970+ nodes = [
971+ oh .make_node ("SequenceEmpty" , [], ["seq_empty" ], dtype = TFLOAT ),
972+ oh .make_node (
973+ "Loop" ,
974+ inputs = ["trip_count" , "cond" , "seq_empty" ],
975+ outputs = ["seq_res" ],
976+ body = oh .make_graph (
977+ [
978+ oh .make_node (
979+ "Identity" , inputs = ["cond_in" ], outputs = ["cond_out" ]
980+ ),
981+ oh .make_node (
982+ "Constant" ,
983+ inputs = [],
984+ outputs = ["x" ],
985+ value = oh .make_tensor (
986+ name = "const_tensor_x" ,
987+ data_type = TFLOAT ,
988+ dims = x .shape ,
989+ vals = x .flatten ().astype (float ),
990+ ),
991+ ),
992+ oh .make_node (
993+ "Constant" ,
994+ inputs = [],
995+ outputs = ["one" ],
996+ value = oh .make_tensor (
997+ name = "const_tensor_one" ,
998+ data_type = TINT64 ,
999+ dims = (),
1000+ vals = [1 ],
1001+ ),
1002+ ),
1003+ oh .make_node (
1004+ "Constant" ,
1005+ inputs = [],
1006+ outputs = ["slice_start" ],
1007+ value = oh .make_tensor (
1008+ name = "const_tensor_zero" ,
1009+ data_type = TINT64 ,
1010+ dims = (1 ,),
1011+ vals = [0 ],
1012+ ),
1013+ ),
1014+ oh .make_node (
1015+ "Add" , inputs = ["iter_count" , "one" ], outputs = ["end" ]
1016+ ),
1017+ oh .make_node (
1018+ "Constant" ,
1019+ inputs = [],
1020+ outputs = ["axes" ],
1021+ value = oh .make_tensor (
1022+ name = "const_tensor_axes" ,
1023+ data_type = TINT64 ,
1024+ dims = (1 ,),
1025+ vals = [0 ],
1026+ ),
1027+ ),
1028+ oh .make_node (
1029+ "Unsqueeze" , inputs = ["end" , "axes" ], outputs = ["slice_end" ]
1030+ ),
1031+ oh .make_node (
1032+ "Slice" ,
1033+ inputs = ["x" , "slice_start" , "slice_end" ],
1034+ outputs = ["slice_out" ],
1035+ ),
1036+ oh .make_node (
1037+ "SequenceInsert" ,
1038+ inputs = ["seq_in" , "slice_out" ],
1039+ outputs = ["seq_out" ],
1040+ ),
1041+ ],
1042+ "loop_body" ,
1043+ [
1044+ oh .make_tensor_value_info ("iter_count" , TINT64 , []),
1045+ oh .make_tensor_value_info (
1046+ "cond_in" , onnx .TensorProto .BOOL , []
1047+ ),
1048+ oh .make_tensor_sequence_value_info ("seq_in" , TFLOAT , None ),
1049+ ],
1050+ [
1051+ oh .make_tensor_value_info (
1052+ "cond_out" , onnx .TensorProto .BOOL , []
1053+ ),
1054+ oh .make_tensor_sequence_value_info ("seq_out" , TFLOAT , None ),
1055+ ],
1056+ ),
1057+ ),
1058+ oh .make_node (
1059+ "ConcatFromSequence" ,
1060+ inputs = ["seq_res" ],
1061+ outputs = ["res" ],
1062+ axis = 0 ,
1063+ new_axis = 0 ,
1064+ ),
1065+ ],
1066+ )
1067+ )
1068+ self ._finalize_test (
1069+ model , torch .tensor (5 , dtype = torch .int64 ), torch .tensor (1 , dtype = torch .bool )
1070+ )
1071+
9591072
9601073if __name__ == "__main__" :
9611074 unittest .main (verbosity = 2 )
0 commit comments