
Commit 123453e

alanhdu authored and facebook-github-bot committed
Use torch.amp instead of torch.cuda.amp (#877)
Summary:
Pull Request resolved: #877

`torch.cuda.amp` throws up some deprecation warnings, so let's just use `torch.amp` instead. This requires PyTorch 2.3+, so I've also modified the requirements.txt to add a minimum supported torch version.

Reviewed By: JKSenthil

Differential Revision: D61151603

fbshipit-source-id: 83ff784132ed21565a51e06c5e7bfa69c113aeda
1 parent c39dadb commit 123453e
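In code terms this is a one-line swap; below is a minimal before/after sketch of the pattern the commit applies (assuming PyTorch 2.3+ and a CUDA build, since the device is now passed explicitly):

import torch

# Before: deprecated constructor, emits a deprecation warning on recent PyTorch.
# scaler = torch.cuda.amp.GradScaler()

# After: device-generic API available since PyTorch 2.3.
scaler = torch.amp.GradScaler("cuda")

# The scaler is used the same way as before (scale -> step -> update);
# only the construction site changes.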

File tree

5 files changed (+15, -22 lines)

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-torch
+torch>=2.3.0
 numpy==1.24.4
 fsspec
 tensorboard

tests/framework/test_app_state_mixin.py

Lines changed: 6 additions & 6 deletions
@@ -32,7 +32,7 @@ def __init__(self) -> None:
 self.lr_scheduler_d = torch.optim.lr_scheduler.StepLR(
     self.optimizer_c, step_size=30, gamma=0.1
 )
-self.grad_scaler_e = torch.cuda.amp.GradScaler()
+self.grad_scaler_e = torch.amp.GradScaler("cuda")
 self.optimizer_class_f = torch.optim.SGD

@@ -218,7 +218,7 @@ def __init__(self) -> None:
 self.lr_2 = torch.optim.lr_scheduler.StepLR(
     self.optimizer_placeholder, step_size=50, gamma=0.3
 )
-self.grad_scaler_e = torch.cuda.amp.GradScaler()
+self.grad_scaler_e = torch.amp.GradScaler("cuda")

 def tracked_modules(self) -> Dict[str, nn.Module]:
     ret = super().tracked_modules()
@@ -235,7 +235,7 @@ def tracked_lr_schedulers(

 def tracked_misc_statefuls(self) -> Dict[str, Any]:
     ret = super().tracked_misc_statefuls()
-    ret["another_scaler"] = torch.cuda.amp.GradScaler()
+    ret["another_scaler"] = torch.amp.GradScaler("cuda")
     return ret

 o = Override()
@@ -266,6 +266,6 @@ def test_construct_tracked_optimizers_and_schedulers(self) -> None:
 ):
     result = auto_unit._construct_tracked_optimizers_and_schedulers()

-self.assertTrue(isinstance(result["optimizer"], FSDPOptimizerWrapper))
-self.assertTrue(isinstance(result["optim2"], torch.optim.Optimizer))
-self.assertTrue(isinstance(result["lr_scheduler"], TLRScheduler))
+self.assertIsInstance(result["optimizer"], FSDPOptimizerWrapper)
+self.assertIsInstance(result["optim2"], torch.optim.Optimizer)
+self.assertIsInstance(result["lr_scheduler"], TLRScheduler)

tests/framework/test_auto_unit.py

Lines changed: 3 additions & 5 deletions
@@ -62,11 +62,9 @@ def test_app_state_mixin(self) -> None:
 )

 self.assertEqual(auto_unit.tracked_modules()["module"], my_module)
-self.assertTrue(
-    isinstance(
-        auto_unit.tracked_misc_statefuls()["grad_scaler"],
-        torch.cuda.amp.GradScaler,
-    )
+self.assertIsInstance(
+    auto_unit.tracked_misc_statefuls()["grad_scaler"],
+    torch.amp.GradScaler,
 )
 for key in ("module", "optimizer", "lr_scheduler", "grad_scaler"):
     self.assertIn(key, auto_unit.app_state())

tests/utils/test_precision.py

Lines changed: 3 additions & 3 deletions
@@ -10,7 +10,7 @@
 import unittest

 import torch
-from torch.cuda.amp.grad_scaler import GradScaler
+from torch.amp.grad_scaler import GradScaler
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

 from torchtnt.utils.precision import (
@@ -49,7 +49,7 @@ def test_get_grad_scaler_from_precision(self) -> None:
 grad_scaler = get_grad_scaler_from_precision(
     torch.float16, is_fsdp_module=False
 )
-self.assertTrue(isinstance(grad_scaler, GradScaler))
+self.assertIsInstance(grad_scaler, GradScaler)

 grad_scaler = get_grad_scaler_from_precision(torch.float16, is_fsdp_module=True)
-self.assertTrue(isinstance(grad_scaler, ShardedGradScaler))
+self.assertIsInstance(grad_scaler, ShardedGradScaler)

torchtnt/utils/precision.py

Lines changed: 2 additions & 7 deletions
@@ -10,12 +10,7 @@
 from typing import Mapping, Optional

 import torch
-from torch.cuda.amp.grad_scaler import GradScaler as CudaGradScaler
-
-try:
-    from torch.amp.grad_scaler import GradScaler
-except Exception:
-    GradScaler = CudaGradScaler
+from torch.amp.grad_scaler import GradScaler

 _DTYPE_STRING_TO_DTYPE_MAPPING: Mapping[str, Optional[torch.dtype]] = {
     "fp16": torch.float16,
@@ -63,5 +58,5 @@ def get_grad_scaler_from_precision(

             return ShardedGradScaler()
         else:
-            return CudaGradScaler()
+            return GradScaler("cuda")
     return None
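
For context, here is a minimal usage sketch of the updated get_grad_scaler_from_precision helper, based only on the calls exercised in the test diff above; the toy model, optimizer, and loss are placeholders, and a CUDA-capable environment is assumed (on CPU-only builds the scaler disables itself):

import torch

from torchtnt.utils.precision import get_grad_scaler_from_precision

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# fp16 + non-FSDP module -> torch.amp.GradScaler("cuda")
# fp16 + FSDP module     -> ShardedGradScaler
# other precisions       -> None
scaler = get_grad_scaler_from_precision(torch.float16, is_fsdp_module=False)

loss = model(torch.randn(8, 4)).sum()
if scaler is not None:
    scaler.scale(loss).backward()  # scale the loss before backward
    scaler.step(optimizer)         # unscale gradients, then step the optimizer
    scaler.update()                # adjust the scale factor for the next iteration
else:
    loss.backward()
    optimizer.step()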
