Skip to content

Commit e989e53

Browse files
authored
Merge pull request #492 from yahoo/leewyang_perf
InputMode.SPARK perf optimization
2 parents 12e6595 + c8db6ae commit e989e53

File tree

3 files changed

+13
-11
lines changed

3 files changed

+13
-11
lines changed

examples/mnist/estimator/mnist_pipeline.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ def parse(ln):
181181
model = TFModel(args) \
182182
.setInputMapping({'image': 'features'}) \
183183
.setOutputMapping({'logits': 'prediction'}) \
184+
.setSignatureDefKey('serving_default') \
184185
.setExportDir(args.export_dir) \
185186
.setBatchSize(args.batch_size)
186187

examples/mnist/keras/mnist_pipeline.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def parse(ln):
134134
model = TFModel(args) \
135135
.setInputMapping({'image': 'conv2d_input'}) \
136136
.setOutputMapping({'dense_1': 'prediction'}) \
137+
.setSignatureDefKey('serving_default') \
137138
.setExportDir(args.export_dir) \
138139
.setBatchSize(args.batch_size)
139140

tensorflowonspark/TFNode.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,9 @@ def __init__(self, mgr, train_mode=True, qname_in='input', qname_out='output', i
231231
self.done_feeding = False
232232
self.input_tensors = [tensor for col, tensor in sorted(input_mapping.items())] if input_mapping is not None else None
233233

234+
self.queue_in = mgr.get_queue(qname_in)
235+
self.queue_out = mgr.get_queue(qname_out)
236+
234237
def next_batch(self, batch_size):
235238
"""Gets a batch of items from the input RDD.
236239
@@ -249,34 +252,33 @@ def next_batch(self, batch_size):
249252
Returns:
250253
A batch of items or a dictionary of tensors.
251254
"""
252-
logger.debug("next_batch() invoked")
253-
queue = self.mgr.get_queue(self.qname_in)
254255
tensors = [] if self.input_tensors is None else {tensor: [] for tensor in self.input_tensors}
255256
count = 0
257+
queue_in = self.queue_in
258+
no_input_tensors = self.input_tensors is None
256259
while count < batch_size:
257-
item = queue.get(block=True)
260+
item = queue_in.get(block=True)
258261
if item is None:
259262
# End of Feed
260263
logger.info("next_batch() got None")
261-
queue.task_done()
264+
queue_in.task_done()
262265
self.done_feeding = True
263266
break
264267
elif type(item) is marker.EndPartition:
265268
# End of Partition
266269
logger.info("next_batch() got EndPartition")
267-
queue.task_done()
270+
queue_in.task_done()
268271
if not self.train_mode and count > 0:
269272
break
270273
else:
271274
# Normal item
272-
if self.input_tensors is None:
275+
if no_input_tensors:
273276
tensors.append(item)
274277
else:
275278
for i in range(len(self.input_tensors)):
276279
tensors[self.input_tensors[i]].append(item[i])
277280
count += 1
278-
queue.task_done()
279-
logger.debug("next_batch() returning {0} items".format(count))
281+
queue_in.task_done()
280282
return tensors
281283

282284
def should_stop(self):
@@ -292,11 +294,9 @@ def batch_results(self, results):
292294
Args:
293295
:results: array of output data for the equivalent batch of input data.
294296
"""
295-
logger.debug("batch_results() invoked")
296-
queue = self.mgr.get_queue(self.qname_out)
297+
queue = self.queue_out
297298
for item in results:
298299
queue.put(item, block=True)
299-
logger.debug("batch_results() returning data")
300300

301301
def terminate(self):
302302
"""Terminate data feeding early.

0 commit comments

Comments (0)