Skip to content

Commit 5c99490

Browse files
[ORT] [TRT] Support SAM2.1 (#464)
* sam2 python version commit * clean code
1 parent 2298d63 commit 5c99490

File tree

5 files changed

+1634
-0
lines changed

5 files changed

+1634
-0
lines changed

docs/python/sam2/engine_build.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import tensorrt as trt
2+
import numpy as np
3+
4+
def build_flexible_sam2_engine():
5+
"""使用 Python API 构建灵活的 SAM2 引擎"""
6+
7+
TRT_LOGGER = trt.Logger(trt.Logger.INFO)
8+
9+
with trt.Builder(TRT_LOGGER) as builder, \
10+
builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) as network, \
11+
trt.OnnxParser(network, TRT_LOGGER) as parser:
12+
13+
# 解析 ONNX
14+
print("Loading ONNX model...")
15+
with open("sam2.1_large.onnx", 'rb') as model:
16+
if not parser.parse(model.read()):
17+
print("Failed to parse ONNX model")
18+
for error in range(parser.num_errors):
19+
print(f"Error {error}: {parser.get_error(error)}")
20+
return False
21+
22+
# 创建 builder config
23+
config = builder.create_builder_config()
24+
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 8 << 30) # 8GB
25+
26+
# 打印网络信息
27+
print(f"Network has {network.num_inputs} inputs and {network.num_outputs} outputs")
28+
29+
for i in range(network.num_inputs):
30+
input_tensor = network.get_input(i)
31+
print(f"Input {i}: {input_tensor.name}, shape: {input_tensor.shape}")
32+
33+
for i in range(network.num_outputs):
34+
output_tensor = network.get_output(i)
35+
print(f"Output {i}: {output_tensor.name}, shape: {output_tensor.shape}")
36+
37+
# 创建优化配置文件
38+
profile = builder.create_optimization_profile()
39+
40+
# 设置动态输入的形状范围
41+
profile.set_shape("image_embeddings", (1, 256, 64, 64), (1, 256, 64, 64), (1, 256, 64, 64))
42+
profile.set_shape("high_res_features1", (1, 32, 256, 256), (1, 32, 256, 256), (1, 32, 256, 256))
43+
profile.set_shape("high_res_features2", (1, 64, 128, 128), (1, 64, 128, 128), (1, 64, 128, 128))
44+
profile.set_shape("point_coords", (1, 1, 2), (1, 5, 2), (1, 10, 2))
45+
profile.set_shape("point_labels", (1, 1), (1, 5), (1, 10))
46+
profile.set_shape("mask_input", (1, 1, 256, 256), (1, 1, 256, 256), (1, 1, 256, 256))
47+
profile.set_shape("has_mask_input", (1,), (1,), (1,))
48+
49+
# 关键修复:为形状张量 orig_im_size 设置具体的值
50+
# 使用你想要的目标尺寸 1024x1797
51+
profile.set_shape("orig_im_size", (2,), (2,), (2,))
52+
53+
# 为形状张量设置具体的值
54+
target_size = np.array([1920, 1080], dtype=np.int64)
55+
profile.set_shape_input("orig_im_size", target_size, target_size, target_size)
56+
57+
# 添加配置文件
58+
config.add_optimization_profile(profile)
59+
60+
print("Building TensorRT engine... This may take several minutes.")
61+
print(f"Target output size: {target_size}")
62+
63+
# 构建引擎
64+
serialized_engine = builder.build_serialized_network(network, config)
65+
66+
if serialized_engine is None:
67+
print("Failed to build engine")
68+
return False
69+
70+
# 保存引擎
71+
engine_path = "sam2.1_large_1024x1797_python.engine"
72+
with open(engine_path, 'wb') as f:
73+
f.write(serialized_engine)
74+
75+
print(f"✅ Engine successfully built and saved to: {engine_path}")
76+
77+
# 验证引擎
78+
runtime = trt.Runtime(TRT_LOGGER)
79+
engine = runtime.deserialize_cuda_engine(serialized_engine)
80+
81+
if engine:
82+
print("✅ Engine validation successful!")
83+
print(f"Number of IO tensors: {engine.num_io_tensors}")
84+
85+
# 打印张量信息
86+
for i in range(engine.num_io_tensors):
87+
tensor_name = engine.get_tensor_name(i)
88+
tensor_shape = engine.get_tensor_shape(tensor_name)
89+
tensor_mode = engine.get_tensor_mode(tensor_name)
90+
io_type = "INPUT" if tensor_mode == trt.TensorIOMode.INPUT else "OUTPUT"
91+
print(f" {io_type}: {tensor_name} -> {tensor_shape}")
92+
93+
return True
94+
else:
95+
print("❌ Engine validation failed!")
96+
return False
97+
98+
if __name__ == "__main__":
99+
success = build_flexible_sam2_engine()
100+
if not success:
101+
print("Engine build failed!")

docs/python/sam2/readme.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
### SAM2.1
2+
3+
1. [Download ONNX](https://github.com/ryouchinsa/sam-cpp-macos?tab=readme-ov-file)
4+
2. python engine_build.py
5+
3. python sam_trt.py
6+
7+
``` 运行engine_build.py会生成engine文件 但是目前由于ONNX的属性问题现在只能编译图片大小的engine文件```
8+
9+
```在engine_build.py的line 54中可以修改图片大小的参数```

0 commit comments

Comments
 (0)