-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Expand file tree
/
Copy pathgooglenet.cpp
More file actions
424 lines (378 loc) · 16.3 KB
/
googlenet.cpp
File metadata and controls
424 lines (378 loc) · 16.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
#include <NvInfer.h>
#include <array>
#include <cassert>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <limits>
#include <map>
#include <memory>
#include <new>
#include <opencv2/opencv.hpp>
#include <string>
#include <utility>
#include <vector>
#include "logging.h"
#include "utils.h"
using WeightMap = std::map<std::string, Weights>;
using M = nvinfer1::MatrixOperation;
using E = nvinfer1::ElementWiseOperation;
using NDCF = nvinfer1::NetworkDefinitionCreationFlag;

// Logger handed to createInferBuilder / createInferRuntime.
static Logger gLogger;

// ---- fixed properties of googlenet and of this demo ----
// batch size used for host-side preprocessing
static constexpr const std::size_t N = 1;
static constexpr const int32_t INPUT_H = 224;
static constexpr const int32_t INPUT_W = 224;
// element count per batch item of each I/O tensor, indexed like NAMES (input first, then output)
static constexpr const std::array<int32_t, 2> SIZES = {3 * INPUT_H * INPUT_W, 1000};
static constexpr const std::array<const char*, 2> NAMES = {"data", "prob"};
// true: the engine takes a uint8 NHWC image and normalizes it with a transform layer on the GPU;
// false: the host feeds an already normalized float NCHW tensor
static constexpr const bool TRT_PREPROCESS = TRT_VERSION >= 8510;
static constexpr const char* WTS_PATH = "../models/googlenet.wts";
static constexpr const char* ENGINE_PATH = "../models/googlenet.engine";
static constexpr const char* LABELS_PATH = "../assets/imagenet1000_clsidx_to_labels.txt";
// torchvision's ImageNet per-channel normalization constants, passed to the preprocessing helpers
static constexpr const std::array<const float, 3> mean = {0.485f, 0.456f, 0.406f};
static constexpr const std::array<const float, 3> stdv = {0.229f, 0.224f, 0.225f};
/**
 * @brief Fold an inference-mode BatchNorm2d into one per-channel IScaleLayer.
 *
 * With gamma/beta/running_mean/running_var read from the weight map, the layer computes
 * y = x * scale + shift, where
 *   scale = gamma / sqrt(running_var + eps)
 *   shift = beta - running_mean * scale
 * The folded buffers are malloc'ed and stored back into @p m under "<lname>.scale" and
 * "<lname>.shift", so they are released by whoever frees the weight map's buffers.
 *
 * @param network network definition from TensorRT API
 * @param m weight map; must contain "<lname>.weight", ".bias", ".running_mean" and ".running_var"
 * @param input input tensor
 * @param lname layer name prefix in the weight map
 * @param eps value added to the variance for numerical stability
 * @return ILayer* the scale layer implementing the folded batchnorm
 * @throws std::out_of_range if one of the batchnorm entries is missing from @p m
 * @throws std::bad_alloc if the folded buffers cannot be allocated
 */
