There are a few points worth noting in the above user code, when using PS training:
* Batch-level callbacks will be disabled; that is, if users override `on_batch_begin` and `on_batch_end`, an error will be raised. This is necessary for reasonable performance, as described below.
* The cluster is synced at the end of every epoch. This is an implementation detail users do not necessarily need to be aware of, but it is important for the correctness of epoch-level callbacks.
* The `model.fit(..., run_eagerly=True)` case is not supported. This is because `ClusterCoordinator.schedule` requires a `tf.function` to be `schedule`d*; a regular Python function cannot be.
* The "bundled" evaluation in `model.fit` is performed locally (in this initial design). If an evaluator task, running concurrently with training (aka sidecar evaluator), is preferred, the user should use a `ModelCheckpoint(save_weights_only=True)` callback to produce checkpoint files for the evaluator.

*There are a couple of reasons why we chose to only support `tf.function`s for scheduling. Primarily, we generally have better control over the behavior of `tf.function`s, including variable and resource creation. Furthermore, this forces the content of the function to be executed on remote workers, as opposed to possibly executing Python code on the coordinator.
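
Returning to the last bullet above (local "bundled" evaluation, with checkpoints produced for a separately launched sidecar evaluator), a minimal sketch of the training side might look as follows. This is illustrative only: the placeholder model and dataset, and passing a dataset function to `model.fit`, reflect the dataset-factory input proposed in this design rather than a finalized API; the `ModelCheckpoint(save_weights_only=True)` callback is the specific piece the bullet prescribes.

```python
import tensorflow as tf

# Assumed: TF_CONFIG is already populated for the coordinator, worker and ps tasks.
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)

with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  model.compile(optimizer="adam", loss="mse")

def train_dataset_fn():
  # A fresh dataset per call -- the "dataset factory" pattern proposed in this design.
  x = tf.random.uniform((64, 10))
  y = tf.random.uniform((64, 1))
  return tf.data.Dataset.from_tensor_slices((x, y)).repeat().batch(8)

# Weights-only checkpoints are what the sidecar evaluator consumes (see "Evaluation" below).
ckpt_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath="/tmp/checkpoint_dir/ckpt-{epoch}", save_weights_only=True)

model.fit(train_dataset_fn, epochs=5, steps_per_epoch=100, callbacks=[ckpt_callback])
```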
Initially, we aim to have `model.evaluate` and `model.predict` carried out only on the coordinator. That is, they do not involve distribution via a `ClusterCoordinator`, and the evaluate function is thus executed on the coordinator.

In the longer term, we seek distributed support for `model.evaluate`, where the evaluate function is scheduled onto the workers to execute.

Note that, similar to the dataset factory change for `model.fit`, the validation dataset will also need to be a dataset factory. That is, `model.fit` will take a `DatasetFactory` for the `validation_data` argument, and `model.evaluate` will take a `DatasetFactory` for `x`, as opposed to a `dataset` instance. Even though a `DatasetFactory` for `model.evaluate` may seem unnecessary at this point, having the type restriction allows compatible changes in the future for distributed evaluation. See the "Plan for distributed evaluation" section in the Appendix for further discussion.
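
As a hedged illustration, continuing the training sketch above (`DatasetFactory` is the type proposed in this design, not an existing TF symbol, and the helper function below is hypothetical):

```python
def validation_dataset_fn():
  # Built fresh on each call, so it can be (re)created wherever evaluation ends up running.
  x = tf.random.uniform((32, 10))
  y = tf.random.uniform((32, 1))
  return tf.data.Dataset.from_tensor_slices((x, y)).batch(8)

validation_factory = DatasetFactory(validation_dataset_fn)  # proposed type, name not final

model.fit(train_dataset_fn, epochs=5, steps_per_epoch=100,
          validation_data=validation_factory, validation_steps=4)
model.evaluate(validation_factory, steps=4)
```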

See the “Evaluation” section below for other proposed evaluation solutions accompanying `model.fit` usage.
### Evaluation
In addition to the existing train-evaluate solution provided by `model.fit` (though `model.evaluate` is local in the first phase), we also support a dedicated evaluator task, aka a sidecar evaluator, which runs on a separate machine that is not considered part of the training cluster.

Users then have two evaluation schemes to choose from: alternating evaluation, sidecar evaluation, or both, depending on their needs.
#### Built-in, alternating evaluation in `model.fit`
If the `validation_data` argument is provided and certain conditions are satisfied, `model.fit` also runs evaluation via the `model.evaluate` API every epoch, in a train-evaluate alternating manner. As described above, at this time only the coordinator is used for `model.evaluate`. See the "Plan for distributed evaluation" section in the Appendix for further discussion of distributed evaluation.
#### Sidecar evaluation
For custom training loop users, sidecar evaluation is currently available via a [recommended user flow](https://www.tensorflow.org/tutorials/distribute/parameter_server_training#side-car_evaluation). With `model.fit`, users can readily apply that flow in an evaluator task for continuous evaluation; in addition, this section discusses the proposed sidecar evaluator APIs accompanying `model.fit` usage, with the goal that users do not need to write the detailed looping or checkpoint-loading logic themselves.
##### `SidecarEvaluator` API
As part of this design, we propose a `SidecarEvaluator` API for managed, continuous evaluation. With this, the user is expected to start an additional `evaluator` task, in which the Python program runs a `SidecarEvaluator` as follows:
```python
SidecarEvaluator(
    model=model,                            # the model to evaluate (assumed built as in training)
    data=eval_dataset,                      # evaluation dataset; the name here is illustrative
    checkpoint_dir='/tmp/checkpoint_dir',   # dir for training-saved checkpoint
    # (other arguments elided in this excerpt)
).start()
```
`SidecarEvaluator` continuously loads the checkpoint files saved by `ModelCheckpoint` in the training counterpart as they become available, and performs evaluation until the evaluation dataset is exhausted or the designated number of steps has been reached.

Currently, `SidecarEvaluator` has an [implementation](https://github.com/tensorflow/tensorflow/blob/c046c38f6f7d80497174710f8547f9c89923bdc2/tensorflow/python/keras/distribute/sidecar_evaluator.py#L35) that was created for custom training loop use cases. We are proposing the following changes to `SidecarEvaluator` in order to be fully compatible with `model.fit`:
* `SidecarEvaluator` should accept the checkpoint file format saved by the `ModelCheckpoint` callback.
* `SidecarEvaluator` should accept arbitrary callbacks to be used in its internal `model.evaluate` call.
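
As a hedged sketch of what usage could look like once both changes are in place (`model` and `eval_dataset` are placeholders from the earlier snippet; the `callbacks` argument reflects the proposed extension, not a finalized signature):

```python
SidecarEvaluator(
    model=model,
    data=eval_dataset,
    # Directory written by the ModelCheckpoint callback on the training side.
    checkpoint_dir='/tmp/checkpoint_dir',
    # Arbitrary callbacks forwarded to the internal model.evaluate call.
    callbacks=[tf.keras.callbacks.TensorBoard(log_dir='/tmp/eval_logs')],
).start()
```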
##### Coordinator-driven sidecar evaluation (via a thread)
A potentially more seamless and encapsulated sidecar evaluation, where the user is not required to allocate an evaluator task or run separate code for evaluation, can be done with an evaluation thread on the coordinator. This should be considered the preferred route for sidecar evaluation, because it would be started by `model.fit` automatically and does not require a separate binary to be provided. The discussion is available in the "Discussion on coordinator-driven sidecar evaluation" section in the Appendix; it is considered beyond the scope of this proposal, and we plan to cover it in a future design/RFC.
### Fault tolerance
At this time, we're proposing to have an attribute in `ParameterServerStrategy` that holds the `ClusterCoordinator` instead.
## Appendix
This section gathers topics that are considered out of scope for this proposal, but are kept here to continue the relevant discussions.
### Plan for distributed evaluation
The current `ClusterCoordinator` API has a limitation in that distributed evaluation does not have a visitation guarantee when workers can become unavailable. Thus, we have a couple of options:
1. Implement distributed `model.evaluate` without a visitation guarantee, but require the user's opt-in because of the behavior change (e.g. via `model.evaluate(..., distributed_eval=True)`).
2. Enable distributed `model.evaluate` only after `ClusterCoordinator` supports a visitation-guarantee mechanism.

There will be a separate design RFC discussing the technical details.
### Discussion on coordinator-driven sidecar evaluation
This section continues the discussion from the "Coordinator-driven sidecar evaluation" section above. With this approach, the user does not allocate a task with type 'evaluator', because one 'worker' task (that runs a `tf.distribute.Server`) from the cluster can be used for evaluation. It can be any of the workers, but for convenience, let’s say the Nth worker is used for evaluation.

The thread would be started by `model.fit` if the user opts in via an argument such as `fit(..., run_sidecar_eval_thread=True)`. The thread would remotely execute an evaluation function on this worker #N and wait for its result synchronously. Once the result is returned, it can write a summary, adjust the learning rate, or signal to end the training. After that, it re-`schedule`s an evaluation function, and so on:
```python
# At some point, we start a thread for sidecar eval
t = threading.Thread(target=self._continuously_evaluate)
t.start()
...
if run_sidecar_eval_thread:
  self.should_eval = False
  t.join()
```
Note that with this approach, the training cluster will be limited to the first N-1 workers it has remaining, so the training cluster and evaluation do not block each other. It is also worth noting that the actual continuous-evaluation logic can be handled by a slightly modified version of `SidecarEvaluator`; we should aim to reuse that component for sidecar evaluation running in a thread.

If we compare the sidecar evaluator thread solution with the sidecar evaluator task (process) solution:

Pros (advantages of evaluator thread approach):
* This does not require a task to be set aside as evaluator, so 1) there is less work for the user, and 2) there is one fewer version of the Python binary to maintain.
* There is easier communication between the sidecar evaluator (thread) and the coordinator main thread, which is important for many callbacks

Cons (disadvantages of evaluator thread approach):
* This solution presents a challenge when workers can easily become unavailable, in which case it is not straightforward to immediately find another available worker to take over*
* This solution is blocked on `tf.keras.models.load_model` being available on PS, if `variable_partitioner` is used. Here, model saving and loading are used to clone the model, so if there is an alternative way to clone, this solution is not blocked.
* Users who can afford to allocate a high priority to an evaluator task cannot do so with workers; workers would simply have the same, usually lower, priority (and thus more frequent function-takeovers)*

*Fault tolerance, the first con, may be further addressed with another `ClusterCoordinator`, provided it shares threads with the existing `ClusterCoordinator` and the library allows multiple function queues to be accessed by those threads. More details may be discussed in a separate RFC.

*Regarding priority, the third con: we can address it by having a separate job (with only one task for now), say "eval_worker", for the worker that is solely used for evaluation. It would be a little more work, since TF_CONFIG, the device filter, etc. need to be changed, but it is possible, and it gives us the flexibility to assign a higher job priority.
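
As a hedged sketch of what such a cluster spec could look like (the "eval_worker" job name and the hosts are purely illustrative):

```python
import json
import os

# Illustrative TF_CONFIG for the dedicated evaluation worker: giving it its own
# "eval_worker" job allows it to be scheduled at a higher priority than "worker".
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "chief": ["host0:2222"],
        "ps": ["host1:2222"],
        "worker": ["host2:2222", "host3:2222"],
        "eval_worker": ["host4:2222"],
    },
    "task": {"type": "eval_worker", "index": 0},
})
```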