Skip to content

Commit 77c6035

Browse files
committed
QuadricCustomOp: Handle multiple outputs when shape inferencing (#17)
1 parent 4f4d738 commit 77c6035

File tree

2 files changed

+19
-12
lines changed

2 files changed

+19
-12
lines changed

onnxruntime/core/graph/contrib_ops/contrib_defs.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3043,7 +3043,6 @@ Input absmax is stored in same type as original type of B(float32, float16) with
30433043
.Attr("element_wise", "True (1) if only element-wise ops, False (0) otherwise", AttributeProto::INT, true)
30443044
.TypeConstraint("T", OpSchema::all_tensor_types_with_bfloat(),
30453045
"Allow inputs and outputs to be any kind of tensor.");
3046-
// FIXME: Add a type/shape inference function
30473046

30483047
#ifdef ENABLE_TRAINING_OPS
30493048
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or

onnxruntime/python/tools/symbolic_shape_infer.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,6 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
232232
"QLinearAveragePool": self._infer_qlinear_unary_op,
233233
# Quadric custom operators
234234
"QuadricCustomOp": self._infer_custom_op,
235-
"QuadricCustomOpElementWise": self._infer_custom_op
236235
}
237236
self.aten_op_dispatcher_ = {
238237
"embedding": self._infer_Gather,
@@ -460,6 +459,7 @@ def _onnx_infer_single_node(self, node):
460459
"If",
461460
"Loop",
462461
"Scan",
462+
"QuadricCustomOp",
463463
"SplitToSequence",
464464
"ZipMap", # contrib ops
465465
"Attention",
@@ -981,17 +981,21 @@ def _infer_qgemm(self, node):
981981
def _infer_custom_op(self, node):
982982
# For the CCL custom operators the shape and dtype of the output are present in
983983
# the attributes and can be used to directly create the value info
984-
attr_map = {n.name:n for n in list(node.attribute)}
985-
assert "shape" in attr_map and "elem_type" in attr_map,\
986-
"Custom op output type not found"
987-
vi = self.known_vi_[node.output[0]]
988-
vi.CopyFrom(
989-
helper.make_tensor_value_info(
990-
node.output[0],
991-
attr_map["elem_type"].i,
992-
attr_map["shape"].ints,
984+
attr_map = {n.name: n for n in list(node.attribute)}
985+
assert "shape" in attr_map and "elem_type" in attr_map, "Custom op output type not found"
986+
if len(node.output) > 1:
987+
for i, out in enumerate(node.output):
988+
vi = self.known_vi_[out]
989+
vi.CopyFrom(
990+
helper.make_tensor_value_info(
991+
out,
992+
attr_map["elem_type"].ints[i],
993+
attr_map["shape"].tensors[i].int32_data,
994+
)
993995
)
994-
)
996+
else:
997+
vi = self.known_vi_[node.output[0]]
998+
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], attr_map["elem_type"].i, attr_map["shape"].ints))
995999

9961000
def _infer_ConcatFromSequence(self, node):
9971001
seq_shape = self._get_shape(node, 0)
@@ -2610,6 +2614,10 @@ def get_prereq(node):
26102614
get_attribute(node, "then_branch"),
26112615
get_attribute(node, "else_branch"),
26122616
]
2617+
elif node.op_type == "QuadricCustomOp":
2618+
# Should have a subgraph, but allow for cases where it's not there
2619+
subgraph = get_attribute(node, "sub_graph")
2620+
subgraphs = [subgraph] if subgraph else []
26132621
elif node.op_type in ["Loop", "Scan"]:
26142622
subgraphs = [get_attribute(node, "body")]
26152623
for g in subgraphs:

0 commit comments

Comments (0)