@@ -47,18 +47,29 @@ def init_async(self):
         tf.distribute.experimental.coordinator.ClusterCoordinator(
             self._strategy))
 
+  def coordinator_for_async(
+      self,
+  ) -> tf.distribute.experimental.coordinator.ClusterCoordinator:
+    if not self._coordinator:
+      raise ValueError(
+          "Coordinator uninitialized for async run. Call init_async() first."
+      )
+    return self._coordinator
+
   def join(self):
     """Join all async steps. Only useful in aysnc training."""
     if getattr(self, "_is_async", False):
-      self._coordinator.join()
+      self.coordinator_for_async().join()
 
   def create_train_loop_fn(self):
     """Creates a eval loop from the given step function and options."""
     train_loop_fn = super().create_train_loop_fn()
     if getattr(self, "_is_async", False):
 
       def _async_loop_fn(iterator, num_steps):
-        self._coordinator.schedule(train_loop_fn, args=(iterator, num_steps))
+        self.coordinator_for_async().schedule(
+            train_loop_fn, args=(iterator, num_steps)
+        )
 
       return _async_loop_fn
     else:
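
The coordinator that `coordinator_for_async()` guards is the standard `tf.distribute` ClusterCoordinator created in `init_async()`. For context, a minimal sketch of the schedule/join pattern these hooks feed into, assuming a parameter-server cluster is already described by `TF_CONFIG` (all names below are illustrative and not part of this change):

    import tensorflow as tf

    # Assumes TF_CONFIG describes a running parameter-server cluster.
    strategy = tf.distribute.experimental.ParameterServerStrategy(
        tf.distribute.cluster_resolver.TFConfigClusterResolver())
    coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

    @tf.function
    def train_step():
      return tf.constant(1.0)  # placeholder per-step work executed on workers

    for _ in range(10):
      coordinator.schedule(train_step)  # dispatch steps to workers asynchronously
    coordinator.join()  # block until every scheduled step has finished
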
@@ -76,7 +87,9 @@ def create_eval_loop_fn(self, has_state: bool):
       def _async_loop_fn(iterator, num_steps, state=None, reduce_fn=None):
         assert state is None
         assert reduce_fn is None
-        self._coordinator.schedule(eval_loop_fn, args=(iterator, num_steps))
+        self.coordinator_for_async().schedule(
+            eval_loop_fn, args=(iterator, num_steps)
+        )
 
       return _async_loop_fn
     else:
@@ -102,7 +115,9 @@ def distribute_dataset(self, dataset_or_fn, *args, **kwargs):
                                                     *args, **kwargs)
       per_worker_dataset_fn = tf.function(per_worker_dataset_fn)
 
-      return self._coordinator.create_per_worker_dataset(per_worker_dataset_fn)
+      return self.coordinator_for_async().create_per_worker_dataset(
+          per_worker_dataset_fn
+      )
     else:
       return orbit.utils.make_distributed_dataset(self._strategy, dataset_or_fn,
                                                   *args, **kwargs)
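
The `create_per_worker_dataset` call wrapped above is the usual ClusterCoordinator data flow: each worker builds its own dataset instance, and per-worker iterators are passed to scheduled step functions. Roughly, reusing the `coordinator` from the sketch after the first hunk (function names here are illustrative):

    @tf.function
    def per_worker_dataset_fn():
      # Each worker builds its own dataset instance.
      return tf.data.Dataset.range(100).batch(8).repeat()

    per_worker_dataset = coordinator.create_per_worker_dataset(per_worker_dataset_fn)
    per_worker_iterator = iter(per_worker_dataset)

    @tf.function
    def step_fn(iterator):
      return tf.reduce_sum(next(iterator))

    coordinator.schedule(step_fn, args=(per_worker_iterator,))
    coordinator.join()
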
@@ -352,7 +367,10 @@ def next_train_inputs(self, iterator):
     This method provides a way to control how to fetch the next model input, and
     what data to send to the model.
 
-    This function runs in eager mode.
+    Note: This function runs on the host side when accelerators are used.
+
+    Note: Depending on the training setup this may or may not run in eager mode.
+    In most cases it will be run in graph mode.
 
     Args:
       iterator: Dataset iterator to generate the next inputs from.
@@ -399,7 +417,10 @@ def next_eval_inputs(self, iterator):
     processed later in `aggregate_logs`. This is useful for sending extra logs
     downstream that are not compatible with the accelerators.
 
-    This function runs in eager mode.
+    Note: This function runs on the host side when accelerators are used.
+
+    Note: Depending on the training setup this may or may not run in eager mode.
+    In most cases it will be run in graph mode.
 
     Args:
       iterator: Dataset iterator to generate the next inputs from.