@@ -231,6 +231,9 @@ def __init__(self, mgr, train_mode=True, qname_in='input', qname_out='output', i
231231 self .done_feeding = False
232232 self .input_tensors = [tensor for col , tensor in sorted (input_mapping .items ())] if input_mapping is not None else None
233233
234+ self .queue_in = mgr .get_queue (qname_in )
235+ self .queue_out = mgr .get_queue (qname_out )
236+
234237 def next_batch (self , batch_size ):
235238 """Gets a batch of items from the input RDD.
236239
@@ -249,34 +252,33 @@ def next_batch(self, batch_size):
249252 Returns:
250253 A batch of items or a dictionary of tensors.
251254 """
252- logger .debug ("next_batch() invoked" )
253- queue = self .mgr .get_queue (self .qname_in )
254255 tensors = [] if self .input_tensors is None else {tensor : [] for tensor in self .input_tensors }
255256 count = 0
257+ queue_in = self .queue_in
258+ no_input_tensors = self .input_tensors is None
256259 while count < batch_size :
257- item = queue .get (block = True )
260+ item = queue_in .get (block = True )
258261 if item is None :
259262 # End of Feed
260263 logger .info ("next_batch() got None" )
261- queue .task_done ()
264+ queue_in .task_done ()
262265 self .done_feeding = True
263266 break
264267 elif type (item ) is marker .EndPartition :
265268 # End of Partition
266269 logger .info ("next_batch() got EndPartition" )
267- queue .task_done ()
270+ queue_in .task_done ()
268271 if not self .train_mode and count > 0 :
269272 break
270273 else :
271274 # Normal item
272- if self . input_tensors is None :
275+ if no_input_tensors :
273276 tensors .append (item )
274277 else :
275278 for i in range (len (self .input_tensors )):
276279 tensors [self .input_tensors [i ]].append (item [i ])
277280 count += 1
278- queue .task_done ()
279- logger .debug ("next_batch() returning {0} items" .format (count ))
281+ queue_in .task_done ()
280282 return tensors
281283
282284 def should_stop (self ):
@@ -292,11 +294,9 @@ def batch_results(self, results):
292294 Args:
293295 :results: array of output data for the equivalent batch of input data.
294296 """
295- logger .debug ("batch_results() invoked" )
296- queue = self .mgr .get_queue (self .qname_out )
297+ queue = self .queue_out
297298 for item in results :
298299 queue .put (item , block = True )
299- logger .debug ("batch_results() returning data" )
300300
301301 def terminate (self ):
302302 """Terminate data feeding early.
0 commit comments