@@ -3,164 +3,21 @@ API - Distributed Training
 
 (Alpha release - usage might change later)
 
-Helper sessions and methods to run distributed training.
-Check this `mnist example <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_mnist_distributed.py>`_.
+Helper API to run distributed training.
+Check these `examples <https://github.com/tensorlayer/tensorlayer/tree/master/examples/distributed_training>`_.
 
 .. automodule:: tensorlayer.distributed
 
 .. autosummary::
 
-   TaskSpecDef
-   TaskSpec
-   DistributedSession
-   StopAtTimeHook
-   LoadCheckpoint
-
+   Trainer
 
 Distributed training
 --------------------
 
-
-TaskSpecDef
+Trainer
 ^^^^^^^^^^^
 
-.. autofunction:: TaskSpecDef
-
-Create TaskSpecDef from environment variables
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-.. autofunction:: TaskSpec
-
-Distributed session object
-^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-.. autofunction:: DistributedSession
-
-Data sharding
-^^^^^^^^^^^^^
-
-In some cases we want to shard the data among all the training servers instead of
-using the full dataset on every server. TensorFlow >= 1.4 provides helper classes
-that support data sharding: `Datasets <https://www.tensorflow.org/programmers_guide/datasets>`_.
-
-When sharding, it is important that the shuffle and any other non-deterministic
-operation is applied after the shards are created:
-
-.. code-block:: python
-
-    import tensorflow as tf
-    import tensorlayer as tl
-    from tensorflow.contrib.data import TextLineDataset
-    from tensorflow.contrib.data import Dataset
-
-    task_spec = TaskSpec()
-    task_spec.create_server()
-    files_dataset = Dataset.list_files(files_pattern)
-    dataset = TextLineDataset(files_dataset)
-    dataset = dataset.map(your_python_map_function, num_threads=4)
-    if task_spec is not None:
-        # shard before shuffling so every worker sees a disjoint subset
-        dataset = dataset.shard(task_spec.num_workers, task_spec.shard_index)
-    dataset = dataset.shuffle(buffer_size)
-    dataset = dataset.batch(batch_size)
-    dataset = dataset.repeat(num_epochs)
-    iterator = dataset.make_one_shot_iterator()
-    next_element = iterator.get_next()
-    with tf.device(task_spec.device_fn()):
-        tensors = create_graph(next_element)
-    with tl.DistributedSession(task_spec=task_spec,
-                               checkpoint_dir='/tmp/ckpt') as session:
-        while not session.should_stop():
-            session.run(tensors)
-
-
-Logging
-^^^^^^^
-
-We can use ``task_spec`` to log only on the master server:
-
-.. code-block:: python
-
-    while not session.should_stop():
-        should_log = task_spec.is_master() and your_conditions
-        if should_log:
-            results = session.run(tensors_with_log_info)
-            logging.info(...)
-        else:
-            results = session.run(tensors)
-
-Continuous evaluation
-^^^^^^^^^^^^^^^^^^^^^
-
-You can use one of the workers to continuously run evaluation on the saved checkpoints:
-
-.. code-block:: python
-
-    import time
-
-    import tensorflow as tf
-    from tensorflow.python.training import session_run_hook
-    from tensorflow.python.training.monitored_session import SingularMonitoredSession
-
-    class Evaluator(session_run_hook.SessionRunHook):
-        def __init__(self, checkpoints_path, output_path):
-            self.checkpoints_path = checkpoints_path
-            self.summary_writer = tf.summary.FileWriter(output_path)
-            self.latest_checkpoint = ''
-
-        def after_create_session(self, session, coord):
-            checkpoint = tf.train.latest_checkpoint(self.checkpoints_path)
-            # wait until a new checkpoint is available
-            while self.latest_checkpoint == checkpoint:
-                time.sleep(30)
-                checkpoint = tf.train.latest_checkpoint(self.checkpoints_path)
-            self.saver.restore(session, checkpoint)
-            self.latest_checkpoint = checkpoint
-
-        def end(self, session):
-            super(Evaluator, self).end(session)
-            # save summaries
-            step = int(self.latest_checkpoint.split('-')[-1])
-            self.summary_writer.add_summary(self.summary, step)
-
-        def _create_graph(self):
-            # your code to create the graph with the dataset;
-            # it should return the summary tensors to evaluate
-            pass
-
-        def run_evaluation(self):
-            with tf.Graph().as_default():
-                summary_tensors = self._create_graph()
-                self.saver = tf.train.Saver(var_list=tf.trainable_variables())
-                hooks = self.create_hooks()  # your own helper returning a list of hooks
-                hooks.append(self)
-                if self.max_time_secs and self.max_time_secs > 0:  # optional time limit
-                    hooks.append(StopAtTimeHook(self.max_time_secs))
-                # this evaluation runs indefinitely, until the process is killed
-                while True:
-                    with SingularMonitoredSession(hooks=hooks) as session:
-                        try:
-                            while not session.should_stop():
-                                self.summary = session.run(summary_tensors)
-                        except tf.errors.OutOfRangeError:
-                            pass
-                        # end of evaluation
-
-    task_spec = TaskSpec().user_last_worker_as_evaluator()
-    if task_spec.is_evaluator():
-        Evaluator(checkpoints_path, output_path).run_evaluation()
-    else:
-        task_spec.create_server()
-        # run normal training
-
-
-
-Session hooks
--------------
-
-TensorFlow provides `Session Hooks <https://www.tensorflow.org/api_guides/python/train#Training_Hooks>`_
-to run operations inside a session. We added a few more to help with common operations.
-
-
-Stop after maximum time
-^^^^^^^^^^^^^^^^^^^^^^^
-
-.. autofunction:: StopAtTimeHook
+.. autofunction:: Trainer
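+
+A minimal usage sketch of the ``Trainer`` is shown below. The constructor arguments and
+the method name here are illustrative assumptions only; refer to the linked examples and
+the ``Trainer`` API reference above for the exact signature:
+
+.. code-block:: python
+
+    import tensorflow as tf
+    import tensorlayer as tl
+
+    # build_model and make_dataset are placeholders for your own code:
+    # build_model should construct the network and return the loss,
+    # make_dataset should return a tf.data.Dataset of training samples.
+    trainer = tl.distributed.Trainer(
+        build_training_func=build_model,
+        training_dataset=make_dataset(),
+        optimizer=tf.train.RMSPropOptimizer,
+        optimizer_args={'learning_rate': 0.001},
+        batch_size=32,
+    )
+    trainer.train_to_end()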
 
-Initialize network with checkpoint
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-.. autofunction:: LoadCheckpoint