auto addBatchNorm2d(INetworkDefinition* network, WeightMap& m, ITensor& input, const std::string& lname,
                    float eps = 1e-3) -> ILayer* {
    static const Weights none{DataType::kFLOAT, nullptr, 0ll};
    // .at() fails loudly on a missing key instead of default-inserting an empty (nullptr) entry
    const auto* gamma = static_cast<const float*>(m.at(lname + ".weight").values);
    const auto* beta = static_cast<const float*>(m.at(lname + ".bias").values);
    const auto* runningMean = static_cast<const float*>(m.at(lname + ".running_mean").values);
    const Weights& varWeights = m.at(lname + ".running_var");
    const auto* runningVar = static_cast<const float*>(varWeights.values);
    const int64_t len = varWeights.count;

    auto* scval = static_cast<float*>(malloc(sizeof(float) * len));
    auto* shval = static_cast<float*>(malloc(sizeof(float) * len));
    if (len > 0 && (scval == nullptr || shval == nullptr)) {
        free(scval);
        free(shval);
        throw std::bad_alloc();
    }
    for (int64_t i = 0; i < len; ++i) {
        scval[i] = gamma[i] / std::sqrt(runningVar[i] + eps);
        shval[i] = beta[i] - runningMean[i] * scval[i];
    }
    const Weights scale{DataType::kFLOAT, scval, len};
    const Weights shift{DataType::kFLOAT, shval, len};
    // hand ownership of the folded buffers to the weight map
    m[lname + ".scale"] = scale;
    m[lname + ".shift"] = shift;
    m[lname + ".power"] = none;

    auto* bn = network->addScale(input, ScaleMode::kCHANNEL, shift, scale, none);
    assert(bn);
    bn->setName(lname.c_str());
    return bn;
}
/**
 * @brief A basic conv2d+bn+relu layer from googlenet
 *
 * The convolution is added without a bias; the following (folded) batchnorm supplies the shift.
 *
 * @param network network definition from TensorRT API
 * @param weightMap weight map; must contain "<lname>.conv.weight" and the "<lname>.bn.*" entries
 * @param input input tensor
 * @param lname layer name from weight map
 * @param outch output channels
 * @param k kernel size for convolution
 * @param s stride size for convolution
 * @param p padding size for convolution
 * @return ILayer* the final relu layer
 */
ILayer* basicConv2d(INetworkDefinition* network, WeightMap& weightMap, ITensor& input, const std::string& lname,
                    int32_t outch, int k, int s = 1, int p = 0) {
    static const Weights none{DataType::kFLOAT, nullptr, 0ll};

    auto* conv = network->addConvolutionNd(input, outch, DimsHW{k, k}, weightMap.at(lname + ".conv.weight"), none);
    // check each layer before its output is used as the next layer's input
    assert(conv);
    conv->setName(lname.c_str());
    conv->setStrideNd(DimsHW{s, s});
    conv->setPaddingNd(DimsHW{p, p});

    auto* bn = addBatchNorm2d(network, weightMap, *conv->getOutput(0), lname + ".bn");
    assert(bn);
    bn->setName((lname + ".bn").c_str());

    auto* relu = network->addActivation(*bn->getOutput(0), ActivationType::kRELU);
    assert(relu);
    relu->setName((lname + ".relu").c_str());
    return relu;
}
/**
 * @brief Inception module from googlenet implementation in torchvision, see:
 * https://github.com/pytorch/vision/blob/v0.24.1/torchvision/models/googlenet.py#L184
 *
 * Four parallel branches are concatenated along the channel axis:
 *   branch1: 1x1 conv
 *   branch2: 1x1 reduction -> 3x3 conv
 *   branch3: 1x1 reduction -> 3x3 conv (kernel 3 as in torchvision, despite the "5x5" parameter names)
 *   branch4: 3x3/s1 max pool (output size rounded up) -> 1x1 projection
 *
 * @param network network definition from TensorRT API
 * @param weightMap weight map
 * @param input input tensor
 * @param lname layer name prefix from weight map, including the trailing '.' (e.g. "inception3a.")
 * @param ch1x1 output channels of branch1
 * @param ch3x3red reduction channels of branch2
 * @param ch3x3 output channels of branch2
 * @param ch5x5red reduction channels of branch3
 * @param ch5x5 output channels of branch3
 * @param pool_proj output channels of branch4's projection
 * @return IConcatenationLayer* the concatenation of the four branches
 */
