
Commit be32a4e: Fix CPU-only mode, add test for it (#64)
Signed-off-by: szalpal <[email protected]>
1 parent: bb9204c

File tree: 10 files changed (+267, -13 lines)

qa/L0_identity_cpu/identity_client.py (path inferred; test.sh below invokes the client as identity_client.py)

Lines changed: 135 additions & 0 deletions (new file)

```python
#!/usr/bin/env python

# The MIT License (MIT)
#
# Copyright (c) 2020 NVIDIA CORPORATION
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import argparse
import math
import sys

import numpy as np
from numpy.random import randint
import tritongrpcclient

np.random.seed(100019)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('-v', '--verbose', action="store_true", required=False, default=False,
                        help='Enable verbose output')
    parser.add_argument('-u', '--url', type=str, required=False, default='localhost:8001',
                        help='Inference server URL. Default is localhost:8001.')
    parser.add_argument('--batch_size', type=int, required=False, default=4,
                        help='Batch size')
    parser.add_argument('--n_iter', type=int, required=False, default=-1,
                        help='Number of iterations, with `batch_size` samples each')
    parser.add_argument('--model_name', type=str, required=False, default="dali_identity_cpu",
                        help='Model name')
    return parser.parse_args()


def array_from_list(arrays):
    """
    Convert a list of ndarrays into a single ndarray with ndim += 1,
    zero-padding every array to the length of the longest one.
    """
    max_len = max(arr.shape[0] for arr in arrays)
    arrays = [np.pad(arr, (0, max_len - arr.shape[0])) for arr in arrays]
    for arr in arrays:
        assert arr.shape == arrays[0].shape, "Arrays must have the same shape"
    return np.stack(arrays)


def batcher(dataset, max_batch_size, n_iterations=-1):
    """
    Generator that splits the dataset into batches of random size
    (at most max_batch_size - 1, since randint's upper bound is exclusive).
    """
    iter_idx = 0
    data_idx = 0
    while data_idx < len(dataset):
        if 0 < n_iterations <= iter_idx:
            return  # raising StopIteration inside a generator is an error since PEP 479
        batch_size = min(randint(1, max_batch_size), len(dataset) - data_idx)
        iter_idx += 1
        yield dataset[data_idx : data_idx + batch_size]
        data_idx += batch_size


def main():
    FLAGS = parse_args()
    try:
        triton_client = tritongrpcclient.InferenceServerClient(url=FLAGS.url,
                                                               verbose=FLAGS.verbose)
    except Exception as e:
        print("channel creation failed: " + str(e))
        sys.exit(1)

    # All three conditions must hold (hence `and`); note that the methods must be
    # called -- a bare bound method is always truthy.
    if not (triton_client.is_server_live() and
            triton_client.is_server_ready() and
            triton_client.is_model_ready(model_name=FLAGS.model_name)):
        print("Error connecting to server: Server live {}. Server ready {}. Model ready {}".format(
            triton_client.is_server_live(), triton_client.is_server_ready(),
            triton_client.is_model_ready(model_name=FLAGS.model_name)))
        sys.exit(1)

    model_name = FLAGS.model_name

    # Random uint8 samples of random lengths; the dataset size is also randomized.
    input_data = [randint(0, 255, size=randint(100), dtype='uint8')
                  for _ in range(randint(100) * FLAGS.batch_size)]
    input_data = array_from_list(input_data)

    # Infer
    outputs = []
    input_name = "DALI_INPUT_0"
    output_name = "DALI_OUTPUT_0"
    input_shape = list(input_data.shape)
    outputs.append(tritongrpcclient.InferRequestedOutput(output_name))

    for batch in batcher(input_data, FLAGS.batch_size, FLAGS.n_iter):
        print("Input mean before backend processing:", np.mean(batch))
        input_shape[0] = np.shape(batch)[0]
        print("Batch size: ", input_shape[0])
        inputs = [tritongrpcclient.InferInput(input_name, input_shape, "UINT8")]
        # Initialize the data
        inputs[0].set_data_from_numpy(batch)

        # Test with outputs
        results = triton_client.infer(model_name=model_name,
                                      inputs=inputs,
                                      outputs=outputs)

        # Get the output arrays from the results
        output0_data = results.as_numpy(output_name)
        print("Output mean after backend processing:", np.mean(output0_data))
        print("Output shape: ", np.shape(output0_data))
        if not math.isclose(np.mean(output0_data), np.mean(batch)):
            print("Pre/post average does not match")
            sys.exit(1)
        else:
            print("pass")

    statistics = triton_client.get_inference_statistics(model_name=model_name)
    if len(statistics.model_stats) != 1:
        print("FAILED: Inference Statistics")
        sys.exit(1)


if __name__ == '__main__':
    main()
```
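To make the padding helper above concrete, here is a minimal, standalone sketch of what array_from_list does with ragged 1-D inputs (illustrative only, not part of the commit):

```python
import numpy as np

# Two uint8 samples of different lengths, like the client's random inputs.
arrays = [np.array([1, 2, 3], dtype=np.uint8), np.array([7], dtype=np.uint8)]

# Zero-pad every sample to the longest one, then stack into a [2, 3] array.
max_len = max(a.shape[0] for a in arrays)
stacked = np.stack([np.pad(a, (0, max_len - a.shape[0])) for a in arrays])

print(stacked)  # [[1 2 3]
                #  [7 0 0]]
```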
qa/L0_identity_cpu/model_repository/dali_identity_cpu/config.pbtxt (path inferred from the model name and setup.sh below)

Lines changed: 39 additions & 0 deletions (new file)

```protobuf
# The MIT License (MIT)
#
# Copyright (c) 2020 NVIDIA CORPORATION
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

name: "dali_identity_cpu"
backend: "dali"
max_batch_size: 256
input [
  {
    name: "DALI_INPUT_0"
    data_type: TYPE_UINT8
    dims: [ -1 ]
  }
]

output [
  {
    name: "DALI_OUTPUT_0"
    data_type: TYPE_UINT8
    dims: [ -1 ]
  }
]
```
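The names, data type, and dims here form the contract between the config, the pipeline's external_source, and the client: dims: [ -1 ] declares one variable-length dimension per sample, and max_batch_size lets Triton prepend a batch dimension. A sketch of the tensor the client ends up sending (values are illustrative, not from the commit):

```python
import numpy as np

# With dims: [ -1 ], each sample is a 1-D uint8 tensor of arbitrary length;
# the client pads samples to a common length and sends [batch, length]:
batch = np.random.randint(0, 255, size=(4, 57), dtype=np.uint8)

# The matching request input is declared with the same name and shape:
# tritongrpcclient.InferInput("DALI_INPUT_0", [4, 57], "UINT8")
```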
qa/L0_identity_cpu/model_repository/identity_pipeline.py (path inferred; setup.sh below runs it from inside model_repository)

Lines changed: 44 additions & 0 deletions (new file)

```python
# The MIT License (MIT)
#
# Copyright (c) 2021 NVIDIA CORPORATION
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import nvidia.dali as dali


def _parse_args():
    import argparse
    parser = argparse.ArgumentParser(description="Serialize the pipeline and save it to a file")
    parser.add_argument('file_path', type=str,
                        help='The path where to save the serialized pipeline')
    return parser.parse_args()


# device_id=None builds a CPU-only pipeline: no GPU is needed to serialize or run it.
@dali.pipeline_def(batch_size=3, num_threads=1, device_id=None)
def pipe():
    data = dali.fn.external_source(device="cpu", name="DALI_INPUT_0")
    return data


def main(filename):
    pipe().serialize(filename=filename)


if __name__ == '__main__':
    args = _parse_args()
    main(args.file_path)
```
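device_id=None is what makes this pipeline CPU-only, which is exactly the mode this commit fixes. For contrast, a GPU variant of the same identity pipeline would differ only in the device arguments (a sketch, not part of the commit):

```python
import nvidia.dali as dali

# Hypothetical GPU counterpart of the identity pipeline above:
@dali.pipeline_def(batch_size=3, num_threads=1, device_id=0)
def gpu_pipe():
    data = dali.fn.external_source(device="gpu", name="DALI_INPUT_0")
    return data
```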

qa/L0_identity_cpu/setup.sh

Lines changed: 9 additions & 0 deletions (new file)

```bash
#!/bin/bash -ex

pushd model_repository

# Serialize the identity pipeline into version directory 1 of the dali_identity_cpu model.
mkdir -p dali_identity_cpu/1
python identity_pipeline.py dali_identity_cpu/1/model.dali
echo "Identity model ready."

popd
```
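After setup.sh runs, the model repository should look roughly like this (layout assumed from the script and standard Triton conventions, not shown in the commit):

```
model_repository/
└── dali_identity_cpu/
    ├── config.pbtxt      # the configuration shown above
    └── 1/
        └── model.dali    # serialized pipeline written by identity_pipeline.py
```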

qa/L0_identity_cpu/test.sh

Lines changed: 5 additions & 0 deletions (new file)

```bash
#!/bin/bash -ex

# GRPC_ADDR defaults to the first script argument, or localhost:8001 if neither is given.
: ${GRPC_ADDR:=${1:-"localhost:8001"}}

python identity_client.py -u "$GRPC_ADDR"
```

src/dali_backend.cc

Lines changed: 8 additions & 4 deletions

```diff
@@ -178,7 +178,7 @@ class DaliModelInstance : public ::triton::backend::BackendModelInstance {
   }

   void Execute(const std::vector<TritonRequest>& requests) {
-    DeviceGuard dg(device_id_);
+    DeviceGuard dg(DetermineDeviceId());
     int total_batch_size = 0;
     TimeInterval batch_compute_interval{};
     TimeInterval batch_exec_interval{};
@@ -215,7 +215,7 @@ class DaliModelInstance : public ::triton::backend::BackendModelInstance {
     auto serialized_pipeline = dali_model_->GetModelProvider().GetModel();
     auto max_batch_size = dali_model_->MaxBatchSize();
     auto num_threads = dali_model_->GetModelParamters().GetNumThreads();
-    DaliPipeline pipeline(serialized_pipeline, max_batch_size, num_threads, device_id_);
+    DaliPipeline pipeline(serialized_pipeline, max_batch_size, num_threads, DetermineDeviceId());
     dali_executor_ = std::make_unique<DaliExecutor>(std::move(pipeline));
   }

@@ -261,14 +261,18 @@ class DaliModelInstance : public ::triton::backend::BackendModelInstance {
       std::vector<IBufferDescr> buffers;
       buffers.reserve(input_buffer_count);
       for (uint32_t buffer_idx = 0; buffer_idx < input_buffer_count; ++buffer_idx) {
-        auto buffer = input.GetBuffer(buffer_idx, device_type_t::CPU, device_id_);
+        auto buffer = input.GetBuffer(buffer_idx, device_type_t::CPU, DetermineDeviceId());
        buffers.push_back(buffer);
      }
      ret.push_back({input.Meta(), std::move(buffers)});
    }
    return ret;
  }

+  int32_t DetermineDeviceId() {
+    return !CudaStream() ? ::dali::CPU_ONLY_DEVICE_ID : device_id_;
+  }
+
   /**
    * @brief Allocate outputs required by a given request.
    *
@@ -292,7 +296,7 @@ class DaliModelInstance : public ::triton::backend::BackendModelInstance {
       out_meta.type = outputs_info[output_idx].type;
       out_meta.shape = outputs_info[output_idx].shape;
       auto output = response.GetOutput(out_meta);
-      auto buffer = output.AllocateBuffer(outputs_info[output_idx].device, device_id_);
+      auto buffer = output.AllocateBuffer(outputs_info[output_idx].device, DetermineDeviceId());
       outputs[output_idx] = {out_meta, {buffer}};
     }
     return outputs;
```
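The heart of the fix is DetermineDeviceId(): an instance without a CUDA stream is treated as CPU-only, and DALI receives the CPU_ONLY_DEVICE_ID sentinel instead of a real device ordinal. A one-line restatement of the rule in Python (illustrative only; the sentinel's value is assumed from NoGpu() in dali_pipeline.h, which checks device_id_ < 0):

```python
CPU_ONLY_DEVICE_ID = -1  # assumed value; NoGpu() treats negative ids as CPU-only

def determine_device_id(cuda_stream, device_id):
    # No CUDA stream means a CPU-only instance: hand DALI the sentinel id.
    return CPU_ONLY_DEVICE_ID if cuda_stream is None else device_id
```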

src/dali_executor/dali_executor.cc

Lines changed: 2 additions & 0 deletions (whitespace only)

```diff
@@ -50,6 +50,7 @@ void DaliExecutor::SetupInputs(const std::vector<IDescr>& inputs) {
   }
 }

+
 IDescr DaliExecutor::ScheduleInputCopy(const IDescr& input) {
   assert(input.buffers.size() > 0);
   IOBufferI* buffer;
@@ -79,6 +80,7 @@ void DaliExecutor::RunInputCopy() {
   thread_pool_.RunAll();
 }

+
 bool DaliExecutor::IsNoCopy(const IDescr& input) {
   return input.buffers.size() == 1 && (input.buffers[0].device == device_type_t::CPU ||
                                        input.buffers[0].device_id == pipeline_.DeviceId());
```

src/dali_executor/dali_executor.h

Lines changed: 3 additions & 1 deletion

```diff
@@ -61,7 +61,9 @@ class DaliExecutor {
   void SetupInputs(const std::vector<IDescr>& inputs);

   /**
-   * @brief Schedule copy to a continous buffer and return IDecr to the new buffer.
+   * @brief Schedule a copy of all buffers within the input IDescr to a contiguous buffer.
+   * The copy is performed after calling RunInputCopy().
+   * @return IDescr pointing to the new, contiguous buffer.
    */
   IDescr ScheduleInputCopy(const IDescr& buffers);
```

src/dali_executor/dali_pipeline.cc

Lines changed: 3 additions & 1 deletion

```diff
@@ -78,8 +78,10 @@ void DaliPipeline::SetInput(const IDescr& io_descr) {
 }

 void DaliPipeline::SyncOutputStream() {
+  if (NoGpu())
+    return;
   DeviceGuard dg(device_id_);
-  CUDA_CALL(cudaStreamSynchronize(output_stream_));
+  CUDA_CALL_GUARD(cudaStreamSynchronize(output_stream_));
 }

 void DaliPipeline::PutOutput(void* destination, int output_idx, device_type_t destination_device) {
```

src/dali_executor/dali_pipeline.h

Lines changed: 19 additions & 7 deletions

```diff
@@ -32,8 +32,6 @@
 #include "src/dali_executor/utils/utils.h"
 #include "src/error_handling.h"

-using std::cout;
-using std::endl;

 namespace triton { namespace backend { namespace dali {

@@ -138,15 +136,25 @@
     CreatePipeline();
   }

-  int DeviceId() {
+
+  int DeviceId() const {
     return device_id_;
   }

-  int NumThreadsArg() {
+
+  int NumThreadsArg() const {
     return num_threads_;
   }

+
  private:
+  /**
+   * @return True if this DALI pipeline does not have a GPU available.
+   */
+  bool NoGpu() const noexcept {
+    return device_id_ < 0;
+  }
+
   void CreatePipeline() {
     daliCreatePipeline(&handle_, serialized_pipeline_.c_str(), serialized_pipeline_.length(),
                        max_batch_size_, num_threads_, device_id_, 0, 1, 0, 0, 0);
@@ -160,9 +168,11 @@
   }

   void ReleaseStream() {
+    if (NoGpu())
+      return;
     if (output_stream_) {
-      CUDA_CALL(cudaStreamSynchronize(output_stream_));
-      CUDA_CALL(cudaStreamDestroy(output_stream_));
+      CUDA_CALL_GUARD(cudaStreamSynchronize(output_stream_));
+      CUDA_CALL_GUARD(cudaStreamDestroy(output_stream_));
       output_stream_ = nullptr;
     }
   }
@@ -175,7 +185,9 @@
   }

   void InitStream() {
-    CUDA_CALL(cudaStreamCreate(&output_stream_));
+    if (NoGpu())
+      return;
+    CUDA_CALL_GUARD(cudaStreamCreate(&output_stream_));
   }

   std::string serialized_pipeline_{};
```
