
Commit 6d0c078

Victor Bourgin authored and facebook-github-bot committed
Recognizer: Incorporate EMA (#922)
Summary:
Add EMA to the recognizer:
- Separate learning rate scheduler updates from EMA model updates. In d2go, the EMA weights were updated every step, while the scheduler was updated every epoch. We separate the two to get the same behavior in Vizard, and override `on_train_step_end` to update the EMA weights every step (irrespective of other parameters).
- Update torchtnt auto_unit to use `self.device` for the EMA / SWA model, since the device may be set from the environment in the superclass init. This enables model evaluation on GPU.

Differential Revision: D64206735
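The split between per-step EMA updates and per-epoch scheduler updates described in the summary can be sketched roughly as follows. This is a minimal, pure-Python illustration, not the Vizard or torchtnt implementation: `ToyTrainer`, `on_train_epoch_end`, `lr_gamma`, and the decay values are hypothetical names and numbers chosen for the example; only the `on_train_step_end` hook name comes from the commit message.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyTrainer:
    """Hypothetical stand-in for a training unit (not the real API)."""

    weights: List[float]
    lr: float = 0.1
    lr_gamma: float = 0.5   # assumed multiplicative LR decay applied per epoch
    ema_decay: float = 0.9  # assumed EMA decay factor
    ema_weights: List[float] = field(default_factory=list)
    ema_updates: int = 0
    scheduler_updates: int = 0

    def __post_init__(self) -> None:
        # The EMA copy starts out identical to the live weights.
        self.ema_weights = list(self.weights)

    def train_step(self, grads: List[float]) -> None:
        # Plain SGD update, followed by the per-step hook.
        self.weights = [w - self.lr * g for w, g in zip(self.weights, grads)]
        self.on_train_step_end()

    def on_train_step_end(self) -> None:
        # EMA weights are refreshed on every step, independent of the scheduler.
        d = self.ema_decay
        self.ema_weights = [
            d * e + (1.0 - d) * w for e, w in zip(self.ema_weights, self.weights)
        ]
        self.ema_updates += 1

    def on_train_epoch_end(self) -> None:
        # The learning rate scheduler only advances once per epoch.
        self.lr *= self.lr_gamma
        self.scheduler_updates += 1


def run(epochs: int = 2, steps_per_epoch: int = 3) -> ToyTrainer:
    trainer = ToyTrainer(weights=[1.0, -1.0])
    for _ in range(epochs):
        for _ in range(steps_per_epoch):
            trainer.train_step(grads=[1.0, -1.0])
        trainer.on_train_epoch_end()
    return trainer
```

With two epochs of three steps each, `run()` performs six EMA updates but only two scheduler updates, which is the decoupling the summary describes.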
1 parent 1beb1f0 commit 6d0c078

1 file changed: +1 -1 lines changed

torchtnt/framework/auto_unit.py

Lines changed: 1 addition & 1 deletion
@@ -512,7 +512,7 @@ def __init__(

             self.swa_model = AveragedModel(
                 module_for_swa,
-                device=device,
+                device=self.device,
                 use_buffers=swa_params.use_buffers,
                 averaging_method=swa_params.averaging_method,
                 ema_decay=swa_params.ema_decay,
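Why `device=self.device` rather than `device=device` matters can be illustrated with a small sketch. It assumes, as the summary states, that the superclass init may resolve the device from the environment when none is passed. Everything here is hypothetical and pure Python: `resolve_device`, `BaseUnit`, `AveragedCopy`, and the `TOY_DEVICE` variable are invented for the example and are not torchtnt or PyTorch names.

```python
import os
from typing import Optional


def resolve_device(device: Optional[str]) -> str:
    """Hypothetical helper: use the argument, else an env var, else CPU."""
    if device is not None:
        return device
    return os.environ.get("TOY_DEVICE", "cpu")


class BaseUnit:
    def __init__(self, device: Optional[str] = None) -> None:
        # The superclass resolves the device, possibly from the environment.
        self.device = resolve_device(device)


class AveragedCopy:
    """Toy stand-in for an averaged (EMA / SWA) model; it records its device."""

    def __init__(self, device: Optional[str]) -> None:
        self.device = device


class UnitBefore(BaseUnit):
    def __init__(self, device: Optional[str] = None) -> None:
        super().__init__(device)
        # Passes the raw argument, which is still None if the caller omitted it.
        self.averaged = AveragedCopy(device=device)


class UnitAfter(BaseUnit):
    def __init__(self, device: Optional[str] = None) -> None:
        super().__init__(device)
        # Passes the resolved device, so the averaged copy matches the unit.
        self.averaged = AveragedCopy(device=self.device)
```

In this sketch, when no device is passed and `TOY_DEVICE` is set, `UnitBefore` hands `None` to the averaged copy while `UnitAfter` hands it the resolved device, which mirrors the one-line change in the diff.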
