Commit 428cffd

galrotem authored and facebook-github-bot committed
BC for grad scaler type (#753)
Summary: Pull Request resolved: #753
Reviewed By: diego-urgell
Differential Revision: D55132971
fbshipit-source-id: c1008baf411ad89922d51184bc5ac2951d31704b
1 parent c7095bf commit 428cffd

2 files changed (+10, −4 lines)


torchtnt/framework/auto_unit.py

Lines changed: 2 additions & 1 deletion
@@ -38,6 +38,7 @@
 from torchtnt.utils.precision import (
     convert_precision_str_to_dtype,
     get_grad_scaler_from_precision,
+    GradScaler,
 )
 from torchtnt.utils.prepare_module import (
     _is_fsdp_module,
@@ -505,7 +506,7 @@ def __init__(
             enable_compiled_autograd=enable_compiled_autograd,
         )

-        self.grad_scaler: Optional[torch.amp.GradScaler] = None
+        self.grad_scaler: Optional[GradScaler] = None
         if self.precision:
             self.grad_scaler = get_grad_scaler_from_precision(
                 self.precision,
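
The annotation change above relies on a GradScaler name re-exported from torchtnt.utils.precision (see the second diff below). A minimal usage sketch, with illustrative variable names that are not part of the commit:

import torch
from typing import Optional
from torchtnt.utils.precision import get_grad_scaler_from_precision, GradScaler

# Mirrors the AutoUnit change: annotate with the version-agnostic alias
# instead of torch.amp.GradScaler, which older PyTorch does not expose.
grad_scaler: Optional[GradScaler] = get_grad_scaler_from_precision(
    torch.float16, is_fsdp_module=False
)
# Per get_grad_scaler_from_precision: fp16 without FSDP returns a CUDA grad
# scaler; other precisions return None.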

torchtnt/utils/precision.py

Lines changed: 8 additions & 3 deletions
@@ -10,7 +10,12 @@
 from typing import Mapping, Optional

 import torch
-from torch.cuda.amp.grad_scaler import GradScaler
+from torch.cuda.amp.grad_scaler import GradScaler as CudaGradScaler
+
+try:
+    from torch.amp.grad_scaler import GradScaler
+except Exception:
+    GradScaler = CudaGradScaler

 _DTYPE_STRING_TO_DTYPE_MAPPING: Mapping[str, Optional[torch.dtype]] = {
     "fp16": torch.float16,
@@ -39,7 +44,7 @@ def convert_precision_str_to_dtype(precision: str) -> Optional[torch.dtype]:

 def get_grad_scaler_from_precision(
     precision: torch.dtype, *, is_fsdp_module: Optional[bool] = False
-) -> Optional[torch.amp.GradScaler]:
+) -> Optional[GradScaler]:
     """
     Returns the correct grad scaler to use based on the precision and whether
     or not the model is FSDP.
@@ -58,5 +63,5 @@ def get_grad_scaler_from_precision(

         return ShardedGradScaler()
     else:
-        return GradScaler()
+        return CudaGradScaler()
     return None
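
A hedged sketch of why the fallback keeps annotations and runtime checks consistent, assuming only that torch.cuda.amp.grad_scaler exists on every supported PyTorch while torch.amp.grad_scaler ships only with newer releases: on new builds the CUDA scaler derives from the generic torch.amp scaler, and on old builds the alias is the CUDA class itself, so the value returned above satisfies Optional[GradScaler] either way.

import torch
from torchtnt.utils.precision import get_grad_scaler_from_precision, GradScaler

scaler = get_grad_scaler_from_precision(torch.float16)
if scaler is not None:
    # Passes on both old and new PyTorch: GradScaler is either the CUDA class
    # (fallback path) or a base class of the CUDA scaler (torch.amp path).
    assert isinstance(scaler, GradScaler)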
