
Commit 83f8114

wz337 authored and facebook-github-bot committed
Refactor _batch_cal_norm and remove #pyre-ignore (#3200)
Summary:
Pull Request resolved: #3200

As title.

1. Parse `norm_type` with `self._norm_type = float(norm_type)` immediately in GradientClippingOptimizer's `__init__()` and use only `self._norm_type` afterwards, so it is not susceptible to the error in D78326114.
2. Remove the repeated `total_grad_norm` calculation.
3. Fix #pyre errors.

Reviewed By: aliafzal

Differential Revision: D78398248

fbshipit-source-id: 44fe4cb19609503168beba2a0e1e9c1c23f097cc
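A minimal sketch of change 1, assuming only the constructor argument shown here (the real GradientClippingOptimizer also takes the wrapped optimizer, clipping mode, max_gradient, and other settings that are omitted, and the class name below is made up for the example): parsing once in `__init__()` means every later comparison and exponent can read `self._norm_type` directly instead of re-wrapping it in `float()` at each call site.

import torch

class NormTypeParsingSketch:
    # Hypothetical stand-in class, not the torchrec implementation.
    def __init__(self, norm_type: float | str = 2.0) -> None:
        # float() accepts numbers as well as the string "inf", so
        # self._norm_type is always a float from here on and
        # comparisons against torch.inf behave as expected.
        self._norm_type = float(norm_type)

    def norm_exponent(self) -> float:
        # Downstream code uses self._norm_type as-is; no per-call float().
        return 1.0 if self._norm_type == torch.inf else 1.0 / self._norm_type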
1 parent 3ac28bf commit 83f8114

File tree

1 file changed (+18, -18 lines)


torchrec/optim/clipping.py

Lines changed: 18 additions & 18 deletions
@@ -59,7 +59,7 @@ def __init__(
         super().__init__(optimizer)
         self._clipping = clipping
         self._max_gradient = max_gradient
-        self._norm_type = norm_type
+        self._norm_type = float(norm_type)
         self._check_meta: bool = True
         self._enable_global_grad_clip = enable_global_grad_clip
         self._step_num = 0
@@ -124,7 +124,7 @@ def step(self, closure: Any = None) -> None:
                 torch.nn.utils.clip_grad_norm_(
                     replicate_params,
                     self._max_gradient,
-                    norm_type=float(self._norm_type),
+                    norm_type=self._norm_type,
                 )
             else:
                 self.clip_grad_norm_()
@@ -139,7 +139,6 @@ def step(self, closure: Any = None) -> None:
     def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
         """Clip the gradient norm of all parameters."""
         max_norm = self._max_gradient
-        norm_type = float(self._norm_type)
         all_grads = []
         total_grad_norm = None

@@ -157,15 +156,15 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
             sharded_grad_norm = _batch_cal_norm(
                 sharded_grads,
                 max_norm,
-                norm_type,
+                self._norm_type,
                 pgs,
             )
             total_grad_norm = (
                 sharded_grad_norm
                 if total_grad_norm is None
                 else (
                     torch.maximum(total_grad_norm, sharded_grad_norm)
-                    if norm_type == torch.inf
+                    if self._norm_type == torch.inf
                     else total_grad_norm + sharded_grad_norm
                 )
             )
@@ -184,27 +183,36 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
             replicated_grad_norm = _batch_cal_norm(
                 replicated_grads,
                 max_norm,
-                norm_type,
+                self._norm_type,
                 None,
             )
             total_grad_norm = (
                 replicated_grad_norm
                 if total_grad_norm is None
                 else (
                     torch.maximum(total_grad_norm, replicated_grad_norm)
-                    if norm_type == torch.inf
+                    if self._norm_type == torch.inf
                     else total_grad_norm + replicated_grad_norm
                 )
             )
             square_replicated_grad_norm = replicated_grad_norm
         else:
             square_replicated_grad_norm = 0

+        if total_grad_norm is not None:
+            total_grad_norm = (
+                torch.pow(total_grad_norm, 1.0 / self._norm_type)
+                if self._norm_type != torch.inf
+                else total_grad_norm
+            )
+        else:
+            return None
+
         global log_grad_norm
         if log_grad_norm:
-            if total_grad_norm is not None and norm_type != torch.inf:
+            if total_grad_norm is not None and self._norm_type != torch.inf:
                 # pyre-ignore[58]
-                grad_norm = total_grad_norm ** (1.0 / norm_type)
+                grad_norm = total_grad_norm ** (1.0 / self._norm_type)
             else:
                 grad_norm = total_grad_norm

@@ -213,15 +221,7 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
                 f"Clipping [rank={rank}, step={self._step_num}]: square_sharded_grad_norm = {square_sharded_grad_norm}, square_replicated_grad_norm = {square_replicated_grad_norm}, total_grad_norm = {grad_norm}"
             )

-        # Aggregation
-        if total_grad_norm is None:
-            return
-
-        if norm_type != torch.inf:
-            # pyre-ignore [58]: ** is not supported for operand types torch._tensor.Tensor and float.
-            total_grad_norm = total_grad_norm ** (1.0 / norm_type)
-        # pyre-ignore [58]: / is not supported for operand types float and Union[float, torch._tensor.Tensor].
-        clip_coef = cast(torch.Tensor, max_norm / (total_grad_norm + 1e-6))
+        clip_coef = torch.tensor(max_norm) / (total_grad_norm + 1e-6)
         clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
         torch._foreach_mul_(all_grads, clip_coef_clamped)
         return total_grad_norm
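For orientation, here is a self-contained sketch of the aggregation the refactored `clip_grad_norm_()` performs. The function and argument names are made up for the example, and it assumes (this is not stated in the diff) that `_batch_cal_norm` returns each parameter group's contribution already raised to the p-th power for finite `norm_type`, which is why contributions are summed and the 1/p root is taken at the end; this is an illustration, not the torchrec code.

import torch

def aggregate_and_clip(
    all_grads: list[torch.Tensor],
    group_contribs: list[torch.Tensor],
    max_norm: float,
    norm_type: float,
) -> torch.Tensor:
    # Combine per-group contributions: max for the inf-norm,
    # sum of p-th powers otherwise (mirrors the total_grad_norm update).
    total = group_contribs[0]
    for contrib in group_contribs[1:]:
        total = (
            torch.maximum(total, contrib)
            if norm_type == torch.inf
            else total + contrib
        )
    # Undo the p-th power for finite p, like the new
    # torch.pow(total_grad_norm, 1.0 / self._norm_type) step.
    if norm_type != torch.inf:
        total = torch.pow(total, 1.0 / norm_type)
    # Scale every gradient by min(1, max_norm / total), as in the diff.
    clip_coef = torch.tensor(max_norm) / (total + 1e-6)
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    torch._foreach_mul_(all_grads, clip_coef_clamped)
    return total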
