// Copyright (c) Quadric, Inc. All rights reserved.
// Licensed under the MIT License.

#include "quadric_custom_op.h"
#include "core/common/common.h"
#include "core/framework/op_kernel_context_internal.h"
#include "core/framework/ortdevice.h"
#include "core/framework/ortmemoryinfo.h"
#include "core/framework/session_options.h"
#include "core/framework/session_state.h"
#include "core/framework/tensorprotoutils.h"
#include "core/framework/utils.h"

using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;

namespace onnxruntime {

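// Register the kernel for opset version 1 of the custom Quadric domain on the CPU execution
// provider. No type constraints are declared on the KernelDefBuilder, so type checking is
// effectively deferred to the subgraph.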
ONNX_OPERATOR_KERNEL_EX(QuadricCustomOp, kQuadricDomain, 1, kCpuExecutionProvider, KernelDefBuilder(), QuadricCustomOp);

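// Info captures the static relationship between the QuadricCustomOp node and its subgraph:
// input/output counts, which node inputs actually feed the subgraph, and the subgraph's
// output names (used later to build the FeedsFetchesManager).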
QuadricCustomOp::Info::Info(const onnxruntime::Node& node, const GraphViewer& subgraph_in) : subgraph(subgraph_in), used_inputs(node.InputDefs().size(), false) {
  num_inputs = static_cast<int>(node.InputDefs().size());
  num_outputs = static_cast<int>(node.OutputDefs().size());

  auto& subgraph_inputs = subgraph.GetInputs();
  auto num_subgraph_inputs = subgraph_inputs.size();

  for (size_t i = 0; i < num_subgraph_inputs; ++i) {
    auto& input = subgraph_inputs[i];
    subgraph_input_names.insert(input->Name());
  }

  // The following check is intentionally disabled: initializers are passed as inputs to the
  // custom op but *not* to the subgraph, so the two input counts can legitimately differ.
  // Unfortunately, ORT doesn't reliably tell us whether an input is truly an initializer,
  // so we can't filter initializers out and compare the remainder.
  /*ORT_ENFORCE(num_subgraph_inputs == static_cast<size_t>(num_inputs),
              "'QuadricCustomOp' node has ", num_inputs, " inputs which doesn't match the subgraph's ",
              num_subgraph_inputs, " inputs.");
  */

  auto& subgraph_outputs = subgraph.GetOutputs();
  auto num_subgraph_outputs = subgraph_outputs.size();

  // Outputs should always match up, so enforce that.
  ORT_ENFORCE(num_subgraph_outputs == static_cast<size_t>(num_outputs),
              "'QuadricCustomOp' node has ", num_outputs, " outputs which doesn't match the subgraph's ",
              num_subgraph_outputs, " outputs.");

  subgraph_output_names.reserve(num_subgraph_outputs);
  for (size_t i = 0; i < num_subgraph_outputs; ++i) {
    auto& output = subgraph_outputs[i];
    subgraph_output_names.push_back(output->Name());
  }
}

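// Runtime helper that lives only for the duration of a single Compute call. It allocates the
// node's output tensors up front where shapes are known, then drives the subgraph execution.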
class QuadricCustomOpImpl {
 public:
  QuadricCustomOpImpl(OpKernelContextInternal& context,
                      const SessionState& session_state,
                      const QuadricCustomOp::Info& info);

  Status Initialize();
  Status Execute(const FeedsFetchesManager& ffm);

 private:
  OpKernelContextInternal& context_;
  const SessionState& session_state_;
  const QuadricCustomOp::Info& info_;

  Status AllocateOutputTensors();

  enum class AllocationType {
    Delayed,  // allocation of the QuadricCustomOp output will be done by subgraph execution
    SubgraphOutput
  };

  // track where the fetches provided to subgraph execution were allocated.
  std::vector<std::pair<AllocationType, OrtValue>> outputs_;
};

QuadricCustomOpImpl::QuadricCustomOpImpl(OpKernelContextInternal& context,
                                         const SessionState& session_state,
                                         const QuadricCustomOp::Info& info) : context_(context),
                                                                              session_state_(session_state),
                                                                              info_(info) {}

Status QuadricCustomOpImpl::Initialize() {
  auto status = AllocateOutputTensors();
  ORT_RETURN_IF_ERROR(status);

  return Status::OK();
}

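// Outputs whose shapes are fully known are allocated eagerly so the subgraph can write into
// them directly; outputs with a symbolic dimension (or no shape at all) are deferred and
// allocated by the subgraph's execution frame via the fetch allocator installed in Execute().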
Status QuadricCustomOpImpl::AllocateOutputTensors() {
  // This function is mostly copied from if.cc
  int index = 0;

  const GraphViewer& subgraph = session_state_.GetGraphViewer();

  const auto& graph_outputs = subgraph.GetOutputs();

  for (auto& graph_output : graph_outputs) {
    const auto* graph_output_type = graph_output->TypeAsProto();

    ORT_ENFORCE(graph_output_type->has_tensor_type() || graph_output_type->has_sequence_type(),
                "Only tensors or tensor sequences are supported");
    if (graph_output_type->has_tensor_type()) {
      auto* graph_output_shape = graph_output->Shape();
      bool symbolic_dim_in_shape = false;

      if (graph_output_shape) {
        TensorShape output_shape = onnxruntime::utils::GetTensorShapeFromTensorShapeProto(*graph_output_shape);

        // if size < 0 we have a symbolic dimension and need to use a temporary OrtValue in the subgraph execution
        if (output_shape.Size() < 0) {
          symbolic_dim_in_shape = true;
        } else {
          auto* tensor = context_.Output(index, output_shape);

          if (!tensor)
            return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create output tensor for ", graph_output->Name());

          outputs_.push_back({AllocationType::SubgraphOutput, *context_.GetOutputMLValue(index)});
        }
      }

      if (!graph_output_shape || symbolic_dim_in_shape) {
        // we still need a value to put in the feeds we give to the execution frame, so just use an empty MLValue
        outputs_.push_back({AllocationType::Delayed, {}});
      }
    } else if (graph_output_type->has_sequence_type()) {
      auto* seq_tensor = context_.Output<TensorSeq>(index);
      if (!seq_tensor)
        return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create output tensor for ", graph_output->Name());
      outputs_.push_back({AllocationType::SubgraphOutput, *context_.GetOutputMLValue(index)});
    }
    ++index;
  }

  return Status::OK();
}

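// Execute builds the feed list from the node inputs that the subgraph actually consumes and
// the fetch list from the pre-allocated outputs, then runs the subgraph synchronously.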
Status QuadricCustomOpImpl::Execute(const FeedsFetchesManager& ffm) {
  Status status = Status::OK();

  auto num_inputs = context_.InputCount();
  std::vector<OrtValue> feeds;
  feeds.reserve(num_inputs);

  // This will contain only the used inputs, so some or all initializers may not be present.
  for (int i = 0; i < num_inputs; ++i) {
    if (info_.used_inputs[i]) {
      feeds.push_back(*context_.GetInputMLValue(i));
    }
  }

  std::vector<OrtValue> fetches;
  std::unordered_map<size_t, IExecutor::CustomAllocator> fetch_allocators;

  fetches.reserve(info_.num_outputs);
  for (int i = 0; i < info_.num_outputs; ++i) {
    fetches.push_back(outputs_[i].second);

    if (outputs_[i].first == AllocationType::Delayed) {
      // functor to forward the allocation request from the subgraph to the QuadricCustomOp node's
      // context so that the allocation plan for the QuadricCustomOp node's output is used.
      fetch_allocators[i] = [this, i, &fetches](const TensorShape& shape, const OrtDevice& location,
                                                OrtValue& ort_value, bool& allocated) {
        // if the device the QuadricCustomOp output is allocated on does not match the required
        // device for the subgraph output, we don't update the provided OrtValue and return false
        // for 'allocated'. the execution frame will allocate a buffer on the required device, and
        // the fetches copy logic in utils::ExecuteSubgraph will handle moving it into the tensor
        // we allocated here.

        auto* tensor = context_.Output(i, shape);
        if (!tensor)
          return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create output tensor for QuadricCustomOp output ", i);

        const OrtValue& value = *context_.GetOutputMLValue(i);

        if (tensor->Location().device == location) {
          // return OrtValue for allocated tensor
          ort_value = value;
          allocated = true;
        } else {
          // put the allocated value into fetches so the copy logic in utils::ExecuteGraphImpl can use it
          fetches[i] = value;
        }

        return Status::OK();
      };
    }
  }

  status = utils::ExecuteSubgraph(session_state_, ffm, feeds, fetches, fetch_allocators,
                                  ExecutionMode::ORT_SEQUENTIAL, context_.GetTerminateFlag(),
                                  context_.Logger(), context_.GetComputeStream());

  return status;
}

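// The sub_graph attribute is only validated here; the parsed GraphProto is discarded because
// ORT builds a dedicated SessionState for the subgraph and hands it to this kernel via
// SetupSubgraphExecutionInfo before the first Compute call.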
QuadricCustomOp::QuadricCustomOp(const OpKernelInfo& info) : IControlFlowKernel(info) {
  ONNX_NAMESPACE::GraphProto proto;
  ORT_ENFORCE(info.GetAttr<ONNX_NAMESPACE::GraphProto>("sub_graph", &proto).IsOK());
  ORT_IGNORE_RETURN_VALUE(proto);
}

Status QuadricCustomOp::Compute(OpKernelContext* ctx) const {
  auto ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
  auto* session_state = ctx_internal->SubgraphSessionState("sub_graph");
  ORT_ENFORCE(session_state, "Subgraph SessionState was not found for sub_graph attribute.");

  QuadricCustomOpImpl impl{*ctx_internal, *session_state, *info_};
  auto status = impl.Initialize();
  ORT_RETURN_IF_ERROR(status);

  status = impl.Execute(*feeds_fetches_manager_);
  ORT_RETURN_IF_ERROR(status);

  return Status::OK();
}

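// Called once during session initialization, after the subgraph's SessionState has been
// created. This resolves feed/fetch names and device placements up front so that Compute
// only has to bind values.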
Status QuadricCustomOp::SetupSubgraphExecutionInfo(const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) {
  ORT_UNUSED_PARAMETER(attribute_name);

  const auto& node = Node();
  info_ = std::make_unique<QuadricCustomOp::Info>(node, subgraph_session_state.GetGraphViewer());

  const auto& subgraph_map = subgraph_session_state.GetOrtValueNameIdxMap();

  std::vector<std::string> feed_names;

  const auto& input_defs = node.InputDefs();
  for (size_t i = 0, end = static_cast<size_t>(info_->num_inputs); i < end; ++i) {
    const auto* input = input_defs[i];
    // Not all subgraph inputs have names that correspond to the node's inputs. The inputs that
    // diverge like this are limited *only* to initializers, and we don't need to create feeds
    // for them. Furthermore, since they are not actually used by the custom op (nor by the
    // subgraph, which contains its own copies of the initializers), they can be removed from
    // the graph during an optimization step, so we can't prove an input is an initializer
    // using Graph::IsInitializedTensor.

    if (info_->subgraph_input_names.find(input->Name()) != info_->subgraph_input_names.end()) {
      feed_names.push_back(input->Name());
      info_->used_inputs[i] = true;
    }
  }

  std::unique_ptr<FeedsFetchesManager> ffm;
  ORT_RETURN_IF_ERROR(FeedsFetchesManager::Create(feed_names, info_->subgraph_output_names, subgraph_map, ffm));
  ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(subgraph_session_state, *ffm));

  // find the location all the feeds will be coming from
  std::vector<OrtDevice> feed_locations;
  feed_locations.resize(feed_names.size());
  for (size_t i = 0, end = feed_names.size(); i < end; ++i) {
    const auto& location = utils::FindDeviceForValue(session_state, feed_names[i]);
    feed_locations[i] = location;
  }

  std::vector<const OrtDevice*> fetch_locations;
  fetch_locations.reserve(info_->num_outputs);

  // we need the allocator info for each output from the QuadricCustomOp node
  // as the subgraph execution will write directly into those buffers
  const auto& outputs = node.OutputDefs();
  for (int i = 0, end = info_->num_outputs; i < end; ++i) {
    const auto& alloc_info = utils::FindDeviceForValue(session_state, outputs[i]->Name());
    fetch_locations.push_back(&alloc_info);
  }

  utils::FinalizeFeedFetchCopyInfo(*ffm, feed_locations, fetch_locations);

  feeds_fetches_manager_ = std::move(ffm);

  return Status::OK();
}

}  // namespace onnxruntime