Skip to content

Commit 24898da

Browse files
authored
QuadricCustomOp handling (#12)
* Add QuadricCustomOp * Update README_EPU.md with correct instructions
1 parent 421a7ca commit 24898da

File tree

7 files changed

+353
-3
lines changed

7 files changed

+353
-3
lines changed

README_EPU.md

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,19 @@ This repository contains a distribution of onnxruntime with additional opera
1111
```
1212
git clone --recursive https://github.com/quadric-io/onnxruntime onnxruntime
1313
cd onnxruntime
14-
# Install wheel
15-
pip install wheel
14+
python3.9 -m venv venv
15+
source venv/bin/activate
16+
# Install required packages. numpy version is restricted by TVM
17+
pip3 install wheel packaging numpy==1.24.4
1618
# Build the python package
1719
./build.sh --build_wheel --config Release --parallel
1820
```
1921

2022
## Install
2123
```
22-
pip install build/MacOS/Release/dist/onnxruntime-1.14.0-cp39-cp39-macosx_11_0_x86_64.whl
24+
# Find the wheel you just created
25+
$ find . -name '*.whl'
26+
./build/MacOS/Release/dist/onnxruntime-1.16.0-cp39-cp39-macosx_13_0_arm64.whl
27+
# Install it
28+
pip3 install ./build/MacOS/Release/dist/onnxruntime-1.16.0-cp39-cp39-macosx_13_0_arm64.whl
2329
```

include/onnxruntime/core/graph/constants.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ constexpr const char* kMSDmlDomain = "com.microsoft.dml";
2424
constexpr const char* kNGraphDomain = "com.intel.ai";
2525
constexpr const char* kMIGraphXDomain = "";
2626
constexpr const char* kVitisAIDomain = "com.xilinx";
27+
constexpr const char* kQuadricDomain = "com.quadric";
2728

2829
// This is moved from the OrtApis::GetAvailableProviders implementation
2930
// where it is enforced
Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
// Copyright (c) Quadric, Inc. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "quadric_custom_op.h"
5+
#include "core/common/common.h"
6+
#include "core/framework/op_kernel_context_internal.h"
7+
#include "core/framework/ortdevice.h"
8+
#include "core/framework/ortmemoryinfo.h"
9+
#include "core/framework/session_options.h"
10+
#include "core/framework/session_state.h"
11+
#include "core/framework/tensorprotoutils.h"
12+
#include "core/framework/utils.h"
13+
14+
using namespace ONNX_NAMESPACE;
15+
using namespace onnxruntime::common;
16+
17+
namespace onnxruntime {
18+
19+
ONNX_OPERATOR_KERNEL_EX(QuadricCustomOp, kQuadricDomain, 1, kCpuExecutionProvider, KernelDefBuilder(), QuadricCustomOp);
20+
21+
// Captures the static node/subgraph pairing information needed to build the
// feeds/fetches plan: input/output counts, the subgraph's input-name set, and
// the ordered list of subgraph output names.
QuadricCustomOp::Info::Info(const onnxruntime::Node& node, const GraphViewer& subgraph_in)
    : subgraph(subgraph_in), used_inputs(node.InputDefs().size(), false) {
  num_inputs = static_cast<int>(node.InputDefs().size());
  num_outputs = static_cast<int>(node.OutputDefs().size());

  // Record which names are explicit subgraph inputs so callers can tell which
  // of the node's inputs actually feed the subgraph.
  for (const auto* graph_input : subgraph.GetInputs()) {
    subgraph_input_names.insert(graph_input->Name());
  }

  // NOTE: the node/subgraph *input* counts are deliberately not compared.
  // Initializers appear as inputs on the custom op but *NOT* on the subgraph,
  // so the counts can legitimately differ, and ORT gives no reliable way to
  // identify which inputs are truly initializers.

  const auto& subgraph_outputs = subgraph.GetOutputs();
  const auto num_subgraph_outputs = subgraph_outputs.size();

  // outputs should always match up, so enforce that.
  ORT_ENFORCE(num_subgraph_outputs == static_cast<size_t>(num_outputs),
              "'QuadricCustomOp' node has ", num_outputs, " outputs which doesn't match the subgraph's ",
              num_subgraph_outputs, " outputs.");

  subgraph_output_names.reserve(num_subgraph_outputs);
  for (const auto* graph_output : subgraph_outputs) {
    subgraph_output_names.push_back(graph_output->Name());
  }
}
56+
57+
// Per-Compute()-call executor for the QuadricCustomOp subgraph. Holds
// references to the kernel context, the subgraph's SessionState, and the
// precomputed Info for the duration of a single execution.
class QuadricCustomOpImpl {
 public:
  QuadricCustomOpImpl(OpKernelContextInternal& context,
                      const SessionState& session_state,
                      const QuadricCustomOp::Info& info);

  // Pre-allocates output tensors where shapes are fully known; must be
  // called before Execute().
  Status Initialize();

  // Runs the subgraph using the feeds/fetches plan built in
  // SetupSubgraphExecutionInfo.
  Status Execute(const FeedsFetchesManager& ffm);

 private:
  OpKernelContextInternal& context_;
  const SessionState& session_state_;
  const QuadricCustomOp::Info& info_;

  Status AllocateOutputTensors();

  enum class AllocationType {
    Delayed, // allocation of the output will be done by subgraph execution
    SubgraphOutput  // output was allocated up front from the node's context
  };

  // track where the fetches provided to subgraph execution were allocated.
  std::vector<std::pair<AllocationType, OrtValue>> outputs_;
};
81+
82+
QuadricCustomOpImpl::QuadricCustomOpImpl(OpKernelContextInternal& context,
83+
const SessionState& session_state,
84+
const QuadricCustomOp::Info& info) : context_(context),
85+
session_state_(session_state),
86+
info_(info) {}
87+
88+
// Prepares for execution; currently this only pre-allocates output tensors.
Status QuadricCustomOpImpl::Initialize() {
  ORT_RETURN_IF_ERROR(AllocateOutputTensors());
  return Status::OK();
}
94+
95+
// Allocates output tensors whose shapes are fully known at this point and
// records, per output, whether the allocation happened here
// (AllocationType::SubgraphOutput) or must be deferred to subgraph execution
// (AllocationType::Delayed — shape missing or has a symbolic dimension).
// This function is mostly copied from if.cc.
//
// Fix: removed the dead local `Status status = Status::OK();` — it was never
// read or returned; all paths return a freshly constructed Status.
Status QuadricCustomOpImpl::AllocateOutputTensors() {
  int index = 0;

  const GraphViewer& subgraph = session_state_.GetGraphViewer();

  const auto& graph_outputs = subgraph.GetOutputs();

  for (const auto* graph_output : graph_outputs) {
    const auto* graph_output_type = graph_output->TypeAsProto();

    ORT_ENFORCE(graph_output_type->has_tensor_type() || graph_output_type->has_sequence_type(),
                "Only tensors or tensor sequences are supported");

    if (graph_output_type->has_tensor_type()) {
      auto* graph_output_shape = graph_output->Shape();
      bool symbolic_dim_in_shape = false;

      if (graph_output_shape) {
        TensorShape output_shape = onnxruntime::utils::GetTensorShapeFromTensorShapeProto(*graph_output_shape);

        // if size < 0 we have a symbolic dimension and need to use a temporary OrtValue in the subgraph execution
        if (output_shape.Size() < 0) {
          symbolic_dim_in_shape = true;
        } else {
          auto* tensor = context_.Output(index, output_shape);

          if (!tensor)
            return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create output tensor for ", graph_output->Name());

          outputs_.push_back({AllocationType::SubgraphOutput, *context_.GetOutputMLValue(index)});
        }
      }

      if (!graph_output_shape || symbolic_dim_in_shape) {
        // we still need a value to put in the feeds we give to the execution frame, so just use an empty MLValue
        outputs_.push_back({AllocationType::Delayed, {}});
      }
    } else if (graph_output_type->has_sequence_type()) {
      auto* seq_tensor = context_.Output<TensorSeq>(index);
      if (!seq_tensor)
        return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create output tensor for ", graph_output->Name());
      outputs_.push_back({AllocationType::SubgraphOutput, *context_.GetOutputMLValue(index)});
    }
    ++index;
  }

  return Status::OK();
}
143+
144+
// Runs the subgraph. Feeds are only the node inputs the subgraph actually
// consumes (per info_.used_inputs); fetches are the OrtValues recorded by
// AllocateOutputTensors(), with Delayed outputs serviced lazily via
// per-output allocator functors invoked by the execution frame.
Status QuadricCustomOpImpl::Execute(const FeedsFetchesManager& ffm) {
  Status status = Status::OK();

  auto num_inputs = context_.InputCount();
  std::vector<OrtValue> feeds;
  feeds.reserve(num_inputs);

  // This will contain used inputs, so some/all initializers may not be present
  for (int i = 0; i < num_inputs; ++i) {
    if(info_.used_inputs[i]) {
      feeds.push_back(*context_.GetInputMLValue(i));
    }
  }

  std::vector<OrtValue> fetches;
  std::unordered_map<size_t, IExecutor::CustomAllocator> fetch_allocators;

  fetches.reserve(info_.num_outputs);
  for (int i = 0; i < info_.num_outputs; ++i) {
    fetches.push_back(outputs_[i].second);

    if (outputs_[i].first == AllocationType::Delayed) {
      // functor to forward the allocation request from the subgraph to the node's context so that the
      // allocation plan for the node's output is used.
      // NOTE(review): the lambda captures `fetches` by reference and `i` by
      // value; `fetches` must outlive the ExecuteSubgraph call below (it does —
      // both are scoped to this function).
      fetch_allocators[i] = [this, i, &fetches](const TensorShape& shape, const OrtDevice& location,
                                                OrtValue& ort_value, bool& allocated) {
        // if the device the QuadricCustomOp output is allocated on does not match the required device for the subgraph output
        // we don't update the provided OrtValue and return false for 'allocated'.
        // the execution frame will allocate a buffer on the required device, and the fetches copy
        // logic in utils::ExecuteSubgraph will handle moving it into the tensor we allocated here.

        auto* tensor = context_.Output(i, shape);
        if (!tensor)
          return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create output tensor for QuadricCustomOp output ", i);

        const OrtValue& value = *context_.GetOutputMLValue(i);

        if (tensor->Location().device == location) {
          // return OrtValue for allocated tensor
          ort_value = value;
          allocated = true;
        } else {
          // put the allocated value into fetches so the copy logic in utils::ExecuteGraphImpl can use it
          fetches[i] = value;
        }

        return Status::OK();
      };
    }
  }

  // Sequential execution of the subgraph, honoring the parent context's
  // terminate flag, logger and compute stream.
  status = utils::ExecuteSubgraph(session_state_, ffm, feeds, fetches, fetch_allocators,
                                  ExecutionMode::ORT_SEQUENTIAL, context_.GetTerminateFlag(),
                                  context_.Logger(), context_.GetComputeStream());

  ORT_RETURN_IF_ERROR(status);

  return status;
}
203+
204+
// Validates at construction time that the node carries the mandatory
// 'sub_graph' GRAPH attribute. The proto is parsed only to prove presence;
// the subgraph itself is wired up later in SetupSubgraphExecutionInfo.
//
// Fix: the GetAttr failure status was previously collapsed into a bare
// ORT_ENFORCE(...IsOK()), losing the underlying error message; it is now
// included in the enforce output. (ORT_IGNORE_RETURN_VALUE on a local was
// also a misuse and is removed.)
QuadricCustomOp::QuadricCustomOp(const OpKernelInfo& info) : IControlFlowKernel(info) {
  ONNX_NAMESPACE::GraphProto proto;
  auto status = info.GetAttr<ONNX_NAMESPACE::GraphProto>("sub_graph", &proto);
  ORT_ENFORCE(status.IsOK(), "QuadricCustomOp requires a 'sub_graph' GRAPH attribute: ", status.ErrorMessage());
}
209+
210+
// Executes the captured subgraph in place of this node.
//
// BUG FIX: the status returned by impl.Execute() was previously assigned and
// then discarded — Compute() unconditionally returned Status::OK(), silently
// masking any subgraph execution failure. The Execute status is now
// propagated to the caller.
Status QuadricCustomOp::Compute(OpKernelContext* ctx) const {
  auto ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
  auto* session_state = ctx_internal->SubgraphSessionState("sub_graph");
  ORT_ENFORCE(session_state, "Subgraph SessionState was not found for sub_graph attribute.");

  QuadricCustomOpImpl impl{*ctx_internal, *session_state, *info_};
  ORT_RETURN_IF_ERROR(impl.Initialize());

  return impl.Execute(*feeds_fetches_manager_);
}
223+
224+
// Builds the FeedsFetchesManager for the subgraph once its SessionState is
// available: computes which node inputs actually feed the subgraph, the
// output-name list, and the device locations of all feeds and fetches.
//
// Fixes: the unused 'attribute_name' parameter is now explicitly marked
// (silences -Wunused-parameter); the input loop no longer mixes size_t with
// the int 'num_inputs' member.
Status QuadricCustomOp::SetupSubgraphExecutionInfo(const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) {
  ORT_UNUSED_PARAMETER(attribute_name);  // this kernel has exactly one subgraph attribute ('sub_graph')

  const auto& node = Node();
  info_ = std::make_unique<QuadricCustomOp::Info>(node, subgraph_session_state.GetGraphViewer());

  const auto& subgraph_map = subgraph_session_state.GetOrtValueNameIdxMap();

  std::vector<std::string> feed_names;

  const auto& input_defs = node.InputDefs();
  for (size_t i = 0, end = static_cast<size_t>(info_->num_inputs); i < end; ++i) {
    const auto* input = input_defs[i];
    // Not all subgraph inputs will have names that correspond to the node's inputs. The inputs
    // that diverge like this are limited *only* to initializers and we don't need to create
    // feeds for them. Furthermore, since they are not actually used by the custom op (and
    // not even by the sub-graph since the subgraph contains its own version of initializers)
    // they end up getting removed from the graph during an optimization step and so we can't
    // prove that it's an initializer using Graph::IsInitializedTensor

    if (info_->subgraph_input_names.find(input->Name()) != info_->subgraph_input_names.end()) {
      feed_names.push_back(input->Name());
      info_->used_inputs[i] = true;
    }
  }

  std::unique_ptr<FeedsFetchesManager> ffm;
  ORT_RETURN_IF_ERROR(FeedsFetchesManager::Create(feed_names, info_->subgraph_output_names, subgraph_map, ffm));
  ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(subgraph_session_state, *ffm));

  // find the location all the feeds will be coming from
  std::vector<OrtDevice> feed_locations;
  feed_locations.resize(feed_names.size());
  for (size_t i = 0, end = feed_names.size(); i < end; ++i) {
    feed_locations[i] = utils::FindDeviceForValue(session_state, feed_names[i]);
  }

  std::vector<const OrtDevice*> fetch_locations;
  fetch_locations.reserve(info_->num_outputs);

  // we need the allocator info for each output from the QuadricCustomOp node
  // as the subgraph execution will write directly into those buffers
  const auto& outputs = node.OutputDefs();
  for (int i = 0, end = info_->num_outputs; i < end; ++i) {
    const auto& alloc_info = utils::FindDeviceForValue(session_state, outputs[i]->Name());
    fetch_locations.push_back(&alloc_info);
  }

  utils::FinalizeFeedFetchCopyInfo(*ffm, feed_locations, fetch_locations);

  feeds_fetches_manager_ = std::move(ffm);

  return Status::OK();
}
277+
278+
} // namespace onnxruntime
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Fix: this header was missing an include guard; without it, a double
// inclusion is an ODR/redefinition error.
#pragma once

