@@ -3855,6 +3855,302 @@ def test_associative_scan_different_input_size_wrong_dim(self):
38553855 combine_mode = "pointwise" ,
38563856 )
38573857
@unittest.skipIf(not SM70OrLater, "triton")
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
@parametrize("combine_mode", ["pointwise", "generic"])
@parametrize("reverse", [False, True])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
# combine_mode="pointwise" is skipped because the current associative_scan
# lowering does not support lifted arguments (free variables).
@decorateIf(
    unittest.skip,
    lambda params: (params["combine_mode"] == "pointwise"),
)
def test_associative_scan_freevars_simple(
    self, compile_mode, combine_mode, reverse, device
):
    # Scans along dim 0 with combine functions that close over tensors
    # created in this test body (free variables of the combine_fn).
    H = torch.rand(2, device=device)

    # One free tensor, mixed with a Python constant.
    def fct_freevars1(x: torch.Tensor, y: torch.Tensor):
        return x * H + y * 2

    # The same free tensor used twice.
    def fct_freevars2(x: torch.Tensor, y: torch.Tensor):
        return x * H + y * H

    H1 = torch.rand(1, device=device)
    H2 = torch.rand(1, device=device)

    # Two distinct free tensors.
    def fct_freevars3(x: torch.Tensor, y: torch.Tensor):
        return x * H1 + y * H2

    inp = torch.randn(3, 2, 2, device=device)

    # Each combine_fn gets its own pair of models (real and fake kwargs),
    # both handed to self._run_test together with the same input.
    for fct in (fct_freevars1, fct_freevars2, fct_freevars3):
        kwargs = {
            "dim": 0,
            "reverse": reverse,
            "compile_mode": compile_mode,
            "combine_fn": fct,
            "combine_mode": combine_mode,
        }
        kwargs_fake = self._prepare_fake_kwargs(kwargs)
        self._run_test(
            model=AssociativeScanModels.CombineFn(**kwargs),
            model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
            inputs=inp,
        )
3908+
@unittest.skipIf(not SM70OrLater, "triton")
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
@parametrize("combine_mode", ["pointwise", "generic"])
@parametrize("reverse", [False, True])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
# combine_mode="pointwise" is skipped because the current associative_scan
# lowering does not support lifted arguments (free variables).
@decorateIf(
    unittest.skip,
    lambda params: (params["combine_mode"] == "pointwise"),
)
def test_associative_scan_freevars_nested(
    self, compile_mode, combine_mode, reverse, device
):
    # Scans along dim 0 with combine functions whose free variables are
    # reached through a nested inner function.
    H1 = torch.rand(4, 5, device=device)
    H2 = torch.rand(4, 1, device=device)

    # Both free tensors live outside the combine_fn; H2 is captured by the
    # inner function, H1 by the combine_fn itself.
    def fct_nested_outside(x: torch.Tensor, y: torch.Tensor):
        def inner(xi):
            return xi * H2

        ret = inner(y)
        return x + ret * H1

    # Separate but identical function object used for the fake model.
    def fct_nested_outside_fake(x: torch.Tensor, y: torch.Tensor):
        def inner(xi):
            return xi * H2

        ret = inner(y)
        return x + ret * H1

    # Dedicated free tensor for the "inside" variant. Previously this tensor
    # was created but never read by any combine_fn.
    H1_i = torch.rand(4, 5, device=device)

    # TODO: Using random tensors in the `combine_fn` triggers the vmap randomness error:
    # RuntimeError: vmap: called random operation while in randomness error mode.
    # Please either use the 'same' or 'different' randomness flags on vmap or perform the randomness operation out of vmap
    def fct_nested_inside(x: torch.Tensor, y: torch.Tensor):
        # A constant tensor stands in for torch.rand(4, 1, device=device)
        # because of the vmap randomness error described above.
        H2_i = torch.ones(4, 1, device=device) * 42

        def inner(xi):
            return xi * H2_i

        ret = inner(y)
        return x + ret * H1_i

    # Separate but identical function object used for the fake model.
    def fct_nested_inside_fake(x: torch.Tensor, y: torch.Tensor):
        H2_i = torch.ones(4, 1, device=device) * 42

        def inner(xi):
            return xi * H2_i

        ret = inner(y)
        return x + ret * H1_i

    inp = torch.randn(3, 4, 5, device=device)

    for fct, fct_fake in [
        (fct_nested_outside, fct_nested_outside_fake),
        (fct_nested_inside, fct_nested_inside_fake),
    ]:
        kwargs = {
            "dim": 0,
            "reverse": reverse,
            "compile_mode": compile_mode,
            "combine_fn": fct,
            "combine_mode": combine_mode,
        }
        kwargs_fake = self._prepare_fake_kwargs(kwargs)
        # The fake model gets its own copy of the combine_fn.
        kwargs_fake["combine_fn"] = fct_fake
        self._run_test(
            model=AssociativeScanModels.CombineFn(**kwargs),
            model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
            inputs=inp,
        )
