@@ -93,6 +93,16 @@ const char *SAM::CreateSession(SEG::DL_INIT_PARAM &iParams) {
9393
9494 auto input_shape =
9595 _session->GetInputTypeInfo (0 ).GetTensorTypeAndShapeInfo ().GetShape ();
96+ // Optional shape check when model has fixed dims (not -1)
97+ if (input_shape.size () >= 4 && input_shape[2 ] > 0 && input_shape[3 ] > 0 ) {
98+ const int64_t expectH = _imgSize.at (1 );
99+ const int64_t expectW = _imgSize.at (0 );
100+ if (input_shape[2 ] != expectH || input_shape[3 ] != expectW) {
101+ std::cerr << " [SAM]: Model input (H,W)=(" << input_shape[2 ] << " ," << input_shape[3 ]
102+ << " ) mismatches configured imgSize (W,H)=(" << _imgSize[0 ] << " ," << _imgSize[1 ] << " )."
103+ << std::endl;
104+ }
105+ }
96106 auto output_shape =
97107 _session->GetOutputTypeInfo (0 ).GetTensorTypeAndShapeInfo ().GetShape ();
98108 auto output_type = _session->GetOutputTypeInfo (0 )
@@ -127,9 +137,9 @@ const char *SAM::RunSession(const cv::Mat &iImg,
127137 utilities.BlobFromImage (processedImg, blob);
128138 std::vector<int64_t > inputNodeDims;
129139 if (_modelType == SEG::SAM_SEGMENT_ENCODER) {
130- inputNodeDims = {1 , 3 , _imgSize.at (0 ), _imgSize.at (1 )};
140+ // NCHW: H = imgSize[1], W = imgSize[0]
141+ inputNodeDims = {1 , 3 , _imgSize.at (1 ), _imgSize.at (0 )};
131142 } else if (_modelType == SEG::SAM_SEGMENT_DECODER) {
132- // Input size or SAM decoder model is 256x64x64 for the decoder
133143 inputNodeDims = {1 , 256 , 64 , 64 };
134144 }
135145 TensorProcess (starttime_1, iImg, blob, inputNodeDims, _modelType, oResult,
@@ -329,8 +339,9 @@ char *SAM::WarmUpSession(SEG::MODEL_TYPE _modelType) {
329339
330340 float *blob = new float [iImg.total () * 3 ];
331341 utilities.BlobFromImage (processedImg, blob);
332- std::vector<int64_t > SAM_input_node_dims = {1 , 3 , _imgSize.at (0 ),
333- _imgSize.at (1 )};
342+
343+ // NCHW: H = imgSize[1], W = imgSize[0]
344+ std::vector<int64_t > SAM_input_node_dims = {1 , 3 , _imgSize.at (1 ), _imgSize.at (0 )};
334345 switch (_modelType) {
335346 case SEG::SAM_SEGMENT_ENCODER: {
336347 Ort::Value input_tensor = Ort::Value::CreateTensor<float >(
0 commit comments