
Commit 80703ca

drisspg authored and pytorchmergebot committed
[FlexAttention] Allow dispatch to SAC for flex (pytorch#150080)
Pull Request resolved: pytorch#150080 Approved by: https://github.com/zou3519
1 parent fa63de0 commit 80703ca
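
In user terms, this commit lets the flex attention HOP participate in selective activation checkpointing (SAC) policies. A minimal sketch of the resulting usage pattern, mirroring the new test below; the flex_attention_hop import path, tensor shapes, and the CUDA/bfloat16 setup are illustrative assumptions, not part of this commit:

import functools

import torch
from torch.nn.attention.flex_attention import flex_attention, flex_attention_hop
from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts

# Save the flex attention HOP's outputs under SAC instead of recomputing them.
context_fn = functools.partial(
    create_selective_checkpoint_contexts, [flex_attention_hop]
)


def attn_fn(q, k, v):
    return flex_attention(q, k, v)


q, k, v = (
    torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    for _ in range(3)
)
out = checkpoint(attn_fn, q, k, v, use_reentrant=False, context_fn=context_fn)
out.sum().backward()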

File tree

5 files changed: +165 -6 lines changed


test/inductor/test_flex_attention.py

Lines changed: 128 additions & 0 deletions
@@ -14,6 +14,7 @@
 from unittest.mock import patch

 import torch
+import torch.nn as nn
 from torch._dynamo.testing import CompileCounterWithBackend, normalize_gm
 from torch._inductor import metrics
 from torch._inductor.runtime.triton_compat import HAS_WARP_SPEC
@@ -30,6 +31,7 @@
     BlockMask,
     create_block_mask,
     flex_attention,
+    flex_attention_hop,
     noop_mask,
     or_masks,
 )
@@ -3842,6 +3844,132 @@ def forward(self, q, k, v, block_mask):
         attn_output = mod(q, k, v, mask)
         self.assertEqual(attn_output.device, torch.device("cuda:1"))

+    @supported_platform
+    @skip_on_cpu
+    @common_utils.parametrize(
+        "ops_to_save",
+        [
+            [
+                torch.ops.aten.mm.default,
+            ],
+            [
+                flex_attention_hop,
+            ],
+            [torch.ops.aten.mm.default, flex_attention_hop],
+        ],
+    )
+    def test_selective_ac(self, device, ops_to_save):
+        class FlexAttentionModule(nn.Module):
+            def __init__(self, hidden_size, num_heads):
+                super().__init__()
+                self.hidden_size = hidden_size
+                self.num_heads = num_heads
+                self.head_dim = hidden_size // num_heads
+
+                # In-projections (query, key, value)
+                self.q_proj = nn.Linear(hidden_size, hidden_size)
+                self.k_proj = nn.Linear(hidden_size, hidden_size)
+                self.v_proj = nn.Linear(hidden_size, hidden_size)
+
+                # Out-projection
+                self.out_proj = nn.Linear(hidden_size, hidden_size)
+
+            def forward(self, x):
+                batch_size, seq_len, _ = x.size()
+
+                # Project queries, keys, and values
+                q = (
+                    self.q_proj(x)
+                    .view(batch_size, seq_len, self.num_heads, self.head_dim)
+                    .transpose(1, 2)
+                )
+                k = (
+                    self.k_proj(x)
+                    .view(batch_size, seq_len, self.num_heads, self.head_dim)
+                    .transpose(1, 2)
+                )
+                v = (
+                    self.v_proj(x)
+                    .view(batch_size, seq_len, self.num_heads, self.head_dim)
+                    .transpose(1, 2)
+                )
+
+                # Apply flex attention
+                attn_output = flex_attention(
+                    q,
+                    k,
+                    v,
+                )
+
+                # Reshape output
+                attn_output = (
+                    attn_output.transpose(1, 2)
+                    .contiguous()
+                    .view(batch_size, seq_len, self.hidden_size)
+                )
+
+                # Out projection
+                output = self.out_proj(attn_output)
+
+                return output
+
+        from torch.utils.checkpoint import (
+            checkpoint,
+            create_selective_checkpoint_contexts,
+        )
+
+        context_fn = functools.partial(
+            create_selective_checkpoint_contexts, ops_to_save
+        )
+
+        # Define a model that uses FlexAttention with selective activation checkpointing
+        class SacModule(nn.Module):
+            def __init__(self, hidden_size, num_heads, context_fn):
+                super().__init__()
+                self.flex_attn = FlexAttentionModule(hidden_size, num_heads)
+                self.context_fn = context_fn
+
+            def forward(self, x):
+                def flex_attn_fn(x):
+                    return self.flex_attn(x)
+
+                output = checkpoint(
+                    flex_attn_fn,
+                    x,
+                    use_reentrant=False,
+                    context_fn=self.context_fn,
+                )
+
+                return output
+
+        flex_module = SacModule(hidden_size=512, num_heads=8, context_fn=context_fn).to(
+            "cuda", dtype=torch.bfloat16
+        )
+        x = torch.ones(8, 1024, 512, device="cuda", dtype=torch.bfloat16)
+
+        # Run without compilation
+        output_module = flex_module(x)
+        compiled_module = torch.compile(flex_module)
+        output_compiled = compiled_module(x)
+
+        torch.testing.assert_close(output_module, output_compiled, rtol=1e-2, atol=1e-2)
+
+        # Calculate gradients and compare them
+        x.requires_grad_(True)
+        output_module = flex_module(x)
+        output_compiled = compiled_module(x)
+        grad_output = torch.ones_like(output_module)
+
+        grad_module = torch.autograd.grad(
+            outputs=output_module, inputs=x, grad_outputs=grad_output, retain_graph=True
+        )[0]
+
+        grad_compiled = torch.autograd.grad(
+            outputs=output_compiled, inputs=x, grad_outputs=grad_output
+        )[0]
+
+        torch.testing.assert_close(grad_module, grad_compiled, rtol=1e-2, atol=1e-2)
+
     @supported_platform
     @skip_on_cpu
     def test_validate_small_embedding_size_error_message(self, device):

torch/_dynamo/variables/higher_order_ops.py

Lines changed: 3 additions & 0 deletions
@@ -941,6 +941,9 @@ def call_function(
     ) -> VariableTracker:
         unimplemented(f"HigherOrderOperator {self.value.__name__}")

+    def as_python_constant(self):
+        return self.value
+

 class CustomFunctionHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable):
     """

torch/_higher_order_ops/flex_attention.py

Lines changed: 7 additions & 0 deletions
@@ -10,6 +10,7 @@
     _has_potential_branch_input_mutation,
     _maybe_reenter_make_fx,
     autograd_not_implemented,
+    redirect_to_mode,
     reenter_make_fx,
     save_tensors_and_symints_for_backward,
     saved_tensors_and_symints,
@@ -24,6 +25,7 @@
     track_tensor_tree,
 )
 from torch.fx.graph_module import GraphModule
+from torch.utils.checkpoint import _CachedTorchDispatchMode, _CachingTorchDispatchMode


 # Duplicate of _inductor/kernel/flex_attention.py to avoid circular import
@@ -481,6 +483,11 @@ def flex_attention_fake_tensor_mode(
         return out, logsumexp


+# Registers dispatches for SAC
+redirect_to_mode(flex_attention, _CachingTorchDispatchMode)
+redirect_to_mode(flex_attention, _CachedTorchDispatchMode)
+
+
 # ---------------------------- Autograd Implementation ----------------------------
 def create_fw_bw_graph(
     score_mod: Callable,

torch/_higher_order_ops/utils.py

Lines changed: 19 additions & 4 deletions
@@ -531,6 +531,24 @@ def _maybe_fake_prop_ignore_unbacked(fn, args):
        return fn(*args)


+def redirect_to_mode(hop: OperatorBase, mode):
+    """Utility for redispatching a HOP to an underlying mode.
+
+    Args:
+        hop: The HOP to redispatch
+        mode: The mode to redispatch to
+
+    Returns:
+        A decorated function that implements the HOP for the given mode
+    """
+
+    @hop.py_impl(mode)
+    def impl(mode, *args, **kwargs):
+        return mode.__torch_dispatch__(hop, [], args, kwargs)
+
+    return impl
+
+
 # TODO: The parameter use_output_and_grad_bw is required because some operations
 # that utilize this function, such as the while_loop, may require (grad, fwd_outputs)
 def create_fw_bw_graph(fn, use_output_and_grad_bw, fw_inputs, fw_outputs):
@@ -897,10 +915,7 @@ def register_fake(hop, fn=None):
     def register(func):
         from torch._subclasses.fake_tensor import FakeTensorMode

-        # Redirect the hop to the fake tensor mode implementation.
-        @hop.py_impl(FakeTensorMode)
-        def _(mode, *args, **kwargs):
-            return mode.__torch_dispatch__(hop, [], args, kwargs)
+        redirect_to_mode(hop, FakeTensorMode)

        registered_hop_fake_fns[hop] = func
        return func
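
The helper is generic: any dispatch mode class a HOP should flow through can be registered the same way. A hypothetical sketch (LoggingMode is invented here for illustration; only redirect_to_mode and the flex attention HOP come from this commit), mirroring how the SAC modes are wired up in torch/_higher_order_ops/flex_attention.py:

from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop
from torch._higher_order_ops.utils import redirect_to_mode
from torch.utils._python_dispatch import TorchDispatchMode


class LoggingMode(TorchDispatchMode):
    """Toy mode: log each op (including HOPs redirected to it), then run it."""

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print(f"dispatching {func}")
        return func(*args, **kwargs)


# After this call, invoking the flex attention HOP while a LoggingMode is active
# lands in LoggingMode.__torch_dispatch__, just like a regular ATen op would.
redirect_to_mode(flex_attention_hop, LoggingMode)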

torch/utils/checkpoint.py

Lines changed: 8 additions & 2 deletions
@@ -1297,7 +1297,13 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):

         out = func(*args, **kwargs)

-        any_ret_has_alias_info = any(ret.alias_info is not None for ret in func._schema.returns)
+        # HOPs don't support func._schema
+        # HOPs don't alias -> this is true today and will remain true for a long time
+        # TODO: HOPs don't mutate -> this is true today but will not be true forever
+        if isinstance(func, torch._ops.HigherOrderOperator):
+            any_ret_has_alias_info = False
+        else:
+            any_ret_has_alias_info = any(ret.alias_info is not None for ret in func._schema.returns)

         if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling:
             self.storage[func].append(tree_map(lambda x: _VersionWrapper(_maybe_detach(x, any_ret_has_alias_info)), out))
@@ -1396,7 +1402,7 @@ def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mu
     # context_fn anyway, so proceed as usual.
     if isinstance(policy_fn_or_list, list):
         for op in policy_fn_or_list:
-            if not isinstance(op, torch._ops.OpOverload):
+            if not isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)):
                 _extra_msg = (
                     "Please update the OpOverloadPacket to a specific OpOverload."
                     "For example, if you have `torch.ops.aten.mm`, change it to `torch.ops.aten.mm.default`."
