本次教程以 HRNet 分类器(HRNet-W18-C-Small-v2)为例子
code:https://github.com/HRNet/HRNet-Image-Classification
paper:https://arxiv.org/abs/1908.07919
无论是仅仅使用网络还是要对网络改进,首先都要对网络有一定了解。对于这种比较火的网络,网上大批详解博客,可以多去阅读,加上论文,来对网络理解。
HRNet 分类器网络看起来很简单,如下图
从网络中可看到基本组件很简单:卷积和 upsmple。【这里就表明网络 TRT 加速时不会有 plugin 的需求。】
参考博客:
- https://www.cnblogs.com/darkknightzh/p/12150637.html
- https://zhuanlan.zhihu.com/p/143385915
- https://blog.csdn.net/weixin_37993251/article/details/88043650
- https://blog.csdn.net/weixin_38715903/article/details/101629781?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-2.channel_param&depth_1-utm_source=dis
跑通 demo 是很重要的一步。跑通后就可以一步一步跟进,看到底走了哪些层,这样心里就会有一个基本框架;然后可以生成 wts 文件;同时也可以生成 onnx 文件。
上述的参考博客 4中对代码有详细介绍,可以详细分析下。
建议:对于运行环境,建议使用 anaconda 的 conda create 创建虚拟环境,这样没有一系列环境问题。
conda create -n xx python=3.7 # 创建环境
activate xx # 激活
pip install xxxx # 安装包
deactivate xx # 推出环境在生成 wts 文件时,没有必须每次都是去配置gen_wts.py,主要是读取模型,保存模型参数。只要 demo 文件跑通就可以随时保存为 wts。
这一步骤单独拉出来是因为在 debug 的过程中,要关注经过哪些层,预处理有哪些,后处理有哪些。另外在后面搭建 TRT 网络时,还要根据 debug 过程在中的一些信息来调试 trt 网络。
将 pytorch 模型保存为 onnx,可有可无。但是建议如果可以保存,就使用 onnx 来可视化网络。这样对网络架构一级每层的输入输出就会非常明了。
如果无法保存 onnx,搭建网络时,要根据 wts 来分析,比较麻烦。
另外强烈建议:无论是否保存了 onnx,都要手动在纸上将网络在画一遍,,并且将每层的输出维度标注下来,这样搭建层比较多的网络时,不会晕,并且在 debugTRT 网络时可以有效定位错误。
在手动画网络图时,可以给每个节点“标号”,利用该“标号”在搭建 TRT 网络时,可以很清楚知道 “哪个节点输入,经过某种操作,输出哪个节点。”
在 onnx 图中看到几个层一定要心里有数:
比如下面红线框出的一大块实际上就是 upsample 层
下面的为 FC 层:
Conv+BN+Relu 层
ResBlock 层
单击节点。会有详细信息,这些信息使搭建网络变得方便。
如果无法导出 onnx:
搭建网络时需要从 wts 中查看层名,各个卷积层信息需要从代码中分析。
搭建网络时就按照 onnx 图一层一层搭建。
几点建议:
1 要不断去查 API 的使用 https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/index.html
2 利用已有的模块,不要重复造轮子
3 各个层名使用 onnx 的 id,这样在搭建网络时不会晕。,根据 onnx 的结点信息,各层之间的连接也不会出错。
搭建网络过程肯定会出错,debug 是必要的手段:
1 打印每层的维度
Dims dim = id_1083->getOutput(0)->getDimensions();
std::cout << dim[0] << " " << dim[1] << " " << dim[2] << " " << dim[3] << std::endl;一般如果出现生成 engine 就失败的情况,就从 createEngine 的第一句开始调试,并且随时关注窗口输出,如果在某一层出现大量提示信息,那么该层就会有问题,就将该层的输入 tensor 维度和输出 tensor 维度信息都打印出来,看输出的维度是否正常。
2 打印输出
TRT 是先构建网络,然后再 enqueue 时才能得到各层的输出信息,因此若想对比每一层的输出,需要将该层设置为 output 层
out->getOutput(0)->setName(OUTPUT_BLOB_NAME); // out可替换为任意一层
network->markOutput(*out->getOutput(0));3 关注输入层 data
数据层的 debug 无需第 2 步的做法,直接可以查看预处理后的结果。在 debug
这里就是将 TRT 搭建的网络,能封装函数,就封装为函数模块,增加代码可读性。







