Commit f9f566b

JKSenthil authored and facebook-github-bot committed
enable use_buffers for SWA in AutoUnit (#844)
Summary:

Pull Request resolved: #844

Enables batch normalization statistics to be updated via the `use_buffers` flag. See https://fb.workplace.com/groups/1323951304836028/posts/1668131377084684/?comment_id=1668215887076233 for the user request. This functionality existed in the base AveragedModel but was never exposed publicly in TNT.

Reviewed By: diego-urgell

Differential Revision: D58368554

fbshipit-source-id: fff608da39c471ee6e61af945a2f782e76ece5b9
1 parent 9d99ea6 commit f9f566b

File tree

2 files changed (+9, -1 lines changed)


torchtnt/framework/auto_unit.py

Lines changed: 6 additions & 1 deletion

@@ -84,6 +84,10 @@ class SWAParams:
     Args:
         warmup_steps_or_epochs: number of steps or epochs before starting SWA
         step_or_epoch_update_freq: number of steps or epochs between each SWA update
+        use_buffers: if ``True``, it will compute running averages for
+            both the parameters and the buffers of the model. (default: ``True``)
+            This will update activation statistics for Batch Normalization. This is an
+            alternative to calling `torch.optim.swa_utils.update_bn` post-training.
         averaging_method: whether to use SWA or EMA to average model weights
         ema_decay: the exponential decay applied to the averaged parameters. This param
             is only needed for EMA, and is ignored otherwise (for SWA).

@@ -101,6 +105,7 @@ class SWAParams:

     warmup_steps_or_epochs: int
     step_or_epoch_update_freq: int
+    use_buffers: bool = True
     averaging_method: Literal["ema", "swa"] = "ema"
     ema_decay: float = 0.999
     use_lit: bool = False

@@ -487,7 +492,7 @@ def __init__(
             self.swa_model = AveragedModel(
                 module_for_swa,
                 device=device,
-                use_buffers=True,
+                use_buffers=swa_params.use_buffers,
                 averaging_method=swa_params.averaging_method,
                 ema_decay=swa_params.ema_decay,
                 skip_deepcopy=skip_deepcopy,
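The fields shown in the `SWAParams` hunks above can be sketched as a plain dataclass. This is a minimal stand-in (the name `SWAParamsSketch` is hypothetical, not torchtnt API), assuming the field names and defaults shown in the diff:

```python
from dataclasses import dataclass
from typing import Literal

@dataclass
class SWAParamsSketch:
    """Hypothetical stand-in mirroring the SWAParams fields from the diff."""
    warmup_steps_or_epochs: int
    step_or_epoch_update_freq: int
    use_buffers: bool = True  # new in this commit; True preserves prior behavior
    averaging_method: Literal["ema", "swa"] = "ema"
    ema_decay: float = 0.999

# Omitting use_buffers leaves it at True, matching the previously
# hard-coded use_buffers=True in the AveragedModel construction.
params = SWAParamsSketch(warmup_steps_or_epochs=100, step_or_epoch_update_freq=10)
print(params.use_buffers)
```

Because the default is `True`, existing callers that never set `use_buffers` keep the old behavior; only callers who opt out of buffer averaging need to pass `use_buffers=False`.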

torchtnt/utils/swa.py

Lines changed: 3 additions & 0 deletions

@@ -46,6 +46,9 @@ def __init__(
         to see what the model, device, and use_buffer arguments entail.

     Args:
+        use_buffers: if ``True``, it will compute running averages for
+            both the parameters and the buffers of the model. (default: ``False``)
+            This will update activation statistics for Batch Normalization.
         averaging_method: Whether to use EMA or SWA.
         ema_decay: The exponential decay applied to the averaged parameters. This param
             is only needed for EMA, and is ignored otherwise (for SWA).
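To illustrate what "running averages for both the parameters and the buffers" means, here is a toy sketch using plain dicts to stand in for a module's parameters and its BatchNorm running statistics (`ema_update` and the key names are illustrative, not torchtnt or PyTorch API):

```python
# Toy sketch of EMA averaging applied uniformly to "parameters" and
# "buffers", which is the effect of use_buffers=True. Names are
# illustrative only.
def ema_update(avg: dict, new: dict, decay: float = 0.999) -> dict:
    """Exponential moving average over every entry, params and buffers alike."""
    return {k: decay * avg[k] + (1.0 - decay) * new[k] for k in avg}

# "weight" stands in for a learnable parameter; "running_mean" for a
# BatchNorm buffer. With use_buffers=True both are averaged together,
# so there is no need for a separate post-training update_bn pass.
averaged = {"weight": 1.0, "running_mean": 0.0}
current = {"weight": 2.0, "running_mean": 1.0}
averaged = ema_update(averaged, current)
```

With `use_buffers=False`, only the parameter entries would be averaged and the buffers would have to be recomputed afterwards, e.g. via `torch.optim.swa_utils.update_bn`.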
