diff --git a/torchrec/optim/tests/test_clipping.py b/torchrec/optim/tests/test_clipping.py
index 5f311459c..7a5660187 100644
--- a/torchrec/optim/tests/test_clipping.py
+++ b/torchrec/optim/tests/test_clipping.py
@@ -245,6 +245,8 @@ def test_clip_no_gradients_norm_meta_device(
 @unittest.skipIf(not torch.cuda.is_available(), "Skip when CUDA is not available")
 @instantiate_parametrized_tests
 class TestGradientClippingDTensor(DTensorTestBase):
+    """No tests for replicated DTensors, as they are handled prior to GradientClippingOptimizer."""
+
     def _get_params_to_pg(
         self, params: List[DTensor]
     ) -> Dict[DTensor, List[ProcessGroup]]:
@@ -252,28 +254,32 @@ def _get_params_to_pg(
 
     @with_comms
     @parametrize("norm_type", ("inf", 1, 2))
-    def test_dtensor_clip_all_gradients_norm(
+    def test_tensor_and_sharded_dtensor_clip_all_gradients_norm(
         self, norm_type: Union[float, str]
     ) -> None:
         """
         Test to ensure that the gradient clipping optimizer clips gradients
-        correctly with mixed DTensor and tensor by comparing gradients to its
+        correctly with mixed sharded DTensor and tensor by comparing gradients to their
         torch.tensor counterpart.
 
         Note that clipping for DTensor may require communication.
         """
 
+        # data for testing clipping
+        data_1 = torch.tensor([1.0, 2.0, 3.0], device=self.device_type)
+        data_2 = torch.tensor([4.0, 5.0, 6.0], device=self.device_type)
+        data_1_grad = torch.tensor([12.0, 15.0, 18.0], device=self.device_type)
+        data_2_grad = torch.tensor([20.0, 30.0, 15.0], device=self.device_type)
+
         # create gradient clipping optimizer containing no dtensor for reference
-        ref_param_1 = torch.nn.Parameter(
-            torch.tensor([1.0, 2.0, 3.0], device=self.device_type)
-        )
-        ref_param_2 = torch.nn.Parameter(
-            torch.tensor([4.0, 5.0, 6.0], device=self.device_type)
-        )
+        ref_param_1 = torch.nn.Parameter(data_1.clone())
+        ref_param_2 = torch.nn.Parameter(data_2.clone())
+        ref_param_1.grad = data_1_grad.clone()
+        ref_param_2.grad = data_2_grad.clone()
         ref_keyed_optimizer = DummyKeyedOptimizer(
-            {"param_1": ref_param_1, "param_2": ref_param_2},
-            {},
-            [{"params": [ref_param_1, ref_param_2]}],
+            params={"param_1": ref_param_1, "param_2": ref_param_2},
+            state={},
+            param_groups=[{"params": [ref_param_1, ref_param_2]}],
         )
         ref_gradient_clipping_optimizer = GradientClippingOptimizer(
             optimizer=ref_keyed_optimizer,
@@ -281,26 +287,31 @@ def test_dtensor_clip_all_gradients_norm(
             max_gradient=10.0,
             norm_type=norm_type,
         )
-        ref_gradient_clipping_optimizer.zero_grad()
-        ref_param_1.grad = torch.tensor([12.0, 15.0, 18.0], device=self.device_type)
-        ref_param_2.grad = torch.tensor([20.0, 30.0, 15.0], device=self.device_type)
         ref_gradient_clipping_optimizer.step()
 
-        # create gradient clipping optimizer containing both DTensor and tensor
+        # create gradient clipping optimizer containing a DTensor and a tensor
         device_mesh = init_device_mesh(self.device_type, (self.world_size,))
         param_1 = distribute_tensor(
-            torch.tensor([1.0, 2.0, 3.0], requires_grad=True, device=self.device_type),
-            device_mesh,
-            [Shard(0)],
+            tensor=torch.tensor(
+                data_1.clone(), requires_grad=True, device=self.device_type
+            ),
+            device_mesh=device_mesh,
+            placements=[Shard(0)],
         )
         param_2 = torch.tensor(
-            [4.0, 5.0, 6.0], requires_grad=True, device=self.device_type
+            data_2.clone(), requires_grad=True, device=self.device_type
         )
+        param_1.grad = distribute_tensor(
+            tensor=data_1_grad.clone(),
+            device_mesh=device_mesh,
+            placements=[Shard(0)],
+        )
+        param_2.grad = data_2_grad.clone()
         param_to_pgs = self._get_params_to_pg([param_1])
         keyed_optimizer = DummyKeyedOptimizer(
-            {"dtensor_param_1": param_1, "dtensor_param_2": param_2},
-            {},
-            [{"params": [param_1, param_2]}],
+            params={"dtensor_param_1": param_1, "dtensor_param_2": param_2},
+            state={},
+            param_groups=[{"params": [param_1, param_2]}],
         )
         gradient_clipping_optimizer = GradientClippingOptimizer(
             optimizer=keyed_optimizer,
@@ -310,21 +321,113 @@ def test_dtensor_clip_all_gradients_norm(
             enable_global_grad_clip=True,
             param_to_pgs=param_to_pgs,  # pyre-ignore[6]
         )
-        gradient_clipping_optimizer.zero_grad()
+        gradient_clipping_optimizer.step()
+
+        for param_group, ref_param_group in zip(
+            gradient_clipping_optimizer.param_groups,
+            ref_gradient_clipping_optimizer.param_groups,
+            strict=True,
+        ):
+            for param, ref_param in zip(
+                param_group["params"], ref_param_group["params"]
+            ):
+                param_grad = (
+                    param.grad.full_tensor()  # pyre-ignore[16]
+                    if isinstance(param, DTensor)
+                    else param.grad
+                )
+                self.assertEqual(
+                    param_grad,
+                    ref_param.grad,
+                    f"Expect gradient to be the same. However, found {param_grad=}, {ref_param.grad=}",
+                )
+
+    @with_comms
+    @parametrize("norm_type", ("inf", 1, 2))
+    def test_multiple_sharded_dtensors_clip_all_gradients_norm(
+        self, norm_type: Union[float, str]
+    ) -> None:
+        """
+        Test to ensure that the gradient clipping optimizer clips gradients
+        correctly with multiple sharded DTensors by comparing gradients to their
+        torch.tensor counterpart.
+
+        Note that clipping for DTensor may require communication.
+        """
+
+        # data for testing clipping
+        data_1 = torch.tensor([1.0, 2.0, 3.0], device=self.device_type)
+        data_2 = torch.tensor([4.0, 5.0, 6.0], device=self.device_type)
+        data_1_grad = torch.tensor([12.0, 15.0, 18.0], device=self.device_type)
+        data_2_grad = torch.tensor([20.0, 30.0, 15.0], device=self.device_type)
+
+        # create gradient clipping optimizer containing no dtensor for reference
+        ref_param_1 = torch.nn.Parameter(data_1.clone())
+        ref_param_2 = torch.nn.Parameter(data_2.clone())
+        ref_param_1.grad = data_1_grad.clone()
+        ref_param_2.grad = data_2_grad.clone()
+        ref_keyed_optimizer = DummyKeyedOptimizer(
+            params={"param_1": ref_param_1, "param_2": ref_param_2},
+            state={},
+            param_groups=[{"params": [ref_param_1, ref_param_2]}],
+        )
+        ref_gradient_clipping_optimizer = GradientClippingOptimizer(
+            optimizer=ref_keyed_optimizer,
+            clipping=GradientClipping.NORM,
+            max_gradient=10.0,
+            norm_type=norm_type,
+        )
+        ref_gradient_clipping_optimizer.step()
+
+        # create gradient clipping optimizer containing 2 DTensors
+        device_mesh = init_device_mesh(self.device_type, (self.world_size,))
+        param_1 = distribute_tensor(
+            tensor=torch.tensor(
+                data_1.clone(), requires_grad=True, device=self.device_type
+            ),
+            device_mesh=device_mesh,
+            placements=[Shard(0)],
+        )
+        param_2 = distribute_tensor(
+            tensor=torch.tensor(
+                data_2.clone(), requires_grad=True, device=self.device_type
+            ),
+            device_mesh=device_mesh,
+            placements=[Shard(0)],
+        )
         param_1.grad = distribute_tensor(
-            torch.tensor([12.0, 15.0, 18.0], device=self.device_type),
-            device_mesh,
-            [Shard(0)],
+            tensor=data_1_grad.clone(),
+            device_mesh=device_mesh,
+            placements=[Shard(0)],
+        )
+        param_2.grad = distribute_tensor(
+            tensor=data_2_grad.clone(),
+            device_mesh=device_mesh,
+            placements=[Shard(0)],
+        )
+        param_to_pgs = self._get_params_to_pg([param_1, param_2])
+        keyed_optimizer = DummyKeyedOptimizer(
+            params={"dtensor_param_1": param_1, "dtensor_param_2": param_2},
+            state={},
+            param_groups=[{"params": [param_1, param_2]}],
+        )
+        gradient_clipping_optimizer = GradientClippingOptimizer(
+            optimizer=keyed_optimizer,
+            clipping=GradientClipping.NORM,
+            max_gradient=10.0,
+            norm_type=norm_type,
+            enable_global_grad_clip=True,
+            param_to_pgs=param_to_pgs,  # pyre-ignore[6]
         )
-        param_2.grad = torch.tensor([20.0, 30.0, 15.0], device=self.device_type)
         gradient_clipping_optimizer.step()
 
         for param_group, ref_param_group in zip(
             gradient_clipping_optimizer.param_groups,
             ref_gradient_clipping_optimizer.param_groups,
+            strict=True,
         ):
             for param, ref_param in zip(
-                param_group["params"], ref_param_group["params"]
+                param_group["params"], ref_param_group["params"], strict=True
             ):
                 param_grad = (
                     param.grad.full_tensor()  # pyre-ignore[16]
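For reference, the clipped gradients that both tests compare against can be reproduced with plain torch.nn.utils.clip_grad_norm_ on the same data. The sketch below is illustrative only and is not part of the patch; it assumes that GradientClippingOptimizer with GradientClipping.NORM follows the usual clip_grad_norm_ semantics, i.e. every gradient is scaled by max_gradient / total_norm whenever the combined norm exceeds max_gradient.

import torch

# Gradient data shared by both tests above (data_1_grad, data_2_grad).
grads = [
    torch.tensor([12.0, 15.0, 18.0]),
    torch.tensor([20.0, 30.0, 15.0]),
]
max_gradient = 10.0

for norm_type in (float("inf"), 1.0, 2.0):
    # Fresh parameters whose gradients mirror the test data.
    params = [torch.nn.Parameter(torch.zeros(3)) for _ in grads]
    for p, g in zip(params, grads):
        p.grad = g.clone()
    # clip_grad_norm_ rescales all gradients in place when their combined
    # norm exceeds max_norm and returns the pre-clip total norm.
    total_norm = torch.nn.utils.clip_grad_norm_(
        params, max_norm=max_gradient, norm_type=norm_type
    )
    print(f"{norm_type=}: {total_norm=}, clipped={[p.grad.tolist() for p in params]}")

For example, with norm_type=2 the combined norm is sqrt(12^2 + 15^2 + 18^2 + 20^2 + 30^2 + 15^2) = sqrt(2218) ≈ 47.1, so under this assumption every gradient entry is scaled by roughly 10 / 47.1 ≈ 0.21 in both the reference and the DTensor paths.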