@@ -81,31 +81,28 @@ def get_batch_sizes(max_batch_size):
8181
8282
8383# TODO: This only covers dynamic shape for batch size, not dynamic shape for other dimensions
84- def create_optimization_profiles (builder , inputs , explicit_batch , batch_sizes = [1 ,8 ,16 ,32 ,64 ]):
84+ def create_optimization_profiles (builder , inputs , batch_sizes = [1 ,8 ,16 ,32 ,64 ]):
8585 # Check if all inputs are fixed explicit batch to create a single profile and avoid duplicates
86- if explicit_batch and all ([inp .shape [0 ] > - 1 for inp in inputs ]):
86+ if all ([inp .shape [0 ] > - 1 for inp in inputs ]):
8787 profile = builder .create_optimization_profile ()
8888 for inp in inputs :
8989 fbs , shape = inp .shape [0 ], inp .shape [1 :]
9090 profile .set_shape (inp .name , min = (fbs , * shape ), opt = (fbs , * shape ), max = (fbs , * shape ))
9191 return [profile ]
9292
93- # Otherwise for implicit, or mixed fixed+dynamic explicit batch inputs, create several profiles
93+ # Otherwise for mixed fixed+dynamic explicit batch inputs, create several profiles
9494 profiles = {}
9595 for bs in batch_sizes :
9696 if not profiles .get (bs ):
9797 profiles [bs ] = builder .create_optimization_profile ()
9898
9999 for inp in inputs :
100- shape = inp .shape [1 :] if explicit_batch else inp .shape
100+ shape = inp .shape [1 :]
101+ # Check if fixed explicit batch
102+ if inp .shape [0 ] > - 1 :
103+ bs = inp .shape [0 ]
101104
102- # Dynamic explicit batch or implicit batch
103- if inp .shape [0 ] == - 1 or not explicit_batch :
104- profiles [bs ].set_shape (inp .name , min = (bs , * shape ), opt = (bs , * shape ), max = (bs , * shape ))
105- # Fixed explicit batch
106- else :
107- fbs = inp .shape [0 ]
108- profiles [bs ].set_shape (inp .name , min = (fbs , * shape ), opt = (fbs , * shape ), max = (fbs , * shape ))
105+ profiles [bs ].set_shape (inp .name , min = (bs , * shape ), opt = (bs , * shape ), max = (bs , * shape ))
109106
110107 return list (profiles .values ())
111108
@@ -196,14 +193,14 @@ def main():
196193 # Display network info and check certain properties
197194 check_network (network )
198195
199- # Add optimization profiles
200- batch_sizes = [ 1 , 8 , 16 , 32 , 64 ]
201- inputs = [network . get_input ( i ) for i in range ( network . num_inputs ) ]
202- opt_profiles = create_optimization_profiles ( builder , inputs , args . explicit_batch , batch_sizes )
203- add_profiles ( config , inputs , opt_profiles )
204-
196+ if args . explicit_batch :
197+ # Add optimization profiles
198+ batch_sizes = [1 , 8 , 16 , 32 , 64 ]
199+ inputs = [ network . get_input ( i ) for i in range ( network . num_inputs )]
200+ opt_profiles = create_optimization_profiles ( builder , inputs , args . explicit_batch , batch_sizes )
201+ add_profiles ( config , inputs , opt_profiles )
205202 # Implicit Batch Network
206- if not args . explicit_batch :
203+ else :
207204 builder .max_batch_size = args .max_batch_size
208205
209206 logger .info ("Building Engine..." )
0 commit comments