Skip to content

Commit e29066e

Browse files
Yolo26-Cls Added (#1704)
1 parent 664f222 commit e29066e

File tree

7 files changed

+402
-8
lines changed

7 files changed

+402
-8
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ The basic workflow of TensorRTx is:
1717

1818
## News
1919

20-
- `2 Feb 2026`. [fazligorkembal](https://github.com/fazligorkembal) Yolo26-Det, Yolo26-Obb
20+
- `2 Feb 2026`. [fazligorkembal](https://github.com/fazligorkembal) Yolo26-Det, Yolo26-Obb, Yolo26-Cls
2121
- `15 Jan 2026`. [zgjja](https://github.com/zgjja) Refactor multiple old CV models to support TensorRT SDK through 7~10.
2222
- `8 Jan 2026`. [ydk61](https://github.com/ydk61): YOLOv13
2323
- `10 May 2025`. [pranavm-nvidia](https://github.com/pranavm-nvidia): [YOLO11](./yolo11_tripy) writen in [Tripy](https://github.com/NVIDIA/TensorRT-Incubator/tree/main/tripy).

yolo26/CMakeLists.txt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,10 @@ add_executable(yolo26_obb ${PROJECT_SOURCE_DIR}/yolo26_obb.cpp ${SRCS})
4848
target_link_libraries(yolo26_obb nvinfer)
4949
target_link_libraries(yolo26_obb cudart)
5050
target_link_libraries(yolo26_obb yololayerplugins)
51-
target_link_libraries(yolo26_obb ${OpenCV_LIBS})
51+
target_link_libraries(yolo26_obb ${OpenCV_LIBS})
52+
53+
# Classification demo executable and its link dependencies.
add_executable(yolo26_cls ${PROJECT_SOURCE_DIR}/yolo26_cls.cpp ${SRCS})
target_link_libraries(yolo26_cls nvinfer cudart yololayerplugins ${OpenCV_LIBS})

yolo26/README.md

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ Training code [link](https://github.com/ultralytics/ultralytics/archive/refs/tag
1616

1717
* [] Yolo26n-det, Yolo26s-det, Yolo26m-det, Yolo26l-det, Yolo26sx-det, support FP32/FP16 and C++ API
1818
* [] Yolo26n-obb, Yolo26s-obb, Yolo26m-obb, Yolo26l-obb, Yolo26sx-obb, support FP32/FP16 and C++ API
19+
* [] Yolo26n-cls, Yolo26s-cls, Yolo26m-cls, Yolo26l-cls, Yolo26sx-cls, support FP32/FP16 and C++ API
1920

2021
## COMING FEATURES
2122
* [] Windows OS Support
@@ -54,6 +55,13 @@ cp [PATH-TO-MAIN-FOLDER]/gen_wts.py .
5455
python gen_wts.py -w yolo26n-obb.pt -o yolo26n-obb.wts -t obb
5556
# A file 'yolo26n-obb.wts' will be generated.
5657

58+
# Download models for Cls
59+
wget https://github.com/ultralytics/assets/releases/download/v8.4.0/yolo26n-cls.pt -O yolo26n-cls.pt # to download other models, replace 'yolo26n-cls.pt' with 'yolo26s-cls.pt', 'yolo26m-cls.pt', 'yolo26l-cls.pt' or 'yolo26x-cls.pt'
60+
# Generate .wts
61+
cp [PATH-TO-MAIN-FOLDER]/gen_wts.py .
62+
python gen_wts.py -w yolo26n-cls.pt -o yolo26n-cls.wts -t cls
63+
# A file 'yolo26n-cls.wts' will be generated.
64+
5765
```
5866

5967
2. build and run
@@ -81,7 +89,20 @@ cp [PATH-TO-ultralytics]/yolo26n-obb.wts .
8189
# Build and serialize TensorRT engine
8290
./yolo26_obb -s yolo26n-obb.wts yolo26n-obb.engine [n/s/m/l/x]
8391
# Run inference
84-
./yolo26_obb -d yolo26n.engine ../images
92+
./yolo26_obb -d yolo26n-obb.engine ../images
93+
# results saved in build directory
94+
```
95+
96+
### Cls
97+
```shell
98+
# Generate the classification labels file in the build folder, or download it:
99+
# wget https://raw.githubusercontent.com/joannzhang00/ImageNet-dataset-classes-labels/main/imagenet_classes.txt
100+
101+
cp [PATH-TO-ultralytics]/yolo26n-cls.wts .
102+
# Build and serialize TensorRT engine
103+
./yolo26_cls -s yolo26n-cls.wts yolo26n-cls.engine [n/s/m/l/x]
104+
# Run inference
105+
./yolo26_cls -d yolo26n-cls.engine ../images
85106
# results saved in build directory
86107
```
87108

yolo26/include/model.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,8 @@ nvinfer1::IHostMemory* buildEngineYolo26Det(nvinfer1::IBuilder* builder, nvinfer
1010

1111
nvinfer1::IHostMemory* buildEngineYolo26Obb(nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config,
1212
nvinfer1::DataType dt, const std::string& wts_path, float& gd, float& gw,
13-
int& max_channels, std::string& type);
13+
int& max_channels, std::string& type);
14+
15+
nvinfer1::IHostMemory* buildEngineYolo26Cls(nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config,
16+
nvinfer1::DataType dt, const std::string& wts_path, float& gd, float& gw,
17+
int& max_channels, std::string& type);

yolo26/include/utils.h

100644100755
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,22 @@ static inline int read_files_in_dir(const char* p_dir_name, std::vector<std::str
4747
return 0;
4848
}
4949

50+
// Reads class labels from a text file, one label per line (e.g. imagenet_classes.txt).
// Used by the cls demo to map a class index to a human-readable name.
//   file_name : path to the labels file.
// Returns the labels in file order. If the file cannot be opened, prints an error and
// asserts; in NDEBUG builds the assert is compiled out and an empty vector is returned.
inline std::vector<std::string> read_classes(const std::string& file_name) {
    std::vector<std::string> classes;
    std::ifstream ifs(file_name, std::ios::in);
    if (!ifs.is_open()) {
        std::cerr << file_name << " is not found, pls refer to README and download it." << std::endl;
        assert(0);  // NOTE: no-op under NDEBUG; callers then see an empty vector.
    }
    std::string s;
    while (std::getline(ifs, s)) {
        // Strip a trailing '\r' so labels from CRLF (Windows-edited) files compare
        // equal to plain strings and print without stray carriage returns.
        if (!s.empty() && s.back() == '\r') {
            s.pop_back();
        }
        classes.push_back(s);
    }
    return classes;  // ifstream closes itself on destruction (RAII).
}
65+
5066
// Function to trim leading and trailing whitespace from a string
5167
static inline std::string trim_leading_whitespace(const std::string& str) {
5268
size_t first = str.find_first_not_of(' ');

yolo26/src/model.cpp

Lines changed: 112 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,9 @@ nvinfer1::IHostMemory* buildEngineYolo26Det(nvinfer1::IBuilder* builder, nvinfer
8787
C3K2(network, weightMap, *block7->getOutput(0), get_width(1024, gw, max_channels),
8888
get_width(1024, gw, max_channels), get_depth(2, gd), true, true, false, 0.5, "model.8");
8989

90-
nvinfer1::IElementWiseLayer* block9 = SPPF(network, weightMap, *block8->getOutput(0),
91-
get_width(1024, gw, max_channels), get_width(1024, gw, max_channels), 5,
92-
true, "model.9"); // TODO: VERIFY THIS BLOCK FOR OTHER YOLO26 MODELS
90+
nvinfer1::IElementWiseLayer* block9 =
91+
SPPF(network, weightMap, *block8->getOutput(0), get_width(1024, gw, max_channels),
92+
get_width(1024, gw, max_channels), 5, true, "model.9");
9393

9494
nvinfer1::IElementWiseLayer* block10 =
9595
C2PSA(network, weightMap, *block9->getOutput(0), get_width(1024, gw, max_channels),
@@ -869,4 +869,112 @@ nvinfer1::IHostMemory* buildEngineYolo26Obb(nvinfer1::IBuilder* builder, nvinfer
869869
free((void*)(mem.second.values));
870870
}
871871
return serialized_model;
872-
}
872+
}
873+
874+
// Builds and serializes a TensorRT engine for the YOLO26 classification model.
//
//   builder, config : TensorRT builder objects owned and released by the caller.
//   dt              : data type of the input tensor (e.g. nvinfer1::DataType::kFLOAT).
//   wts_path        : path to the .wts weight file produced by gen_wts.py.
//   gd, gw          : depth / width multipliers selecting the model scale.
//   max_channels    : channel cap forwarded to get_width().
//   type            : model size letter ("n"/"s"/"m"/"l"/"x"); m/l/x enable c3k bottlenecks.
//
// Returns the serialized engine; the caller owns the returned IHostMemory.
nvinfer1::IHostMemory* buildEngineYolo26Cls(nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config,
                                            nvinfer1::DataType dt, const std::string& wts_path, float& gd, float& gw,
                                            int& max_channels, std::string& type) {
    std::map<std::string, nvinfer1::Weights> weightMap = loadWeights(wts_path);

    nvinfer1::INetworkDefinition* network = builder->createNetworkV2(
            1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH));

    /*******************************************************************************************************
    ****************************************** YOLO26 INPUT **********************************************
    *******************************************************************************************************/

    // Classification uses its own input resolution (kClsInputH x kClsInputW), not the detection one.
    nvinfer1::ITensor* data =
            network->addInput(kInputTensorName, dt, nvinfer1::Dims4{kBatchSize, 3, kClsInputH, kClsInputW});
    assert(data);

    /*******************************************************************************************************
    ***************************************** YOLO26 BACKBONE ********************************************
    *******************************************************************************************************/

    nvinfer1::IElementWiseLayer* block0 =
            convBnSiLU(network, weightMap, *data, get_width(64, gw, max_channels), {3, 3}, 2, "model.0");

    nvinfer1::IElementWiseLayer* block1 = convBnSiLU(network, weightMap, *block0->getOutput(0),
                                                     get_width(128, gw, max_channels), {3, 3}, 2, "model.1");

    // The m/l/x variants use the heavier c3k bottleneck inside the first two C3K2 stages.
    bool c3k = false;
    if (type == "m" || type == "l" || type == "x") {
        c3k = true;
    }

    nvinfer1::IElementWiseLayer* block2 =
            C3K2(network, weightMap, *block1->getOutput(0), get_width(128, gw, max_channels),
                 get_width(256, gw, max_channels), get_depth(2, gd), c3k, true, false, 0.25, "model.2");

    nvinfer1::IElementWiseLayer* block3 = convBnSiLU(network, weightMap, *block2->getOutput(0),
                                                     get_width(256, gw, max_channels), {3, 3}, 2, "model.3");

    nvinfer1::IElementWiseLayer* block4 =
            C3K2(network, weightMap, *block3->getOutput(0), get_width(256, gw, max_channels),
                 get_width(512, gw, max_channels), get_depth(2, gd), c3k, true, false, 0.25, "model.4");

    nvinfer1::IElementWiseLayer* block5 = convBnSiLU(network, weightMap, *block4->getOutput(0),
                                                     get_width(512, gw, max_channels), {3, 3}, 2, "model.5");

    nvinfer1::IElementWiseLayer* block6 =
            C3K2(network, weightMap, *block5->getOutput(0), get_width(512, gw, max_channels),
                 get_width(512, gw, max_channels), get_depth(2, gd), true, true, false, 0.5, "model.6");

    nvinfer1::IElementWiseLayer* block7 = convBnSiLU(network, weightMap, *block6->getOutput(0),
                                                     get_width(1024, gw, max_channels), {3, 3}, 2, "model.7");

    nvinfer1::IElementWiseLayer* block8 =
            C3K2(network, weightMap, *block7->getOutput(0), get_width(1024, gw, max_channels),
                 get_width(1024, gw, max_channels), get_depth(2, gd), true, true, false, 0.5, "model.8");

    // model.9 is C2PSA directly — the cls backbone has no SPPF stage (unlike Det/Obb).
    nvinfer1::IElementWiseLayer* block9 =
            C2PSA(network, weightMap, *block8->getOutput(0), get_width(1024, gw, max_channels),
                  get_width(1024, gw, max_channels), get_depth(2, gd), 0.5, "model.9");

    /*******************************************************************************************************
    ************************************** YOLO26 CLASSIFY HEAD ******************************************
    *******************************************************************************************************/

    // model.10: 1x1 conv to a fixed 1280 channels, global average pool, flatten, linear, sigmoid.
    nvinfer1::IElementWiseLayer* block10_convbn =
            convBnSiLU(network, weightMap, *block9->getOutput(0), 1280, {1, 1}, 1, "model.10.conv");
    nvinfer1::Dims dims =
            block10_convbn->getOutput(0)->getDimensions();  // NCHW dims of the conv output
    assert(dims.nbDims == 4);
    // Global average pooling: the pool window covers the whole spatial extent.
    nvinfer1::IPoolingLayer* block10_pool = network->addPoolingNd(
            *block10_convbn->getOutput(0), nvinfer1::PoolingType::kAVERAGE, nvinfer1::DimsHW{dims.d[2], dims.d[3]});
    // Flatten {N, 1280, 1, 1} -> {N, 1280} for the matrix multiply.
    nvinfer1::IShuffleLayer* block10_reshape = network->addShuffle(*block10_pool->getOutput(0));
    block10_reshape->setReshapeDimensions(nvinfer1::Dims2{kBatchSize, 1280});
    nvinfer1::IConstantLayer* block10_linear_weight =
            network->addConstant(nvinfer1::Dims2{kClsNumClass, 1280}, weightMap["model.10.linear.weight"]);
    // Bias must be {1, kClsNumClass} so the elementwise kSUM broadcasts it per batch row onto the
    // {kBatchSize, kClsNumClass} matmul output. The previous {kClsNumClass, 1} shape broadcast the
    // sum to a {kClsNumClass, kClsNumClass} tensor, applying the wrong bias to each class.
    nvinfer1::IConstantLayer* block10_linear_bias =
            network->addConstant(nvinfer1::Dims2{1, kClsNumClass}, weightMap["model.10.linear.bias"]);
    nvinfer1::IMatrixMultiplyLayer* block10_linear_matrix_multiply =
            network->addMatrixMultiply(*block10_reshape->getOutput(0), nvinfer1::MatrixOperation::kNONE,
                                       *block10_linear_weight->getOutput(0), nvinfer1::MatrixOperation::kTRANSPOSE);
    nvinfer1::IElementWiseLayer* block10_linear_add =
            network->addElementWise(*block10_linear_matrix_multiply->getOutput(0), *block10_linear_bias->getOutput(0),
                                    nvinfer1::ElementWiseOperation::kSUM);
    // NOTE(review): ultralytics applies softmax to cls logits at inference time; confirm the
    // cls demo's postprocess expects sigmoid scores (argmax is unaffected either way).
    nvinfer1::IActivationLayer* output =
            network->addActivation(*block10_linear_add->getOutput(0), nvinfer1::ActivationType::kSIGMOID);
    assert(output);

    output->getOutput(0)->setName(kOutputTensorName);
    network->markOutput(*output->getOutput(0));
    // Use setMemoryPoolLimit instead of deprecated setMaxWorkspaceSize
    config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, 16 * (1 << 20));

#if defined(USE_FP16)
    config->setFlag(nvinfer1::BuilderFlag::kFP16);
#elif defined(USE_INT8)
    std::cerr << "INT8 not supported for YOLO26 model yet." << std::endl;
#endif

    std::cout << "Building engine, please wait for a while..." << std::endl;
    nvinfer1::IHostMemory* serialized_model = builder->buildSerializedNetwork(*network, *config);
    std::cout << "Build engine successfully!" << std::endl;

    delete network;

    // Weight buffers were malloc'd by loadWeights(); free them now that the engine is serialized.
    for (auto& mem : weightMap) {
        free((void*)(mem.second.values));
    }
    return serialized_model;
}

0 commit comments

Comments
 (0)