Skip to content
This repository was archived by the owner on Mar 14, 2025. It is now read-only.

Commit 0ece6bd

Browse files
committed
remove opt profile logic if implicit batch network
1 parent 42208c0 commit 0ece6bd

File tree

1 file changed

+15
-18
lines changed

1 file changed

+15
-18
lines changed

classification/imagenet/onnx_to_tensorrt.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -81,31 +81,28 @@ def get_batch_sizes(max_batch_size):
8181

8282

8383
# TODO: This only covers dynamic shape for batch size, not dynamic shape for other dimensions
84-
def create_optimization_profiles(builder, inputs, explicit_batch, batch_sizes=[1,8,16,32,64]):
84+
def create_optimization_profiles(builder, inputs, batch_sizes=[1,8,16,32,64]):
8585
# Check if all inputs are fixed explicit batch to create a single profile and avoid duplicates
86-
if explicit_batch and all([inp.shape[0] > -1 for inp in inputs]):
86+
if all([inp.shape[0] > -1 for inp in inputs]):
8787
profile = builder.create_optimization_profile()
8888
for inp in inputs:
8989
fbs, shape = inp.shape[0], inp.shape[1:]
9090
profile.set_shape(inp.name, min=(fbs, *shape), opt=(fbs, *shape), max=(fbs, *shape))
9191
return [profile]
9292

93-
# Otherwise for implicit, or mixed fixed+dynamic explicit batch inputs, create several profiles
93+
# Otherwise for mixed fixed+dynamic explicit batch inputs, create several profiles
9494
profiles = {}
9595
for bs in batch_sizes:
9696
if not profiles.get(bs):
9797
profiles[bs] = builder.create_optimization_profile()
9898

9999
for inp in inputs:
100-
shape = inp.shape[1:] if explicit_batch else inp.shape
100+
shape = inp.shape[1:]
101+
# Check if fixed explicit batch
102+
if inp.shape[0] > -1:
103+
bs = inp.shape[0]
101104

102-
# Dynamic explicit batch or implicit batch
103-
if inp.shape[0] == -1 or not explicit_batch:
104-
profiles[bs].set_shape(inp.name, min=(bs, *shape), opt=(bs, *shape), max=(bs, *shape))
105-
# Fixed explicit batch
106-
else:
107-
fbs = inp.shape[0]
108-
profiles[bs].set_shape(inp.name, min=(fbs, *shape), opt=(fbs, *shape), max=(fbs, *shape))
105+
profiles[bs].set_shape(inp.name, min=(bs, *shape), opt=(bs, *shape), max=(bs, *shape))
109106

110107
return list(profiles.values())
111108

@@ -196,14 +193,14 @@ def main():
196193
# Display network info and check certain properties
197194
check_network(network)
198195

199-
# Add optimization profiles
200-
batch_sizes = [1, 8, 16, 32, 64]
201-
inputs = [network.get_input(i) for i in range(network.num_inputs)]
202-
opt_profiles = create_optimization_profiles(builder, inputs, args.explicit_batch, batch_sizes)
203-
add_profiles(config, inputs, opt_profiles)
204-
196+
if args.explicit_batch:
197+
# Add optimization profiles
198+
batch_sizes = [1, 8, 16, 32, 64]
199+
inputs = [network.get_input(i) for i in range(network.num_inputs)]
200+
opt_profiles = create_optimization_profiles(builder, inputs, args.explicit_batch, batch_sizes)
201+
add_profiles(config, inputs, opt_profiles)
205202
# Implicit Batch Network
206-
if not args.explicit_batch:
203+
else:
207204
builder.max_batch_size = args.max_batch_size
208205

209206
logger.info("Building Engine...")

0 commit comments

Comments
 (0)