@@ -158,9 +158,14 @@ TRITONSERVER_Error* DaliModel::Create(TRITONBACKEND_Model* triton_model, DaliMod
   return error;
 }
 
-struct RequestMeta {
-  TimeInterval compute_interval;  // nanoseconds
-  int batch_size;
+struct ProcessingMeta {
+  TimeInterval compute_interval{};
+  int total_batch_size = 0;
+};
+
+struct InputsInfo {
+  std::vector<IDescr> inputs;
+  std::vector<int> reqs_batch_sizes;  // batch size of each request
 };
 
 class DaliModelInstance : public ::triton::backend::BackendModelInstance {
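ProcessingMeta replaces the per-request RequestMeta, while InputsInfo pairs the merged input descriptors with each request's batch size. A minimal sketch of how the two relate, using a placeholder IDescr (illustrative only, not the backend's real descriptor):

#include <numeric>
#include <vector>

struct IDescr {};  // placeholder for the backend's input descriptor

struct InputsInfo {
  std::vector<IDescr> inputs;
  std::vector<int> reqs_batch_sizes;  // batch size of each request
};

// The batch handed to DALI concatenates all requests, so the total
// batch size recorded in ProcessingMeta is the sum of per-request sizes.
int TotalBatchSize(const InputsInfo& info) {
  return std::accumulate(info.reqs_batch_sizes.begin(),
                         info.reqs_batch_sizes.end(), 0);
}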
@@ -179,34 +184,22 @@ class DaliModelInstance : public ::triton::backend::BackendModelInstance {
 
   void Execute(const std::vector<TritonRequest>& requests) {
     DeviceGuard dg(GetDaliDeviceId());
-    int total_batch_size = 0;
-    TimeInterval batch_compute_interval{};
-    TimeInterval batch_exec_interval{};
-    start_timer_ns(batch_exec_interval);
-    for (size_t i = 0; i < requests.size(); i++) {
-      TimeInterval req_exec_interval{};
-      start_timer_ns(req_exec_interval);
-      auto response = TritonResponse::New(requests[i]);
-      RequestMeta request_meta;
-      TritonError error{};
-      try {
-        request_meta = ProcessRequest(response, requests[i]);
-      } catch (...) { error = ErrorHandler(); }
-
-      if (i == 0) {
-        batch_compute_interval.start = request_meta.compute_interval.start;
-      } else if (i == requests.size() - 1) {
-        batch_compute_interval.end = request_meta.compute_interval.end;
-      }
-
-      end_timer_ns(req_exec_interval);
-      ReportStats(requests[i], req_exec_interval, request_meta.compute_interval, !error);
-      SendResponse(std::move(response), std::move(error));
-
-      total_batch_size += request_meta.batch_size;
+    TimeInterval exec_interval{};
+    start_timer_ns(exec_interval);
+    auto responses = CreateResponses(requests);
+    ProcessingMeta proc_meta{};
+    TritonError error{};
+    try {
+      proc_meta = ProcessRequests(requests, responses);
+    } catch (...) { error = ErrorHandler(); }
+    for (auto& response : responses) {
+      SendResponse(std::move(response), TritonError::Copy(error));
+    }
+    end_timer_ns(exec_interval);
+    for (auto& request : requests) {
+      ReportStats(request, exec_interval, proc_meta.compute_interval, !error);
     }
-    end_timer_ns(batch_exec_interval);
-    ReportBatchStats(total_batch_size, batch_exec_interval, batch_compute_interval);
+    ReportBatchStats(proc_meta.total_batch_size, exec_interval, proc_meta.compute_interval);
   }
 
  private:
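A subtlety in the response loop above: SendResponse consumes its error argument, so every response needs an independent copy, which is what TritonError::Copy provides. A sketch of that broadcast pattern with a hypothetical move-only error handle (UniqueErr stands in for TritonError; this is not the real API):

#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical move-only error handle, assumed to behave like TritonError.
struct UniqueErr {
  std::unique_ptr<std::string> msg;  // empty pointer means "no error"
  static UniqueErr Copy(const UniqueErr& e) {
    return {e.msg ? std::make_unique<std::string>(*e.msg) : nullptr};
  }
};

// Stand-in for SendResponse: takes ownership of the error it is given.
void Send(UniqueErr err) { (void)err; /* deliver err with the response */ }

void Broadcast(const UniqueErr& err, size_t n_responses) {
  for (size_t i = 0; i < n_responses; ++i) {
    Send(UniqueErr::Copy(err));  // each response gets its own copy
  }
}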
@@ -234,70 +227,133 @@ class DaliModelInstance : public ::triton::backend::BackendModelInstance {
234227 " Failed reporting batch statistics." );
235228 }
236229
237- /* * Run inference for a given \p request and prepare a response. */
238- RequestMeta ProcessRequest (TritonResponseView response, TritonRequestView request) {
239- RequestMeta ret;
230+ /* *
231+ * @brief Create a response for each request.
232+ *
233+ * @return responses vector
234+ */
235+ std::vector<TritonResponse> CreateResponses (const std::vector<TritonRequest>& requests) {
236+ std::vector<TritonResponse> responses;
237+ responses.reserve (requests.size ());
238+ for (auto & request : requests) {
239+ responses.push_back (TritonResponse::New (request));
240+ }
241+ return responses;
242+ }
240243
241- auto dali_inputs = GenerateInputs (request);
242- ret.batch_size = dali_inputs[0 ].meta .shape .num_samples (); // Batch size is expected to be the
243- // same in every input
244+ /* *
245+ * @brief Run inference for a given \p request and prepare a response.
246+ * @return computation time interval and total batch size
247+ */
248+ ProcessingMeta ProcessRequests (const std::vector<TritonRequest>& requests,
249+ const std::vector<TritonResponse>& responses) {
250+ ProcessingMeta ret{};
251+ auto inputs_info = GenerateInputs (requests);
244252 start_timer_ns (ret.compute_interval );
245- auto outputs_info = dali_executor_->Run (dali_inputs );
253+ auto outputs_info = dali_executor_->Run (inputs_info. inputs );
246254 end_timer_ns (ret.compute_interval );
247- auto dali_outputs = AllocateOutputs (request, response, outputs_info);
255+ for (auto & bs : inputs_info.reqs_batch_sizes ) {
256+ ret.total_batch_size += bs;
257+ }
258+ auto dali_outputs =
259+ AllocateOutputs (requests, responses, inputs_info.reqs_batch_sizes , outputs_info);
248260 dali_executor_->PutOutputs (dali_outputs);
249261 return ret;
250262 }
251263
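ProcessRequests brackets only the dali_executor_->Run call with the compute interval, while Execute times the whole batch. The start_timer_ns/end_timer_ns helpers are not part of this diff; a plausible implementation, assuming TimeInterval holds steady-clock nanoseconds (the names come from the code above, the bodies are a guess):

#include <chrono>
#include <cstdint>

struct TimeInterval {
  uint64_t start = 0;  // nanoseconds
  uint64_t end = 0;    // nanoseconds
};

inline uint64_t now_ns() {
  using namespace std::chrono;
  return duration_cast<nanoseconds>(steady_clock::now().time_since_epoch()).count();
}

inline void start_timer_ns(TimeInterval& interval) { interval.start = now_ns(); }
inline void end_timer_ns(TimeInterval& interval) { interval.end = now_ns(); }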
-  /** @brief Generate descriptors of inputs provided by a given request. */
-  std::vector<IDescr> GenerateInputs(TritonRequestView request) {
-    uint32_t input_cnt = request.InputCount();
-    std::vector<IDescr> ret;
-    ret.reserve(input_cnt);
-    for (uint32_t input_idx = 0; input_idx < input_cnt; ++input_idx) {
-      auto input = request.InputByIdx(input_idx);
-      auto input_byte_size = input.ByteSize();
-      auto input_buffer_count = input.BufferCount();
-      std::vector<IBufferDescr> buffers;
-      buffers.reserve(input_buffer_count);
-      for (uint32_t buffer_idx = 0; buffer_idx < input_buffer_count; ++buffer_idx) {
-        auto buffer = input.GetBuffer(buffer_idx, device_type_t::CPU, GetDaliDeviceId());
-        buffers.push_back(buffer);
+  /**
+   * @brief Generate descriptors of inputs provided by given \p requests.
+   * @return input descriptors and batch size of each request
+   */
+  InputsInfo GenerateInputs(const std::vector<TritonRequest>& requests) {
+    uint32_t input_cnt = requests[0].InputCount();
+    std::vector<IDescr> inputs;
+    inputs.reserve(input_cnt);
+    std::unordered_map<std::string, IDescr> input_map;
+    std::vector<int> reqs_batch_sizes(requests.size());
+    for (size_t ri = 0; ri < requests.size(); ++ri) {
+      auto& request = requests[ri];
+      ENFORCE(request.InputCount() == input_cnt,
+              "Each request must provide the same number of inputs.");
+      for (uint32_t input_idx = 0; input_idx < input_cnt; ++input_idx) {
+        auto input = request.InputByIdx(input_idx);
+        auto input_byte_size = input.ByteSize();
+        auto input_buffer_count = input.BufferCount();
+        auto meta = input.Meta();
+        auto& idescr = input_map[meta.name];
+        for (uint32_t buffer_idx = 0; buffer_idx < input_buffer_count; ++buffer_idx) {
+          auto buffer = input.GetBuffer(buffer_idx, device_type_t::CPU, GetDaliDeviceId());
+          idescr.buffers.push_back(buffer);
+        }
+        if (idescr.meta.shape.num_samples() == 0) {
+          idescr.meta = meta;
+        } else {
+          ENFORCE(idescr.meta.type == meta.type,
+                  make_string("Mismatched type for input ", idescr.meta.name));
+          idescr.meta.shape.append(meta.shape);
+        }
+        if (input_idx == 0) {
+          reqs_batch_sizes[ri] = meta.shape.num_samples();
+        } else {
+          ENFORCE(meta.shape.num_samples() == reqs_batch_sizes[ri],
+                  "Each input in a request must have the same batch size.");
+        }
       }
-      ret.push_back({input.Meta(), std::move(buffers)});
     }
-    return ret;
+    for (const auto& descrs : input_map) {
+      inputs.push_back(descrs.second);
+    }
+    return {inputs, reqs_batch_sizes};
   }
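The merging above is the core of the batching change: inputs with the same name from consecutive requests are folded into one descriptor by appending their shape lists, so DALI sees a single concatenated batch. A toy illustration of that invariant, with a hypothetical ShapeList standing in for the real shape type:

#include <cassert>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical shape list: one entry per sample, mirroring the
// meta.shape.append(...) usage above. Real DALI shape types differ.
struct ShapeList {
  std::vector<std::vector<int64_t>> samples;
  int num_samples() const { return static_cast<int>(samples.size()); }
  void append(const ShapeList& other) {
    samples.insert(samples.end(), other.samples.begin(), other.samples.end());
  }
};

int main() {
  // Two requests with batch sizes 2 and 3 and the same input name
  // merge into a single input with a batch of 5 samples.
  std::unordered_map<std::string, ShapeList> input_map;
  ShapeList req0{{{224, 224, 3}, {224, 224, 3}}};
  ShapeList req1{{{224, 224, 3}, {224, 224, 3}, {224, 224, 3}}};
  input_map["INPUT0"].append(req0);
  input_map["INPUT0"].append(req1);
  assert(input_map["INPUT0"].num_samples() == 5);
  return 0;
}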
 
   int32_t GetDaliDeviceId() {
     return !CudaStream() ? ::dali::CPU_ONLY_DEVICE_ID : device_id_;
   }
 
   /**
-   * @brief Allocate outputs required by a given request.
+   * @brief Allocate outputs expected by given \p requests.
    *
-   * Lifetime of the created buffer is bound to the \p response
+   * Lifetime of the created buffer is bound to each of the \p responses.
+   * @param batch_sizes batch size of each request
    */
-  std::vector<ODescr> AllocateOutputs(TritonRequestView request, TritonResponseView response,
+  std::vector<ODescr> AllocateOutputs(const std::vector<TritonRequest>& requests,
+                                      const std::vector<TritonResponse>& responses,
+                                      const std::vector<int>& batch_sizes,
                                       const std::vector<OutputInfo>& outputs_info) {
-    uint32_t output_cnt = request.OutputCount();
+    assert(requests.size() > 0);
+    assert(requests.size() == responses.size());
+    assert(requests.size() == batch_sizes.size());
+    uint32_t output_cnt = requests[0].OutputCount();
+    for (auto& req : requests) {
+      ENFORCE(output_cnt == req.OutputCount(),
+              "All of the requests must expect the same number of outputs.");
+    }
     ENFORCE(outputs_info.size() == output_cnt,
-            make_string("Number of outputs in the model configuration (", output_cnt,
-                        ") does not match to the number of outputs from DALI pipeline (",
-                        outputs_info.size(), ")"));
+            make_string("Number of outputs expected by the requests (", output_cnt,
+                        ") does not match the number of outputs from DALI pipeline (",
+                        outputs_info.size(), ")."));
     const auto& output_indices = dali_model_->GetOutputOrder();
+    ENFORCE(output_cnt == output_indices.size(),
+            make_string("Number of outputs expected by the requests (", output_cnt,
+                        ") does not match the number of outputs in the config (",
+                        output_indices.size(), ")."));
     std::vector<ODescr> outputs(output_cnt);
     outputs.reserve(output_cnt);
-    for (uint32_t i = 0; i < output_cnt; ++i) {
-      auto name = request.OutputName(i);
-      int output_idx = output_indices.at(name);
+    for (const auto& out_index : output_indices) {
+      auto name = out_index.first;
+      int output_idx = out_index.second;
+      auto shapes = split_list_shape(outputs_info[output_idx].shape, batch_sizes);
+      std::vector<OBufferDescr> buffers(requests.size());
       IOMeta out_meta{};
       out_meta.name = name;
       out_meta.type = outputs_info[output_idx].type;
+      for (size_t ri = 0; ri < requests.size(); ++ri) {
+        out_meta.shape = shapes[ri];
+        auto output = responses[ri].GetOutput(out_meta);
+        buffers[ri] = output.AllocateBuffer(outputs_info[output_idx].device, GetDaliDeviceId());
+      }
       out_meta.shape = outputs_info[output_idx].shape;
-      auto output = response.GetOutput(out_meta);
-      auto buffer = output.AllocateBuffer(outputs_info[output_idx].device, GetDaliDeviceId());
-      outputs[output_idx] = {out_meta, {buffer}};
+      outputs[output_idx] = {out_meta, buffers};
     }
     return outputs;
   }
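AllocateOutputs relies on split_list_shape to carve the batched pipeline output back into per-request pieces before allocating each response's buffer. The helper's definition is outside this diff; the sketch below shows the behaviour the call site appears to assume (signature and semantics inferred, not verified against the real helper):

#include <cstddef>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;
using ShapeList = std::vector<Shape>;  // one shape per sample in the batch

// Split a batched shape list into consecutive per-request chunks,
// where batch_sizes[i] is the number of samples owned by request i.
std::vector<ShapeList> split_list_shape(const ShapeList& batch,
                                        const std::vector<int>& batch_sizes) {
  std::vector<ShapeList> result;
  result.reserve(batch_sizes.size());
  std::size_t offset = 0;
  for (int bs : batch_sizes) {
    result.emplace_back(batch.begin() + offset, batch.begin() + offset + bs);
    offset += bs;
  }
  return result;
}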