@@ -232,7 +232,6 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
232232 "QLinearAveragePool" : self ._infer_qlinear_unary_op ,
233233 # Quadric custom operators
234234 "QuadricCustomOp" : self ._infer_custom_op ,
235- "QuadricCustomOpElementWise" : self ._infer_custom_op
236235 }
237236 self .aten_op_dispatcher_ = {
238237 "embedding" : self ._infer_Gather ,
@@ -460,6 +459,7 @@ def _onnx_infer_single_node(self, node):
460459 "If" ,
461460 "Loop" ,
462461 "Scan" ,
462+ "QuadricCustomOp" ,
463463 "SplitToSequence" ,
464464 "ZipMap" , # contrib ops
465465 "Attention" ,
@@ -981,17 +981,21 @@ def _infer_qgemm(self, node):
981981 def _infer_custom_op (self , node ):
982982 # For the CCL custom operators the shape and dtype of the output are present in
983983 # the attributes and can be used to directly create the value info
984- attr_map = {n .name :n for n in list (node .attribute )}
985- assert "shape" in attr_map and "elem_type" in attr_map ,\
986- "Custom op output type not found"
987- vi = self .known_vi_ [node .output [0 ]]
988- vi .CopyFrom (
989- helper .make_tensor_value_info (
990- node .output [0 ],
991- attr_map ["elem_type" ].i ,
992- attr_map ["shape" ].ints ,
984+ attr_map = {n .name : n for n in list (node .attribute )}
985+ assert "shape" in attr_map and "elem_type" in attr_map , "Custom op output type not found"
986+ if len (node .output ) > 1 :
987+ for i , out in enumerate (node .output ):
988+ vi = self .known_vi_ [out ]
989+ vi .CopyFrom (
990+ helper .make_tensor_value_info (
991+ out ,
992+ attr_map ["elem_type" ].ints [i ],
993+ attr_map ["shape" ].tensors [i ].int32_data ,
994+ )
993995 )
994- )
996+ else :
997+ vi = self .known_vi_ [node .output [0 ]]
998+ vi .CopyFrom (helper .make_tensor_value_info (node .output [0 ], attr_map ["elem_type" ].i , attr_map ["shape" ].ints ))
995999
9961000 def _infer_ConcatFromSequence (self , node ):
9971001 seq_shape = self ._get_shape (node , 0 )
@@ -2610,6 +2614,10 @@ def get_prereq(node):
26102614 get_attribute (node , "then_branch" ),
26112615 get_attribute (node , "else_branch" ),
26122616 ]
2617+ elif node .op_type == "QuadricCustomOp" :
2618+ # Should have a subgraph, but allow for cases where it's not there
2619+ subgraph = get_attribute (node , "sub_graph" )
2620+ subgraphs = [subgraph ] if subgraph else []
26132621 elif node .op_type in ["Loop" , "Scan" ]:
26142622 subgraphs = [get_attribute (node , "body" )]
26152623 for g in subgraphs :
0 commit comments