Skip to content

Commit 83c57b8

Browse files
KangJialiangshyama7004
authored andcommitted
Fix yoloPostProcessing` to handle variable number of classes (nc)
Previously, the yoloPostProcessing function assumed that the number of classes (nc) was fixed at 80. This caused incorrect behavior when a different number of classes was specified, leading to mismatched output shapes. This update modifies the code to use the provided `nc` value dynamically, ensuring that the output shapes are correctly calculated based on the specified number of classes. This prevents issues when `nc` is not equal to 80 and allows for greater flexibility in model configurations.
1 parent 563fc59 commit 83c57b8

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

modules/dnn/test/test_onnx_importer.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2691,7 +2691,7 @@ void yoloPostProcessing(
26912691
}
26922692

26932693
if (model_name == "yolonas"){
2694-
// outs contains 2 elemets of shape [1, 8400, 80] and [1, 8400, 4]. Concat them to get [1, 8400, 84]
2694+
// outs contains 2 elemets of shape [1, 8400, nc] and [1, 8400, 4]. Concat them to get [1, 8400, nc+4]
26952695
Mat concat_out;
26962696
// squeeze the first dimension
26972697
outs[0] = outs[0].reshape(1, outs[0].size[1]);
@@ -2701,12 +2701,12 @@ void yoloPostProcessing(
27012701
// remove the second element
27022702
outs.pop_back();
27032703
// unsqueeze the first dimension
2704-
outs[0] = outs[0].reshape(0, std::vector<int>{1, 8400, 84});
2704+
outs[0] = outs[0].reshape(0, std::vector<int>{1, outs[0].size[0], outs[0].size[1]});
27052705
}
27062706

2707-
// assert if last dim is 85 or 84
2708-
CV_CheckEQ(outs[0].dims, 3, "Invalid output shape. The shape should be [1, #anchors, 85 or 84]");
2709-
CV_CheckEQ((outs[0].size[2] == nc + 5 || outs[0].size[2] == 80 + 4), true, "Invalid output shape: ");
2707+
// assert if last dim is nc+5 or nc+4
2708+
CV_CheckEQ(outs[0].dims, 3, "Invalid output shape. The shape should be [1, #anchors, nc+5 or nc+4]");
2709+
CV_CheckEQ((outs[0].size[2] == nc + 5 || outs[0].size[2] == nc + 4), true, "Invalid output shape: ");
27102710

27112711
for (auto preds : outs){
27122712

samples/dnn/yolo_detector.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ void yoloPostProcessing(
125125

126126
if (model_name == "yolonas")
127127
{
128-
// outs contains 2 elemets of shape [1, 8400, 80] and [1, 8400, 4]. Concat them to get [1, 8400, 84]
128+
// outs contains 2 elemets of shape [1, 8400, nc] and [1, 8400, 4]. Concat them to get [1, 8400, nc+4]
129129
Mat concat_out;
130130
// squeeze the first dimension
131131
outs[0] = outs[0].reshape(1, outs[0].size[1]);
@@ -135,12 +135,12 @@ void yoloPostProcessing(
135135
// remove the second element
136136
outs.pop_back();
137137
// unsqueeze the first dimension
138-
outs[0] = outs[0].reshape(0, std::vector<int>{1, 8400, nc + 4});
138+
outs[0] = outs[0].reshape(0, std::vector<int>{1, outs[0].size[0], outs[0].size[1]});
139139
}
140140

141-
// assert if last dim is 85 or 84
142-
CV_CheckEQ(outs[0].dims, 3, "Invalid output shape. The shape should be [1, #anchors, 85 or 84]");
143-
CV_CheckEQ((outs[0].size[2] == nc + 5 || outs[0].size[2] == 80 + 4), true, "Invalid output shape: ");
141+
// assert if last dim is nc+5 or nc+4
142+
CV_CheckEQ(outs[0].dims, 3, "Invalid output shape. The shape should be [1, #anchors, nc+5 or nc+4]");
143+
CV_CheckEQ((outs[0].size[2] == nc + 5 || outs[0].size[2] == nc + 4), true, "Invalid output shape: ");
144144

145145
for (auto preds : outs)
146146
{

0 commit comments

Comments
 (0)