Commit 191ae68

Enable dynamic batching. (#76)
Enable dynamic batching.

Signed-off-by: Rafal <[email protected]>
1 parent fa927e6 commit 191ae68

8 files changed: 222 additions, 97 deletions

qa/L0_DALI_GPU_ensemble/client.py

Lines changed: 47 additions & 27 deletions
@@ -25,6 +25,8 @@
 from itertools import cycle, islice
 from numpy.random import randint
 import argparse
+import asyncio
+from concurrent.futures import ThreadPoolExecutor, as_completed
 
 # TODO: Extend and move to a separate file
 def type_to_string(dtype):
@@ -35,6 +37,14 @@ def type_to_string(dtype):
     elif dtype == np.double:
         return "FP64"
 
+def grouper(n, iterable):
+    it = iter(iterable)
+    while True:
+        chunk = tuple(islice(it, n))
+        if not chunk:
+            return
+        yield chunk
+
 # TODO: Extend and move to a separate file
 class TestClient:
     def __init__(self, model_name: str, input_names: Sequence[str], output_names: Sequence[str],
@@ -51,57 +61,67 @@ def _get_input(batch, name):
         inp.set_data_from_numpy(batch)
         return inp
 
-    def run_inference(self, batches):
-        assert(len(batches) == len(self.input_names))
-        if (len(batches) > 1):
-            for b in batches:
-                assert b.shape[0] == batches[0].shape[0]
-        inputs = [self._get_input(batch, name) for batch, name in zip(batches, self.input_names)]
+    def test_infer(self, data, it):
+        assert(len(data) == len(self.input_names))
+        if (len(data) > 1):
+            for b in data:
+                assert b.shape[0] == data[0].shape[0]
+        inputs = [self._get_input(batch, name) for batch, name in zip(data, self.input_names)]
         outputs = [t_client.InferRequestedOutput(name) for name in self.output_names]
-        results = self.client.infer(model_name=self.model_name, inputs=inputs, outputs=outputs)
-        return [results.as_numpy(name) for name in self.output_names]
+        res = self.client.infer(model_name=self.model_name, inputs=inputs, outputs=outputs)
+        res_data = [res.as_numpy(name) for name in self.output_names]
+        return it, data, res_data
 
     def run_tests(self, data, compare_to, n_infers=-1, eps=1e-7):
         generator = data if n_infers < 1 else islice(cycle(data), n_infers)
-        for it, batches in enumerate(generator):
-            results = self.run_inference(batches)
-            ref = compare_to(*batches)
-            assert(len(results) == len(ref))
-            for out_i, (out, ref_out) in enumerate(zip(results, ref)):
-                assert out.shape == ref_out.shape
-                if not np.allclose(out, ref_out, atol=eps):
-                    print("Test failure in iteration", it)
-                    print("Output", out_i)
-                    print("Expected:\n", ref_out)
-                    print("Actual:\n", out)
-                    assert False
-            print('PASS iteration:', it)
+        for pack in grouper(self.concurrency, enumerate(generator)):
+            with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
+                results_f = [executor.submit(self.test_infer, data, it) for it, data in pack]
+                for future in as_completed(results_f):
+                    it, data, results = future.result()
+                    ref = compare_to(*data)
+                    assert(len(results) == len(ref))
+                    for out_i, (out, ref_out) in enumerate(zip(results, ref)):
+                        assert out.shape == ref_out.shape
+                        if not np.allclose(out, ref_out, atol=eps):
+                            print("Test failure in iteration", it)
+                            print("Output", out_i)
+                            print("Expected:\n", ref_out)
+                            print("Actual:\n", out)
+                            assert False
+                    print('PASS iteration:', it)
 
 
 # TODO: Use actual DALI pipelines to calculate ground truth
 def ref_func(inp1, inp2):
     return inp1 * 2 / 3, (inp2 * 3).astype(np.half).astype(np.single) / 2
 
-def random_gen(max_batch_size):
+
+def random_gen(max_batch_size, uniform_groups=1):
     while True:
-        bs = randint(1, max_batch_size + 1)
         size1 = randint(100, 300)
         size2 = randint(100, 300)
-        yield np.random.random((bs, size1)).astype(np.single), \
+        for i in range(uniform_groups):
+            bs = randint(1, max_batch_size + 1)
+            yield np.random.random((bs, size1)).astype(np.single), \
               np.random.random((bs, size2)).astype(np.single)
 
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('-u', '--url', type=str, required=False, default='localhost:8001',
                         help='Inference server GRPC URL. Default is localhost:8001.')
-    parser.add_argument('--n_iters', type=int, required=False, default=1, help='Number of iterations')
+    parser.add_argument('-n', '--n_iters', type=int, required=False, default=1, help='Number of iterations')
+    parser.add_argument('-c', '--concurrency', type=int, required=False, default=1,
+                        help='Request concurrency level')
     parser.add_argument('-b', '--max_batch_size', type=int, required=False, default=256)
     return parser.parse_args()
 
 def main():
     args = parse_args()
-    client = TestClient('dali_ensemble', ['INPUT_0', 'INPUT_1'], ['OUTPUT_0', 'OUTPUT_1'], args.url)
-    client.run_tests(random_gen(args.max_batch_size), ref_func, n_infers=args.n_iters, eps=1e-4)
+    client = TestClient('dali_ensemble', ['INPUT_0', 'INPUT_1'], ['OUTPUT_0', 'OUTPUT_1'], args.url,
+                        concurrency=args.concurrency)
+    client.run_tests(random_gen(args.max_batch_size, args.concurrency), ref_func,
+                     n_infers=args.n_iters, eps=1e-4)
 
 if __name__ == '__main__':
     main()
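
For reference, the concurrency pattern the updated run_tests relies on can be reproduced in isolation. The sketch below is not part of the commit: it groups an enumerated work stream into chunks of the requested concurrency and submits each chunk through a thread pool, the same way the client now issues parallel requests so the server's dynamic batcher has a chance to coalesce them. `fake_infer` is a hypothetical stand-in for TestClient.test_infer.

```python
# Minimal, self-contained sketch (not part of the commit) of the grouped,
# thread-pooled submission pattern used by run_tests above.
from concurrent.futures import ThreadPoolExecutor, as_completed
from itertools import islice


def grouper(n, iterable):
    it = iter(iterable)
    while True:
        chunk = tuple(islice(it, n))
        if not chunk:
            return
        yield chunk


def fake_infer(data, it):
    # Placeholder for a blocking client.infer(...) call.
    return it, data, data


def run_concurrent(items, concurrency=4):
    # Split the (index, item) stream into groups and run each group in parallel.
    for pack in grouper(concurrency, enumerate(items)):
        with ThreadPoolExecutor(max_workers=concurrency) as executor:
            futures = [executor.submit(fake_infer, data, it) for it, data in pack]
            for future in as_completed(futures):
                it, data, result = future.result()
                print('iteration', it, '->', result)


if __name__ == '__main__':
    run_concurrent(['a', 'b', 'c', 'd', 'e'], concurrency=2)
```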

qa/L0_DALI_GPU_ensemble/model_repository/dali_1/config.pbtxt

Lines changed: 5 additions & 0 deletions
@@ -53,3 +53,8 @@ output [
     dims: [ -1 ]
   }
 ]
+
+dynamic_batching {
+  preferred_batch_size: [ 256 ]
+  max_queue_delay_microseconds: 500
+}

qa/L0_DALI_GPU_ensemble/model_repository/dali_2/config.pbtxt

Lines changed: 5 additions & 0 deletions
@@ -53,3 +53,8 @@ output [
     dims: [ -1 ]
   }
 ]
+
+dynamic_batching {
+  preferred_batch_size: [ 256 ]
+  max_queue_delay_microseconds: 500
+}
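
Both DALI models in the ensemble now enable Triton's dynamic batcher with a preferred batch size of 256 and a maximum queue delay of 500 microseconds, so individual requests may be held briefly and merged into larger batches. A minimal sketch (not part of the commit) of how the deployed configuration could be checked, assuming a Triton server with the dali_1 model reachable on localhost:8001:

```python
# Hedged sketch: fetch the served model configuration and inspect it manually.
# Look for the dynamic batching section in the printed config.
import tritonclient.grpc as t_client

client = t_client.InferenceServerClient(url='localhost:8001')
print(client.get_model_config('dali_1', as_json=True))
```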

qa/L0_DALI_GPU_ensemble/test.sh

Lines changed: 5 additions & 2 deletions
@@ -23,5 +23,8 @@
 
 : ${GRPC_ADDR:=${1:-"localhost:8001"}}
 
-echo "RUN CLIENT"
-python client.py -u $GRPC_ADDR --n_iters 200
+echo "RUN SEQUENTIAL CLIENT"
+python client.py -b 256 -u $GRPC_ADDR -n 256
+
+echo "RUN CONCURRENT CLIENT"
+python client.py -b 16 -c 16 -u $GRPC_ADDR -n 256

src/dali_backend.cc

Lines changed: 123 additions & 67 deletions
@@ -158,9 +158,14 @@ TRITONSERVER_Error* DaliModel::Create(TRITONBACKEND_Model* triton_model, DaliMod
   return error;
 }
 
-struct RequestMeta {
-  TimeInterval compute_interval;  // nanoseconds
-  int batch_size;
+struct ProcessingMeta {
+  TimeInterval compute_interval{};
+  int total_batch_size = 0;
+};
+
+struct InputsInfo {
+  std::vector<IDescr> inputs;
+  std::vector<int> reqs_batch_sizes;  // batch size of each request
 };
 
 class DaliModelInstance : public ::triton::backend::BackendModelInstance {
@@ -179,34 +184,22 @@ class DaliModelInstance : public ::triton::backend::BackendModelInstance {
 
   void Execute(const std::vector<TritonRequest>& requests) {
     DeviceGuard dg(GetDaliDeviceId());
-    int total_batch_size = 0;
-    TimeInterval batch_compute_interval{};
-    TimeInterval batch_exec_interval{};
-    start_timer_ns(batch_exec_interval);
-    for (size_t i = 0; i < requests.size(); i++) {
-      TimeInterval req_exec_interval{};
-      start_timer_ns(req_exec_interval);
-      auto response = TritonResponse::New(requests[i]);
-      RequestMeta request_meta;
-      TritonError error{};
-      try {
-        request_meta = ProcessRequest(response, requests[i]);
-      } catch (...) { error = ErrorHandler(); }
-
-      if (i == 0) {
-        batch_compute_interval.start = request_meta.compute_interval.start;
-      } else if (i == requests.size() - 1) {
-        batch_compute_interval.end = request_meta.compute_interval.end;
-      }
-
-      end_timer_ns(req_exec_interval);
-      ReportStats(requests[i], req_exec_interval, request_meta.compute_interval, !error);
-      SendResponse(std::move(response), std::move(error));
-
-      total_batch_size += request_meta.batch_size;
+    TimeInterval exec_interval{};
+    start_timer_ns(exec_interval);
+    auto responses = CreateResponses(requests);
+    ProcessingMeta proc_meta{};
+    TritonError error{};
+    try {
+      proc_meta = ProcessRequests(requests, responses);
+    } catch (...) { error = ErrorHandler(); }
+    for (auto& response : responses) {
+      SendResponse(std::move(response), TritonError::Copy(error));
+    }
+    end_timer_ns(exec_interval);
+    for (auto& request : requests) {
+      ReportStats(request, exec_interval, proc_meta.compute_interval, !error);
     }
-    end_timer_ns(batch_exec_interval);
-    ReportBatchStats(total_batch_size, batch_exec_interval, batch_compute_interval);
+    ReportBatchStats(proc_meta.total_batch_size, exec_interval, proc_meta.compute_interval);
   }
 
  private:
@@ -234,70 +227,133 @@ class DaliModelInstance : public ::triton::backend::BackendModelInstance {
                  "Failed reporting batch statistics.");
   }
 
-  /** Run inference for a given \p request and prepare a response. */
-  RequestMeta ProcessRequest(TritonResponseView response, TritonRequestView request) {
-    RequestMeta ret;
+  /**
+   * @brief Create a response for each request.
+   *
+   * @return responses vector
+   */
+  std::vector<TritonResponse> CreateResponses(const std::vector<TritonRequest>& requests) {
+    std::vector<TritonResponse> responses;
+    responses.reserve(requests.size());
+    for (auto& request : requests) {
+      responses.push_back(TritonResponse::New(request));
+    }
+    return responses;
+  }
 
-    auto dali_inputs = GenerateInputs(request);
-    ret.batch_size = dali_inputs[0].meta.shape.num_samples();  // Batch size is expected to be the
-                                                               // same in every input
+  /**
+   * @brief Run inference for a given \p request and prepare a response.
+   * @return computation time interval and total batch size
+   */
+  ProcessingMeta ProcessRequests(const std::vector<TritonRequest>& requests,
+                                 const std::vector<TritonResponse>& responses) {
+    ProcessingMeta ret{};
+    auto inputs_info = GenerateInputs(requests);
     start_timer_ns(ret.compute_interval);
-    auto outputs_info = dali_executor_->Run(dali_inputs);
+    auto outputs_info = dali_executor_->Run(inputs_info.inputs);
     end_timer_ns(ret.compute_interval);
-    auto dali_outputs = AllocateOutputs(request, response, outputs_info);
+    for (auto& bs : inputs_info.reqs_batch_sizes) {
+      ret.total_batch_size += bs;
+    }
+    auto dali_outputs =
+        AllocateOutputs(requests, responses, inputs_info.reqs_batch_sizes, outputs_info);
     dali_executor_->PutOutputs(dali_outputs);
     return ret;
  }
 
-  /** @brief Generate descriptors of inputs provided by a given request. */
-  std::vector<IDescr> GenerateInputs(TritonRequestView request) {
-    uint32_t input_cnt = request.InputCount();
-    std::vector<IDescr> ret;
-    ret.reserve(input_cnt);
-    for (uint32_t input_idx = 0; input_idx < input_cnt; ++input_idx) {
-      auto input = request.InputByIdx(input_idx);
-      auto input_byte_size = input.ByteSize();
-      auto input_buffer_count = input.BufferCount();
-      std::vector<IBufferDescr> buffers;
-      buffers.reserve(input_buffer_count);
-      for (uint32_t buffer_idx = 0; buffer_idx < input_buffer_count; ++buffer_idx) {
-        auto buffer = input.GetBuffer(buffer_idx, device_type_t::CPU, GetDaliDeviceId());
-        buffers.push_back(buffer);
+  /**
+   * @brief Generate descriptors of inputs provided by given \p requests
+   * @return input descriptors and batch size of each request
+   */
+  InputsInfo GenerateInputs(const std::vector<TritonRequest>& requests) {
+    uint32_t input_cnt = requests[0].InputCount();
+    std::vector<IDescr> inputs;
+    inputs.reserve(input_cnt);
+    std::unordered_map<std::string, IDescr> input_map;
+    std::vector<int> reqs_batch_sizes(requests.size());
+    for (size_t ri = 0; ri < requests.size(); ++ri) {
+      auto& request = requests[ri];
+      ENFORCE(request.InputCount() == input_cnt,
+              "Each request must provide the same number of inputs.");
+      for (uint32_t input_idx = 0; input_idx < input_cnt; ++input_idx) {
+        auto input = request.InputByIdx(input_idx);
+        auto input_byte_size = input.ByteSize();
+        auto input_buffer_count = input.BufferCount();
+        auto meta = input.Meta();
+        auto& idescr = input_map[meta.name];
+        for (uint32_t buffer_idx = 0; buffer_idx < input_buffer_count; ++buffer_idx) {
+          auto buffer = input.GetBuffer(buffer_idx, device_type_t::CPU, GetDaliDeviceId());
+          idescr.buffers.push_back(buffer);
+        }
+        if (idescr.meta.shape.num_samples() == 0) {
+          idescr.meta = meta;
+        } else {
+          ENFORCE(idescr.meta.type == meta.type,
+                  make_string("Mismatched type for input ", idescr.meta.name));
+          idescr.meta.shape.append(meta.shape);
+        }
+        if (input_idx == 0) {
+          reqs_batch_sizes[ri] = meta.shape.num_samples();
+        } else {
+          ENFORCE(meta.shape.num_samples() == reqs_batch_sizes[ri],
+                  "Each input in a request must have the same batch size.");
+        }
       }
-      ret.push_back({input.Meta(), std::move(buffers)});
     }
-    return ret;
+    for (const auto& descrs : input_map) {
+      inputs.push_back(descrs.second);
+    }
+    return {inputs, reqs_batch_sizes};
   }
 
   int32_t GetDaliDeviceId() {
     return !CudaStream() ? ::dali::CPU_ONLY_DEVICE_ID : device_id_;
   }
 
   /**
-   * @brief Allocate outputs required by a given request.
+   * @brief Allocate outputs expected by given \p requests.
    *
-   * Lifetime of the created buffer is bound to the \p response
+   * Lifetime of the created buffer is bound to each of the \p responses
+   * @param batch_sizes batch size of each request
    */
-  std::vector<ODescr> AllocateOutputs(TritonRequestView request, TritonResponseView response,
+  std::vector<ODescr> AllocateOutputs(const std::vector<TritonRequest>& requests,
+                                      const std::vector<TritonResponse>& responses,
+                                      const std::vector<int>& batch_sizes,
                                       const std::vector<OutputInfo>& outputs_info) {
-    uint32_t output_cnt = request.OutputCount();
+    assert(requests.size() > 0);
+    assert(requests.size() == responses.size());
+    assert(requests.size() == batch_sizes.size());
+    uint32_t output_cnt = requests[0].OutputCount();
+    for (auto& req : requests) {
+      ENFORCE(output_cnt == req.OutputCount(),
+              "All of the requests must expect the same number of outputs.");
+    }
     ENFORCE(outputs_info.size() == output_cnt,
-            make_string("Number of outputs in the model configuration (", output_cnt,
-                        ") does not match to the number of outputs from DALI pipeline (",
-                        outputs_info.size(), ")"));
+            make_string("Number of outputs expected by the requests (", output_cnt,
+                        ") does not match the number of outputs from DALI pipeline (",
+                        outputs_info.size(), ")."));
     const auto& output_indices = dali_model_->GetOutputOrder();
+    ENFORCE(output_cnt == output_indices.size(),
+            make_string("Number of outputs exptected by the requests (", output_cnt,
+                        ") does not match the number of outputs in the config (",
+                        output_indices.size(), ")."));
     std::vector<ODescr> outputs(output_cnt);
     outputs.reserve(output_cnt);
-    for (uint32_t i = 0; i < output_cnt; ++i) {
-      auto name = request.OutputName(i);
-      int output_idx = output_indices.at(name);
+    for (const auto& out_index : output_indices) {
+      auto name = out_index.first;
+      int output_idx = out_index.second;
+      auto shapes = split_list_shape(outputs_info[output_idx].shape, batch_sizes);
+      std::vector<OBufferDescr> buffers(requests.size());
       IOMeta out_meta{};
      out_meta.name = name;
      out_meta.type = outputs_info[output_idx].type;
+      for (size_t ri = 0; ri < requests.size(); ++ri) {
+        out_meta.shape = shapes[ri];
+        auto output = responses[ri].GetOutput(out_meta);
+        buffers[ri] = output.AllocateBuffer(outputs_info[output_idx].device, GetDaliDeviceId());
+      }
       out_meta.shape = outputs_info[output_idx].shape;
-      auto output = response.GetOutput(out_meta);
-      auto buffer = output.AllocateBuffer(outputs_info[output_idx].device, GetDaliDeviceId());
-      outputs[output_idx] = {out_meta, {buffer}};
+      outputs[output_idx] = {out_meta, buffers};
     }
     return outputs;
   }
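
The core of the backend change is a gather/scatter scheme: GenerateInputs concatenates the samples from all requests scheduled together (recording each request's batch size), the DALI pipeline runs once on the merged batch, and AllocateOutputs uses split_list_shape to slice every output back into per-request pieces. Below is a simplified numpy sketch of that flow, not the backend code; `pipeline` is a hypothetical stand-in for the DALI executor.

```python
# Simplified sketch (not part of the commit) of the gather-run-scatter idea behind
# GenerateInputs and AllocateOutputs: merge per-request batches, run once, split back.
import numpy as np


def run_batched(requests, pipeline):
    # requests: list of per-request input arrays; batch size is the leading dim.
    batch_sizes = [r.shape[0] for r in requests]
    gathered = np.concatenate(requests, axis=0)   # one merged batch for the pipeline
    output = pipeline(gathered)                   # single pipeline run
    # Slice the merged output back per request, like split_list_shape does for shapes.
    splits = np.cumsum(batch_sizes)[:-1]
    return np.split(output, splits, axis=0)


if __name__ == '__main__':
    reqs = [np.ones((2, 3)), np.ones((5, 3)), np.ones((1, 3))]
    outs = run_batched(reqs, lambda x: x * 2)
    print([o.shape for o in outs])   # [(2, 3), (5, 3), (1, 3)]
```

Each Triton response thus receives only its own slice of the batched result, while ReportBatchStats sees the total batch size accumulated across the merged requests.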