#include "core/providers/cpu/controlflow/utils.h"

#include "core/framework/feeds_fetches_manager.h"
#include "core/framework/op_kernel.h"
#include "core/framework/op_kernel_context_internal.h"
#include "core/session/onnxruntime_cxx_api.h"

#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

namespace onnxruntime {

// Control-flow kernel that executes the ONNX subgraph carried in the node's
// 'sub_graph' GRAPH attribute in place of the fused QuadricCustomOp node.
struct QuadricCustomOp : public controlflow::IControlFlowKernel {
  QuadricCustomOp(const OpKernelInfo& info);

  Status Compute(OpKernelContext* ctx) const override;

  // Called by ORT during session initialization once the subgraph's
  // SessionState exists; builds the feeds/fetches plan.
  virtual Status SetupSubgraphExecutionInfo(const SessionState& session_state,
                                            const std::string& attribute_name,
                                            const SessionState& subgraph_session_state) override;

  // Static node/subgraph pairing data, computed once in
  // SetupSubgraphExecutionInfo.
  struct Info {
    Info(const onnxruntime::Node& node, const GraphViewer& subgraph_in);
    const GraphViewer& subgraph;

    int num_inputs{0};   // number of inputs on the node (may exceed subgraph inputs — initializers)
    int num_outputs{0};  // number of outputs on the node (must equal subgraph outputs)