IConcatenationLayer* inception(INetworkDefinition* network, WeightMap& weightMap, ITensor& input,
                               const std::string& lname, int ch1x1, int ch3x3red, int ch3x3, int ch5x5red, int ch5x5,
                               int pool_proj) {
    // "cbr" means "Conv-Batchnorm-Relu"
    auto* cbr1 = basicConv2d(network, weightMap, input, lname + "branch1", ch1x1, 1);

    auto* cbr2 = basicConv2d(network, weightMap, input, lname + "branch2.0", ch3x3red, 1);
    auto* cbr3 = basicConv2d(network, weightMap, *cbr2->getOutput(0), lname + "branch2.1", ch3x3, 3, 1, 1);

    auto* cbr4 = basicConv2d(network, weightMap, input, lname + "branch3.0", ch5x5red, 1);
    auto* cbr5 = basicConv2d(network, weightMap, *cbr4->getOutput(0), lname + "branch3.1", ch5x5, 3, 1, 1);

    auto* pool1 = network->addPoolingNd(input, PoolingType::kMAX, DimsHW{3, 3});
    // check the pool before its output feeds the projection conv
    assert(pool1);
    pool1->setStrideNd(DimsHW{1, 1});
    pool1->setPaddingNd(DimsHW{1, 1});
    pool1->setPaddingMode(PaddingMode::kEXPLICIT_ROUND_UP);
    pool1->setName((lname + "branch4.0").c_str());
    auto* cbr6 = basicConv2d(network, weightMap, *pool1->getOutput(0), lname + "branch4.1", pool_proj, 1);
    assert(cbr1 && cbr2 && cbr3 && cbr4 && cbr5 && cbr6);

    std::array<ITensor*, 4> inputTensors = {cbr1->getOutput(0), cbr3->getOutput(0), cbr5->getOutput(0),
                                            cbr6->getOutput(0)};
    IConcatenationLayer* cat1 =
            network->addConcatenation(inputTensors.data(), static_cast<int32_t>(inputTensors.size()));
    assert(cat1);
    cat1->setName((lname + "cat").c_str());
    return cat1;
}
/**
 * @brief Create the googlenet engine using only the TensorRT network API (no parser).
 *
 * Loads the weights from WTS_PATH, builds the network layer by layer and builds an engine from it.
 * Every host weight buffer in the weight map (including the folded batchnorm ones) is freed before
 * returning.
 *
 * @param N batch size of the network input
 * @param runtime runtime used to deserialize the built plan (TensorRT >= 8)
 * @param builder TensorRT builder
 * @param config builder config
 * @param dt input data type; forced to kUINT8 when TRT_PREPROCESS is enabled
 * @return ICudaEngine* the built engine (owned by the caller), or nullptr if the build failed
 */
