
Commit 926b5ec

alanhdu authored and facebook-github-bot committed
Add callback for enabling Tensorfloat32 (#885)
Summary:
Pull Request resolved: #885

This is something that can boost performance quite a bit with float32 training on CUDA, so I figured it'd make sense to package it up into a reusable callback.

Reviewed By: diego-urgell

Differential Revision: D61608792

fbshipit-source-id: ccd0712c9022029bf59ee0730a71ad59feea60ae
1 parent ebda066 commit 926b5ec
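
To make the usage concrete, here is a minimal sketch of how the new callback might be attached to a training run, mirroring the pattern used in the unit tests added by this commit. It reuses the repo's test helpers (DummyTrainUnit, generate_random_dataloader) purely for illustration; any real TrainUnit and dataloader would work the same way.

    from torchtnt.framework._test_utils import DummyTrainUnit, generate_random_dataloader
    from torchtnt.framework.callbacks import EnableTensorFloat32
    from torchtnt.framework.train import train

    # Toy unit and data: (dataset_len, input_dim, batch_size) = (5, 2, 2)
    unit = DummyTrainUnit(input_dim=2)
    dataloader = generate_random_dataloader(5, 2, 2)

    # The callback enables TF32 at on_train_start and restores the original
    # global settings at on_train_end.
    train(unit, dataloader, max_epochs=2, callbacks=[EnableTensorFloat32()])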

File tree

3 files changed: +192 -0 lines changed
Test file for the EnableTensorFloat32 callback

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import contextlib
import unittest
from typing import Iterator

import torch
from torchtnt.framework._test_utils import (
    DummyFitUnit,
    DummyPredictUnit,
    DummyTrainUnit,
    generate_random_dataloader,
)
from torchtnt.framework.callback import Callback
from torchtnt.framework.callbacks.tensorfloat32 import EnableTensorFloat32
from torchtnt.framework.fit import fit
from torchtnt.framework.predict import predict
from torchtnt.framework.state import State
from torchtnt.framework.train import train
from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit


class _CheckTensorFloat32Enabled(Callback):
    def __init__(self, testcase: unittest.TestCase) -> None:
        self.testcase = testcase

    def assert_enabled(self) -> None:
        self.testcase.assertEqual(torch.get_float32_matmul_precision(), "high")
        self.testcase.assertTrue(torch.backends.cudnn.allow_tf32)
        self.testcase.assertTrue(torch.backends.cuda.matmul.allow_tf32)

    def on_train_step_start(self, state: State, unit: TTrainUnit) -> None:
        self.assert_enabled()

    def on_eval_step_start(self, state: State, unit: TEvalUnit) -> None:
        self.assert_enabled()

    def on_predict_step_start(self, state: State, unit: TPredictUnit) -> None:
        self.assert_enabled()


class EnableTensorFloat32Test(unittest.TestCase):
    @contextlib.contextmanager
    def check_proper_restore(self) -> Iterator[EnableTensorFloat32]:
        callback = EnableTensorFloat32()

        # Disable TensorFloat32
        torch.set_float32_matmul_precision("highest")
        torch.backends.cudnn.allow_tf32 = False
        torch.backends.cuda.matmul.allow_tf32 = False

        yield callback

        # Original values are restored
        self.assertIsNone(callback.original_cuda_matmul)
        self.assertIsNone(callback.original_cudnn)
        self.assertIsNone(callback.original_float32_matmul_precision)

        self.assertEqual(torch.get_float32_matmul_precision(), "highest")
        self.assertFalse(torch.backends.cudnn.allow_tf32)
        self.assertFalse(torch.backends.cuda.matmul.allow_tf32)

    def test_tensorfloat32_callback_train(self) -> None:
        input_dim = batch_size = max_epochs = 2
        dataset_len = 5

        unit = DummyTrainUnit(input_dim=input_dim)
        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
        with self.check_proper_restore() as callback:
            callbacks: list[Callback] = [callback, _CheckTensorFloat32Enabled(self)]
            train(unit, dataloader, max_epochs=max_epochs, callbacks=callbacks)

    def test_tensorfloat32_callback_fit(self) -> None:
        input_dim = batch_size = max_epochs = 2
        dataset_len = 5

        unit = DummyFitUnit(input_dim=input_dim)
        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
        with self.check_proper_restore() as callback:
            callbacks: list[Callback] = [callback, _CheckTensorFloat32Enabled(self)]
            fit(
                unit,
                dataloader,
                dataloader,
                max_epochs=max_epochs,
                callbacks=callbacks,
            )

    def test_tensorfloat32_callback_predict(self) -> None:
        input_dim = batch_size = 2
        dataset_len = 5

        unit = DummyPredictUnit(input_dim=input_dim)
        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
        with self.check_proper_restore() as callback:
            callbacks: list[Callback] = [callback, _CheckTensorFloat32Enabled(self)]
            predict(unit, dataloader, callbacks=callbacks)

torchtnt/framework/callbacks/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,7 @@
 from .slow_rank_detector import SlowRankDetector
 from .system_resources_monitor import SystemResourcesMonitor
 from .tensorboard_parameter_monitor import TensorBoardParameterMonitor
+from .tensorfloat32 import EnableTensorFloat32
 from .throughput_logger import ThroughputLogger
 from .time_limit_interrupter import TimeLimitInterrupter
 from .time_wait_for_batch_logger import TimeWaitForBatchLogger
@@ -34,6 +35,7 @@
     "BaseCSVWriter",
     "EarlyStopping",
     "EmptyCudaCache",
+    "EnableTensorFloat32",
     "GarbageCollector",
     "IterationTimeLogger",
     "Lambda",
torchtnt/framework/callbacks/tensorfloat32.py

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging
from typing import Optional

import torch
from torchtnt.framework.callback import Callback
from torchtnt.framework.state import EntryPoint, State
from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit
from torchtnt.utils.rank_zero_log import rank_zero_info

logger: logging.Logger = logging.getLogger(__name__)


class EnableTensorFloat32(Callback):
    """
    A callback that enables TensorFloat32 operations on CUDA.

    Args:
        float32_matmul_precision: precision to use for float32 matmul operations.
            See `torch.set_float32_matmul_precision` for details.
    """

    def __init__(self, float32_matmul_precision: str = "high") -> None:
        self.float32_matmul_precision = float32_matmul_precision

        self.original_float32_matmul_precision: Optional[str] = None
        self.original_cuda_matmul: Optional[bool] = None
        self.original_cudnn: Optional[bool] = None

    def _enable(self) -> None:
        rank_zero_info("Enabling TensorFloat32 operations on CUDA", logger=logger)
        assert self.original_float32_matmul_precision is None
        assert self.original_cuda_matmul is None
        assert self.original_cudnn is None

        self.original_float32_matmul_precision = torch.get_float32_matmul_precision()
        self.original_cuda_matmul = torch.backends.cuda.matmul.allow_tf32
        self.original_cudnn = torch.backends.cudnn.allow_tf32

        torch.set_float32_matmul_precision(self.float32_matmul_precision)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    def _reset(self) -> None:
        rank_zero_info(
            "Restoring original TensorFloat32 permissions on CUDA", logger=logger
        )
        if self.original_float32_matmul_precision is not None:
            torch.set_float32_matmul_precision(self.original_float32_matmul_precision)
            self.original_float32_matmul_precision = None

        if self.original_cuda_matmul is not None:
            torch.backends.cuda.matmul.allow_tf32 = self.original_cuda_matmul
            self.original_cuda_matmul = None

        if self.original_cudnn is not None:
            torch.backends.cudnn.allow_tf32 = self.original_cudnn
            self.original_cudnn = None

    def on_train_start(self, state: State, unit: TTrainUnit) -> None:
        self._enable()

    def on_train_end(self, state: State, unit: TTrainUnit) -> None:
        self._reset()

    def on_eval_start(self, state: State, unit: TEvalUnit) -> None:
        if state.entry_point == EntryPoint.FIT:
            return  # if fitting, this is already handled in on_train_start
        self._enable()

    def on_eval_end(self, state: State, unit: TEvalUnit) -> None:
        if state.entry_point == EntryPoint.FIT:
            return  # if fitting, this is already handled in on_train_end
        self._reset()

    def on_predict_start(self, state: State, unit: TPredictUnit) -> None:
        self._enable()

    def on_predict_end(self, state: State, unit: TPredictUnit) -> None:
        self._reset()
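
For readers unfamiliar with these knobs, the following standalone sketch shows the same save/enable/restore pattern using the PyTorch APIs directly, outside of any callback. It is an illustration of what _enable() and _reset() do, not part of this commit.

    import torch

    # Save the current global settings (what _enable() records).
    prev_precision = torch.get_float32_matmul_precision()
    prev_cuda_matmul = torch.backends.cuda.matmul.allow_tf32
    prev_cudnn = torch.backends.cudnn.allow_tf32

    # Enable TF32: allow float32 matmuls and convolutions to use TensorFloat32
    # on GPUs that support it (Ampere and newer).
    torch.set_float32_matmul_precision("high")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    try:
        pass  # ... run float32 training / inference here ...
    finally:
        # Restore the original settings (what _reset() does at the end of the run).
        torch.set_float32_matmul_precision(prev_precision)
        torch.backends.cuda.matmul.allow_tf32 = prev_cuda_matmul
        torch.backends.cudnn.allow_tf32 = prev_cudnn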
