@@ -59,7 +59,7 @@ def __init__(
         super().__init__(optimizer)
         self._clipping = clipping
         self._max_gradient = max_gradient
-        self._norm_type = float(norm_type)
+        self._norm_type = norm_type
         self._check_meta: bool = True
         self._enable_global_grad_clip = enable_global_grad_clip
         self._step_num = 0
@@ -124,7 +124,7 @@ def step(self, closure: Any = None) -> None:
                 torch.nn.utils.clip_grad_norm_(
                     replicate_params,
                     self._max_gradient,
-                    norm_type=self._norm_type,
+                    norm_type=float(self._norm_type),
                 )
             else:
                 self.clip_grad_norm_()
@@ -139,6 +139,7 @@ def step(self, closure: Any = None) -> None:
     def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
         """Clip the gradient norm of all parameters."""
         max_norm = self._max_gradient
+        norm_type = float(self._norm_type)
         all_grads = []
         total_grad_norm = None
 
@@ -156,15 +157,15 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
             sharded_grad_norm = _batch_cal_norm(
                 sharded_grads,
                 max_norm,
-                self._norm_type,
+                norm_type,
                 pgs,
             )
             total_grad_norm = (
                 sharded_grad_norm
                 if total_grad_norm is None
                 else (
                     torch.maximum(total_grad_norm, sharded_grad_norm)
-                    if self._norm_type == torch.inf
+                    if norm_type == torch.inf
                     else total_grad_norm + sharded_grad_norm
                 )
             )
@@ -183,36 +184,27 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
             replicated_grad_norm = _batch_cal_norm(
                 replicated_grads,
                 max_norm,
-                self._norm_type,
+                norm_type,
                 None,
             )
             total_grad_norm = (
                 replicated_grad_norm
                 if total_grad_norm is None
                 else (
                     torch.maximum(total_grad_norm, replicated_grad_norm)
-                    if self._norm_type == torch.inf
+                    if norm_type == torch.inf
                     else total_grad_norm + replicated_grad_norm
                 )
             )
             square_replicated_grad_norm = replicated_grad_norm
         else:
             square_replicated_grad_norm = 0
 
-        if total_grad_norm is not None:
-            total_grad_norm = (
-                torch.pow(total_grad_norm, 1.0 / self._norm_type)
-                if self._norm_type != torch.inf
-                else total_grad_norm
-            )
-        else:
-            return None
-
         global log_grad_norm
         if log_grad_norm:
-            if total_grad_norm is not None and self._norm_type != torch.inf:
+            if total_grad_norm is not None and norm_type != torch.inf:
                 # pyre-ignore[58]
-                grad_norm = total_grad_norm ** (1.0 / self._norm_type)
+                grad_norm = total_grad_norm ** (1.0 / norm_type)
             else:
                 grad_norm = total_grad_norm
 
@@ -221,7 +213,15 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
                 f"Clipping [rank={rank}, step={self._step_num}]: square_sharded_grad_norm = {square_sharded_grad_norm}, square_replicated_grad_norm = {square_replicated_grad_norm}, total_grad_norm = {grad_norm}"
             )
 
-        clip_coef = torch.tensor(max_norm) / (total_grad_norm + 1e-6)
+        # Aggregation
+        if total_grad_norm is None:
+            return
+
+        if norm_type != torch.inf:
+            # pyre-ignore [58]: ** is not supported for operand types torch._tensor.Tensor and float.
+            total_grad_norm = total_grad_norm ** (1.0 / norm_type)
+        # pyre-ignore [58]: / is not supported for operand types float and Union[float, torch._tensor.Tensor].
+        clip_coef = cast(torch.Tensor, max_norm / (total_grad_norm + 1e-6))
         clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
         torch._foreach_mul_(all_grads, clip_coef_clamped)
         return total_grad_norm
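Taken together, the hunks above defer the `float()` conversion of `norm_type` to the call sites and fold the per-group norms together before a single aggregation step at the end of `clip_grad_norm_`. The sketch below is a minimal illustration of that aggregation strategy, not the actual code path: `combine_group_norms` is a hypothetical helper, and it assumes gradients are already collected into plain lists of tensors per group (e.g. sharded vs. replicated), whereas the real implementation routes them through `_batch_cal_norm` and process groups.

```python
import torch


def combine_group_norms(group_grads, max_norm, norm_type):
    """Illustrative only: combine per-group gradient norms the way the
    patched clip_grad_norm_ does, then scale the gradients in place.

    group_grads: list of lists of gradient tensors, one inner list per
    parameter group (e.g. sharded vs. replicated).
    """
    norm_type = float(norm_type)  # mirrors the deferred conversion in the diff
    total = None
    for grads in group_grads:
        if not grads:
            continue
        if norm_type == torch.inf:
            # inf-norm: each group's norm is a max, so groups combine with max
            group_norm = torch.max(
                torch.stack([g.detach().abs().max() for g in grads])
            )
            total = group_norm if total is None else torch.maximum(total, group_norm)
        else:
            # p-norm: keep the sum of |g|^p per group; the 1/p root is taken
            # only once, after all groups have been folded in
            group_norm = sum(g.detach().abs().pow(norm_type).sum() for g in grads)
            total = group_norm if total is None else total + group_norm
    if total is None:
        return None
    if norm_type != torch.inf:
        total = total ** (1.0 / norm_type)
    # clip coefficient, clamped so gradients are never scaled up
    clip_coef = torch.clamp(max_norm / (total + 1e-6), max=1.0)
    for grads in group_grads:
        if grads:
            torch._foreach_mul_(grads, clip_coef)
    return total
```

Keeping the per-group results as sums of p-th powers (or running maxima for the inf-norm) is what lets the sharded and replicated contributions be combined with a plain `+` (or `torch.maximum`), with the 1/p root and the clip coefficient applied exactly once at the end, as the new `# Aggregation` block in the diff does.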