
Commit d319630

Merge pull request #282 from tensorlayer/luo-yapf2

format core library using yapf.

2 parents: 73e64d4 + 92c861a

File tree

10 files changed: +1821 −1777 lines

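Note: the commit reformats the core library with yapf. As a hedged illustration of that workflow (the repository's actual yapf style configuration is not shown on this page, so style_config='pep8' below is an assumption), yapf can be driven from Python:

# Minimal sketch of invoking yapf programmatically; in recent yapf
# releases FormatCode returns (formatted_source, changed).
from yapf.yapflib.yapf_api import FormatCode

source = "def f( x ):\n  return x+1\n"
formatted, changed = FormatCode(source, style_config='pep8')

print(formatted)  # def f(x):\n    return x + 1
print(changed)    # True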

tensorlayer/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -3,7 +3,6 @@
 """
 from __future__ import absolute_import
 
-
 try:
     install_instr = "Please make sure you install a recent enough version of TensorFlow."
     import tensorflow

tensorlayer/activation.py

Lines changed: 9 additions & 4 deletions
@@ -18,9 +18,11 @@ def identity(x, name=None):
     """
     return x
 
+
 # Shortcut
 linear = identity
 
+
 def ramp(x=None, v_min=0, v_max=1, name=None):
     """The ramp activation function.
 
@@ -41,6 +43,7 @@ def ramp(x=None, v_min=0, v_max=1, name=None):
     """
     return tf.clip_by_value(x, clip_value_min=v_min, clip_value_max=v_max, name=name)
 
+
 def leaky_relu(x=None, alpha=0.1, name="lrelu"):
     """The LeakyReLU, Shortcut is ``lrelu``.
 
@@ -65,12 +68,13 @@ def leaky_relu(x=None, alpha=0.1, name="lrelu"):
     - `Rectifier Nonlinearities Improve Neural Network Acoustic Models, Maas et al. (2013) <http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf>`_
     """
     # with tf.name_scope(name) as scope:
-    # x = tf.nn.relu(x)
-    # m_x = tf.nn.relu(-x)
-    # x -= alpha * m_x
+    #     x = tf.nn.relu(x)
+    #     m_x = tf.nn.relu(-x)
+    #     x -= alpha * m_x
     x = tf.maximum(x, alpha * x, name=name)
     return x
 
+
 #Shortcut
 lrelu = leaky_relu
 
@@ -88,9 +92,10 @@ def swish(x, name='swish'):
     A `Tensor` with the same type as `x`.
     """
     with tf.name_scope(name) as scope:
-      x = tf.nn.sigmoid(x) * x
+        x = tf.nn.sigmoid(x) * x
     return x
 
+
 def pixel_wise_softmax(output, name='pixel_wise_softmax'):
     """Return the softmax outputs of images, every pixels have multiple label, the sum of a pixel is 1.
     Usually be used for image segmentation.
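Note: the LeakyReLU in the hunk above is computed as tf.maximum(x, alpha * x). A quick NumPy sketch (illustrative, not part of the commit) checks that this matches the piecewise definition f(x) = x if x > 0 else alpha * x, which holds whenever 0 <= alpha <= 1:

import numpy as np

alpha = 0.1
x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])

# max(x, alpha*x) picks x on the positive side and alpha*x on the
# negative side, since alpha*x >= x whenever x <= 0 and alpha <= 1.
via_max = np.maximum(x, alpha * x)
piecewise = np.where(x > 0, x, alpha * x)

assert np.allclose(via_max, piecewise)
print(via_max)  # [-0.2  -0.05  0.    0.5   2.  ]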

tensorlayer/cost.py

Lines changed: 181 additions & 199 deletions (large diff not rendered on this page)

tensorlayer/db.py

Lines changed: 71 additions & 83 deletions
@@ -6,7 +6,6 @@
 Latest Version
 """
 
-
 import inspect
 import math
 import pickle
@@ -23,15 +22,12 @@
 
 
 def AutoFill(func):
-    def func_wrapper(self,*args,**kwargs):
-        d=inspect.getcallargs(func,self,*args,**kwargs)
-        d['args'].update({"studyID":self.studyID})
-        return func(**d)
-    return func_wrapper
-
-
-
+    def func_wrapper(self, *args, **kwargs):
+        d = inspect.getcallargs(func, self, *args, **kwargs)
+        d['args'].update({"studyID": self.studyID})
+        return func(**d)
 
+    return func_wrapper
 
 
 class TensorDB(object):
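Note: the AutoFill decorator reformatted above uses inspect.getcallargs to bind the wrapped method's call arguments to their parameter names, then injects the instance's studyID into the method's args dict before calling it. A self-contained sketch of the same pattern (the Example class below is hypothetical; only the decorator body mirrors the diff):

import inspect


def AutoFill(func):
    def func_wrapper(self, *args, **kwargs):
        # Map every positional/keyword argument onto its parameter name.
        d = inspect.getcallargs(func, self, *args, **kwargs)
        # The wrapped method is expected to take a dict parameter named 'args'.
        d['args'].update({"studyID": self.studyID})
        return func(**d)

    return func_wrapper


class Example:
    def __init__(self):
        self.studyID = "study-123"

    @AutoFill
    def find(self, args={}):
        return args


print(Example().find({"epoch": 1}))
# -> {'epoch': 1, 'studyID': 'study-123'}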
@@ -68,32 +64,24 @@ class TensorDB(object):
     1 : You may like to install MongoChef or Mongo Management Studo APP for
     visualizing or testing your MongoDB.
     """
-    def __init__(
-        self,
-        ip = 'localhost',
-        port = 27017,
-        db_name = 'db_name',
-        user_name = None,
-        password = 'password',
-        studyID=None
-    ):
+
+    def __init__(self, ip='localhost', port=27017, db_name='db_name', user_name=None, password='password', studyID=None):
         ## connect mongodb
         client = MongoClient(ip, port)
         self.db = client[db_name]
         if user_name != None:
             self.db.authenticate(user_name, password)
 
-
         if studyID is None:
-            self.studyID=str(uuid.uuid1())
+            self.studyID = str(uuid.uuid1())
         else:
-            self.studyID=studyID
+            self.studyID = studyID
 
         ## define file system (Buckets)
         self.datafs = gridfs.GridFS(self.db, collection="datafs")
         self.modelfs = gridfs.GridFS(self.db, collection="modelfs")
         self.paramsfs = gridfs.GridFS(self.db, collection="paramsfs")
-        self.archfs=gridfs.GridFS(self.db,collection="ModelArchitecture")
+        self.archfs = gridfs.GridFS(self.db, collection="ModelArchitecture")
         ##
         print("[TensorDB] Connect SUCCESS {}:{} {} {} {}".format(ip, port, db_name, user_name, studyID))
 
@@ -102,16 +90,16 @@ def __init__(
         self.db_name = db_name
         self.user_name = user_name
 
-    def __autofill(self,args):
-        return args.update({'studyID':self.studyID})
+    def __autofill(self, args):
+        return args.update({'studyID': self.studyID})
 
-    def __serialization(self,ps):
+    def __serialization(self, ps):
         return pickle.dumps(ps, protocol=2)
 
-    def __deserialization(self,ps):
+    def __deserialization(self, ps):
         return pickle.loads(ps)
 
-    def save_params(self, params=[], args={}):#, file_name='parameters'):
+    def save_params(self, params=[], args={}):  #, file_name='parameters'):
         """ Save parameters into MongoDB Buckets, and save the file ID into Params Collections.
 
         Parameters
@@ -125,15 +113,15 @@ def save_params(self, params=[], args={}):#, file_name='parameters'):
         """
         self.__autofill(args)
         s = time.time()
-        f_id = self.paramsfs.put(self.__serialization(params))#, file_name=file_name)
+        f_id = self.paramsfs.put(self.__serialization(params))  #, file_name=file_name)
         args.update({'f_id': f_id, 'time': datetime.utcnow()})
         self.db.Params.insert_one(args)
         # print("[TensorDB] Save params: {} SUCCESS, took: {}s".format(file_name, round(time.time()-s, 2)))
-        print("[TensorDB] Save params: SUCCESS, took: {}s".format(round(time.time()-s, 2)))
+        print("[TensorDB] Save params: SUCCESS, took: {}s".format(round(time.time() - s, 2)))
         return f_id
 
     @AutoFill
-    def find_one_params(self, args={},sort=None):
+    def find_one_params(self, args={}, sort=None):
         """ Find one parameter from MongoDB Buckets.
 
         Parameters
@@ -148,7 +136,7 @@ def find_one_params(self, args={},sort=None):
 
         s = time.time()
         # print(args)
-        d = self.db.Params.find_one(filter=args,sort=sort)
+        d = self.db.Params.find_one(filter=args, sort=sort)
 
         if d is not None:
             f_id = d['f_id']
@@ -157,7 +145,7 @@ def find_one_params(self, args={},sort=None):
             return False, False
         try:
             params = self.__deserialization(self.paramsfs.get(f_id).read())
-            print("[TensorDB] Find one params SUCCESS, {} took: {}s".format(args, round(time.time()-s, 2)))
+            print("[TensorDB] Find one params SUCCESS, {} took: {}s".format(args, round(time.time() - s, 2)))
             return params, f_id
         except:
             return False, False
@@ -182,14 +170,14 @@ def find_all_params(self, args={}):
         if pc is not None:
             f_id_list = pc.distinct('f_id')
             params = []
-            for f_id in f_id_list: # you may have multiple Buckets files
+            for f_id in f_id_list:  # you may have multiple Buckets files
                 tmp = self.paramsfs.get(f_id).read()
                 params.append(self.__deserialization(tmp))
         else:
             print("[TensorDB] FAIL! Cannot find any: {}".format(args))
             return False
 
-        print("[TensorDB] Find all params SUCCESS, took: {}s".format(round(time.time()-s, 2)))
+        print("[TensorDB] Find all params SUCCESS, took: {}s".format(round(time.time() - s, 2)))
         return params
 
     @AutoFill
@@ -217,7 +205,7 @@ def _print_dict(self, args):
         string = ''
         for key, value in args.items():
             if key is not '_id':
-                string += str(key) + ": "+ str(value) + " / "
+                string += str(key) + ": " + str(value) + " / "
         return string
 
     ## =========================== LOG =================================== ##
@@ -267,7 +255,7 @@ def valid_log(self, args={}):
         _result = self.db.ValidLog.insert_one(args)
         # _log = "".join(str(key) + ": " + str(value) for key, value in args.items())
         _log = self._print_dict(args)
-        print("[TensorDB] ValidLog: " +_log)
+        print("[TensorDB] ValidLog: " + _log)
         return _result
 
     @AutoFill
@@ -297,7 +285,7 @@ def test_log(self, args={}):
         _result = self.db.TestLog.insert_one(args)
         # _log = "".join(str(key) + str(value) for key, value in args.items())
         _log = self._print_dict(args)
-        print("[TensorDB] TestLog: " +_log)
+        print("[TensorDB] TestLog: " + _log)
         return _result
 
     @AutoFill
@@ -314,14 +302,14 @@ def del_test_log(self, args={}):
 
     ## =========================== Network Architecture ================== ##
     @AutoFill
-    def save_model_architecture(self,s,args={}):
+    def save_model_architecture(self, s, args={}):
         self.__autofill(args)
-        fid=self.archfs.put(s,filename="modelarchitecture")
-        args.update({"fid":fid})
+        fid = self.archfs.put(s, filename="modelarchitecture")
+        args.update({"fid": fid})
         self.db.march.insert_one(args)
 
     @AutoFill
-    def load_model_architecture(self,args={}):
+    def load_model_architecture(self, args={}):
 
         d = self.db.march.find_one(args)
         if d is not None:
@@ -331,7 +319,7 @@ def load_model_architecture(self,args={}):
             # "print find"
         else:
             print("[TensorDB] FAIL! Cannot find: {}".format(args))
-            print ("no idtem")
+            print("no idtem")
             return False, False
         try:
             archs = self.archfs.get(fid).read()
@@ -385,7 +373,6 @@ def find_one_job(self, args={}):
         dictionary : contains all meta data and script.
         """
 
-
         temp = self.db.Job.find_one(args)
 
         if temp is not None:
@@ -400,34 +387,34 @@ def find_one_job(self, args={}):
 
         return temp
 
-    def push_job(self,margs, wargs,dargs,epoch):
+    def push_job(self, margs, wargs, dargs, epoch):
 
-        ms,mid=self.load_model_architecture(margs)
-        weight,wid=self.find_one_params(wargs)
-        args={"weight":wid,"model":mid,"dargs":dargs,"epoch":epoch,"time":datetime.utcnow(),"Running":False}
+        ms, mid = self.load_model_architecture(margs)
+        weight, wid = self.find_one_params(wargs)
+        args = {"weight": wid, "model": mid, "dargs": dargs, "epoch": epoch, "time": datetime.utcnow(), "Running": False}
         self.__autofill(args)
         self.db.JOBS.insert_one(args)
 
     def peek_job(self):
-        args={'Running':False}
+        args = {'Running': False}
         self.__autofill(args)
-        m=self.db.JOBS.find_one(args)
+        m = self.db.JOBS.find_one(args)
         print(m)
         if m is None:
             return False
 
-        s=self.paramsfs.get(m['weight']).read()
-        w=self.__deserialization(s)
+        s = self.paramsfs.get(m['weight']).read()
+        w = self.__deserialization(s)
 
-        ach=self.archfs.get(m['model']).read()
+        ach = self.archfs.get(m['model']).read()
 
-        return m['_id'], ach,w,m["dargs"],m['epoch']
+        return m['_id'], ach, w, m["dargs"], m['epoch']
 
-    def run_job(self,jid):
-        self.db.JOBS.find_one_and_update({'_id':jid},{'$set': {'Running': True,"Since":datetime.utcnow()}})
+    def run_job(self, jid):
+        self.db.JOBS.find_one_and_update({'_id': jid}, {'$set': {'Running': True, "Since": datetime.utcnow()}})
 
-    def del_job(self,jid):
-        self.db.JOBS.find_one_and_update({'_id':jid},{'$set': {'Running': True,"Finished":datetime.utcnow()}})
+    def del_job(self, jid):
+        self.db.JOBS.find_one_and_update({'_id': jid}, {'$set': {'Running': True, "Finished": datetime.utcnow()}})
 
     def __str__(self):
         _s = "[TensorDB] Info:\n"
@@ -502,49 +489,50 @@ def __str__(self):
     # return data
 
 
-
 class DBLogger:
     """ """
-    def __init__(self,db,model):
-        self.db=db
-        self.model=model
 
-    def on_train_begin(self,logs={}):
+    def __init__(self, db, model):
+        self.db = db
+        self.model = model
+
+    def on_train_begin(self, logs={}):
         print("start")
 
-    def on_train_end(self,logs={}):
+    def on_train_end(self, logs={}):
         print("end")
 
-    def on_epoch_begin(self,epoch,logs={}):
-        self.epoch=epoch
-        self.et=time.time()
+    def on_epoch_begin(self, epoch, logs={}):
+        self.epoch = epoch
+        self.et = time.time()
         return
 
     def on_epoch_end(self, epoch, logs={}):
-        self.et=time.time()-self.et
+        self.et = time.time() - self.et
         print("ending")
         print(epoch)
-        logs['epoch']=epoch
-        logs['time']=datetime.utcnow()
-        logs['stepTime']=self.et
-        logs['acc']=np.asscalar(logs['acc'])
+        logs['epoch'] = epoch
+        logs['time'] = datetime.utcnow()
+        logs['stepTime'] = self.et
+        logs['acc'] = np.asscalar(logs['acc'])
         print(logs)
 
-        w=self.model.Params
-        fid=self.db.save_params(w,logs)
-        logs.update({'params':fid})
+        w = self.model.Params
+        fid = self.db.save_params(w, logs)
+        logs.update({'params': fid})
         self.db.valid_log(logs)
-    def on_batch_begin(self, batch,logs={}):
-        self.t=time.time()
+
+    def on_batch_begin(self, batch, logs={}):
+        self.t = time.time()
         self.losses = []
-        self.batch=batch
+        self.batch = batch
 
     def on_batch_end(self, batch, logs={}):
-        self.t2=time.time()-self.t
-        logs['acc']=np.asscalar(logs['acc'])
+        self.t2 = time.time() - self.t
+        logs['acc'] = np.asscalar(logs['acc'])
         #logs['loss']=np.asscalar(logs['loss'])
-        logs['step_time']=self.t2
-        logs['time']=datetime.utcnow()
-        logs['epoch']=self.epoch
-        logs['batch']=self.batch
+        logs['step_time'] = self.t2
+        logs['time'] = datetime.utcnow()
+        logs['epoch'] = self.epoch
+        logs['batch'] = self.batch
         self.db.train_log(logs)
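Note: for context, a hypothetical usage sketch of the reformatted TensorDB class, based only on the signatures visible in this diff; it assumes a reachable MongoDB server on localhost:27017 and is not part of the commit:

from tensorlayer.db import TensorDB

db = TensorDB(ip='localhost', port=27017, db_name='db_name',
              user_name=None, password='password', studyID=None)

# Save a (possibly empty) parameter list tagged with metadata; @AutoFill
# adds the studyID to the query/insert dicts automatically.
f_id = db.save_params(params=[], args={'name': 'demo'})

# Look the record up again; returns (params, f_id) or (False, False).
params, f_id = db.find_one_params(args={'name': 'demo'})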
