@@ -226,7 +226,6 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
226226 "QLinearAveragePool" : self ._infer_qlinear_unary_op ,
227227 # Quadric custom operators
228228 "QuadricCustomOp" : self ._infer_custom_op ,
229- "QuadricCustomOpElementWise" : self ._infer_custom_op
230229 }
231230 self .aten_op_dispatcher_ = {
232231 "embedding" : self ._infer_Gather ,
@@ -455,6 +454,7 @@ def _onnx_infer_single_node(self, node):
455454 "If" ,
456455 "Loop" ,
457456 "Scan" ,
457+ "QuadricCustomOp" ,
458458 "SplitToSequence" ,
459459 "ZipMap" , # contrib ops
460460 "Attention" ,
@@ -974,17 +974,21 @@ def _infer_qgemm(self, node):
974974 def _infer_custom_op (self , node ):
975975 # For the CCL custom operators the shape and dtype of the output are present in
976976 # the attributes and can be used to directly create the value info
977- attr_map = {n .name :n for n in list (node .attribute )}
978- assert "shape" in attr_map and "elem_type" in attr_map ,\
979- "Custom op output type not found"
980- vi = self .known_vi_ [node .output [0 ]]
981- vi .CopyFrom (
982- helper .make_tensor_value_info (
983- node .output [0 ],
984- attr_map ["elem_type" ].i ,
985- attr_map ["shape" ].ints ,
977+ attr_map = {n .name : n for n in list (node .attribute )}
978+ assert "shape" in attr_map and "elem_type" in attr_map , "Custom op output type not found"
979+ if len (node .output ) > 1 :
980+ for i , out in enumerate (node .output ):
981+ vi = self .known_vi_ [out ]
982+ vi .CopyFrom (
983+ helper .make_tensor_value_info (
984+ out ,
985+ attr_map ["elem_type" ].ints [i ],
986+ attr_map ["shape" ].tensors [i ].int32_data ,
987+ )
986988 )
987- )
989+ else :
990+ vi = self .known_vi_ [node .output [0 ]]
991+ vi .CopyFrom (helper .make_tensor_value_info (node .output [0 ], attr_map ["elem_type" ].i , attr_map ["shape" ].ints ))
988992
989993 def _infer_ConcatFromSequence (self , node ):
990994 seq_shape = self ._get_shape (node , 0 )
@@ -2582,6 +2586,10 @@ def get_prereq(node):
25822586 get_attribute (node , "then_branch" ),
25832587 get_attribute (node , "else_branch" ),
25842588 ]
2589+ elif node .op_type == "QuadricCustomOp" :
2590+ # Should have a subgraph, but allow for cases where it's not there
2591+ subgraph = get_attribute (node , "sub_graph" )
2592+ subgraphs = [subgraph ] if subgraph else []
25852593 elif node .op_type in ["Loop" , "Scan" ]:
25862594 subgraphs = [get_attribute (node , "body" )]
25872595 for g in subgraphs :
0 commit comments