@@ -59,7 +59,7 @@ def __init__(
         super().__init__(optimizer)
         self._clipping = clipping
         self._max_gradient = max_gradient
-        self._norm_type = norm_type
+        self._norm_type = float(norm_type)
         self._check_meta: bool = True
         self._enable_global_grad_clip = enable_global_grad_clip
         self._step_num = 0
@@ -124,7 +124,7 @@ def step(self, closure: Any = None) -> None:
                 torch.nn.utils.clip_grad_norm_(
                     replicate_params,
                     self._max_gradient,
-                    norm_type=float(self._norm_type),
+                    norm_type=self._norm_type,
                 )
             else:
                 self.clip_grad_norm_()
@@ -139,7 +139,6 @@ def step(self, closure: Any = None) -> None:
     def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
         """Clip the gradient norm of all parameters."""
         max_norm = self._max_gradient
-        norm_type = float(self._norm_type)
         all_grads = []
         total_grad_norm = None
 
@@ -157,15 +156,15 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
             sharded_grad_norm = _batch_cal_norm(
                 sharded_grads,
                 max_norm,
-                norm_type,
+                self._norm_type,
                 pgs,
             )
             total_grad_norm = (
                 sharded_grad_norm
                 if total_grad_norm is None
                 else (
                     torch.maximum(total_grad_norm, sharded_grad_norm)
-                    if norm_type == torch.inf
+                    if self._norm_type == torch.inf
                     else total_grad_norm + sharded_grad_norm
                 )
             )
@@ -184,27 +183,36 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
             replicated_grad_norm = _batch_cal_norm(
                 replicated_grads,
                 max_norm,
-                norm_type,
+                self._norm_type,
                 None,
             )
             total_grad_norm = (
                 replicated_grad_norm
                 if total_grad_norm is None
                 else (
                     torch.maximum(total_grad_norm, replicated_grad_norm)
-                    if norm_type == torch.inf
+                    if self._norm_type == torch.inf
                     else total_grad_norm + replicated_grad_norm
                 )
             )
             square_replicated_grad_norm = replicated_grad_norm
         else:
             square_replicated_grad_norm = 0
 
+        if total_grad_norm is not None:
+            total_grad_norm = (
+                torch.pow(total_grad_norm, 1.0 / self._norm_type)
+                if self._norm_type != torch.inf
+                else total_grad_norm
+            )
+        else:
+            return None
+
         global log_grad_norm
         if log_grad_norm:
-            if total_grad_norm is not None and norm_type != torch.inf:
+            if total_grad_norm is not None and self._norm_type != torch.inf:
                 # pyre-ignore[58]
-                grad_norm = total_grad_norm ** (1.0 / norm_type)
+                grad_norm = total_grad_norm ** (1.0 / self._norm_type)
             else:
                 grad_norm = total_grad_norm
 
@@ -213,15 +221,7 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
                 f"Clipping [rank={rank}, step={self._step_num}]: square_sharded_grad_norm = {square_sharded_grad_norm}, square_replicated_grad_norm = {square_replicated_grad_norm}, total_grad_norm = {grad_norm}"
             )
 
-        # Aggregation
-        if total_grad_norm is None:
-            return
-
-        if norm_type != torch.inf:
-            # pyre-ignore [58]: ** is not supported for operand types torch._tensor.Tensor and float.
-            total_grad_norm = total_grad_norm ** (1.0 / norm_type)
-        # pyre-ignore [58]: / is not supported for operand types float and Union[float, torch._tensor.Tensor].
-        clip_coef = cast(torch.Tensor, max_norm / (total_grad_norm + 1e-6))
+        clip_coef = torch.tensor(max_norm) / (total_grad_norm + 1e-6)
         clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
         torch._foreach_mul_(all_grads, clip_coef_clamped)
         return total_grad_norm
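For reference, below is a minimal standalone sketch of the aggregation this change implements: each parameter group contributes its norm raised to the p-th power (max-abs for the inf norm), the partial results are combined across groups, the p-th root is taken once, and the clip coefficient is clamped to 1.0 before the gradients are scaled in place. `_group_norm_pow` and `clip_all_` are hypothetical stand-ins, not torchrec's `_batch_cal_norm` or `clip_grad_norm_`; the assumption that a group's partial value is the p-th power of its norm is inferred from the `square_*` variable names and the `torch.pow(total_grad_norm, 1.0 / self._norm_type)` step in the diff.

# Illustrative sketch only (not torchrec's implementation).
import torch


def _group_norm_pow(grads, norm_type):
    # Hypothetical stand-in for _batch_cal_norm: returns the group's norm
    # raised to the p-th power, or its max-abs value for the inf norm.
    if norm_type == torch.inf:
        return max(g.detach().abs().max() for g in grads)
    return sum(g.detach().norm(norm_type) ** norm_type for g in grads)


def clip_all_(grad_groups, max_norm, norm_type=2.0):
    norm_type = float(norm_type)  # mirrors the float() coercion now done in __init__
    all_grads = [g for group in grad_groups for g in group]
    partials = [_group_norm_pow(group, norm_type) for group in grad_groups if group]
    if not partials:
        return None
    if norm_type == torch.inf:
        # inf norm: the global norm is the max over the per-group maxima.
        total_grad_norm = torch.stack(partials).max()
    else:
        # p norm: sum the p-th powers across groups, then take one p-th root.
        total_grad_norm = torch.pow(sum(partials), 1.0 / norm_type)
    clip_coef = torch.tensor(max_norm) / (total_grad_norm + 1e-6)
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    torch._foreach_mul_(all_grads, clip_coef_clamped)
    return total_grad_norm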