
Commit 3a29992

bohnstingl authored and pytorchmergebot committed
[associative_scan] Lifted arguments (pytorch#140043)
This PR implements lifted arguments for associative_scan.

Pull Request resolved: pytorch#140043
Approved by: https://github.com/ydwu4
1 parent f59a56e commit 3a29992
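Editor's note: "lifted arguments" here means free variables that the combine function closes over, which tracing must lift into the scan's subgraph as extra inputs. A minimal sketch of the pattern this PR enables (illustrative only, not part of the commit; it assumes the associative_scan call signature used in the tests below):

    import torch
    from torch._higher_order_ops.associative_scan import associative_scan

    H = torch.rand(2, device="cuda")  # free variable: captured, not passed as a scan operand

    def combine_fn(x, y):
        # H is closed over; tracing must lift it into the combine subgraph.
        return (x + y) * H

    xs = torch.randn(3, 2, 2, device="cuda")
    out = associative_scan(combine_fn, xs, 0, combine_mode="generic")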

File tree: 5 files changed, +431 −31 lines changed


test/export/test_export.py

Lines changed: 67 additions & 0 deletions

@@ -29,6 +29,7 @@
     is_param,
     register_dataclass_as_pytree_node,
 )
+from torch._higher_order_ops.associative_scan import associative_scan
 from torch._higher_order_ops.hints_wrap import hints_wrapper
 from torch._inductor.compile_fx import split_const_gm
 from torch._subclasses import FakeTensorMode
@@ -6379,6 +6380,72 @@ def forward(self, x):
             len([node for node in gm.graph.nodes if node.op == "placeholder"]), 1
         )
 
+    @requires_gpu
+    def test_export_associative_scan_symbol_dim(self):
+        dim1 = torch.export.Dim("dim0", min=5, max=15)
+        xs = torch.ones(3, 10, 2, device=torch.device("cuda"))
+
+        class Foo(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+
+            def combine_fn(self, x, y):
+                return x + y
+
+            def forward(self, x):
+                return associative_scan(self.combine_fn, x, 2)
+
+        ep = export(Foo(), (xs,), dynamic_shapes={"x": {1: dim1}})
+        self.assertTrue(torch.allclose(ep.module()(xs), Foo()(xs)))
+
+    @requires_gpu
+    def test_export_associative_scan_symbol_scandim(self):
+        dim1 = torch.export.Dim("dim0", min=5, max=15)
+        xs = torch.ones(3, 10, 2, device=torch.device("cuda"))
+
+        class Foo(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+
+            def combine_fn(self, x, y):
+                return x + y
+
+            def forward(self, x):
+                return associative_scan(self.combine_fn, x, 1)
+
+        ep = export(Foo(), (xs,), dynamic_shapes={"x": {1: dim1}})
+        self.assertTrue(torch.allclose(ep.module()(xs), Foo()(xs)))
+
+    # This test is expected to fail because associative_scan's backend is not set to "eager"
+    @unittest.expectedFailure
+    @requires_gpu
+    def test_export_associative_scan_lifted_buffers(self):
+        class M(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.buffer = torch.nn.Buffer(
+                    torch.ones(3, 2, device=torch.device("cuda"))
+                )
+
+            def combine_fn(self, x, y):
+                return (x + y) * self.buffer
+
+            def forward(self, x):
+                # TODO: need combine_mode='generic' here as lifted arguments are not yet supported in inductor
+                return associative_scan(self.combine_fn, x, 1, combine_mode="pointwise")
+
+        inp = torch.ones(3, 10, 2, device=torch.device("cuda"))
+        ep = export(M(), (inp,))
+        epm = ep.module()
+        self.assertTrue(torch.allclose(epm(inp), M()(inp)))
+
+        for gm in epm.named_modules():
+            if not isinstance(gm, torch.fx.GraphModule):
+                continue
+            self.assertEqual(
+                len([node for node in gm.graph.nodes if node.op == "placeholder"]), 1
+            )
+
     # map_fn references module outside the module hierarchy
     @unittest.expectedFailure
    def test_map_buffers(self):
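Editor's note: the allclose checks above compare the exported module against eager execution. As a mental model, an associative scan with a lifted buffer behaves like a left fold along the scan dim; a hedged sketch with a hypothetical helper `reference_scan` (not from this commit):

    import torch

    def reference_scan(combine_fn, xs, dim):
        # Left fold along `dim`, restacking the running results.
        slices = xs.unbind(dim)
        acc = [slices[0]]
        for s in slices[1:]:
            acc.append(combine_fn(acc[-1], s))
        return torch.stack(acc, dim)

    buffer = torch.ones(3, 2)  # plays the role of M.buffer above (CPU here, for illustration)
    xs = torch.ones(3, 10, 2)
    out = reference_scan(lambda x, y: (x + y) * buffer, xs, 1)  # shape (3, 10, 2)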

test/functorch/test_control_flow.py

Lines changed: 296 additions & 0 deletions

@@ -3855,6 +3855,302 @@ def test_associative_scan_different_input_size_wrong_dim(self):
             combine_mode="pointwise",
         )
 
+    @unittest.skipIf(not SM70OrLater, "triton")
+    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
+    @parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
+    @parametrize("combine_mode", ["pointwise", "generic"])
+    @parametrize("reverse", [False, True])
+    @parametrize("device", [torch.device("cpu"), torch.device("cuda")])
+    # Skipping the combine_mode=pointwise
+    # as the current implementation of associative_scan lowering
+    # does not support lifted arguments
+    @decorateIf(
+        unittest.skip,
+        lambda params: (params["combine_mode"] == "pointwise"),
+    )
+    def test_associative_scan_freevars_simple(
+        self, compile_mode, combine_mode, reverse, device
+    ):
+        H = torch.rand(2, device=device)
+
+        def fct_freevars1(x: torch.Tensor, y: torch.Tensor):
+            return x * H + y * 2
+
+        def fct_freevars2(x: torch.Tensor, y: torch.Tensor):
+            return x * H + y * H
+
+        H1 = torch.rand(1, device=device)
+        H2 = torch.rand(1, device=device)
+
+        def fct_freevars3(x: torch.Tensor, y: torch.Tensor):
+            return x * H1 + y * H2
+
+        inp = torch.randn(3, 2, 2, device=device)
+
+        for fct, param in [
+            (fct_freevars1, (H,)),
+            (fct_freevars2, (H,)),
+            (fct_freevars3, (H1, H2)),
+        ]:
+            kwargs = {
+                "dim": 0,
+                "reverse": reverse,
+                "compile_mode": compile_mode,
+                "combine_fn": fct,
+                "combine_mode": combine_mode,
+            }
+            kwargs_fake = self._prepare_fake_kwargs(kwargs)
+            self._run_test(
+                model=AssociativeScanModels.CombineFn(**kwargs),
+                model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
+                inputs=inp,
+            )
+
+    @unittest.skipIf(not SM70OrLater, "triton")
+    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
+    @parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
+    @parametrize("combine_mode", ["pointwise", "generic"])
+    @parametrize("reverse", [False, True])
+    @parametrize("device", [torch.device("cpu"), torch.device("cuda")])
+    # Skipping the combine_mode=pointwise
+    # as the current implementation of associative_scan lowering
+    # does not support lifted arguments
+    @decorateIf(
+        unittest.skip,
+        lambda params: (params["combine_mode"] == "pointwise"),
+    )
+    def test_associative_scan_freevars_nested(
+        self, compile_mode, combine_mode, reverse, device
+    ):
+        H1 = torch.rand(4, 5, device=device)
+        H2 = torch.rand(4, 1, device=device)
+
+        def fct_nested_outside(x: torch.Tensor, y: torch.Tensor):
+            def inner(xi):
+                return xi * H2
+
+            ret = inner(y)
+            return x + ret * H1
+
+        def fct_nested_outside_fake(x: torch.Tensor, y: torch.Tensor):
+            def inner(xi):
+                return xi * H2
+
+            ret = inner(y)
+            return x + ret * H1
+
+        H1_i = torch.rand(4, 5, device=device)
+
+        # TODO: Using random tensors in the `combine_fn` triggers the vmap randomness error:
+        # RuntimeError: vmap: called random operation while in randomness error mode.
+        # Please either use the 'same' or 'different' randomness flags on vmap or perform the randomness operation out of vmap
+        def fct_nested_inside(x: torch.Tensor, y: torch.Tensor):
+            # H2_i = torch.rand(4, 1, device=device)
+            H2_i = torch.ones(4, 1, device=device) * 42
+
+            def inner(xi):
+                return xi * H2_i
+
+            ret = inner(y)
+            return x + ret * H1
+
+        def fct_nested_inside_fake(x: torch.Tensor, y: torch.Tensor):
+            # H2_i = torch.rand(4, 1, device=device)
+            H2_i = torch.ones(4, 1, device=device) * 42
+
+            def inner(xi):
+                return xi * H2_i
+
+            ret = inner(y)
+            return x + ret * H1
+
+        inp = torch.randn(3, 4, 5, device=device)
+
+        for fct, fct_fake, param in [
+            (fct_nested_outside, fct_nested_outside_fake, (H1, H2)),
+            (fct_nested_inside, fct_nested_inside_fake, (H1_i,)),
+        ]:
+            kwargs = {
+                "dim": 0,
+                "reverse": reverse,
+                "compile_mode": compile_mode,
+                "combine_fn": fct,
+                "combine_mode": combine_mode,
+            }
+            kwargs_fake = self._prepare_fake_kwargs(kwargs)
+            kwargs_fake["combine_fn"] = fct_fake
+            self._run_test(
+                model=AssociativeScanModels.CombineFn(**kwargs),
+                model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
+                inputs=inp,
+            )
+
+    @unittest.skipIf(not SM70OrLater, "triton")
+    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
+    @parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
+    @parametrize("combine_mode", ["pointwise", "generic"])
+    @parametrize("reverse", [False, True])
+    @parametrize("device", [torch.device("cpu"), torch.device("cuda")])
+    # Skipping the combine_mode=pointwise
+    # as the current implementation of associative_scan lowering
+    # does not support lifted arguments
+    @decorateIf(
+        unittest.skip,
+        lambda params: (params["combine_mode"] == "pointwise"),
+    )
+    def test_associative_scan_freevars_fct(
+        self, compile_mode, combine_mode, reverse, device
+    ):
+        def additional_fct_no_add_inp(x, y):
+            return x * y
+
+        def fct_nested_outside(x: torch.Tensor, y: torch.Tensor):
+            ret = additional_fct_no_add_inp(y, y)
+            return x + ret
+
+        inp = torch.randn(3, 4, 5, device=device)
+
+        kwargs = {
+            "dim": 0,
+            "reverse": reverse,
+            "compile_mode": compile_mode,
+            "combine_fn": fct_nested_outside,
+            "combine_mode": combine_mode,
+        }
+        kwargs_fake = self._prepare_fake_kwargs(kwargs)
+        self._run_test(
+            model=AssociativeScanModels.CombineFn(**kwargs),
+            model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
+            inputs=inp,
+        )
+
+    @unittest.skipIf(not SM70OrLater, "triton")
+    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
+    @parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
+    @parametrize("reverse", [False, True])
+    @parametrize("device", [torch.device("cpu"), torch.device("cuda")])
+    def test_associative_scan_freevars_fct_generic(self, compile_mode, reverse, device):
+        def additional_fct_no_add_inp(x, y):
+            return x * y
+
+        def fct_nested_outside(x: torch.Tensor, y: torch.Tensor):
+            ret = associative_scan(
+                additional_fct_no_add_inp, y, 1, combine_mode="generic"
+            )
+            return x + ret
+
+        def fct_nested_outside_fake(x: torch.Tensor, y: torch.Tensor):
+            ret = _fake_associative_scan(additional_fct_no_add_inp, y, 1)
+            return x + ret
+
+        inp = torch.randn(3, 4, 5, device=device)
+
+        kwargs = {
+            "dim": 0,
+            "reverse": reverse,
+            "compile_mode": compile_mode,
+            "combine_fn": fct_nested_outside,
+            "combine_mode": "generic",
+        }
+        kwargs_fake = self._prepare_fake_kwargs(kwargs)
+        kwargs_fake["combine_fn"] = fct_nested_outside_fake
+        self._run_test(
+            model=AssociativeScanModels.CombineFn(**kwargs),
+            model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
+            inputs=inp,
+        )
+
+    @unittest.skipIf(not SM70OrLater, "triton")
+    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
+    @parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
+    @parametrize("combine_mode", ["pointwise", "generic"])
+    @parametrize("reverse", [False, True])
+    @parametrize("device", [torch.device("cpu"), torch.device("cuda")])
+    # Skipping the combine_mode=pointwise
+    # as the current implementation of associative_scan lowering
+    # does not support lifted arguments
+    @decorateIf(
+        unittest.skip,
+        lambda params: (params["combine_mode"] == "pointwise"),
+    )
+    def test_associative_scan_freevars_shape_check(
+        self, compile_mode, combine_mode, reverse, device
+    ):
+        H = torch.eye(2, device=device, requires_grad=True)
+
+        def fct_freevars(x: torch.Tensor, y: torch.Tensor):
+            return x @ H + y
+
+        inp = torch.randn(2, 2, 3, device=device, requires_grad=True)
+
+        kwargs = {
+            "dim": 2,
+            "reverse": reverse,
+            "compile_mode": compile_mode,
+            "combine_fn": fct_freevars,
+            "combine_mode": combine_mode,
+        }
+        kwargs_fake = self._prepare_fake_kwargs(kwargs)
+        self._run_test(
+            model=AssociativeScanModels.CombineFn(**kwargs),
+            model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
+            inputs=inp,
+        )
+
+    @unittest.skipIf(not SM70OrLater, "triton")
+    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
+    @parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
+    @parametrize("reverse", [False, True])
+    @parametrize("device", [torch.device("cpu"), torch.device("cuda")])
+    @parametrize("combine_mode", ["pointwise", "generic"])
+    # Skipping the combine_mode=pointwise
+    # as the current implementation of associative_scan lowering
+    # does not support lifted arguments
+    @decorateIf(
+        unittest.skip,
+        lambda params: (params["combine_mode"] == "pointwise"),
+    )
+    def test_associative_scan_freevars_pytree(
+        self, compile_mode, combine_mode, reverse, device
+    ):
+        xf = torch.randn(2, 2, device=device, requires_grad=True)
+        yf = torch.randn(2, 2, device=device, requires_grad=True)
+        zf = torch.randn(2, 2, device=device, requires_grad=True)
+        inpf = {"i": xf, "j": ([yf], [{"o": zf}])}
+
+        def fct_pointwise(x, y):
+            return {
+                "i": (x["i"] * y["i"]) + inpf["i"],
+                "j": (
+                    [(x["j"][0][0] * y["j"][0][0]) + inpf["j"][0][0]],
+                    [
+                        {
+                            "o": (x["j"][1][0]["o"] + y["j"][1][0]["o"])
+                            + inpf["j"][1][0]["o"]
+                        }
+                    ],
+                ),
+            }
+
+        x = torch.randn(3, 2, 2, device=device, requires_grad=True)
+        y = torch.randn(3, 2, 2, device=device, requires_grad=True)
+        z = torch.randn(3, 2, 2, device=device, requires_grad=True)
+        inp = {"i": x, "j": ([y], [{"o": z}])}
+
+        kwargs = {
+            "dim": 0,
+            "reverse": reverse,
+            "compile_mode": compile_mode,
+            "combine_fn": fct_pointwise,
+            "combine_mode": combine_mode,
+        }
+        kwargs_fake = self._prepare_fake_kwargs(kwargs)
+        self._run_test(
+            model=AssociativeScanModels.CombineFn(**kwargs),
+            model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
+            inputs=inp,
+        )
+
     @unittest.skipIf(not SM70OrLater, "triton")
     @requires_cuda
     def test_associative_scan_sparse_tensor(self):
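Editor's note: the TODO in test_associative_scan_freevars_nested refers to functorch's randomness guard, which the generic lowering hits because it routes the combine function through vmap. A standalone repro/workaround sketch using the public torch.vmap API (illustrative only, not part of the commit):

    import torch

    def combine(x):
        # A random op inside a vmapped function trips the default "error" randomness mode.
        return x * torch.rand(())

    xs = torch.randn(4)
    # torch.vmap(combine)(xs)  # RuntimeError: vmap: called random operation while in randomness error mode.
    out = torch.vmap(combine, randomness="different")(xs)  # one independent sample per batch element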
