2 files changed: +10 −4 lines changed
First changed file (diff):

@@ -38,6 +38,7 @@
 from torchtnt.utils.precision import (
     convert_precision_str_to_dtype,
     get_grad_scaler_from_precision,
+    GradScaler,
 )
 from torchtnt.utils.prepare_module import (
     _is_fsdp_module,
@@ -505,7 +506,7 @@ def __init__(
            enable_compiled_autograd=enable_compiled_autograd,
        )

-        self.grad_scaler: Optional[torch.amp.GradScaler] = None
+        self.grad_scaler: Optional[GradScaler] = None
        if self.precision:
            self.grad_scaler = get_grad_scaler_from_precision(
                self.precision,
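
The hunk above swaps the annotation target from torch.amp.GradScaler to the GradScaler alias re-exported by torchtnt.utils.precision, so the attribute type still resolves on PyTorch builds that predate torch.amp.GradScaler. A minimal sketch of how the annotated attribute is typically used, assuming torchtnt is installed; the ScalerHolder class and its arguments are illustrative, only the imported names come from the diff:

from typing import Optional

import torch

from torchtnt.utils.precision import (
    convert_precision_str_to_dtype,
    get_grad_scaler_from_precision,
    GradScaler,
)


class ScalerHolder:
    # Hypothetical stand-in for the unit class edited in the hunk above.

    def __init__(self, precision: Optional[str] = "fp16", is_fsdp: bool = False) -> None:
        # Convert the precision string (e.g. "fp16") to a torch.dtype, if given.
        self.precision: Optional[torch.dtype] = (
            convert_precision_str_to_dtype(precision) if precision else None
        )
        # Annotate with the GradScaler alias from torchtnt.utils.precision so the
        # name resolves even where torch.amp.GradScaler does not exist.
        self.grad_scaler: Optional[GradScaler] = None
        if self.precision:
            self.grad_scaler = get_grad_scaler_from_precision(
                self.precision, is_fsdp_module=is_fsdp
            )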
Second changed file (diff):

@@ -10,7 +10,12 @@
 from typing import Mapping, Optional

 import torch
-from torch.cuda.amp.grad_scaler import GradScaler
+from torch.cuda.amp.grad_scaler import GradScaler as CudaGradScaler
+
+try:
+    from torch.amp.grad_scaler import GradScaler
+except Exception:
+    GradScaler = CudaGradScaler

 _DTYPE_STRING_TO_DTYPE_MAPPING: Mapping[str, Optional[torch.dtype]] = {
     "fp16": torch.float16,
@@ -39,7 +44,7 @@ def convert_precision_str_to_dtype(precision: str) -> Optional[torch.dtype]:

 def get_grad_scaler_from_precision(
     precision: torch.dtype, *, is_fsdp_module: Optional[bool] = False
-) -> Optional[torch.amp.GradScaler]:
+) -> Optional[GradScaler]:
     """
     Returns the correct grad scaler to use based on the precision and whether
     or not the model is FSDP.
@@ -58,5 +63,5 @@ def get_grad_scaler_from_precision(

            return ShardedGradScaler()
        else:
-            return GradScaler()
+            return CudaGradScaler()
    return None
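
The try/except block above is the compatibility shim this change introduces: newer PyTorch provides the device-agnostic torch.amp.grad_scaler.GradScaler, which is now used for the return-type annotation, while older releases fall back to aliasing the CUDA class so the name still resolves at import time; the scaler actually instantiated stays CudaGradScaler. A self-contained sketch of the same pattern, assuming only that torch is installed (make_scaler is an illustrative helper, not part of torchtnt):

from typing import Optional

import torch
from torch.cuda.amp.grad_scaler import GradScaler as CudaGradScaler

try:
    # Present on newer PyTorch releases; used here only for annotations.
    from torch.amp.grad_scaler import GradScaler
except Exception:
    # Older releases: alias the CUDA-specific class so the name still exists.
    GradScaler = CudaGradScaler


def make_scaler(precision: torch.dtype) -> Optional[GradScaler]:
    # Loss scaling is only needed for fp16; bf16 and fp32 train without a scaler.
    if precision == torch.float16:
        return CudaGradScaler()
    return None


print(make_scaler(torch.float16))   # GradScaler instance
print(make_scaler(torch.bfloat16))  # None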