1313import tensorflow as tf
1414
1515from tensorlayer import logging
16- # from tensorlayer.files import load_graph_and_params
16+ from tensorlayer .files import net2static_graph , static_graph2net , assign_weights
17+ from tensorlayer .files import save_weights_to_hdf5 , load_hdf5_to_weights
1718from tensorlayer .files import del_folder , exists_or_mkdir
1819
1920
@@ -112,8 +113,8 @@ def save_model(self, network=None, model_name='model', **kwargs):
112113
113114 Parameters
114115 ----------
115- network : TensorLayer layer
116- TensorLayer layer instance.
116+ network : TensorLayer Model
117+ TensorLayer Model instance.
117118 model_name : str
118119 The name/key of model.
119120 kwargs : other events
@@ -125,15 +126,15 @@ def save_model(self, network=None, model_name='model', **kwargs):
125126 >>> db.save_model(net, accuracy=0.8, loss=2.3, name='second_model')
126127
127128 Load one model with parameters from database (run this in other script)
128- >>> net = db.find_top_model(sess=sess, accuracy=0.8, loss=2.3)
129+ >>> net = db.find_top_model(accuracy=0.8, loss=2.3)
129130
130131 Find and load the latest model.
131- >>> net = db.find_top_model(sess=sess, sort=[("time", pymongo.DESCENDING)])
132- >>> net = db.find_top_model(sess=sess, sort=[("time", -1)])
132+ >>> net = db.find_top_model(sort=[("time", pymongo.DESCENDING)])
133+ >>> net = db.find_top_model(sort=[("time", -1)])
133134
134135 Find and load the oldest model.
135- >>> net = db.find_top_model(sess=sess, sort=[("time", pymongo.ASCENDING)])
136- >>> net = db.find_top_model(sess=sess, sort=[("time", 1)])
136+ >>> net = db.find_top_model(sort=[("time", pymongo.ASCENDING)])
137+ >>> net = db.find_top_model(sort=[("time", 1)])
137138
138139 Get model information
139140 >>> net._accuracy
@@ -146,11 +147,13 @@ def save_model(self, network=None, model_name='model', **kwargs):
146147 kwargs .update ({'model_name' : model_name })
147148 self ._fill_project_info (kwargs ) # put project_name into kwargs
148149
149- params = network .get_all_params ()
150+ # params = network.get_all_params()
151+ params = network .weights
150152
151153 s = time .time ()
152154
153- kwargs .update ({'architecture' : network .all_graphs , 'time' : datetime .utcnow ()})
155+ # kwargs.update({'architecture': network.all_graphs, 'time': datetime.utcnow()})
156+ kwargs .update ({'architecture' : net2static_graph (network ), 'time' : datetime .utcnow ()})
154157
155158 try :
156159 params_id = self .model_fs .put (self ._serialization (params ))
@@ -165,13 +168,11 @@ def save_model(self, network=None, model_name='model', **kwargs):
165168 print ("[Database] Save model: FAIL" )
166169 return False
167170
168- def find_top_model (self , sess , sort = None , model_name = 'model' , ** kwargs ):
171+ def find_top_model (self , sort = None , model_name = 'model' , ** kwargs ):
169172 """Finds and returns a model architecture and its parameters from the database which matches the requirement.
170173
171174 Parameters
172175 ----------
173- sess : Session
174- TensorFlow session.
175176 sort : List of tuple
176177 PyMongo sort comment, search "PyMongo find one sorting" and `collection level operations <http://api.mongodb.com/python/current/api/pymongo/collection.html>`__ for more details.
177178 model_name : str or None
@@ -185,7 +186,7 @@ def find_top_model(self, sess, sort=None, model_name='model', **kwargs):
185186
186187 Returns
187188 ---------
188- network : TensorLayer layer
189+ network : TensorLayer Model
189190 Note that, the returned network contains all information of the document (record), e.g. if you saved accuracy in the document, you can get the accuracy by using ``net._accuracy``.
190191 """
191192 # print(kwargs) # {}
@@ -196,33 +197,38 @@ def find_top_model(self, sess, sort=None, model_name='model', **kwargs):
196197
197198 d = self .db .Model .find_one (filter = kwargs , sort = sort )
198199
199- _temp_file_name = '_find_one_model_ztemp_file'
200+ # _temp_file_name = '_find_one_model_ztemp_file'
200201 if d is not None :
201202 params_id = d ['params_id' ]
202203 graphs = d ['architecture' ]
203204 _datetime = d ['time' ]
204- exists_or_mkdir (_temp_file_name , False )
205- with open (os .path .join (_temp_file_name , 'graph.pkl' ), 'wb' ) as file :
206- pickle .dump (graphs , file , protocol = pickle .HIGHEST_PROTOCOL )
205+ # exists_or_mkdir(_temp_file_name, False)
206+ # with open(os.path.join(_temp_file_name, 'graph.pkl'), 'wb') as file:
207+ # pickle.dump(graphs, file, protocol=pickle.HIGHEST_PROTOCOL)
207208 else :
208209 print ("[Database] FAIL! Cannot find model: {}" .format (kwargs ))
209210 return False
210211 try :
211212 params = self ._deserialization (self .model_fs .get (params_id ).read ())
212- np .savez (os .path .join (_temp_file_name , 'params.npz' ), params = params )
213-
214- network = load_graph_and_params (name = _temp_file_name , sess = sess )
215- del_folder (_temp_file_name )
213+ # TODO : restore model and load weights
214+ network = static_graph2net (graphs )
215+ assign_weights (weights = params , network = network )
216+ # np.savez(os.path.join(_temp_file_name, 'params.npz'), params=params)
217+ #
218+ # network = load_graph_and_params(name=_temp_file_name, sess=sess)
219+ # del_folder(_temp_file_name)
216220
217221 pc = self .db .Model .find (kwargs )
218222 print (
219- "[Database] Find one model SUCCESS. kwargs:{} sort:{} save time:{} took: {}s" .
220- format (kwargs , sort , _datetime , round (time .time () - s , 2 ))
223+ "[Database] Find one model SUCCESS. kwargs:{} sort:{} save time:{} took: {}s" .format (
224+ kwargs , sort , _datetime , round (time .time () - s , 2 )
225+ )
221226 )
222227
228+ # FIXME : not sure what's this for
223229 # put all informations of model into the TL layer
224- for key in d :
225- network .__dict__ .update ({"_%s" % key : d [key ]})
230+ # for key in d:
231+ # network.__dict__.update({"_%s" % key: d[key]})
226232
227233 # check whether more parameters match the requirement
228234 params_id_list = pc .distinct ('params_id' )
@@ -553,12 +559,12 @@ def create_task(self, task_name=None, script=None, hyper_parameters=None, saved_
553559 >>> db.create_task(task_name='mnist', script='example/tutorial_mnist_simple.py', description='simple tutorial')
554560
555561 Finds and runs the latest task
556- >>> db.run_top_task(sess=sess, sort=[("time", pymongo.DESCENDING)])
557- >>> db.run_top_task(sess=sess, sort=[("time", -1)])
562+ >>> db.run_top_task(sort=[("time", pymongo.DESCENDING)])
563+ >>> db.run_top_task(sort=[("time", -1)])
558564
559565 Finds and runs the oldest task
560- >>> db.run_top_task(sess=sess, sort=[("time", pymongo.ASCENDING)])
561- >>> db.run_top_task(sess=sess, sort=[("time", 1)])
566+ >>> db.run_top_task(sort=[("time", pymongo.ASCENDING)])
567+ >>> db.run_top_task(sort=[("time", 1)])
562568
563569 """
564570 if not isinstance (task_name , str ): # is None:
@@ -613,58 +619,57 @@ def run_top_task(self, task_name=None, sort=None, **kwargs):
613619 # find task and set status to running
614620 task = self .db .Task .find_one_and_update (kwargs , {'$set' : {'status' : 'running' }}, sort = sort )
615621
616- try :
617- # get task info e.g. hyper parameters, python script
618- if task is None :
619- logging .info ("[Database] Find Task FAIL: key: {} sort: {}" .format (task_name , sort ))
620- return False
621- else :
622- logging .info ("[Database] Find Task SUCCESS: key: {} sort: {}" .format (task_name , sort ))
623- _datetime = task ['time' ]
624- _script = task ['script' ]
625- _id = task ['_id' ]
626- _hyper_parameters = task ['hyper_parameters' ]
627- _saved_result_keys = task ['saved_result_keys' ]
628- logging .info (" hyper parameters:" )
629- for key in _hyper_parameters :
630- globals ()[key ] = _hyper_parameters [key ]
631- logging .info (" {}: {}" .format (key , _hyper_parameters [key ]))
632- # run task
633- s = time .time ()
634- logging .info ("[Database] Start Task: key: {} sort: {} push time: {}" .format (task_name , sort , _datetime ))
635- _script = _script .decode ('utf-8' )
636- with tf .Graph ().as_default (): # as graph: # clear all TF graphs
637- exec (_script , globals ())
638-
639- # set status to finished
640- _ = self .db .Task .find_one_and_update ({'_id' : _id }, {'$set' : {'status' : 'finished' }})
641-
642- # return results
643- __result = {}
644- for _key in _saved_result_keys :
645- logging .info (" result: {}={} {}" .format (_key , globals ()[_key ], type (globals ()[_key ])))
646- __result .update ({"%s" % _key : globals ()[_key ]})
647- _ = self .db .Task .find_one_and_update (
648- {
649- '_id' : _id
650- }, {'$set' : {
651- 'result' : __result
652- }}, return_document = pymongo .ReturnDocument .AFTER
653- )
654- logging .info (
655- "[Database] Finished Task: task_name - {} sort: {} push time: {} took: {}s" .
656- format (task_name , sort , _datetime ,
657- time .time () - s )
658- )
659- return True
660- except Exception as e :
661- exc_type , exc_obj , exc_tb = sys .exc_info ()
662- fname = os .path .split (exc_tb .tb_frame .f_code .co_filename )[1 ]
663- logging .info ("{} {} {} {} {}" .format (exc_type , exc_obj , fname , exc_tb .tb_lineno , e ))
664- logging .info ("[Database] Fail to run task" )
665- # if fail, set status back to pending
666- _ = self .db .Task .find_one_and_update ({'_id' : _id }, {'$set' : {'status' : 'pending' }})
622+ # try:
623+ # get task info e.g. hyper parameters, python script
624+ if task is None :
625+ logging .info ("[Database] Find Task FAIL: key: {} sort: {}" .format (task_name , sort ))
667626 return False
627+ else :
628+ logging .info ("[Database] Find Task SUCCESS: key: {} sort: {}" .format (task_name , sort ))
629+ _datetime = task ['time' ]
630+ _script = task ['script' ]
631+ _id = task ['_id' ]
632+ _hyper_parameters = task ['hyper_parameters' ]
633+ _saved_result_keys = task ['saved_result_keys' ]
634+ logging .info (" hyper parameters:" )
635+ for key in _hyper_parameters :
636+ globals ()[key ] = _hyper_parameters [key ]
637+ logging .info (" {}: {}" .format (key , _hyper_parameters [key ]))
638+ # run task
639+ s = time .time ()
640+ logging .info ("[Database] Start Task: key: {} sort: {} push time: {}" .format (task_name , sort , _datetime ))
641+ _script = _script .decode ('utf-8' )
642+ with tf .Graph ().as_default (): # # as graph: # clear all TF graphs
643+ exec (_script , globals ())
644+
645+ # set status to finished
646+ _ = self .db .Task .find_one_and_update ({'_id' : _id }, {'$set' : {'status' : 'finished' }})
647+
648+ # return results
649+ __result = {}
650+ for _key in _saved_result_keys :
651+ logging .info (" result: {}={} {}" .format (_key , globals ()[_key ], type (globals ()[_key ])))
652+ __result .update ({"%s" % _key : globals ()[_key ]})
653+ _ = self .db .Task .find_one_and_update (
654+ {'_id' : _id }, {'$set' : {
655+ 'result' : __result
656+ }}, return_document = pymongo .ReturnDocument .AFTER
657+ )
658+ logging .info (
659+ "[Database] Finished Task: task_name - {} sort: {} push time: {} took: {}s" .format (
660+ task_name , sort , _datetime ,
661+ time .time () - s
662+ )
663+ )
664+ return True
665+ # except Exception as e:
666+ # exc_type, exc_obj, exc_tb = sys.exc_info()
667+ # fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
668+ # logging.info("{} {} {} {} {}".format(exc_type, exc_obj, fname, exc_tb.tb_lineno, e))
669+ # logging.info("[Database] Fail to run task")
670+ # # if fail, set status back to pending
671+ # _ = self.db.Task.find_one_and_update({'_id': _id}, {'$set': {'status': 'pending'}})
672+ # return False
668673
669674 def delete_tasks (self , ** kwargs ):
670675 """Delete tasks.
0 commit comments