ICudaEngine* createEngine(int32_t N, IRuntime* runtime, IBuilder* builder, IBuilderConfig* config, DataType dt) {
    WeightMap weightMap = loadWeights(WTS_PATH);
#if TRT_VERSION >= 11200
    auto flag = 1U << static_cast<int>(NDCF::kSTRONGLY_TYPED);
#elif TRT_VERSION >= 10000
    auto flag = 0U;
#else
    auto flag = 1U << static_cast<int>(NDCF::kEXPLICIT_BATCH);
#endif
    auto* network = builder->createNetworkV2(flag);
    assert(network);

    ITensor* input{nullptr};
    if constexpr (TRT_PREPROCESS) {
        // raw uint8 NHWC image in; normalization happens inside the engine
        dt = DataType::kUINT8;
        input = network->addInput(NAMES[0], dt, Dims4{N, INPUT_H, INPUT_W, 3});
        assert(input);
        auto* trans = addTransformLayer(network, *input, true, mean, stdv);
        assert(trans);
        input = trans->getOutput(0);
    } else {
        input = network->addInput(NAMES[0], dt, Dims4{N, 3, INPUT_H, INPUT_W});
    }
    assert(input);

    // every stand-alone max pool is stride 2 with the output size rounded up (PyTorch ceil_mode=True)
    auto maxPoolS2 = [network](ITensor& in, int32_t k, const char* name) -> IPoolingLayer* {
        auto* pool = network->addPoolingNd(in, PoolingType::kMAX, DimsHW{k, k});
        assert(pool);
        pool->setStrideNd(DimsHW{2, 2});
        pool->setPaddingMode(PaddingMode::kEXPLICIT_ROUND_UP);
        pool->setName(name);
        return pool;
    };

    // stem
    auto* relu1 = basicConv2d(network, weightMap, *input, "conv1", 64, 7, 2, 3);
    auto* pool1 = maxPoolS2(*relu1->getOutput(0), 3, "pool1");
    auto* relu2 = basicConv2d(network, weightMap, *pool1->getOutput(0), "conv2", 64, 1);
    auto* relu3 = basicConv2d(network, weightMap, *relu2->getOutput(0), "conv3", 192, 3, 1, 1);
    auto* pool2 = maxPoolS2(*relu3->getOutput(0), 3, "pool2");

    // inception stages
    auto* cat = inception(network, weightMap, *pool2->getOutput(0), "inception3a.", 64, 96, 128, 16, 32, 32);
    cat = inception(network, weightMap, *cat->getOutput(0), "inception3b.", 128, 128, 192, 32, 96, 64);
    auto* pool3 = maxPoolS2(*cat->getOutput(0), 3, "pool3");
    cat = inception(network, weightMap, *pool3->getOutput(0), "inception4a.", 192, 96, 208, 16, 48, 64);
    cat = inception(network, weightMap, *cat->getOutput(0), "inception4b.", 160, 112, 224, 24, 64, 64);
    cat = inception(network, weightMap, *cat->getOutput(0), "inception4c.", 128, 128, 256, 24, 64, 64);
    cat = inception(network, weightMap, *cat->getOutput(0), "inception4d.", 112, 144, 288, 32, 64, 64);
    cat = inception(network, weightMap, *cat->getOutput(0), "inception4e.", 256, 160, 320, 32, 128, 128);
    auto* pool4 = maxPoolS2(*cat->getOutput(0), 2, "pool4");
    cat = inception(network, weightMap, *pool4->getOutput(0), "inception5a.", 256, 160, 320, 32, 128, 128);
    cat = inception(network, weightMap, *cat->getOutput(0), "inception5b.", 384, 192, 384, 48, 128, 128);

    // AdaptiveAvgPool2d in the pytorch implementation; the feature map is 7x7 for a 224x224 input
    IPoolingLayer* pool5 = network->addPoolingNd(*cat->getOutput(0), PoolingType::kAVERAGE, DimsHW{7, 7});
    assert(pool5);
    auto* shuffle = network->addShuffle(*pool5->getOutput(0));
    assert(shuffle);
    shuffle->setName("shuffle");
    // [N, 1024, 1, 1] -> [N, 1024]; keeping the batch dimension makes the FC valid for any N
    shuffle->setReshapeDimensions(Dims2{N, -1});

    // fc: y = x * W^T + b, with W stored as [1000, 1024]
    auto* fcw = network->addConstant(Dims2{1000, 1024}, weightMap.at("fc.weight"));
    auto* fcb = network->addConstant(Dims2{1, 1000}, weightMap.at("fc.bias"));
    assert(fcw && fcb);
    auto* fc0 = network->addMatrixMultiply(*shuffle->getOutput(0), M::kNONE, *fcw->getOutput(0), M::kTRANSPOSE);
    assert(fc0);
    auto* fc1 = network->addElementWise(*fc0->getOutput(0), *fcb->getOutput(0), E::kSUM);
    assert(fc1);
    fc1->getOutput(0)->setName(NAMES[1]);
    network->markOutput(*fc1->getOutput(0));

    // Build engine
#if TRT_VERSION >= 8000
    config->setMemoryPoolLimit(MemoryPoolType::kWORKSPACE, WORKSPACE_SIZE);
    IHostMemory* plan = builder->buildSerializedNetwork(*network, *config);
    ICudaEngine* engine = nullptr;
    if (plan != nullptr) {
        engine = runtime->deserializeCudaEngine(plan->data(), plan->size());
        // the serialized plan is no longer needed once it has been deserialized
        delete plan;
    } else {
        std::cerr << "failed to build serialized network\n";
    }
    delete network;
#else
    builder->setMaxBatchSize(N);
    config->setMaxWorkspaceSize(WORKSPACE_SIZE);
    ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);
    network->destroy();