3987+
@unittest.skipIf(not SM70OrLater, "triton")
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
@parametrize("combine_mode", ["pointwise", "generic"])
@parametrize("reverse", [False, True])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
# combine_mode="pointwise" is skipped because the current associative_scan
# lowering does not support lifted arguments.
@decorateIf(
    unittest.skip,
    lambda params: (params["combine_mode"] == "pointwise"),
)
def test_associative_scan_freevars_fct(
    self, compile_mode, combine_mode, reverse, device
):
    # The combine_fn's only free variable is another local function,
    # not a tensor.
    def multiply(lhs, rhs):
        return lhs * rhs

    def combine(x: torch.Tensor, y: torch.Tensor):
        return x + multiply(y, y)

    scan_input = torch.randn(3, 4, 5, device=device)

    real_kwargs = dict(
        dim=0,
        reverse=reverse,
        compile_mode=compile_mode,
        combine_fn=combine,
        combine_mode=combine_mode,
    )
    fake_kwargs = self._prepare_fake_kwargs(real_kwargs)

    real_model = AssociativeScanModels.CombineFn(**real_kwargs)
    fake_model = AssociativeScanModels.CombineFn(**fake_kwargs)
    self._run_test(model=real_model, model_fake=fake_model, inputs=scan_input)
4026+
@unittest.skipIf(not SM70OrLater, "triton")
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
@parametrize("reverse", [False, True])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
def test_associative_scan_freevars_fct_generic(self, compile_mode, reverse, device):
    # The combine_fn itself runs a nested associative_scan (generic mode,
    # along dim 1) whose combine function is a free variable of the outer one.
    def multiply(lhs, rhs):
        return lhs * rhs

    def combine(x: torch.Tensor, y: torch.Tensor):
        inner_scan = associative_scan(multiply, y, 1, combine_mode="generic")
        return x + inner_scan

    # Counterpart for the fake model: same shape of computation, but the
    # nested scan goes through _fake_associative_scan.
    def combine_fake(x: torch.Tensor, y: torch.Tensor):
        inner_scan = _fake_associative_scan(multiply, y, 1)
        return x + inner_scan

    scan_input = torch.randn(3, 4, 5, device=device)

    real_kwargs = dict(
        dim=0,
        reverse=reverse,
        compile_mode=compile_mode,
        combine_fn=combine,
        combine_mode="generic",
    )
    fake_kwargs = self._prepare_fake_kwargs(real_kwargs)
    fake_kwargs["combine_fn"] = combine_fake

    real_model = AssociativeScanModels.CombineFn(**real_kwargs)
    fake_model = AssociativeScanModels.CombineFn(**fake_kwargs)
    self._run_test(model=real_model, model_fake=fake_model, inputs=scan_input)
