@@ -54,19 +54,19 @@ void parse_output(T *ptr_out, const int num_cls, float qscale,
5454}
5555
5656IspImageClassification::IspImageClassification () : BaseModel() {
57- net_param_.model_config .mean = {0 , 0 , 0 };
58- net_param_.model_config .std = {255 , 255 , 255 };
59- net_param_.model_config .rgb_order = " rgb " ;
57+ net_param_.model_config .mean = {123.675 , 116.28 , 103.52 };
58+ net_param_.model_config .std = {58.395 , 57.12 , 57.375 };
59+ net_param_.model_config .rgb_order = " gray " ;
6060 keep_aspect_ratio_ = true ;
6161}
6262
6363IspImageClassification::~IspImageClassification () {}
6464
6565int IspImageClassification::onModelOpened () {
66- if (net_->getOutputNames ().size () != 1 ) {
67- LOGE (" ImageClassification only expected 1 output branch!\n " );
68- return -1 ;
69- }
66+ // if (net_->getOutputNames().size() != 1) {
67+ // LOGE("ImageClassification only expected 1 output branch!\n");
68+ // return -1;
69+ // }
7070
7171 return 0 ;
7272}
@@ -75,6 +75,11 @@ int32_t IspImageClassification::inference(
7575 const std::vector<std::shared_ptr<BaseImage>> &images,
7676 std::vector<std::shared_ptr<ModelOutputInfo>> &out_datas,
7777 const std::map<std::string, float > ¶meters) {
78+ if (images.empty ()) {
79+ LOGE (" Input images is empty" );
80+ return -1 ;
81+ }
82+
7883 float awb[3 ]; // rgain, ggain, bgain
7984 float ccm[9 ]; // rgb[3][3]
8085 float blc[1 ];
@@ -116,64 +121,50 @@ int32_t IspImageClassification::inference(
116121 memcpy (input_ptr, blc, sizeof (float ));
117122 }
118123
119- // for (auto &image : images) {
120- // // int32_t *temp_buffer = (int32_t *)image->getVirtualAddress()[0];
121- // int32_t *temp_buffer =
122- // reinterpret_cast<int32_t *>(image->getVirtualAddress()[0]);
123- // std::string input_image = net_->getInputNames()[0];
124+ model_timer_.TicToc (" runstart" );
124125
125- // const TensorInfo &tinfo = net_->getTensorInfo(input_image);
126- // input_ptr = (int32_t *)tinfo.sys_mem;
127- // memcpy(input_ptr, temp_buffer, tinfo.tensor_size);
126+ for (auto &image : images) {
127+ std::string input_layer_name = net_->getInputNames ()[0 ];
128128
129- // net_->updateInputTensors();
130- // net_->forward();
131- // net_->updateOutputTensors();
132- // std::vector<std::shared_ptr<ModelOutputInfo>> batch_results;
129+ net_->getInputTensor (input_layer_name)->copyFromImage (image, 0 );
130+ model_timer_.TicToc (" preprocess" );
133131
134- // std::vector<std::shared_ptr<BaseImage>> batch_images = {image};
135- // outputParse(batch_images, batch_results);
132+ net_->updateInputTensors ();
133+ net_->forward ();
134+ model_timer_.TicToc (" tpu" );
135+ net_->updateOutputTensors ();
136+ std::shared_ptr<ModelClassificationInfo> result =
137+ std::make_shared<ModelClassificationInfo>();
136138
137- // out_datas.insert(out_datas.end(), batch_results.begin(),
138- // batch_results.end());
139- // }
139+ outputParse (result);
140+ model_timer_.TicToc (" post" );
140141
141- std::vector<std::shared_ptr<ModelOutputInfo>> batch_out_datas;
142- int ret = BaseModel::inference (images, batch_out_datas);
143- if (ret != 0 ) {
144- LOGE (" inference failed" );
145- return ret;
142+ out_datas.push_back (result);
146143 }
147- out_datas.push_back (batch_out_datas[0 ]);
148144
149145 return 0 ;
150146}
151147
152148int32_t IspImageClassification::outputParse (
153- const std::vector<std::shared_ptr<BaseImage>> &images,
154- std::vector<std::shared_ptr<ModelOutputInfo>> &out_datas) {
155- std::string output_name = net_->getOutputNames ()[0 ];
149+ std::shared_ptr<ModelClassificationInfo> &out_data) {
150+ std::string output_name = net_->getOutputNames ()[1 ];
156151 TensorInfo oinfo = net_->getTensorInfo (output_name);
157152
158153 std::shared_ptr<BaseTensor> output_tensor =
159154 net_->getOutputTensor (output_name);
160155
161- for (size_t b = 0 ; b < images.size (); b++) {
162- std::shared_ptr<ModelClassificationInfo> cls_meta =
163- std::make_shared<ModelClassificationInfo>();
164- if (oinfo.data_type == TDLDataType::INT8) {
165- parse_output<int8_t >(output_tensor->getBatchPtr <int8_t >(b),
166- oinfo.tensor_elem , oinfo.qscale , cls_meta);
167- } else if (oinfo.data_type == TDLDataType::UINT8) {
168- parse_output<uint8_t >(output_tensor->getBatchPtr <uint8_t >(b),
169- oinfo.tensor_elem , oinfo.qscale , cls_meta);
170- } else if (oinfo.data_type == TDLDataType::FP32) {
171- parse_output<float >(output_tensor->getBatchPtr <float >(b),
172- oinfo.tensor_elem , oinfo.qscale , cls_meta);
173- } else {
174- LOGE (" unsupported data type: %d" , oinfo.data_type );
175- }
176- out_datas.push_back (cls_meta);
156+ if (oinfo.data_type == TDLDataType::INT8) {
157+ parse_output<int8_t >(output_tensor->getBatchPtr <int8_t >(0 ),
158+ oinfo.tensor_elem , oinfo.qscale , out_data);
159+ } else if (oinfo.data_type == TDLDataType::UINT8) {
160+ parse_output<uint8_t >(output_tensor->getBatchPtr <uint8_t >(0 ),
161+ oinfo.tensor_elem , oinfo.qscale , out_data);
162+ } else if (oinfo.data_type == TDLDataType::FP32) {
163+ parse_output<float >(output_tensor->getBatchPtr <float >(0 ), oinfo.tensor_elem ,
164+ oinfo.qscale , out_data);
165+ } else {
166+ LOGE (" unsupported data type: %d" , oinfo.data_type );
177167 }
168+
178169 return 0 ;
179170}
0 commit comments