Skip to content

Commit bb9204c

Browse files
authored
Triton API refactoring (#71)
Extend Triton API C++ wrappers. Remove dali_backend.h. Signed-off-by: Rafal <[email protected]>
1 parent 9955992 commit bb9204c

File tree

7 files changed

+345
-183
lines changed

7 files changed

+345
-183
lines changed

qa/L0_multi_input/model_repository/dali_multi_input/config.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ input [
3939
output [
4040
{
4141
name: "DALI_unchanged"
42-
data_type: TYPE_INT32
42+
data_type: TYPE_UINT8
4343
dims: [ -1 ]
4444
},
4545
{

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ add_subdirectory(dali_executor)
2626
add_library(
2727
triton-dali-backend SHARED
2828
dali_backend.cc
29+
utils/triton.cc
2930
)
3031

3132
add_library(

src/dali_backend.cc

Lines changed: 128 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,16 @@
2020
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2121
// SOFTWARE.
2222

23-
#include "src/dali_backend.h"
24-
2523
#include <memory>
2624

25+
#include "src/dali_executor/dali_executor.h"
2726
#include "src/dali_executor/io_buffer.h"
2827
#include "src/dali_executor/utils/dali.h"
2928
#include "src/dali_executor/utils/utils.h"
29+
#include "src/model_provider/model_provider.h"
30+
#include "src/utils/timing.h"
3031
#include "src/utils/triton.h"
32+
#include "triton/backend/backend_common.h"
3133
#include "triton/backend/backend_model.h"
3234
#include "triton/backend/backend_model_instance.h"
3335

@@ -157,7 +159,7 @@ TRITONSERVER_Error* DaliModel::Create(TRITONBACKEND_Model* triton_model, DaliMod
157159
}
158160

159161
struct RequestMeta {
160-
uint64_t compute_start_ns, compute_end_ns;
162+
TimeInterval compute_interval; // nanoseconds
161163
int batch_size;
162164
};
163165

@@ -175,20 +177,36 @@ class DaliModelInstance : public ::triton::backend::BackendModelInstance {
175177
return *dali_model_;
176178
}
177179

178-
RequestMeta ProcessRequest(TRITONBACKEND_Response* response, TritonRequest& request) {
180+
void Execute(const std::vector<TritonRequest>& requests) {
179181
DeviceGuard dg(device_id_);
180-
RequestMeta ret;
181-
auto& outputs_indices = dali_model_->GetOutputOrder();
182+
int total_batch_size = 0;
183+
TimeInterval batch_compute_interval{};
184+
TimeInterval batch_exec_interval{};
185+
start_timer_ns(batch_exec_interval);
186+
for (size_t i = 0; i < requests.size(); i++) {
187+
TimeInterval req_exec_interval{};
188+
start_timer_ns(req_exec_interval);
189+
auto response = TritonResponse::New(requests[i]);
190+
RequestMeta request_meta;
191+
TritonError error{};
192+
try {
193+
request_meta = ProcessRequest(response, requests[i]);
194+
} catch (...) { error = ErrorHandler(); }
195+
196+
if (i == 0) {
197+
batch_compute_interval.start = request_meta.compute_interval.start;
198+
} else if (i == requests.size() - 1) {
199+
batch_compute_interval.end = request_meta.compute_interval.end;
200+
}
182201

183-
auto dali_inputs = GenerateInputs(request);
184-
ret.batch_size = dali_inputs[0].meta.shape.num_samples(); // Batch size is expected to be the
185-
// same in every input
186-
ret.compute_start_ns = detail::capture_time();
187-
auto outputs_info = dali_executor_->Run(dali_inputs);
188-
ret.compute_end_ns = detail::capture_time();
189-
auto dali_outputs = detail::AllocateOutputs(request, response, outputs_info, outputs_indices);
190-
dali_executor_->PutOutputs(dali_outputs);
191-
return ret;
202+
end_timer_ns(req_exec_interval);
203+
ReportStats(requests[i], req_exec_interval, request_meta.compute_interval, !error);
204+
SendResponse(std::move(response), std::move(error));
205+
206+
total_batch_size += request_meta.batch_size;
207+
}
208+
end_timer_ns(batch_exec_interval);
209+
ReportBatchStats(total_batch_size, batch_exec_interval, batch_compute_interval);
192210
}
193211

194212
private:
@@ -201,6 +219,37 @@ class DaliModelInstance : public ::triton::backend::BackendModelInstance {
201219
dali_executor_ = std::make_unique<DaliExecutor>(std::move(pipeline));
202220
}
203221

222+
void ReportStats(TritonRequestView request, TimeInterval exec, TimeInterval compute,
223+
bool success) {
224+
LOG_IF_ERROR(TRITONBACKEND_ModelInstanceReportStatistics(triton_model_instance_, request,
225+
success, exec.start, compute.start,
226+
compute.end, exec.end),
227+
"Failed reporting request statistics.");
228+
}
229+
230+
void ReportBatchStats(uint32_t total_batch_size, TimeInterval exec, TimeInterval compute) {
231+
LOG_IF_ERROR(TRITONBACKEND_ModelInstanceReportBatchStatistics(
232+
triton_model_instance_, total_batch_size, exec.start, compute.start,
233+
compute.end, exec.end),
234+
"Failed reporting batch statistics.");
235+
}
236+
237+
/** Run inference for a given \p request and prepare a response. */
238+
RequestMeta ProcessRequest(TritonResponseView response, TritonRequestView request) {
239+
RequestMeta ret;
240+
241+
auto dali_inputs = GenerateInputs(request);
242+
ret.batch_size = dali_inputs[0].meta.shape.num_samples(); // Batch size is expected to be the
243+
// same in every input
244+
start_timer_ns(ret.compute_interval);
245+
auto outputs_info = dali_executor_->Run(dali_inputs);
246+
end_timer_ns(ret.compute_interval);
247+
auto dali_outputs = AllocateOutputs(request, response, outputs_info);
248+
dali_executor_->PutOutputs(dali_outputs);
249+
return ret;
250+
}
251+
252+
/** @brief Generate descriptors of inputs provided by a given request. */
204253
std::vector<IDescr> GenerateInputs(TritonRequestView request) {
205254
uint32_t input_cnt = request.InputCount();
206255
std::vector<IDescr> ret;
@@ -220,6 +269,61 @@ class DaliModelInstance : public ::triton::backend::BackendModelInstance {
220269
return ret;
221270
}
222271

272+
/**
273+
* @brief Allocate outputs required by a given request.
274+
*
275+
* Lifetime of the created buffer is bound to the \p response
276+
*/
277+
std::vector<ODescr> AllocateOutputs(TritonRequestView request, TritonResponseView response,
278+
const std::vector<OutputInfo>& outputs_info) {
279+
uint32_t output_cnt = request.OutputCount();
280+
ENFORCE(outputs_info.size() == output_cnt,
281+
make_string("Number of outputs in the model configuration (", output_cnt,
282+
") does not match to the number of outputs from DALI pipeline (",
283+
outputs_info.size(), ")"));
284+
const auto& output_indices = dali_model_->GetOutputOrder();
285+
std::vector<ODescr> outputs(output_cnt);
286+
outputs.reserve(output_cnt);
287+
for (uint32_t i = 0; i < output_cnt; ++i) {
288+
auto name = request.OutputName(i);
289+
int output_idx = output_indices.at(name);
290+
IOMeta out_meta{};
291+
out_meta.name = name;
292+
out_meta.type = outputs_info[output_idx].type;
293+
out_meta.shape = outputs_info[output_idx].shape;
294+
auto output = response.GetOutput(out_meta);
295+
auto buffer = output.AllocateBuffer(outputs_info[output_idx].device, device_id_);
296+
outputs[output_idx] = {out_meta, {buffer}};
297+
}
298+
return outputs;
299+
}
300+
301+
TritonError ErrorHandler() {
302+
TritonError error{};
303+
try {
304+
throw;
305+
} catch (TritonError& e) {
306+
LOG_MESSAGE(TRITONSERVER_LOG_ERROR, (e.what()));
307+
error = std::move(e);
308+
} catch (DaliBackendException& e) {
309+
LOG_MESSAGE(TRITONSERVER_LOG_ERROR, (e.what()));
310+
error = TritonError::Unknown(make_string("DALI Backend error: ", e.what()));
311+
} catch (DALIException& e) {
312+
LOG_MESSAGE(TRITONSERVER_LOG_ERROR, (e.what()));
313+
error = TritonError::Unknown(make_string("DALI error: ", e.what()));
314+
} catch (std::runtime_error& e) {
315+
LOG_MESSAGE(TRITONSERVER_LOG_ERROR, (e.what()));
316+
error = TritonError::Unknown(make_string("Runtime error: ", e.what()));
317+
} catch (std::exception& e) {
318+
LOG_MESSAGE(TRITONSERVER_LOG_ERROR, (e.what()));
319+
error = TritonError::Unknown(make_string("Exception: ", e.what()));
320+
} catch (...) {
321+
LOG_MESSAGE(TRITONSERVER_LOG_ERROR, ("Unknown error"));
322+
error = TritonError::Unknown("Unknown error");
323+
}
324+
return error;
325+
}
326+
223327
std::unique_ptr<DaliExecutor> dali_executor_;
224328
DaliModel* dali_model_;
225329
};
@@ -441,76 +545,18 @@ TRITONSERVER_Error* TRITONBACKEND_ModelInstanceFinalize(TRITONBACKEND_ModelInsta
441545
TRITONSERVER_Error* TRITONBACKEND_ModelInstanceExecute(TRITONBACKEND_ModelInstance* instance,
442546
TRITONBACKEND_Request** reqs,
443547
const uint32_t request_count) {
444-
DaliModelInstance* instance_state;
548+
std::vector<TritonRequest> requests;
549+
for (uint32_t idx = 0; idx < request_count; ++idx) {
550+
requests.emplace_back(reqs[idx]);
551+
}
552+
DaliModelInstance* dali_instance;
445553
RETURN_IF_ERROR(
446-
TRITONBACKEND_ModelInstanceState(instance, reinterpret_cast<void**>(&instance_state)));
554+
TRITONBACKEND_ModelInstanceState(instance, reinterpret_cast<void**>(&dali_instance)));
447555
std::vector<TRITONBACKEND_Response*> responses(request_count);
448556

449-
int total_batch_size = 0;
450-
uint64_t exec_start_ns = 0, exec_end_ns = 0, batch_exec_start_ns = 0, batch_exec_end_ns = 0,
451-
batch_compute_start_ns = 0, batch_compute_end_ns = 0;
452-
batch_exec_start_ns = detail::capture_time();
453-
for (size_t i = 0; i < responses.size(); i++) {
454-
TritonRequest request(reqs[i]);
455-
TRITONSERVER_Error* error = nullptr; // success
456-
exec_start_ns = detail::capture_time();
457-
// TODO Do not process requests one by one, but gather all
458-
// into one buffer and process it in DALI all together
459-
LOG_IF_ERROR(TRITONBACKEND_ResponseNew(&responses[i], request),
460-
make_string("Failed creating a response, idx: ", i));
461-
RequestMeta request_meta;
462-
463-
try {
464-
request_meta = instance_state->ProcessRequest(responses[i], request);
465-
} catch (DaliBackendException& e) {
466-
LOG_MESSAGE(TRITONSERVER_LOG_ERROR, (e.what()));
467-
error = TRITONSERVER_ErrorNew(TRITONSERVER_Error_Code::TRITONSERVER_ERROR_UNKNOWN,
468-
make_string("DALI Backend error: ", e.what()).c_str());
469-
} catch (DALIException& e) {
470-
LOG_MESSAGE(TRITONSERVER_LOG_ERROR, (e.what()));
471-
error = TRITONSERVER_ErrorNew(TRITONSERVER_Error_Code::TRITONSERVER_ERROR_UNKNOWN,
472-
make_string("DALI error: ", e.what()).c_str());
473-
} catch (std::runtime_error& e) {
474-
LOG_MESSAGE(TRITONSERVER_LOG_ERROR, (e.what()));
475-
error = TRITONSERVER_ErrorNew(TRITONSERVER_Error_Code::TRITONSERVER_ERROR_UNKNOWN,
476-
make_string("runtime error: ", e.what()).c_str());
477-
} catch (std::exception& e) {
478-
LOG_MESSAGE(TRITONSERVER_LOG_ERROR, (e.what()));
479-
error = TRITONSERVER_ErrorNew(TRITONSERVER_Error_Code::TRITONSERVER_ERROR_UNKNOWN,
480-
make_string("exception: ", e.what()).c_str());
481-
} catch (...) {
482-
LOG_MESSAGE(TRITONSERVER_LOG_ERROR, ("Unknown error"));
483-
error = TRITONSERVER_ErrorNew(TRITONSERVER_Error_Code::TRITONSERVER_ERROR_UNKNOWN,
484-
"Unknown error");
485-
}
486-
487-
exec_end_ns = detail::capture_time();
488-
batch_compute_start_ns = batch_compute_start_ns == 0 ?
489-
request_meta.compute_start_ns :
490-
batch_compute_start_ns; // Ternary to please the compiler
491-
492-
LOG_IF_ERROR(TRITONBACKEND_ModelInstanceReportStatistics(
493-
instance, request, !error, exec_start_ns, request_meta.compute_start_ns,
494-
request_meta.compute_end_ns, exec_end_ns),
495-
make_string("Failed reporting statistics for response idx ", i));
496-
497-
LOG_IF_ERROR(
498-
TRITONBACKEND_ResponseSend(
499-
responses[i], TRITONSERVER_ResponseCompleteFlag::TRITONSERVER_RESPONSE_COMPLETE_FINAL,
500-
error),
501-
make_string("Failed sending response, idx ", i));
502-
503-
total_batch_size += request_meta.batch_size;
504-
batch_exec_end_ns = exec_end_ns;
505-
batch_compute_end_ns = request_meta.compute_end_ns;
506-
}
507-
if (batch_exec_end_ns == 0)
508-
batch_exec_end_ns = detail::capture_time();
509-
510-
LOG_IF_ERROR(TRITONBACKEND_ModelInstanceReportBatchStatistics(
511-
instance, total_batch_size, batch_exec_start_ns, batch_compute_start_ns,
512-
batch_compute_end_ns, batch_exec_end_ns),
513-
make_string("Failed reporting batch statistics"));
557+
try {
558+
dali_instance->Execute(requests);
559+
} catch (TritonError& err) { return err.release(); }
514560

515561
return nullptr;
516562
}

src/dali_backend.h

Lines changed: 0 additions & 94 deletions
This file was deleted.

0 commit comments

Comments (0)