4062+
@unittest.skipIf(not SM70OrLater, "triton")
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
@parametrize("combine_mode", ["pointwise", "generic"])
@parametrize("reverse", [False, True])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
# combine_mode="pointwise" is skipped because the current associative_scan
# lowering does not support lifted arguments.
@decorateIf(
    unittest.skip,
    lambda params: (params["combine_mode"] == "pointwise"),
)
def test_associative_scan_freevars_shape_check(
    self, compile_mode, combine_mode, reverse, device
):
    # The free variable is a 2x2 identity matrix (requires_grad=True) used in
    # a matmul inside the combine_fn; the scan runs along the last dim (dim=2)
    # of a (2, 2, 3) input.
    H = torch.eye(2, device=device, requires_grad=True)

    def combine(x: torch.Tensor, y: torch.Tensor):
        projected = x @ H
        return projected + y

    scan_input = torch.randn(2, 2, 3, device=device, requires_grad=True)

    real_kwargs = dict(
        dim=2,
        reverse=reverse,
        compile_mode=compile_mode,
        combine_fn=combine,
        combine_mode=combine_mode,
    )
    fake_kwargs = self._prepare_fake_kwargs(real_kwargs)

    real_model = AssociativeScanModels.CombineFn(**real_kwargs)
    fake_model = AssociativeScanModels.CombineFn(**fake_kwargs)
    self._run_test(model=real_model, model_fake=fake_model, inputs=scan_input)
4099+
@unittest.skipIf(not SM70OrLater, "triton")
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
@parametrize("reverse", [False, True])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
@parametrize("combine_mode", ["pointwise", "generic"])
# combine_mode="pointwise" is skipped because the current associative_scan
# lowering does not support lifted arguments.
@decorateIf(
    unittest.skip,
    lambda params: (params["combine_mode"] == "pointwise"),
)
def test_associative_scan_freevars_pytree(
    self, compile_mode, combine_mode, reverse, device
):
    # The free variable is a nested pytree (dict -> tuple -> list -> dict) of
    # tensors with the same structure as the scanned input.
    xf = torch.randn(2, 2, device=device, requires_grad=True)
    yf = torch.randn(2, 2, device=device, requires_grad=True)
    zf = torch.randn(2, 2, device=device, requires_grad=True)
    inpf = {"i": xf, "j": ([yf], [{"o": zf}])}

    def fct_pointwise(x, y):
        # Leaves are combined in the order "i", "j"[0][0], "j"[1][0]["o"];
        # each one adds the matching leaf of the captured pytree.
        leaf_i = (x["i"] * y["i"]) + inpf["i"]
        leaf_j = (x["j"][0][0] * y["j"][0][0]) + inpf["j"][0][0]
        leaf_o = (x["j"][1][0]["o"] + y["j"][1][0]["o"]) + inpf["j"][1][0]["o"]
        return {"i": leaf_i, "j": ([leaf_j], [{"o": leaf_o}])}

    x = torch.randn(3, 2, 2, device=device, requires_grad=True)
    y = torch.randn(3, 2, 2, device=device, requires_grad=True)
    z = torch.randn(3, 2, 2, device=device, requires_grad=True)
    scan_input = {"i": x, "j": ([y], [{"o": z}])}

    real_kwargs = dict(
        dim=0,
        reverse=reverse,
        compile_mode=compile_mode,
        combine_fn=fct_pointwise,
        combine_mode=combine_mode,
    )
    fake_kwargs = self._prepare_fake_kwargs(real_kwargs)

    real_model = AssociativeScanModels.CombineFn(**real_kwargs)
    fake_model = AssociativeScanModels.CombineFn(**fake_kwargs)
    self._run_test(model=real_model, model_fake=fake_model, inputs=scan_input)
4153+
38584154 @unittest .skipIf (not SM70OrLater , "triton" )
38594155 @requires_cuda
38604156 def test_associative_scan_sparse_tensor (self ):
0 commit comments