Skip to content

Commit cf89dfe

Browse files
bingnan.nisophgo-yezx
authored andcommitted
[fix](isp_model):fix bugs for models used in isp
[description ]:NA [root cause]:NA [JIRA ID]:NA [chip project]:NA [side effects]:NA [Affected *.ko/*.so/*.a] :NA Change-Id: I91feb17c5ab3b06def625f4e914227c2f52a0a9d Reviewed-on: https://gerrit-ai.sophgo.vip:8443/145199 Reviewed-by: 振兴 叶 <zhenxing.ye@sophgo.com>
1 parent a137e91 commit cf89dfe

File tree

3 files changed

+49
-53
lines changed

3 files changed

+49
-53
lines changed

sample/c/utils/meta_visualize.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,8 @@ static float GetYuvColor(int chanel, color_rgb *color) {
192192
// TODO: Need refactor
193193
int _WriteText(VIDEO_FRAME_INFO_S *frame, int x, int y, const char *name,
194194
color_rgb color, int thickness) {
195-
if (frame->stVFrame.enPixelFormat != PIXEL_FORMAT_NV21 &&
195+
if (frame->stVFrame.enPixelFormat != PIXEL_FORMAT_NV12 &&
196+
frame->stVFrame.enPixelFormat != PIXEL_FORMAT_NV21 &&
196197
frame->stVFrame.enPixelFormat != PIXEL_FORMAT_YUV_PLANAR_420) {
197198
LOGE(
198199
"Only PIXEL_FORMAT_NV21 and PIXEL_FORMAT_YUV_PLANAR_420 are supported "
@@ -566,6 +567,7 @@ int WriteText(char *name, int x, int y, VIDEO_FRAME_INFO_S *drawFrame, float r,
566567
int DrawMeta(const TDLFace *meta, VIDEO_FRAME_INFO_S *drawFrame,
567568
const bool drawText, const std::vector<TDLBrush> &brushes) {
568569
if (drawFrame->stVFrame.enPixelFormat != PIXEL_FORMAT_NV21 &&
570+
drawFrame->stVFrame.enPixelFormat != PIXEL_FORMAT_NV12 &&
569571
drawFrame->stVFrame.enPixelFormat != PIXEL_FORMAT_YUV_PLANAR_420) {
570572
LOGE(
571573
"Only PIXEL_FORMAT_NV21 and PIXEL_FORMAT_YUV_PLANAR_420 are supported "
@@ -606,15 +608,15 @@ int DrawMeta(const TDLFace *meta, VIDEO_FRAME_INFO_S *drawFrame,
606608

607609
TDLBox bbox = meta->info[i].box;
608610

609-
if (drawFrame->stVFrame.enPixelFormat == PIXEL_FORMAT_NV21) {
611+
if (drawFrame->stVFrame.enPixelFormat == PIXEL_FORMAT_NV21 ||
612+
drawFrame->stVFrame.enPixelFormat == PIXEL_FORMAT_NV12) {
610613
DrawRect<FORMAT_NV21>(drawFrame, bbox.x1, bbox.x2, bbox.y1, bbox.y2,
611614
meta->info[i].name, rgb_color, thickness, drawText);
612615
} else {
613616
DrawRect<FORMAT_YUV_420P>(drawFrame, bbox.x1, bbox.x2, bbox.y1, bbox.y2,
614617
meta->info[i].name, rgb_color, thickness,
615618
drawText);
616619
}
617-
return 0;
618620
}
619621

620622
CVI_SYS_IonFlushCache(drawFrame->stVFrame.u64PhyAddr[0],

src/components/nn/image_classification/isp_image_classification.cpp

Lines changed: 40 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -54,19 +54,19 @@ void parse_output(T *ptr_out, const int num_cls, float qscale,
5454
}
5555

5656
IspImageClassification::IspImageClassification() : BaseModel() {
57-
net_param_.model_config.mean = {0, 0, 0};
58-
net_param_.model_config.std = {255, 255, 255};
59-
net_param_.model_config.rgb_order = "rgb";
57+
net_param_.model_config.mean = {123.675, 116.28, 103.52};
58+
net_param_.model_config.std = {58.395, 57.12, 57.375};
59+
net_param_.model_config.rgb_order = "gray";
6060
keep_aspect_ratio_ = true;
6161
}
6262

6363
IspImageClassification::~IspImageClassification() {}
6464

6565
int IspImageClassification::onModelOpened() {
66-
if (net_->getOutputNames().size() != 1) {
67-
LOGE("ImageClassification only expected 1 output branch!\n");
68-
return -1;
69-
}
66+
// if (net_->getOutputNames().size() != 1) {
67+
// LOGE("ImageClassification only expected 1 output branch!\n");
68+
// return -1;
69+
//}
7070

7171
return 0;
7272
}
@@ -75,6 +75,11 @@ int32_t IspImageClassification::inference(
7575
const std::vector<std::shared_ptr<BaseImage>> &images,
7676
std::vector<std::shared_ptr<ModelOutputInfo>> &out_datas,
7777
const std::map<std::string, float> &parameters) {
78+
if (images.empty()) {
79+
LOGE("Input images is empty");
80+
return -1;
81+
}
82+
7883
float awb[3]; // rgain, ggain, bgain
7984
float ccm[9]; // rgb[3][3]
8085
float blc[1];
@@ -116,64 +121,50 @@ int32_t IspImageClassification::inference(
116121
memcpy(input_ptr, blc, sizeof(float));
117122
}
118123

119-
// for (auto &image : images) {
120-
// // int32_t *temp_buffer = (int32_t *)image->getVirtualAddress()[0];
121-
// int32_t *temp_buffer =
122-
// reinterpret_cast<int32_t *>(image->getVirtualAddress()[0]);
123-
// std::string input_image = net_->getInputNames()[0];
124+
model_timer_.TicToc("runstart");
124125

125-
// const TensorInfo &tinfo = net_->getTensorInfo(input_image);
126-
// input_ptr = (int32_t *)tinfo.sys_mem;
127-
// memcpy(input_ptr, temp_buffer, tinfo.tensor_size);
126+
for (auto &image : images) {
127+
std::string input_layer_name = net_->getInputNames()[0];
128128

129-
// net_->updateInputTensors();
130-
// net_->forward();
131-
// net_->updateOutputTensors();
132-
// std::vector<std::shared_ptr<ModelOutputInfo>> batch_results;
129+
net_->getInputTensor(input_layer_name)->copyFromImage(image, 0);
130+
model_timer_.TicToc("preprocess");
133131

134-
// std::vector<std::shared_ptr<BaseImage>> batch_images = {image};
135-
// outputParse(batch_images, batch_results);
132+
net_->updateInputTensors();
133+
net_->forward();
134+
model_timer_.TicToc("tpu");
135+
net_->updateOutputTensors();
136+
std::shared_ptr<ModelClassificationInfo> result =
137+
std::make_shared<ModelClassificationInfo>();
136138

137-
// out_datas.insert(out_datas.end(), batch_results.begin(),
138-
// batch_results.end());
139-
//}
139+
outputParse(result);
140+
model_timer_.TicToc("post");
140141

141-
std::vector<std::shared_ptr<ModelOutputInfo>> batch_out_datas;
142-
int ret = BaseModel::inference(images, batch_out_datas);
143-
if (ret != 0) {
144-
LOGE("inference failed");
145-
return ret;
142+
out_datas.push_back(result);
146143
}
147-
out_datas.push_back(batch_out_datas[0]);
148144

149145
return 0;
150146
}
151147

152148
int32_t IspImageClassification::outputParse(
153-
const std::vector<std::shared_ptr<BaseImage>> &images,
154-
std::vector<std::shared_ptr<ModelOutputInfo>> &out_datas) {
155-
std::string output_name = net_->getOutputNames()[0];
149+
std::shared_ptr<ModelClassificationInfo> &out_data) {
150+
std::string output_name = net_->getOutputNames()[1];
156151
TensorInfo oinfo = net_->getTensorInfo(output_name);
157152

158153
std::shared_ptr<BaseTensor> output_tensor =
159154
net_->getOutputTensor(output_name);
160155

161-
for (size_t b = 0; b < images.size(); b++) {
162-
std::shared_ptr<ModelClassificationInfo> cls_meta =
163-
std::make_shared<ModelClassificationInfo>();
164-
if (oinfo.data_type == TDLDataType::INT8) {
165-
parse_output<int8_t>(output_tensor->getBatchPtr<int8_t>(b),
166-
oinfo.tensor_elem, oinfo.qscale, cls_meta);
167-
} else if (oinfo.data_type == TDLDataType::UINT8) {
168-
parse_output<uint8_t>(output_tensor->getBatchPtr<uint8_t>(b),
169-
oinfo.tensor_elem, oinfo.qscale, cls_meta);
170-
} else if (oinfo.data_type == TDLDataType::FP32) {
171-
parse_output<float>(output_tensor->getBatchPtr<float>(b),
172-
oinfo.tensor_elem, oinfo.qscale, cls_meta);
173-
} else {
174-
LOGE("unsupported data type: %d", oinfo.data_type);
175-
}
176-
out_datas.push_back(cls_meta);
156+
if (oinfo.data_type == TDLDataType::INT8) {
157+
parse_output<int8_t>(output_tensor->getBatchPtr<int8_t>(0),
158+
oinfo.tensor_elem, oinfo.qscale, out_data);
159+
} else if (oinfo.data_type == TDLDataType::UINT8) {
160+
parse_output<uint8_t>(output_tensor->getBatchPtr<uint8_t>(0),
161+
oinfo.tensor_elem, oinfo.qscale, out_data);
162+
} else if (oinfo.data_type == TDLDataType::FP32) {
163+
parse_output<float>(output_tensor->getBatchPtr<float>(0), oinfo.tensor_elem,
164+
oinfo.qscale, out_data);
165+
} else {
166+
LOGE("unsupported data type: %d", oinfo.data_type);
177167
}
168+
178169
return 0;
179170
}

src/components/nn/image_classification/isp_image_classification.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@ class IspImageClassification final : public BaseModel {
1313
const std::map<std::string, float> &parameters = {}) override;
1414
virtual int32_t outputParse(
1515
const std::vector<std::shared_ptr<BaseImage>> &images,
16-
std::vector<std::shared_ptr<ModelOutputInfo>> &out_datas) override;
16+
std::vector<std::shared_ptr<ModelOutputInfo>> &out_datas) override {
17+
return 0;
18+
}
19+
int32_t outputParse(std::shared_ptr<ModelClassificationInfo> &out_data);
1720
virtual int32_t onModelOpened() override;
1821
};
1922

0 commit comments

Comments
 (0)