
Commit 2ce298f

y-sq authored and facebook-github-bot committed
Don't recompute scaling factor in activation checkpointing (#951)
Summary: Pull Request resolved: #951

Add the policy "layer_based_auto_wrap_policy_float8_training". It skips recomputing the float8 scaling factor (a scalar) during activation checkpointing to improve latency. To enable it, change the config file as in P1690229394.

Reviewed By: yoyoyocmu

Differential Revision: D65360604

fbshipit-source-id: bd8c052fcf3c8af48775c08ef66a1e367397f09b
1 parent dc5bafd commit 2ce298f


torchtnt/utils/prepare_module.py

Lines changed: 18 additions & 1 deletion
@@ -8,7 +8,17 @@
 
 from dataclasses import asdict, dataclass
 from functools import partial
-from typing import Any, Callable, cast, Dict, Iterable, Optional, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+)
 
 import torch
 import torch.distributed as dist
@@ -165,6 +175,8 @@ class ActivationCheckpointParams:
     checkpoint_impl: CheckpointImpl
     check_fn: Callable[[torch.nn.Module], bool] = lambda _: True
     auto_wrap_policy: Optional[Callable[[torch.nn.Module, bool, int], bool]] = None
+    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
+    context_fn: Optional[Callable[[], Tuple[ContextManager, ContextManager]]] = None
 
 
 def prepare_ddp(
@@ -357,9 +369,14 @@ def prepare_module(
         checkpoint_impl = activation_checkpoint_params.checkpoint_impl
         check_fn = activation_checkpoint_params.check_fn
         auto_wrap_policy = activation_checkpoint_params.auto_wrap_policy
+        context_fn = activation_checkpoint_params.context_fn
+        additional_params = {}
+        if context_fn:
+            additional_params["context_fn"] = context_fn
         custom_checkpoint_wrapper = partial(
             checkpoint_wrapper,
             checkpoint_impl=checkpoint_impl,
+            **additional_params,
         )
         apply_activation_checkpointing(
             module,