Skip to content

Commit 092469f

Browse files
committed
more pep8; py2 vs. py3 compat
1 parent 31d6464 commit 092469f

File tree

6 files changed

+24
-13
lines changed

6 files changed

+24
-13
lines changed

tensorflowonspark/TFManager.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from multiprocessing.managers import BaseManager
1111
from multiprocessing import JoinableQueue
1212

13+
1314
class TFManager(BaseManager):
1415
"""Python multiprocessing.Manager for distributed, multi-process communication."""
1516
pass
@@ -20,18 +21,22 @@ class TFManager(BaseManager):
2021
qdict = {} # dictionary of queues
2122
kdict = {} # dictionary of key-values
2223

24+
2325
def _get(key):
2426
return kdict[key]
2527

28+
2629
def _set(key, value):
2730
kdict[key] = value
2831

32+
2933
def _get_queue(qname):
3034
try:
3135
return qdict[qname]
3236
except KeyError:
3337
return None
3438

39+
3540
def start(authkey, queues, mode='local'):
3641
"""Create a new multiprocessing.Manager (or return an existing one).
3742
@@ -53,12 +58,13 @@ def start(authkey, queues, mode='local'):
5358
TFManager.register('get', callable=lambda key: _get(key))
5459
TFManager.register('set', callable=lambda key, value: _set(key, value))
5560
if mode == 'remote':
56-
mgr = TFManager(address=('',0), authkey=authkey)
61+
mgr = TFManager(address=('', 0), authkey=authkey)
5762
else:
5863
mgr = TFManager(authkey=authkey)
5964
mgr.start()
6065
return mgr
6166

67+
6268
def connect(address, authkey):
6369
"""Connect to a multiprocessing.Manager.
6470
@@ -75,4 +81,3 @@ def connect(address, authkey):
7581
m = TFManager(address, authkey=authkey)
7682
m.connect()
7783
return m
78-

tensorflowonspark/dfutil.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,10 @@ def _toTFFeature(name, dtype, row):
109109
elif dtype in int64_dtypes:
110110
feature = (name, tf.train.Feature(int64_list=tf.train.Int64List(value=[row[name]])))
111111
elif dtype in bytes_dtypes:
112-
feature = (name, tf.train.Feature(bytes_list=tf.train.BytesList(value=[str(row[name])])))
112+
if dtype == 'binary':
113+
feature = (name, tf.train.Feature(bytes_list=tf.train.BytesList(value=[bytes(row[name])])))
114+
else:
115+
feature = (name, tf.train.Feature(bytes_list=tf.train.BytesList(value=[str(row[name]).encode('utf-8')])))
113116
elif dtype in float_list_dtypes:
114117
feature = (name, tf.train.Feature(float_list=tf.train.FloatList(value=row[name])))
115118
elif dtype in int64_list_dtypes:
@@ -181,16 +184,15 @@ def fromTFExample(iter, binary_features=[]):
181184
"""
182185
# convert from protobuf-like dict to DataFrame-friendly dict
183186
def _get_value(k, v):
184-
# special handling for binary features
185-
if k in binary_features:
186-
return bytearray(v.bytes_list.value[0])
187-
188187
if v.int64_list.value:
189188
result = v.int64_list.value
190189
elif v.float_list.value:
191190
result = v.float_list.value
192-
else:
193-
result = v.bytes_list.value
191+
else: # string or bytearray
192+
if k in binary_features:
193+
return bytearray(v.bytes_list.value[0])
194+
else:
195+
return v.bytes_list.value[0].decode('utf-8')
194196

195197
if len(result) > 1: # represent multi-item tensors as python lists
196198
return list(result)

tensorflowonspark/pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ def _fit(self, dataset):
391391
assert local_args.tfrecord_dir, "Please specify --tfrecord_dir to export DataFrame to TFRecord."
392392
if self.getInputMapping():
393393
# if input mapping provided, filter only required columns before exporting
394-
dataset = dataset.select(self.getInputMapping().keys())
394+
dataset = dataset.select(list(self.getInputMapping()))
395395
logging.info("Exporting DataFrame {} as TFRecord to: {}".format(dataset.dtypes, local_args.tfrecord_dir))
396396
dfutil.saveAsTFRecords(dataset, local_args.tfrecord_dir)
397397
logging.info("Done saving")
@@ -401,7 +401,7 @@ def _fit(self, dataset):
401401
local_args.tensorboard, local_args.input_mode, driver_ps_nodes=local_args.driver_ps_nodes)
402402
if local_args.input_mode == TFCluster.InputMode.SPARK:
403403
# feed data, using a deterministic order for input columns (lexicographic by key)
404-
input_cols = sorted(self.getInputMapping().keys())
404+
input_cols = sorted(self.getInputMapping())
405405
cluster.train(dataset.select(input_cols).rdd, local_args.epochs)
406406
cluster.shutdown()
407407

test/test_TFNode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def test_hdfs_path(self):
2626

2727
def test_datafeed(self):
2828
"""TFNode.DataFeed basic operations"""
29-
mgr = TFManager.start('abc', ['input', 'output'], 'local')
29+
mgr = TFManager.start('abc'.encode('utf-8'), ['input', 'output'], 'local')
3030

3131
# insert 10 numbers followed by an end-of-feed marker
3232
q = mgr.get_queue('input')

test/test_dfutil.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_dfutils(self):
3131
row1 = ('text string', 1, [2, 3, 4, 5], -1.1, [-2.2, -3.3, -4.4, -5.5], bytearray(b'\xff\xfe\xfd\xfc'))
3232
rdd = self.sc.parallelize([row1])
3333
df1 = self.spark.createDataFrame(rdd, ['a', 'b', 'c', 'd', 'e', 'f'])
34-
print ("schema: {}".format(df1.schema))
34+
print("schema: {}".format(df1.schema))
3535

3636
# save the DataFrame as TFRecords
3737
dfutil.saveAsTFRecords(df1, self.tfrecord_dir)

test/test_pipeline.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import shutil
44
import test
5+
import time
56
import unittest
67

78
from tensorflowonspark import TFCluster, dfutil
@@ -323,6 +324,9 @@ def _get_examples(batch_size):
323324
ckpt_name = args.model_dir + "/model.ckpt"
324325
print("Saving checkpoint to: {}".format(ckpt_name))
325326
saver.save(sess, ckpt_name)
327+
328+
# wait for rest of cluster to connect
329+
time.sleep(30)
326330
sv.stop()
327331

328332
def _tf_export(args):

0 commit comments

Comments
 (0)