#endif
    std::cout << "build finished\n";

    // Release host memory; assumes loadWeights allocates with malloc, like addBatchNorm2d does
    for (auto& entry : weightMap) {
        free(const_cast<void*>(entry.second.values));
    }
    return engine;
}
/**
 * @brief Build the googlenet engine with an FP32 input and serialize it.
 *
 * @param N batch size forwarded to createEngine
 * @param runtime runtime forwarded to createEngine (not released here)
 * @param modelStream [out] receives the serialized engine; the caller owns and releases it
 */
void APIToModel(int32_t N, IRuntime* runtime, IHostMemory** modelStream) {
    auto* const b = createInferBuilder(gLogger);
    auto* const cfg = b->createBuilderConfig();

    // build network + engine, then turn the engine into a host-side plan
    auto* const eng = createEngine(N, runtime, b, cfg, DataType::kFLOAT);
    assert(eng != nullptr);
    *modelStream = eng->serialize();

    // tear down in reverse order of creation
#if TRT_VERSION >= 8000
    delete eng;
    delete cfg;
    delete b;
#else
    eng->destroy();
    cfg->destroy();
    b->destroy();
#endif
}
/**
 * @brief Run one synchronous inference on a freshly created CUDA stream.
 *
 * Allocates one device buffer per I/O tensor (sized from SIZES and the tensor's element type),
 * copies @p input into the first one, enqueues the engine and copies every output back as float.
 * All device buffers and the stream are released before returning.
 *
 * @param context execution context created from the googlenet engine
 * @param input host pointer to batchSize * SIZES[0] elements of the input tensor's data type
 * @param batchSize number of images in @p input
 * @return one host vector per output tensor, each holding batchSize * SIZES[i] floats
 */