    std::unordered_set<std::string> subgraph_input_names;
    // used_inputs[i] is true iff node input i corresponds to a subgraph input
    // (i.e. it is not an initializer-only input).
    std::vector<bool> used_inputs;
    std::vector<std::string> subgraph_output_names;
  };

 private:
  std::unique_ptr<Info> info_;
  std::unique_ptr<FeedsFetchesManager> feeds_fetches_manager_;
};

}  // namespace onnxruntime

onnxruntime/core/graph/contrib_ops/contrib_defs.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2913,6 +2913,25 @@ This op functions in much the same way as Dropout-11 and Dropout-13 do, except t
29132913
"Allow inputs and outputs to be any kind of tensor.");
29142914
#endif
29152915

2916+
// Schema for the Quadric fusion op: variadic, heterogeneous inputs/outputs,
// with the replaced subgraph carried in the 'sub_graph' GRAPH attribute.
ONNX_CONTRIB_OPERATOR_SCHEMA(QuadricCustomOp)
    .SetDomain(kQuadricDomain)
    .SinceVersion(1)
    .SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL)
    .SetDoc("QuadricCustomOp")
    .Input(0, "inputs", "QuadricCustomOp inputs.", "T", OpSchema::Variadic,
           /*is_homogeneous*/ false,
           /*min_arity*/ 1)
    .Output(0, "outputs", "QuadricCustomOp outputs.", "T", OpSchema::Variadic,
            /*is_homogeneous*/ false,
            /*min_arity*/ 1)
    .AllowUncheckedAttributes()
    .Attr("ccl_func_name", "Name of CCL function.", AttributeProto::STRING)
    .Attr("sub_graph", "Replaced sub-graph.", AttributeProto::GRAPH)
    .Attr("element_wise", "True (1) if only element-wise ops, False (0) otherwise", AttributeProto::INT, true)
    .TypeConstraint("T", OpSchema::all_tensor_types_with_bfloat(),
                    "Allow inputs and outputs to be any kind of tensor.");
// FIXME: Add a type/shape inference function
2934+
29162935
#ifdef ENABLE_TRAINING_OPS
29172936
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
29182937
// 2). this is needed by inference for other purpose.

onnxruntime/core/providers/cpu/cpu_execution_provider.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -958,6 +958,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
958958
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Scan);
959959
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Shape);
960960

961+
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kQuadricDomain, 1, QuadricCustomOp);
962+
961963
// !!PLEASE READ BELOW!! Following that, add new entries above this comment
962964

963965
/* *** IMPORTANT! ***
@@ -2383,6 +2385,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
23832385
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, uint8_t, Resize)>,
23842386
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Scan)>,
23852387
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Shape)>,
2388+
2389+
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kQuadricDomain, 1, QuadricCustomOp)>,
23862390
};
23872391

23882392
for (auto& function_table_entry : function_table) {

0 commit comments

Comments
 (0)