Commit 2f30b8e

Clean code style for yapf test
1 parent c6797ac commit 2f30b8e

File tree: 77 files changed (+787, -2856 lines)
docs/user/tutorials.rst

Lines changed: 0 additions & 1982 deletions
This file was deleted.

tensorlayer/cost.py

Lines changed: 40 additions & 12 deletions
@@ -781,7 +781,9 @@ def mn_i(weights, name='maxnorm_i_regularizer'):
     return mn_i
 
 
-def huber_loss(output, target,is_mean=True, delta=1.0, dynamichuber=False, reverse=False, axis=-1, epsilon= 0.00001,name=None):
+def huber_loss(
+    output, target, is_mean=True, delta=1.0, dynamichuber=False, reverse=False, axis=-1, epsilon=0.00001, name=None
+):
     """Huber Loss operation, see ``https://en.wikipedia.org/wiki/Huber_loss`` .
     Reverse Huber Loss operation, see ''https://statweb.stanford.edu/~owen/reports/hhu.pdf''.
     Dynamic Reverse Huber Loss operation, see ''https://arxiv.org/pdf/1606.00373.pdf''.
@@ -814,19 +816,45 @@ def huber_loss(output, target,is_mean=True, delta=1.0, dynamichuber=False, rever
     """
     if reverse:
         if dynamichuber:
-            huber_c = 0.2*tf.reduce_max(tf.abs(output - target))
+            huber_c = 0.2 * tf.reduce_max(tf.abs(output - target))
         else:
-            huber_c=delta
+            huber_c = delta
         if is_mean:
-            loss=tf.reduce_mean(tf.where(tf.less_equal(tf.abs(output - target), huber_c), tf.abs(output - target),
-                tf.multiply(tf.pow(output - target, 2.0) + tf.pow(huber_c, 2.0), tf.math.divide_no_nan(.5, huber_c + epsilon))),name=name)
+            loss = tf.reduce_mean(
+                tf.where(
+                    tf.less_equal(tf.abs(output - target), huber_c), tf.abs(output - target),
+                    tf.multiply(
+                        tf.pow(output - target, 2.0) + tf.pow(huber_c, 2.0),
+                        tf.math.divide_no_nan(.5, huber_c + epsilon)
+                    )
+                ), name=name
+            )
         else:
-            loss=tf.reduce_mean(tf.reduce_sum(tf.where(tf.less_equal(tf.abs(output - target), huber_c), tf.abs(output - target),
-                tf.multiply(tf.pow(output - target, 2.0) + tf.pow(huber_c, 2.0), tf.math.divide_no_nan(.5, huber_c + epsilon))),axis),name=name)
+            loss = tf.reduce_mean(
+                tf.reduce_sum(
+                    tf.where(
+                        tf.less_equal(tf.abs(output - target), huber_c), tf.abs(output - target),
+                        tf.multiply(
+                            tf.pow(output - target, 2.0) + tf.pow(huber_c, 2.0),
+                            tf.math.divide_no_nan(.5, huber_c + epsilon)
+                        )
+                    ), axis
+                ), name=name
+            )
     elif is_mean:
-        loss = tf.reduce_mean(tf.where(tf.less_equal(tf.abs(output - target), delta), 0.5*tf.pow(output - target,2),
-            delta*(tf.abs(output - target)-0.5*delta)), name=name)
+        loss = tf.reduce_mean(
+            tf.where(
+                tf.less_equal(tf.abs(output - target), delta), 0.5 * tf.pow(output - target, 2),
+                delta * (tf.abs(output - target) - 0.5 * delta)
+            ), name=name
+        )
     else:
-        loss = tf.reduce_mean(tf.reduce_sum(tf.where(tf.less_equal(tf.abs(output - target), delta), 0.5*tf.pow(output - target,2),
-            delta*(tf.abs(output - target)-0.5*delta)),axis), name=name)
-    return loss
+        loss = tf.reduce_mean(
+            tf.reduce_sum(
+                tf.where(
+                    tf.less_equal(tf.abs(output - target), delta), 0.5 * tf.pow(output - target, 2),
+                    delta * (tf.abs(output - target) - 0.5 * delta)
+                ), axis
+            ), name=name
+        )
+    return loss
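The reformatted body is easier to check against the underlying math: the classic Huber loss is quadratic inside a delta-band and linear outside it, while the reverse=True (berHu) branch is linear inside the band and quadratic outside, with dynamichuber=True setting the threshold to 0.2 * max|error|. A NumPy sketch of the two branches (illustrative only, not code from this commit):

# NumPy reference for the piecewise losses that huber_loss implements.
import numpy as np

def huber_ref(err, delta=1.0):
    # Classic Huber: quadratic for |err| <= delta, linear beyond.
    abs_err = np.abs(err)
    return np.where(abs_err <= delta, 0.5 * err ** 2, delta * (abs_err - 0.5 * delta))

def berhu_ref(err, c=1.0, epsilon=1e-5):
    # Reverse Huber (berHu), the reverse=True branch: linear for |err| <= c,
    # quadratic beyond, i.e. (err**2 + c**2) * 0.5 / (c + epsilon).
    abs_err = np.abs(err)
    return np.where(abs_err <= c, abs_err, (err ** 2 + c ** 2) * (0.5 / (c + epsilon)))

err = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
print(huber_ref(err))   # [1.5   0.125 0.    0.125 1.5  ]
print(berhu_ref(err))   # [~2.5  0.5   0.    0.5   ~2.5 ]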

tensorlayer/db.py

Lines changed: 86 additions & 81 deletions
@@ -13,7 +13,8 @@
 import tensorflow as tf
 
 from tensorlayer import logging
-# from tensorlayer.files import load_graph_and_params
+from tensorlayer.files import net2static_graph, static_graph2net, assign_weights
+from tensorlayer.files import save_weights_to_hdf5, load_hdf5_to_weights
 from tensorlayer.files import del_folder, exists_or_mkdir
 
 
@@ -112,8 +113,8 @@ def save_model(self, network=None, model_name='model', **kwargs):
 
         Parameters
         ----------
-        network : TensorLayer layer
-            TensorLayer layer instance.
+        network : TensorLayer Model
+            TensorLayer Model instance.
         model_name : str
             The name/key of model.
         kwargs : other events
@@ -125,15 +126,15 @@ def save_model(self, network=None, model_name='model', **kwargs):
         >>> db.save_model(net, accuracy=0.8, loss=2.3, name='second_model')
 
         Load one model with parameters from database (run this in other script)
-        >>> net = db.find_top_model(sess=sess, accuracy=0.8, loss=2.3)
+        >>> net = db.find_top_model(accuracy=0.8, loss=2.3)
 
         Find and load the latest model.
-        >>> net = db.find_top_model(sess=sess, sort=[("time", pymongo.DESCENDING)])
-        >>> net = db.find_top_model(sess=sess, sort=[("time", -1)])
+        >>> net = db.find_top_model(sort=[("time", pymongo.DESCENDING)])
+        >>> net = db.find_top_model(sort=[("time", -1)])
 
         Find and load the oldest model.
-        >>> net = db.find_top_model(sess=sess, sort=[("time", pymongo.ASCENDING)])
-        >>> net = db.find_top_model(sess=sess, sort=[("time", 1)])
+        >>> net = db.find_top_model(sort=[("time", pymongo.ASCENDING)])
+        >>> net = db.find_top_model(sort=[("time", 1)])
 
         Get model information
         >>> net._accuracy
@@ -146,11 +147,13 @@ def save_model(self, network=None, model_name='model', **kwargs):
         kwargs.update({'model_name': model_name})
         self._fill_project_info(kwargs)  # put project_name into kwargs
 
-        params = network.get_all_params()
+        # params = network.get_all_params()
+        params = network.weights
 
         s = time.time()
 
-        kwargs.update({'architecture': network.all_graphs, 'time': datetime.utcnow()})
+        # kwargs.update({'architecture': network.all_graphs, 'time': datetime.utcnow()})
+        kwargs.update({'architecture': net2static_graph(network), 'time': datetime.utcnow()})
 
         try:
             params_id = self.model_fs.put(self._serialization(params))
@@ -165,13 +168,11 @@ def save_model(self, network=None, model_name='model', **kwargs):
             print("[Database] Save model: FAIL")
             return False
 
-    def find_top_model(self, sess, sort=None, model_name='model', **kwargs):
+    def find_top_model(self, sort=None, model_name='model', **kwargs):
         """Finds and returns a model architecture and its parameters from the database which matches the requirement.
 
         Parameters
         ----------
-        sess : Session
-            TensorFlow session.
         sort : List of tuple
             PyMongo sort comment, search "PyMongo find one sorting" and `collection level operations <http://api.mongodb.com/python/current/api/pymongo/collection.html>`__ for more details.
         model_name : str or None
@@ -185,7 +186,7 @@ def find_top_model(self, sess, sort=None, model_name='model', **kwargs):
 
         Returns
         ---------
-        network : TensorLayer layer
+        network : TensorLayer Model
             Note that, the returned network contains all information of the document (record), e.g. if you saved accuracy in the document, you can get the accuracy by using ``net._accuracy``.
         """
         # print(kwargs)   # {}
@@ -196,33 +197,38 @@ def find_top_model(self, sess, sort=None, model_name='model', **kwargs):
 
         d = self.db.Model.find_one(filter=kwargs, sort=sort)
 
-        _temp_file_name = '_find_one_model_ztemp_file'
+        # _temp_file_name = '_find_one_model_ztemp_file'
         if d is not None:
             params_id = d['params_id']
             graphs = d['architecture']
             _datetime = d['time']
-            exists_or_mkdir(_temp_file_name, False)
-            with open(os.path.join(_temp_file_name, 'graph.pkl'), 'wb') as file:
-                pickle.dump(graphs, file, protocol=pickle.HIGHEST_PROTOCOL)
+            # exists_or_mkdir(_temp_file_name, False)
+            # with open(os.path.join(_temp_file_name, 'graph.pkl'), 'wb') as file:
+            #     pickle.dump(graphs, file, protocol=pickle.HIGHEST_PROTOCOL)
         else:
             print("[Database] FAIL! Cannot find model: {}".format(kwargs))
             return False
         try:
             params = self._deserialization(self.model_fs.get(params_id).read())
-            np.savez(os.path.join(_temp_file_name, 'params.npz'), params=params)
-
-            network = load_graph_and_params(name=_temp_file_name, sess=sess)
-            del_folder(_temp_file_name)
+            # TODO : restore model and load weights
+            network = static_graph2net(graphs)
+            assign_weights(weights=params, network=network)
+            # np.savez(os.path.join(_temp_file_name, 'params.npz'), params=params)
+            #
+            # network = load_graph_and_params(name=_temp_file_name, sess=sess)
+            # del_folder(_temp_file_name)
 
             pc = self.db.Model.find(kwargs)
             print(
-                "[Database] Find one model SUCCESS. kwargs:{} sort:{} save time:{} took: {}s".
-                format(kwargs, sort, _datetime, round(time.time() - s, 2))
+                "[Database] Find one model SUCCESS. kwargs:{} sort:{} save time:{} took: {}s".format(
+                    kwargs, sort, _datetime, round(time.time() - s, 2)
+                )
             )
 
+            # FIXME : not sure what's this for
             # put all informations of model into the TL layer
-            for key in d:
-                network.__dict__.update({"_%s" % key: d[key]})
+            # for key in d:
+            #     network.__dict__.update({"_%s" % key: d[key]})
 
             # check whether more parameters match the requirement
             params_id_list = pc.distinct('params_id')
@@ -553,12 +559,12 @@ def create_task(self, task_name=None, script=None, hyper_parameters=None, saved_
         >>> db.create_task(task_name='mnist', script='example/tutorial_mnist_simple.py', description='simple tutorial')
 
         Finds and runs the latest task
-        >>> db.run_top_task(sess=sess, sort=[("time", pymongo.DESCENDING)])
-        >>> db.run_top_task(sess=sess, sort=[("time", -1)])
+        >>> db.run_top_task(sort=[("time", pymongo.DESCENDING)])
+        >>> db.run_top_task(sort=[("time", -1)])
 
         Finds and runs the oldest task
-        >>> db.run_top_task(sess=sess, sort=[("time", pymongo.ASCENDING)])
-        >>> db.run_top_task(sess=sess, sort=[("time", 1)])
+        >>> db.run_top_task(sort=[("time", pymongo.ASCENDING)])
+        >>> db.run_top_task(sort=[("time", 1)])
 
         """
         if not isinstance(task_name, str):  # is None:
@@ -613,58 +619,57 @@ def run_top_task(self, task_name=None, sort=None, **kwargs):
         # find task and set status to running
         task = self.db.Task.find_one_and_update(kwargs, {'$set': {'status': 'running'}}, sort=sort)
 
-        try:
-            # get task info e.g. hyper parameters, python script
-            if task is None:
-                logging.info("[Database] Find Task FAIL: key: {} sort: {}".format(task_name, sort))
-                return False
-            else:
-                logging.info("[Database] Find Task SUCCESS: key: {} sort: {}".format(task_name, sort))
-            _datetime = task['time']
-            _script = task['script']
-            _id = task['_id']
-            _hyper_parameters = task['hyper_parameters']
-            _saved_result_keys = task['saved_result_keys']
-            logging.info("  hyper parameters:")
-            for key in _hyper_parameters:
-                globals()[key] = _hyper_parameters[key]
-                logging.info("    {}: {}".format(key, _hyper_parameters[key]))
-            # run task
-            s = time.time()
-            logging.info("[Database] Start Task: key: {} sort: {} push time: {}".format(task_name, sort, _datetime))
-            _script = _script.decode('utf-8')
-            with tf.Graph().as_default():  # as graph: # clear all TF graphs
-                exec(_script, globals())
-
-            # set status to finished
-            _ = self.db.Task.find_one_and_update({'_id': _id}, {'$set': {'status': 'finished'}})
-
-            # return results
-            __result = {}
-            for _key in _saved_result_keys:
-                logging.info("  result: {}={} {}".format(_key, globals()[_key], type(globals()[_key])))
-                __result.update({"%s" % _key: globals()[_key]})
-            _ = self.db.Task.find_one_and_update(
-                {
-                    '_id': _id
-                }, {'$set': {
-                    'result': __result
-                }}, return_document=pymongo.ReturnDocument.AFTER
-            )
-            logging.info(
-                "[Database] Finished Task: task_name - {} sort: {} push time: {} took: {}s".
-                format(task_name, sort, _datetime,
-                       time.time() - s)
-            )
-            return True
-        except Exception as e:
-            exc_type, exc_obj, exc_tb = sys.exc_info()
-            fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
-            logging.info("{} {} {} {} {}".format(exc_type, exc_obj, fname, exc_tb.tb_lineno, e))
-            logging.info("[Database] Fail to run task")
-            # if fail, set status back to pending
-            _ = self.db.Task.find_one_and_update({'_id': _id}, {'$set': {'status': 'pending'}})
+        # try:
+        # get task info e.g. hyper parameters, python script
+        if task is None:
+            logging.info("[Database] Find Task FAIL: key: {} sort: {}".format(task_name, sort))
             return False
+        else:
+            logging.info("[Database] Find Task SUCCESS: key: {} sort: {}".format(task_name, sort))
+        _datetime = task['time']
+        _script = task['script']
+        _id = task['_id']
+        _hyper_parameters = task['hyper_parameters']
+        _saved_result_keys = task['saved_result_keys']
+        logging.info("  hyper parameters:")
+        for key in _hyper_parameters:
+            globals()[key] = _hyper_parameters[key]
+            logging.info("    {}: {}".format(key, _hyper_parameters[key]))
+        # run task
+        s = time.time()
+        logging.info("[Database] Start Task: key: {} sort: {} push time: {}".format(task_name, sort, _datetime))
+        _script = _script.decode('utf-8')
+        with tf.Graph().as_default():  # # as graph: # clear all TF graphs
+            exec(_script, globals())
+
+        # set status to finished
+        _ = self.db.Task.find_one_and_update({'_id': _id}, {'$set': {'status': 'finished'}})
+
+        # return results
+        __result = {}
+        for _key in _saved_result_keys:
+            logging.info("  result: {}={} {}".format(_key, globals()[_key], type(globals()[_key])))
+            __result.update({"%s" % _key: globals()[_key]})
+        _ = self.db.Task.find_one_and_update(
+            {'_id': _id}, {'$set': {
+                'result': __result
+            }}, return_document=pymongo.ReturnDocument.AFTER
+        )
+        logging.info(
+            "[Database] Finished Task: task_name - {} sort: {} push time: {} took: {}s".format(
+                task_name, sort, _datetime,
+                time.time() - s
+            )
+        )
+        return True
+        # except Exception as e:
+        #     exc_type, exc_obj, exc_tb = sys.exc_info()
+        #     fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
+        #     logging.info("{} {} {} {} {}".format(exc_type, exc_obj, fname, exc_tb.tb_lineno, e))
+        #     logging.info("[Database] Fail to run task")
+        #     # if fail, set status back to pending
+        #     _ = self.db.Task.find_one_and_update({'_id': _id}, {'$set': {'status': 'pending'}})
+        #     return False
 
     def delete_tasks(self, **kwargs):
         """Delete tasks.

tensorlayer/files/dataset_loaders/celebA_dataset.py

Lines changed: 1 addition & 2 deletions
@@ -5,8 +5,7 @@
 import zipfile
 
 from tensorlayer import logging
-from tensorlayer.files.utils import (download_file_from_google_drive,
-                                     exists_or_mkdir, load_file_list)
+from tensorlayer.files.utils import (download_file_from_google_drive, exists_or_mkdir, load_file_list)
 
 __all__ = ['load_celebA_dataset']

tensorlayer/files/dataset_loaders/cyclegan_dataset.py

Lines changed: 1 addition & 2 deletions
@@ -6,8 +6,7 @@
 import numpy as np
 
 from tensorlayer import logging, visualize
-from tensorlayer.files.utils import (del_file, folder_exists, load_file_list,
-                                     maybe_download_and_extract)
+from tensorlayer.files.utils import (del_file, folder_exists, load_file_list, maybe_download_and_extract)
 
 __all__ = ['load_cyclegan_dataset']

tensorlayer/files/dataset_loaders/flickr_1M_dataset.py

Lines changed: 3 additions & 3 deletions
@@ -4,9 +4,9 @@
 import os
 
 from tensorlayer import logging, visualize
-from tensorlayer.files.utils import (del_file, folder_exists, load_file_list,
-                                     load_folder_list,
-                                     maybe_download_and_extract, read_file)
+from tensorlayer.files.utils import (
+    del_file, folder_exists, load_file_list, load_folder_list, maybe_download_and_extract, read_file
+)
 
 __all__ = ['load_flickr1M_dataset']

tensorlayer/files/dataset_loaders/flickr_25k_dataset.py

Lines changed: 3 additions & 3 deletions
@@ -4,9 +4,9 @@
 import os
 
 from tensorlayer import logging, visualize
-from tensorlayer.files.utils import (del_file, folder_exists, load_file_list,
-                                     maybe_download_and_extract, natural_keys,
-                                     read_file)
+from tensorlayer.files.utils import (
+    del_file, folder_exists, load_file_list, maybe_download_and_extract, natural_keys, read_file
+)
 
 __all__ = ['load_flickr25k_dataset']

tensorlayer/files/dataset_loaders/mpii_dataset.py

Lines changed: 1 addition & 2 deletions
@@ -4,8 +4,7 @@
 import os
 
 from tensorlayer import logging
-from tensorlayer.files.utils import (del_file, folder_exists, load_file_list,
-                                     maybe_download_and_extract)
+from tensorlayer.files.utils import (del_file, folder_exists, load_file_list, maybe_download_and_extract)
 
 __all__ = ['load_mpii_pose_dataset']

tensorlayer/files/dataset_loaders/voc_dataset.py

Lines changed: 1 addition & 3 deletions
@@ -6,9 +6,7 @@
 import tensorflow as tf
 
 from tensorlayer import logging, utils
-from tensorlayer.files.utils import (del_file, del_folder, folder_exists,
-                                     load_file_list,
-                                     maybe_download_and_extract)
+from tensorlayer.files.utils import (del_file, del_folder, folder_exists, load_file_list, maybe_download_and_extract)
 
 __all__ = ['load_voc_dataset']
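All of the dataset-loader edits above are the same yapf transformation: a grouped import collapses onto one line when it fits the column limit (celebA, cyclegan, mpii, voc) and otherwise is split with a dedented closing bracket (the two flickr loaders). A sketch of style options that would reproduce this wrapping; the knob values are assumptions, not the project's actual .style.yapf:

from yapf.yapflib.yapf_api import FormatCode

src = (
    "from tensorlayer.files.utils import (del_file, folder_exists, load_file_list,\n"
    "                                     load_folder_list,\n"
    "                                     maybe_download_and_extract, read_file)\n"
)

# Assumed option values chosen to mimic the commit's output; the project's
# real configuration may differ.
formatted, _ = FormatCode(
    src, style_config={
        'based_on_style': 'pep8',
        'column_limit': 120,
        'coalesce_brackets': True,
        'dedent_closing_brackets': True,
    }
)
print(formatted)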