std::vector<std::vector<float>> doInference(IExecutionContext& context, void* input, int64_t batchSize) {
    const auto& engine = context.getEngine();
    cudaStream_t stream;
    CHECK(cudaStreamCreate(&stream));
#if TRT_VERSION >= 8000
    const int32_t nIO = engine.getNbIOTensors();
#else
    const int32_t nIO = engine.getNbBindings();
#endif
    // SIZES / NAMES describe every I/O tensor: index 0 is the input, the rest are outputs
    assert(nIO == static_cast<int32_t>(SIZES.size()));

    std::vector<void*> buffers(static_cast<std::size_t>(nIO), nullptr);
    for (int32_t i = 0; i < nIO; ++i) {
#if TRT_VERSION >= 8000
        const char* tensor_name = engine.getIOTensorName(i);
        const auto elemSize = getSize(engine.getTensorDataType(tensor_name));
#else
        const int32_t idx = engine.getBindingIndex(NAMES[i]);
        assert(idx == i);
        const auto elemSize = getSize(engine.getBindingDataType(idx));
#endif
        const std::size_t bytes = elemSize * batchSize * SIZES[i];
        CHECK(cudaMalloc(&buffers[i], bytes));
        if (i == 0) {
            CHECK(cudaMemcpyAsync(buffers[i], input, bytes, cudaMemcpyHostToDevice, stream));
        }
#if TRT_VERSION >= 8000
        context.setTensorAddress(tensor_name, buffers[i]);
#endif
    }

    // The enqueue call must not live inside assert(): with NDEBUG the whole expression is
    // compiled out and inference would silently never run.
#if TRT_VERSION >= 8000
    const bool enqueued = context.enqueueV3(stream);
#else
    const bool enqueued = context.enqueueV2(buffers.data(), stream, nullptr);
#endif
    if (!enqueued) {
        std::cerr << "doInference: failed to enqueue inference\n";
    }
    assert(enqueued);

    std::vector<std::vector<float>> prob;
    for (int32_t i = 1; i < nIO; ++i) {
        std::vector<float> host(batchSize * SIZES[i], std::nanf(""));
        const std::size_t bytes = host.size() * sizeof(float);
        CHECK(cudaMemcpyAsync(host.data(), buffers[i], bytes, cudaMemcpyDeviceToHost, stream));
        // move, not copy: the heap buffer targeted by the async copy is the one that gets returned
        prob.emplace_back(std::move(host));
    }
    CHECK(cudaStreamSynchronize(stream));
    for (void* buffer : buffers) {
        CHECK(cudaFree(buffer));
    }
    CHECK(cudaStreamDestroy(stream));
    return prob;
}
int main(int argc, char** argv) {
checkTrtEnv();
if (argc != 2) {
std::cerr << "arguments not right!\n";
std::cerr << "./googlenet -s // serialize model to plan file\n";
std::cerr << "./googlenet -d // deserialize plan file and run inference\n";
return -1;
}
// create a model using the API directly and serialize it to a stream
IRuntime* runtime = createInferRuntime(gLogger);
assert(runtime != nullptr);
char* trtModelStream{nullptr};
std::streamsize size{0};
if (std::string(argv[1]) == "-s") {
IHostMemory* modelStream{nullptr};
APIToModel(1, runtime, &modelStream);
assert(modelStream != nullptr);
std::ofstream p(ENGINE_PATH, std::ios::binary | std::ios::trunc);
if (!p) {
std::cerr << "could not open plan output file\n";
return -1;
}
if (modelStream->size() > static_cast<std::size_t>(std::numeric_limits<std::streamsize>::max())) {
std::cerr << "this model is too large to serialize\n";
return -1;
}
const auto* data_ptr = reinterpret_cast<const char*>(modelStream->data());
auto data_size = static_cast<std::streamsize>(modelStream->size());
p.write(data_ptr, data_size);
#if TRT_VERSION >= 8000
delete modelStream;
#else
modelStream->destroy();
#endif
return 0;
} else if (std::string(argv[1]) == "-d") {
std::ifstream file(ENGINE_PATH, std::ios::binary);
if (file.good()) {
file.seekg(0, file.end);
size = file.tellg();
file.seekg(0, file.beg);
trtModelStream = new char[size];
assert(trtModelStream);
file.read(trtModelStream, size);
file.close();
}
} else {
return 1;
}
#if TRT_VERSION >= 8000
ICudaEngine* engine = runtime->deserializeCudaEngine(trtModelStream, size);
#else
ICudaEngine* engine = runtime->deserializeCudaEngine(trtModelStream, size, nullptr);
#endif
assert(engine != nullptr);
IExecutionContext* context = engine->createExecutionContext();
assert(context != nullptr);
const std::string img_path = "../assets/cats.jpg";
void* input = nullptr;
std::vector<float> flat_img;
cv::Mat img = cv::imread(img_path, cv::IMREAD_COLOR);
if constexpr (TRT_PREPROCESS) {
// for simplicity, resize image on cpu side
cv::resize(img, img, cv::Size(INPUT_W, INPUT_H), 0, 0, cv::INTER_LINEAR);
input = static_cast<void*>(img.data);
} else {
flat_img = preprocess_img(img, true, mean, stdv, N, INPUT_H, INPUT_W);
input = flat_img.data();
}
assert(input);
for (int32_t i = 0; i < 100; ++i) {
auto _start = std::chrono::system_clock::now();
auto prob = doInference(*context, input, 1);
auto _end = std::chrono::system_clock::now();
auto _time = std::chrono::duration_cast<std::chrono::microseconds>(_end - _start).count();
std::cout << "Execution time: " << _time << "us\n";
for (const auto& vector : prob) {
int idx = 0;
for (auto v : vector) {
std::cout << std::setprecision(4) << v << ", " << std::flush;
if (++idx > 20) {
std::cout << "\n====\n";
break;
}
}
}
if (i == 99) {
std::cout << "prediction result:\n";
auto labels = loadImagenetLabelMap(LABELS_PATH);
int _top = 0;
for (auto& [idx, logits] : topk(prob[0], 3)) {
std::cout << "Top: " << _top++ << " idx: " << idx << ", logits: " << logits
<< ", label: " << labels[idx] << "\n";
}
}
}
delete[] trtModelStream;
#if TRT_VERSION >= 8000
delete context;
delete engine;
delete runtime;
#else
context->destroy();
engine->destroy();
runtime->destroy();
#endif
return 0;
}