1+ import tensorrt as trt
2+ import numpy as np
3+
def build_flexible_sam2_engine(
    onnx_path="sam2.1_large.onnx",
    engine_path="sam2.1_large_1024x1797_python.engine",
    target_size=(1920, 1080),
    workspace_bytes=8 << 30,
):
    """Build a flexible SAM2 TensorRT engine with the Python API.

    Parses the ONNX model, registers a single optimization profile, builds
    and serializes the engine, writes it to disk and finally deserializes it
    again as a sanity check, printing progress along the way.

    The profile pins every input to one shape except ``point_coords`` and
    ``point_labels``, whose second dimension ranges from 1 to 10 (optimum 5).
    ``orig_im_size`` additionally gets fixed min/opt/max *values* because the
    original author treats it as a shape tensor.

    Args:
        onnx_path: Path of the ONNX model to parse.
        engine_path: Path the serialized engine is written to.
        target_size: Two integers used as the min/opt/max values of the
            ``orig_im_size`` input.
        workspace_bytes: Limit of the builder's workspace memory pool, in
            bytes (defaults to 8 GiB).

    Returns:
        True when the engine was built, saved and deserialized successfully;
        False when ONNX parsing, profile registration, engine building or
        the final deserialization failed.

    Raises:
        ValueError: If ``target_size`` does not hold exactly two values.

    NOTE(review): the default engine file name mentions 1024x1797 while the
    default ``target_size`` is (1920, 1080). Both defaults are kept exactly
    as they were; confirm which size is actually intended.
    """
    # Validate the cheap argument first so a bad call fails before the
    # (potentially minutes-long) build is started.
    size_values = np.asarray(target_size, dtype=np.int64)
    if size_values.shape != (2,):
        raise ValueError(
            f"target_size must contain exactly 2 values, got shape {size_values.shape}"
        )

    logger = trt.Logger(trt.Logger.INFO)
    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

    with trt.Builder(logger) as builder, \
            builder.create_network(network_flags) as network, \
            trt.OnnxParser(network, logger) as parser:

        # Parse the ONNX model.
        print("Loading ONNX model...")
        with open(onnx_path, "rb") as model:
            parsed = parser.parse(model.read())
        if not parsed:
            print("Failed to parse ONNX model")
            for index in range(parser.num_errors):
                print(f"Error {index} : {parser.get_error(index)} ")
            return False

        # Create the builder config.
        config = builder.create_builder_config()
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)

        # Print the network's inputs and outputs.
        print(f"Network has {network.num_inputs} inputs and {network.num_outputs} outputs")

        for i in range(network.num_inputs):
            input_tensor = network.get_input(i)
            print(f"Input {i} : {input_tensor.name} , shape: {input_tensor.shape} ")

        for i in range(network.num_outputs):
            output_tensor = network.get_output(i)
            print(f"Output {i} : {output_tensor.name} , shape: {output_tensor.shape} ")

        # Create the optimization profile: (name, min, opt, max) per input.
        profile = builder.create_optimization_profile()
        shape_ranges = (
            ("image_embeddings", (1, 256, 64, 64), (1, 256, 64, 64), (1, 256, 64, 64)),
            ("high_res_features1", (1, 32, 256, 256), (1, 32, 256, 256), (1, 32, 256, 256)),
            ("high_res_features2", (1, 64, 128, 128), (1, 64, 128, 128), (1, 64, 128, 128)),
            ("point_coords", (1, 1, 2), (1, 5, 2), (1, 10, 2)),
            ("point_labels", (1, 1), (1, 5), (1, 10)),
            ("mask_input", (1, 1, 256, 256), (1, 1, 256, 256), (1, 1, 256, 256)),
            ("has_mask_input", (1,), (1,), (1,)),
            ("orig_im_size", (2,), (2,), (2,)),
        )
        for name, min_shape, opt_shape, max_shape in shape_ranges:
            profile.set_shape(name, min_shape, opt_shape, max_shape)

        # Key fix from the original author: give the shape tensor
        # orig_im_size concrete values, not only dimensions.
        profile.set_shape_input("orig_im_size", size_values, size_values, size_values)

        # Register the profile; TensorRT reports an invalid profile with a
        # negative index, in which case building would not use it.
        if config.add_optimization_profile(profile) < 0:
            print("Failed to add optimization profile")
            return False

        print("Building TensorRT engine... This may take several minutes.")
        print(f"Target output size: {size_values} ")

        # Build the engine.
        serialized_engine = builder.build_serialized_network(network, config)
        if serialized_engine is None:
            print("Failed to build engine")
            return False

        # Save the engine.
        with open(engine_path, "wb") as f:
            f.write(serialized_engine)

        print(f"✅ Engine successfully built and saved to: {engine_path} ")

        # Validate the engine by deserializing what was just written.
        runtime = trt.Runtime(logger)
        engine = runtime.deserialize_cuda_engine(serialized_engine)
        if not engine:
            print("❌ Engine validation failed!")
            return False

        print("✅ Engine validation successful!")
        print(f"Number of IO tensors: {engine.num_io_tensors} ")

        # Print the name, direction and shape of every IO tensor.
        for i in range(engine.num_io_tensors):
            tensor_name = engine.get_tensor_name(i)
            tensor_shape = engine.get_tensor_shape(tensor_name)
            tensor_mode = engine.get_tensor_mode(tensor_name)
            io_type = "INPUT" if tensor_mode == trt.TensorIOMode.INPUT else "OUTPUT"
            print(f" {io_type} : {tensor_name} -> {tensor_shape} ")

        return True
97+
if __name__ == "__main__":
    # Build the engine with the default settings when run as a script.
    success = build_flexible_sam2_engine()
    if not success:
        print("Engine build failed!")
        # Exit with a non-zero status so shells and CI jobs can detect the
        # failure instead of seeing a successful exit code.
        raise SystemExit(1)
0 commit comments