 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 // SOFTWARE.

-#include "src/dali_backend.h"
-
 #include <memory>

+#include "src/dali_executor/dali_executor.h"
 #include "src/dali_executor/io_buffer.h"
 #include "src/dali_executor/utils/dali.h"
 #include "src/dali_executor/utils/utils.h"
+#include "src/model_provider/model_provider.h"
+#include "src/utils/timing.h"
 #include "src/utils/triton.h"
+#include "triton/backend/backend_common.h"
 #include "triton/backend/backend_model.h"
 #include "triton/backend/backend_model_instance.h"

@@ -157,7 +159,7 @@ TRITONSERVER_Error* DaliModel::Create(TRITONBACKEND_Model* triton_model, DaliMod
 }

 struct RequestMeta {
-  uint64_t compute_start_ns, compute_end_ns;
+  TimeInterval compute_interval;  // nanoseconds
   int batch_size;
 };

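For orientation: TimeInterval, start_timer_ns, and end_timer_ns come from the newly included "src/utils/timing.h", which is not part of this diff. A minimal sketch of what they plausibly look like, judging only by how they are used below (the actual header may differ):

#include <chrono>
#include <cstdint>

struct TimeInterval {
  uint64_t start = 0;  // nanoseconds
  uint64_t end = 0;
};

inline uint64_t capture_time_ns() {
  return std::chrono::duration_cast<std::chrono::nanoseconds>(
             std::chrono::steady_clock::now().time_since_epoch())
      .count();
}

inline void start_timer_ns(TimeInterval& interval) { interval.start = capture_time_ns(); }
inline void end_timer_ns(TimeInterval& interval) { interval.end = capture_time_ns(); }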
@@ -175,20 +177,36 @@ class DaliModelInstance : public ::triton::backend::BackendModelInstance {
     return *dali_model_;
   }

-  RequestMeta ProcessRequest(TRITONBACKEND_Response* response, TritonRequest& request) {
+  void Execute(const std::vector<TritonRequest>& requests) {
     DeviceGuard dg(device_id_);
-    RequestMeta ret;
-    auto& outputs_indices = dali_model_->GetOutputOrder();
+    int total_batch_size = 0;
+    TimeInterval batch_compute_interval{};
+    TimeInterval batch_exec_interval{};
+    start_timer_ns(batch_exec_interval);
+    for (size_t i = 0; i < requests.size(); i++) {
+      TimeInterval req_exec_interval{};
+      start_timer_ns(req_exec_interval);
+      auto response = TritonResponse::New(requests[i]);
+      RequestMeta request_meta;
+      TritonError error{};
+      try {
+        request_meta = ProcessRequest(response, requests[i]);
+      } catch (...) { error = ErrorHandler(); }
+
+      if (i == 0) {
+        batch_compute_interval.start = request_meta.compute_interval.start;
+      }
+      if (i == requests.size() - 1) {  // not `else if`: both must run for a single request
+        batch_compute_interval.end = request_meta.compute_interval.end;
+      }

-    auto dali_inputs = GenerateInputs(request);
-    ret.batch_size = dali_inputs[0].meta.shape.num_samples();  // Batch size is expected to be the
-                                                               // same in every input
-    ret.compute_start_ns = detail::capture_time();
-    auto outputs_info = dali_executor_->Run(dali_inputs);
-    ret.compute_end_ns = detail::capture_time();
-    auto dali_outputs = detail::AllocateOutputs(request, response, outputs_info, outputs_indices);
-    dali_executor_->PutOutputs(dali_outputs);
-    return ret;
+      end_timer_ns(req_exec_interval);
+      ReportStats(requests[i], req_exec_interval, request_meta.compute_interval, !error);
+      SendResponse(std::move(response), std::move(error));
+
+      total_batch_size += request_meta.batch_size;
+    }
+    end_timer_ns(batch_exec_interval);
+    ReportBatchStats(total_batch_size, batch_exec_interval, batch_compute_interval);
   }

  private:
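Execute leans on the TritonRequest/TritonResponse wrappers, the SendResponse helper, and the TritonError type from "src/utils/triton.h", none of which appear in this diff. A plausible sketch of the TritonError piece, inferred from the calls above (a throwable RAII owner of a TRITONSERVER_Error*, where an empty error means success and release() hands ownership back to the Triton runtime); the real definition may differ:

#include <stdexcept>
#include <string>
#include <utility>

#include "triton/core/tritonserver.h"

class TritonError : public std::runtime_error {
 public:
  TritonError() : std::runtime_error("") {}  // empty error == success

  static TritonError Unknown(const std::string& msg) {
    return TritonError(TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_UNKNOWN, msg.c_str()));
  }

  TritonError(TritonError&& other) noexcept
      : std::runtime_error(other), error_(std::exchange(other.error_, nullptr)) {}

  TritonError& operator=(TritonError&& other) noexcept {
    std::swap(error_, other.error_);
    return *this;
  }

  explicit operator bool() const { return error_ != nullptr; }  // enables `!error`

  // Hand the raw error (and its ownership) back to the Triton runtime.
  TRITONSERVER_Error* release() { return std::exchange(error_, nullptr); }

  ~TritonError() override {
    if (error_) TRITONSERVER_ErrorDelete(error_);
  }

 private:
  explicit TritonError(TRITONSERVER_Error* err)
      : std::runtime_error(TRITONSERVER_ErrorMessage(err)), error_(err) {}

  TRITONSERVER_Error* error_ = nullptr;
};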
@@ -201,6 +219,37 @@ class DaliModelInstance : public ::triton::backend::BackendModelInstance {
     dali_executor_ = std::make_unique<DaliExecutor>(std::move(pipeline));
   }

+  void ReportStats(TritonRequestView request, TimeInterval exec, TimeInterval compute,
+                   bool success) {
+    LOG_IF_ERROR(TRITONBACKEND_ModelInstanceReportStatistics(triton_model_instance_, request,
+                                                             success, exec.start, compute.start,
+                                                             compute.end, exec.end),
+                 "Failed reporting request statistics.");
+  }
+
+  void ReportBatchStats(uint32_t total_batch_size, TimeInterval exec, TimeInterval compute) {
+    LOG_IF_ERROR(TRITONBACKEND_ModelInstanceReportBatchStatistics(
+                     triton_model_instance_, total_batch_size, exec.start, compute.start,
+                     compute.end, exec.end),
+                 "Failed reporting batch statistics.");
+  }
+
+  /** Run inference for a given \p request and prepare a response. */
+  RequestMeta ProcessRequest(TritonResponseView response, TritonRequestView request) {
+    RequestMeta ret;
+
+    auto dali_inputs = GenerateInputs(request);
+    ret.batch_size = dali_inputs[0].meta.shape.num_samples();  // Batch size is expected to be the
+                                                               // same in every input
+    start_timer_ns(ret.compute_interval);
+    auto outputs_info = dali_executor_->Run(dali_inputs);
+    end_timer_ns(ret.compute_interval);
+    auto dali_outputs = AllocateOutputs(request, response, outputs_info);
+    dali_executor_->PutOutputs(dali_outputs);
+    return ret;
+  }
+
+  /** @brief Generate descriptors of inputs provided by a given request. */
   std::vector<IDescr> GenerateInputs(TritonRequestView request) {
     uint32_t input_cnt = request.InputCount();
     std::vector<IDescr> ret;
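The IDescr/ODescr descriptors that GenerateInputs and AllocateOutputs traffic in are declared in "src/dali_executor/io_buffer.h", which is not part of this diff. A rough, self-contained sketch using placeholder stand-ins for the DALI types, inferred purely from how the descriptors are used here:

#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Placeholder stand-ins; the real code uses DALI's own enums and shape types.
enum class DeviceKind { kCPU, kGPU };
using DataType = int;                                  // stands in for the DALI element type
using BatchShape = std::vector<std::vector<int64_t>>;  // per-sample shapes

struct BufferDescr {
  const void* data = nullptr;  // start of a contiguous chunk
  size_t size = 0;             // chunk size in bytes
  DeviceKind device = DeviceKind::kCPU;
  int device_id = 0;
};

struct IOMeta {
  std::string name;   // tensor name from the model configuration
  DataType type;      // element type
  BatchShape shape;   // shape.size() == number of samples in the batch
};

struct IDescr {  // input: a batch may arrive as several chunks
  IOMeta meta;
  std::vector<BufferDescr> buffers;
};

struct ODescr {  // output: a single pre-allocated response buffer
  IOMeta meta;
  std::vector<BufferDescr> buffers;
};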
@@ -220,6 +269,61 @@ class DaliModelInstance : public ::triton::backend::BackendModelInstance {
     return ret;
   }

+  /**
+   * @brief Allocate outputs required by a given request.
+   *
+   * The lifetime of the created buffers is bound to the \p response.
+   */
+  std::vector<ODescr> AllocateOutputs(TritonRequestView request, TritonResponseView response,
+                                      const std::vector<OutputInfo>& outputs_info) {
+    uint32_t output_cnt = request.OutputCount();
+    ENFORCE(outputs_info.size() == output_cnt,
+            make_string("Number of outputs in the model configuration (", output_cnt,
+                        ") does not match the number of outputs from the DALI pipeline (",
+                        outputs_info.size(), ")"));
+    const auto& output_indices = dali_model_->GetOutputOrder();
+    std::vector<ODescr> outputs(output_cnt);
+    for (uint32_t i = 0; i < output_cnt; ++i) {
+      auto name = request.OutputName(i);
+      int output_idx = output_indices.at(name);
+      IOMeta out_meta{};
+      out_meta.name = name;
+      out_meta.type = outputs_info[output_idx].type;
+      out_meta.shape = outputs_info[output_idx].shape;
+      auto output = response.GetOutput(out_meta);
+      auto buffer = output.AllocateBuffer(outputs_info[output_idx].device, device_id_);
+      outputs[output_idx] = {out_meta, {buffer}};
+    }
+    return outputs;
+  }
+
+  /** Translate the currently handled exception into a TritonError, logging it along the way.
+   *  Must only be called from within a catch block. */
+  TritonError ErrorHandler() {
+    TritonError error{};
+    try {
+      throw;  // re-raise the in-flight exception to dispatch on its type
+    } catch (TritonError& e) {
+      LOG_MESSAGE(TRITONSERVER_LOG_ERROR, (e.what()));
+      error = std::move(e);
+    } catch (DaliBackendException& e) {
+      LOG_MESSAGE(TRITONSERVER_LOG_ERROR, (e.what()));
+      error = TritonError::Unknown(make_string("DALI Backend error: ", e.what()));
+    } catch (DALIException& e) {
+      LOG_MESSAGE(TRITONSERVER_LOG_ERROR, (e.what()));
+      error = TritonError::Unknown(make_string("DALI error: ", e.what()));
+    } catch (std::runtime_error& e) {
+      LOG_MESSAGE(TRITONSERVER_LOG_ERROR, (e.what()));
+      error = TritonError::Unknown(make_string("Runtime error: ", e.what()));
+    } catch (std::exception& e) {
+      LOG_MESSAGE(TRITONSERVER_LOG_ERROR, (e.what()));
+      error = TritonError::Unknown(make_string("Exception: ", e.what()));
+    } catch (...) {
+      LOG_MESSAGE(TRITONSERVER_LOG_ERROR, ("Unknown error"));
+      error = TritonError::Unknown("Unknown error");
+    }
+    return error;
+  }
+
   std::unique_ptr<DaliExecutor> dali_executor_;
   DaliModel* dali_model_;
 };
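ErrorHandler's bare `throw;` re-raises whatever exception is currently being handled, which is what lets Execute funnel every failure through a single `catch (...)`. A self-contained illustration of the idiom (plain C++, not backend code):

#include <iostream>
#include <stdexcept>

// Re-raises the active exception and maps its type to a label. Calling this
// outside of an active catch block would invoke std::terminate.
const char* ClassifyActiveException() {
  try {
    throw;  // re-raise the exception currently in flight
  } catch (const std::runtime_error&) {
    return "runtime_error";
  } catch (const std::exception&) {
    return "std::exception";
  } catch (...) {
    return "unknown";
  }
}

int main() {
  try {
    throw std::runtime_error("boom");
  } catch (...) {
    std::cout << ClassifyActiveException() << '\n';  // prints "runtime_error"
  }
}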
@@ -441,76 +545,18 @@ TRITONSERVER_Error* TRITONBACKEND_ModelInstanceFinalize(TRITONBACKEND_ModelInsta
 TRITONSERVER_Error* TRITONBACKEND_ModelInstanceExecute(TRITONBACKEND_ModelInstance* instance,
                                                        TRITONBACKEND_Request** reqs,
                                                        const uint32_t request_count) {
-  DaliModelInstance* instance_state;
+  std::vector<TritonRequest> requests;
+  for (uint32_t idx = 0; idx < request_count; ++idx) {
+    requests.emplace_back(reqs[idx]);
+  }
+  DaliModelInstance* dali_instance;
   RETURN_IF_ERROR(
-      TRITONBACKEND_ModelInstanceState(instance, reinterpret_cast<void**>(&instance_state)));
+      TRITONBACKEND_ModelInstanceState(instance, reinterpret_cast<void**>(&dali_instance)));
   std::vector<TRITONBACKEND_Response*> responses(request_count);

-  int total_batch_size = 0;
-  uint64_t exec_start_ns = 0, exec_end_ns = 0, batch_exec_start_ns = 0, batch_exec_end_ns = 0,
-           batch_compute_start_ns = 0, batch_compute_end_ns = 0;
-  batch_exec_start_ns = detail::capture_time();
-  for (size_t i = 0; i < responses.size(); i++) {
-    TritonRequest request(reqs[i]);
-    TRITONSERVER_Error* error = nullptr;  // success
-    exec_start_ns = detail::capture_time();
-    // TODO Do not process requests one by one, but gather all
-    // into one buffer and process it in DALI all together
-    LOG_IF_ERROR(TRITONBACKEND_ResponseNew(&responses[i], request),
-                 make_string("Failed creating a response, idx: ", i));
-    RequestMeta request_meta;
-
-    try {
-      request_meta = instance_state->ProcessRequest(responses[i], request);
-    } catch (DaliBackendException& e) {
-      LOG_MESSAGE(TRITONSERVER_LOG_ERROR, (e.what()));
-      error = TRITONSERVER_ErrorNew(TRITONSERVER_Error_Code::TRITONSERVER_ERROR_UNKNOWN,
-                                    make_string("DALI Backend error: ", e.what()).c_str());
-    } catch (DALIException& e) {
-      LOG_MESSAGE(TRITONSERVER_LOG_ERROR, (e.what()));
-      error = TRITONSERVER_ErrorNew(TRITONSERVER_Error_Code::TRITONSERVER_ERROR_UNKNOWN,
-                                    make_string("DALI error: ", e.what()).c_str());
-    } catch (std::runtime_error& e) {
-      LOG_MESSAGE(TRITONSERVER_LOG_ERROR, (e.what()));
-      error = TRITONSERVER_ErrorNew(TRITONSERVER_Error_Code::TRITONSERVER_ERROR_UNKNOWN,
-                                    make_string("runtime error: ", e.what()).c_str());
-    } catch (std::exception& e) {
-      LOG_MESSAGE(TRITONSERVER_LOG_ERROR, (e.what()));
-      error = TRITONSERVER_ErrorNew(TRITONSERVER_Error_Code::TRITONSERVER_ERROR_UNKNOWN,
-                                    make_string("exception: ", e.what()).c_str());
-    } catch (...) {
-      LOG_MESSAGE(TRITONSERVER_LOG_ERROR, ("Unknown error"));
-      error = TRITONSERVER_ErrorNew(TRITONSERVER_Error_Code::TRITONSERVER_ERROR_UNKNOWN,
-                                    "Unknown error");
-    }
-
-    exec_end_ns = detail::capture_time();
-    batch_compute_start_ns = batch_compute_start_ns == 0 ?
-                                 request_meta.compute_start_ns :
-                                 batch_compute_start_ns;  // Ternary to please the compiler
-
-    LOG_IF_ERROR(TRITONBACKEND_ModelInstanceReportStatistics(
-                     instance, request, !error, exec_start_ns, request_meta.compute_start_ns,
-                     request_meta.compute_end_ns, exec_end_ns),
-                 make_string("Failed reporting statistics for response idx ", i));
-
-    LOG_IF_ERROR(
-        TRITONBACKEND_ResponseSend(
-            responses[i], TRITONSERVER_ResponseCompleteFlag::TRITONSERVER_RESPONSE_COMPLETE_FINAL,
-            error),
-        make_string("Failed sending response, idx ", i));
-
-    total_batch_size += request_meta.batch_size;
-    batch_exec_end_ns = exec_end_ns;
-    batch_compute_end_ns = request_meta.compute_end_ns;
-  }
-  if (batch_exec_end_ns == 0)
-    batch_exec_end_ns = detail::capture_time();
-
-  LOG_IF_ERROR(TRITONBACKEND_ModelInstanceReportBatchStatistics(
-                   instance, total_batch_size, batch_exec_start_ns, batch_compute_start_ns,
-                   batch_compute_end_ns, batch_exec_end_ns),
-               make_string("Failed reporting batch statistics"));
+  try {
+    dali_instance->Execute(requests);
+  } catch (TritonError& err) { return err.release(); }

   return nullptr;
 }