4545from benchmark_args import BaseCommandLineAPI
4646from benchmark_runner import BaseBenchmarkRunner
4747
48+ SAMPLES_IN_DATASET = 10950
49+
4850
4951class CommandLineAPI (BaseCommandLineAPI ):
5052
@@ -153,6 +155,15 @@ def __init__(self):
153155 "models, False for cased models."
154156 )
155157
def _post_process_args(self, args):
    """Post-process parsed arguments and cap the warmup iteration count.

    The parent class's ``_post_process_args`` runs first; afterwards
    ``args.num_warmup_iterations`` is limited to at most
    ``int(SAMPLES_IN_DATASET / args.batch_size / 2)``.

    Args:
        args: The parsed arguments object; must expose ``batch_size`` and
            ``num_warmup_iterations``.

    Returns:
        The arguments object returned by the parent class, with
        ``num_warmup_iterations`` possibly lowered.
    """
    args = super(CommandLineAPI, self)._post_process_args(args)

    # Upper bound on warmup iterations: half of SAMPLES_IN_DATASET divided
    # by the batch size, truncated to an int.
    # NOTE(review): presumably this keeps warmup from consuming more than
    # half of the dataset's batches -- confirm against the runner.
    warmup_cap = int(SAMPLES_IN_DATASET / (args.batch_size) / 2)

    args.num_warmup_iterations = min(warmup_cap, args.num_warmup_iterations)
    return args
166+
156167
157168# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
158169# %%%%%%%%%%%%%%%%% IMPLEMENT MODEL-SPECIFIC FUNCTIONS HERE %%%%%%%%%%%%%%%%%% #
@@ -178,40 +189,44 @@ def get_dataset_batches(self):
178189
179190 def get_dataset_from_features (features , batch_size ):
180191
181- all_unique_ids = tf .convert_to_tensor ([
182- f .unique_id for f in features
183- ],
184- dtype = tf .int64 )
185- all_input_ids = tf .convert_to_tensor ([
186- f .input_ids for f in features
187- ],
188- dtype = tf .int64 )
189- all_input_mask = tf .convert_to_tensor ([
190- f .attention_mask for f in features
191- ],
192- dtype = tf .int64 )
193- all_segment_ids = tf .convert_to_tensor ([
194- f .token_type_ids for f in features
195- ],
196- dtype = tf .int64 )
197- all_start_pos = tf .convert_to_tensor ([
198- f .start_position for f in features
199- ],
200- dtype = tf .int64 )
201- all_end_pos = tf .convert_to_tensor ([
202- f .end_position for f in features
203- ],
204- dtype = tf .int64 )
205- all_cls_index = tf .convert_to_tensor ([
206- f .cls_index for f in features
207- ],
208- dtype = tf .int64 )
209- all_p_mask = tf .convert_to_tensor ([f .p_mask for f in features ],
210- dtype = tf .float32 )
211- all_is_impossible = tf .convert_to_tensor ([
212- f .is_impossible for f in features
213- ],
214- dtype = tf .float32 )
192+ # yapf: disable
193+ all_unique_ids = tf .convert_to_tensor (
194+ [f .unique_id for f in features ],
195+ dtype = tf .int64
196+ )
197+ all_input_ids = tf .convert_to_tensor (
198+ [f .input_ids for f in features ],
199+ dtype = tf .int64
200+ )
201+ all_input_mask = tf .convert_to_tensor (
202+ [f .attention_mask for f in features ],
203+ dtype = tf .int64
204+ )
205+ all_segment_ids = tf .convert_to_tensor (
206+ [f .token_type_ids for f in features ],
207+ dtype = tf .int64
208+ )
209+ all_start_pos = tf .convert_to_tensor (
210+ [f .start_position for f in features ],
211+ dtype = tf .int64
212+ )
213+ all_end_pos = tf .convert_to_tensor (
214+ [f .end_position for f in features ],
215+ dtype = tf .int64
216+ )
217+ all_cls_index = tf .convert_to_tensor (
218+ [f .cls_index for f in features ],
219+ dtype = tf .int64
220+ )
221+ all_p_mask = tf .convert_to_tensor (
222+ [f .p_mask for f in features ],
223+ dtype = tf .float32
224+ )
225+ all_is_impossible = tf .convert_to_tensor (
226+ [f .is_impossible for f in features ],
227+ dtype = tf .float32
228+ )
229+ # yapf: enable
215230
216231 dataset = tf .data .Dataset .from_tensor_slices ((
217232 all_unique_ids , all_input_ids , all_input_mask , all_segment_ids ,
0 commit comments