Commit a32caed

generatedunixname89002005232357 authored and facebook-github-bot committed
Revert D78398248 (#3225)
Summary:
Pull Request resolved: #3225

This diff reverts D78398248, which broke an E2E NE test: https://www.internalfb.com/intern/test/281475194407311 (see T232023690).

Depends on D78398248

Reviewed By: ztlbells

Differential Revision: D78794599

fbshipit-source-id: bc6bf982954de16c87eefc0544a177b8f2b1ce8c
1 parent 01f8654 commit a32caed

File tree: 1 file changed (+18 −18 lines)

torchrec/optim/clipping.py

Lines changed: 18 additions & 18 deletions
@@ -59,7 +59,7 @@ def __init__(
         super().__init__(optimizer)
         self._clipping = clipping
         self._max_gradient = max_gradient
-        self._norm_type = float(norm_type)
+        self._norm_type = norm_type
         self._check_meta: bool = True
         self._enable_global_grad_clip = enable_global_grad_clip
         self._step_num = 0
@@ -124,7 +124,7 @@ def step(self, closure: Any = None) -> None:
                 torch.nn.utils.clip_grad_norm_(
                     replicate_params,
                     self._max_gradient,
-                    norm_type=self._norm_type,
+                    norm_type=float(self._norm_type),
                 )
             else:
                 self.clip_grad_norm_()
@@ -139,6 +139,7 @@ def step(self, closure: Any = None) -> None:
     def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
         """Clip the gradient norm of all parameters."""
         max_norm = self._max_gradient
+        norm_type = float(self._norm_type)
         all_grads = []
         total_grad_norm = None
 
@@ -156,15 +157,15 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
                 sharded_grad_norm = _batch_cal_norm(
                     sharded_grads,
                     max_norm,
-                    self._norm_type,
+                    norm_type,
                     pgs,
                 )
                 total_grad_norm = (
                     sharded_grad_norm
                     if total_grad_norm is None
                     else (
                         torch.maximum(total_grad_norm, sharded_grad_norm)
-                        if self._norm_type == torch.inf
+                        if norm_type == torch.inf
                         else total_grad_norm + sharded_grad_norm
                     )
                 )
@@ -183,36 +184,27 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
             replicated_grad_norm = _batch_cal_norm(
                 replicated_grads,
                 max_norm,
-                self._norm_type,
+                norm_type,
                 None,
             )
             total_grad_norm = (
                 replicated_grad_norm
                 if total_grad_norm is None
                 else (
                     torch.maximum(total_grad_norm, replicated_grad_norm)
-                    if self._norm_type == torch.inf
+                    if norm_type == torch.inf
                     else total_grad_norm + replicated_grad_norm
                 )
             )
             square_replicated_grad_norm = replicated_grad_norm
         else:
             square_replicated_grad_norm = 0
 
-        if total_grad_norm is not None:
-            total_grad_norm = (
-                torch.pow(total_grad_norm, 1.0 / self._norm_type)
-                if self._norm_type != torch.inf
-                else total_grad_norm
-            )
-        else:
-            return None
-
         global log_grad_norm
         if log_grad_norm:
-            if total_grad_norm is not None and self._norm_type != torch.inf:
+            if total_grad_norm is not None and norm_type != torch.inf:
                 # pyre-ignore[58]
-                grad_norm = total_grad_norm ** (1.0 / self._norm_type)
+                grad_norm = total_grad_norm ** (1.0 / norm_type)
             else:
                 grad_norm = total_grad_norm
 
@@ -221,7 +213,15 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
                 f"Clipping [rank={rank}, step={self._step_num}]: square_sharded_grad_norm = {square_sharded_grad_norm}, square_replicated_grad_norm = {square_replicated_grad_norm}, total_grad_norm = {grad_norm}"
             )
 
-        clip_coef = torch.tensor(max_norm) / (total_grad_norm + 1e-6)
+        # Aggregation
+        if total_grad_norm is None:
+            return
+
+        if norm_type != torch.inf:
+            # pyre-ignore [58]: ** is not supported for operand types torch._tensor.Tensor and float.
+            total_grad_norm = total_grad_norm ** (1.0 / norm_type)
+        # pyre-ignore [58]: / is not supported for operand types float and Union[float, torch._tensor.Tensor].
+        clip_coef = cast(torch.Tensor, max_norm / (total_grad_norm + 1e-6))
         clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
         torch._foreach_mul_(all_grads, clip_coef_clamped)
         return total_grad_norm
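
For readers skimming the diff, the restored logic follows the usual global-gradient-norm clipping recipe: accumulate each group's norm raised to the p-th power (or take an element-wise maximum for the infinity norm), take the 1/p root once at the end, then scale every gradient by max_norm / (total_norm + 1e-6), clamped at 1.0. Below is a minimal single-process sketch of that pattern; it is not the torchrec implementation, and the helper name _clip_grads_by_total_norm_ and the flat grads list are invented for illustration.

# Minimal single-process sketch of the clipping pattern restored by this revert.
# Helper name and the flat ``grads`` list are hypothetical; this is not the
# torchrec GradientClippingOptimizer implementation.
from typing import List, Optional

import torch


def _clip_grads_by_total_norm_(
    grads: List[torch.Tensor],
    max_norm: float,
    norm_type: float = 2.0,
) -> Optional[torch.Tensor]:
    """Scale ``grads`` in place so their combined norm is at most ``max_norm``."""
    if not grads:
        # Mirrors the early ``return`` when no gradient norm was accumulated.
        return None

    norms = [torch.linalg.vector_norm(g.detach(), ord=norm_type) for g in grads]
    if norm_type == torch.inf:
        # The infinity norm combines across groups with an element-wise maximum.
        total_norm = torch.max(torch.stack(norms))
    else:
        # p-norms combine by summing each norm raised to p and taking the 1/p
        # root once at the end -- the ``** (1.0 / norm_type)`` step in the diff.
        total_norm = torch.sum(torch.stack([n**norm_type for n in norms])) ** (
            1.0 / norm_type
        )

    # Same coefficient as in the diff: max_norm / (total_norm + eps), clamped at
    # 1.0 so gradients are only ever scaled down, never up.
    clip_coef = max_norm / (total_norm + 1e-6)
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    torch._foreach_mul_(grads, clip_coef_clamped)
    return total_norm

In clip_grad_norm_ itself the same recipe is applied to two pools (sharded gradients combined per process group via _batch_cal_norm, plus replicated gradients), and the reverted version takes the final 1/p root in the "Aggregation" block immediately before computing the clip coefficient, bailing out with a bare return when no gradient norm was produced.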
