Skip to content

Commit 126cce8

Browse files
authored
QuadricCustomOp: Handle multiple outputs when shape inferencing (#17)
1 parent d0776ea commit 126cce8

File tree

2 files changed

+19
-12
lines changed

2 files changed

+19
-12
lines changed

onnxruntime/core/graph/contrib_ops/contrib_defs.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2930,7 +2930,6 @@ This op functions in much the same was as Dropout-11 and Dropout-13 do, execpt t
29302930
.Attr("element_wise", "True (1) if only element-wise ops, False (0) otherwise", AttributeProto::INT, true)
29312931
.TypeConstraint("T", OpSchema::all_tensor_types_with_bfloat(),
29322932
"Allow inputs and outputs to be any kind of tensor.");
2933-
// FIXME: Add a type/shape inference function
29342933

29352934
#ifdef ENABLE_TRAINING_OPS
29362935
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or

onnxruntime/python/tools/symbolic_shape_infer.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,6 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
226226
"QLinearAveragePool": self._infer_qlinear_unary_op,
227227
# Quadric custom operators
228228
"QuadricCustomOp": self._infer_custom_op,
229-
"QuadricCustomOpElementWise": self._infer_custom_op
230229
}
231230
self.aten_op_dispatcher_ = {
232231
"embedding": self._infer_Gather,
@@ -455,6 +454,7 @@ def _onnx_infer_single_node(self, node):
455454
"If",
456455
"Loop",
457456
"Scan",
457+
"QuadricCustomOp",
458458
"SplitToSequence",
459459
"ZipMap", # contrib ops
460460
"Attention",
@@ -974,17 +974,21 @@ def _infer_qgemm(self, node):
974974
def _infer_custom_op(self, node):
975975
# For the CCL custom operators the shape and dtype of the output are present in
976976
# the attributes and can be used to directly create the value info
977-
attr_map = {n.name:n for n in list(node.attribute)}
978-
assert "shape" in attr_map and "elem_type" in attr_map,\
979-
"Custom op output type not found"
980-
vi = self.known_vi_[node.output[0]]
981-
vi.CopyFrom(
982-
helper.make_tensor_value_info(
983-
node.output[0],
984-
attr_map["elem_type"].i,
985-
attr_map["shape"].ints,
977+
attr_map = {n.name: n for n in list(node.attribute)}
978+
assert "shape" in attr_map and "elem_type" in attr_map, "Custom op output type not found"
979+
if len(node.output) > 1:
980+
for i, out in enumerate(node.output):
981+
vi = self.known_vi_[out]
982+
vi.CopyFrom(
983+
helper.make_tensor_value_info(
984+
out,
985+
attr_map["elem_type"].ints[i],
986+
attr_map["shape"].tensors[i].int32_data,
987+
)
986988
)
987-
)
989+
else:
990+
vi = self.known_vi_[node.output[0]]
991+
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], attr_map["elem_type"].i, attr_map["shape"].ints))
988992

989993
def _infer_ConcatFromSequence(self, node):
990994
seq_shape = self._get_shape(node, 0)
@@ -2582,6 +2586,10 @@ def get_prereq(node):
25822586
get_attribute(node, "then_branch"),
25832587
get_attribute(node, "else_branch"),
25842588
]
2589+
elif node.op_type == "QuadricCustomOp":
2590+
# Should have a subgraph, but allow for cases where it's not there
2591+
subgraph = get_attribute(node, "sub_graph")
2592+
subgraphs = [subgraph] if subgraph else []
25852593
elif node.op_type in ["Loop", "Scan"]:
25862594
subgraphs = [get_attribute(node, "body")]
25872595
for g in subgraphs:

0 commit comments

Comments (0)