
Commit b34da0d

weifengpy authored and facebook-github-bot committed
return total grad norm in torchrec grad clipping (meta-pytorch#2507)
Summary:
Pull Request resolved: meta-pytorch#2507

clip_grad_norm_ now returns the total gradient norm, keeping its behavior consistent with torch.nn.utils.clip_grad_norm_.

Reviewed By: awgu

Differential Revision: D64712277

fbshipit-source-id: 689e02bd21dc37568c3347d5b0833c573f042c15
1 parent 1a57ce1 commit b34da0d

File tree

1 file changed (+2, -1)


torchrec/optim/clipping.py

Lines changed: 2 additions & 1 deletion
@@ -136,7 +136,7 @@ def step(self, closure: Any = None) -> None:
         self._step_num += 1

     @torch.no_grad()
-    def clip_grad_norm_(self) -> None:
+    def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
         """Clip the gradient norm of all parameters."""
         max_norm = self._max_gradient
         norm_type = float(self._norm_type)
@@ -224,6 +224,7 @@ def clip_grad_norm_(self) -> None:
         clip_coef = cast(torch.Tensor, max_norm / (total_grad_norm + 1e-6))
         clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
         torch._foreach_mul_(all_grads, clip_coef_clamped)
+        return total_grad_norm


 def _batch_cal_norm(
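
With this change, clip_grad_norm_ reports the total (pre-clipping) gradient norm to its caller instead of returning None, matching torch.nn.utils.clip_grad_norm_. Below is a minimal sketch of the core-PyTorch pattern this mirrors; the model, optimizer, and logging call are illustrative assumptions, not part of this diff.

    import torch
    import torch.nn as nn

    # Stand-in model and optimizer (assumptions for illustration only).
    model = nn.Linear(16, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    loss = model(torch.randn(8, 16)).sum()
    loss.backward()

    # torch.nn.utils.clip_grad_norm_ returns the total gradient norm computed
    # before clipping, which callers often log; this commit makes the torchrec
    # clipping optimizer's clip_grad_norm_ return the same kind of value.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    print(f"grad norm before clipping: {float(total_norm):.4f}")

    optimizer.step()
    optimizer.zero_grad()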
