Bootcamp Task: Unit Tests for Gradient Clipping for DTensors #3253

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 1 commit into main
159 changes: 131 additions & 28 deletions torchrec/optim/tests/test_clipping.py
@@ -245,62 +245,73 @@ def test_clip_no_gradients_norm_meta_device(
@unittest.skipIf(not torch.cuda.is_available(), "Skip when CUDA is not available")
@instantiate_parametrized_tests
class TestGradientClippingDTensor(DTensorTestBase):
"""No tests for Replicated DTensors as handled prior to GradientClippingOptimizer"""

def _get_params_to_pg(
self, params: List[DTensor]
) -> Dict[DTensor, List[ProcessGroup]]:
return {param: [param.device_mesh.get_group()] for param in params}

@with_comms
@parametrize("norm_type", ("inf", 1, 2))
def test_dtensor_clip_all_gradients_norm(
def test_tensor_and_sharded_dtensor_clip_all_gradients_norm(
self, norm_type: Union[float, str]
) -> None:
"""
Test to ensure that the gradient clipping optimizer clips gradients
correctly with mixed DTensor and tensor by comparing gradients to its
correctly with mixed sharded DTensor and tensor by comparing gradients to its
torch.tensor counterpart.
Note that clipping for DTensor may require communication.
"""

        # data for testing clipping
        data_1 = torch.tensor([1.0, 2.0, 3.0], device=self.device_type)
        data_2 = torch.tensor([4.0, 5.0, 6.0], device=self.device_type)
        data_1_grad = torch.tensor([12.0, 15.0, 18.0], device=self.device_type)
        data_2_grad = torch.tensor([20.0, 30.0, 15.0], device=self.device_type)

        # create gradient clipping optimizer containing no dtensor for reference
        ref_param_1 = torch.nn.Parameter(data_1.clone())
        ref_param_2 = torch.nn.Parameter(data_2.clone())
        ref_param_1.grad = data_1_grad.clone()
        ref_param_2.grad = data_2_grad.clone()
        ref_keyed_optimizer = DummyKeyedOptimizer(
            params={"param_1": ref_param_1, "param_2": ref_param_2},
            state={},
            param_groups=[{"params": [ref_param_1, ref_param_2]}],
        )
        ref_gradient_clipping_optimizer = GradientClippingOptimizer(
            optimizer=ref_keyed_optimizer,
            clipping=GradientClipping.NORM,
            max_gradient=10.0,
            norm_type=norm_type,
        )
        ref_gradient_clipping_optimizer.step()

        # create gradient clipping optimizer containing a DTensor and a tensor
        device_mesh = init_device_mesh(self.device_type, (self.world_size,))
        param_1 = distribute_tensor(
            tensor=torch.tensor(
                data_1.clone(), requires_grad=True, device=self.device_type
            ),
            device_mesh=device_mesh,
            placements=[Shard(0)],
        )
        param_2 = torch.tensor(
            data_2.clone(), requires_grad=True, device=self.device_type
        )
        param_1.grad = distribute_tensor(
            tensor=data_1_grad.clone(),
            device_mesh=device_mesh,
            placements=[Shard(0)],
        )
        param_2.grad = data_2_grad.clone()
        param_to_pgs = self._get_params_to_pg([param_1])
        keyed_optimizer = DummyKeyedOptimizer(
            params={"dtensor_param_1": param_1, "dtensor_param_2": param_2},
            state={},
            param_groups=[{"params": [param_1, param_2]}],
        )
        gradient_clipping_optimizer = GradientClippingOptimizer(
            optimizer=keyed_optimizer,
@@ -310,21 +321,113 @@ def test_dtensor_clip_all_gradients_norm(
            enable_global_grad_clip=True,
            param_to_pgs=param_to_pgs,  # pyre-ignore[6]
        )
        gradient_clipping_optimizer.step()

        for param_group, ref_param_group in zip(
            gradient_clipping_optimizer.param_groups,
            ref_gradient_clipping_optimizer.param_groups,
            strict=True,
        ):
            for param, ref_param in zip(
                param_group["params"], ref_param_group["params"]
            ):
                param_grad = (
                    param.grad.full_tensor()  # pyre-ignore[16]
                    if isinstance(param, DTensor)
                    else param.grad
                )
                self.assertEqual(
                    param_grad,
                    ref_param.grad,
                    f"Expect gradient to be the same. However, found {param_grad=}, {ref_param.grad=}",
                )
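As a sanity check on what these assertions expect: assuming GradientClippingOptimizer with GradientClipping.NORM applies the usual clip-by-total-norm rule (the same formula as torch.nn.utils.clip_grad_norm_), the clipped gradients for the test data above can be reproduced by hand, e.g. for the 2-norm:

import torch

# Sketch only, assuming standard clip-by-total-norm semantics; the concrete
# numbers come from data_1_grad and data_2_grad in the test above.
grads = [torch.tensor([12.0, 15.0, 18.0]), torch.tensor([20.0, 30.0, 15.0])]
max_gradient = 10.0

total_norm = torch.norm(torch.cat(grads), p=2)               # sqrt(2218) ~= 47.1
clip_coef = torch.clamp(max_gradient / total_norm, max=1.0)  # ~= 0.212
expected = [g * clip_coef for g in grads]                    # what both optimizers should yield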

    @with_comms
    @parametrize("norm_type", ("inf", 1, 2))
    def test_multiple_sharded_dtensors_clip_all_gradients_norm(
        self, norm_type: Union[float, str]
    ) -> None:
        """
        Test to ensure that the gradient clipping optimizer clips gradients
        correctly with multiple sharded DTensors by comparing gradients to their
        torch.tensor counterparts.
        Note that clipping for DTensor may require communication.
        """

        # data for testing clipping
        data_1 = torch.tensor([1.0, 2.0, 3.0], device=self.device_type)
        data_2 = torch.tensor([4.0, 5.0, 6.0], device=self.device_type)
        data_1_grad = torch.tensor([12.0, 15.0, 18.0], device=self.device_type)
        data_2_grad = torch.tensor([20.0, 30.0, 15.0], device=self.device_type)

        # create gradient clipping optimizer containing no dtensor for reference
        ref_param_1 = torch.nn.Parameter(data_1.clone())
        ref_param_2 = torch.nn.Parameter(data_2.clone())
        ref_param_1.grad = data_1_grad.clone()
        ref_param_2.grad = data_2_grad.clone()
        ref_keyed_optimizer = DummyKeyedOptimizer(
            params={"param_1": ref_param_1, "param_2": ref_param_2},
            state={},
            param_groups=[{"params": [ref_param_1, ref_param_2]}],
        )
        ref_gradient_clipping_optimizer = GradientClippingOptimizer(
            optimizer=ref_keyed_optimizer,
            clipping=GradientClipping.NORM,
            max_gradient=10.0,
            norm_type=norm_type,
        )
        ref_gradient_clipping_optimizer.step()

        # create gradient clipping optimizer containing 2 DTensors
        device_mesh = init_device_mesh(self.device_type, (self.world_size,))
        param_1 = distribute_tensor(
            tensor=torch.tensor(
                data_1.clone(), requires_grad=True, device=self.device_type
            ),
            device_mesh=device_mesh,
            placements=[Shard(0)],
        )
        param_2 = distribute_tensor(
            tensor=torch.tensor(
                data_2.clone(), requires_grad=True, device=self.device_type
            ),
            device_mesh=device_mesh,
            placements=[Shard(0)],
        )
        param_1.grad = distribute_tensor(
            tensor=data_1_grad.clone(),
            device_mesh=device_mesh,
            placements=[Shard(0)],
        )
        param_2.grad = distribute_tensor(
            tensor=data_2_grad.clone(),
            device_mesh=device_mesh,
            placements=[Shard(0)],
        )
        param_to_pgs = self._get_params_to_pg([param_1, param_2])
        keyed_optimizer = DummyKeyedOptimizer(
            params={"dtensor_param_1": param_1, "dtensor_param_2": param_2},
            state={},
            param_groups=[{"params": [param_1, param_2]}],
        )
        gradient_clipping_optimizer = GradientClippingOptimizer(
            optimizer=keyed_optimizer,
            clipping=GradientClipping.NORM,
            max_gradient=10.0,
            norm_type=norm_type,
            enable_global_grad_clip=True,
            param_to_pgs=param_to_pgs,  # pyre-ignore[6]
        )
        gradient_clipping_optimizer.step()

        for param_group, ref_param_group in zip(
            gradient_clipping_optimizer.param_groups,
            ref_gradient_clipping_optimizer.param_groups,
            strict=True,
        ):
            for param, ref_param in zip(
                param_group["params"], ref_param_group["params"], strict=True
            ):
                param_grad = (
                    param.grad.full_tensor()  # pyre-ignore[16]
                    if isinstance(param, DTensor)
                    else param.grad
                )
                self.assertEqual(
                    param_grad,
                    ref_param.grad,
                    f"Expect gradient to be the same. However, found {param_grad=}, {ref_param.grad=}",
                )
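On the docstrings' note that clipping DTensors may require communication: with a Shard(0) placement each rank holds only a slice of each gradient, so a global norm has to be combined from per-shard norms; the process groups collected by _get_params_to_pg and passed as param_to_pgs are what the optimizer would use for that cross-rank reduction. A minimal single-process sketch of the reduction itself (plain torch, no distributed setup; the chunks stand in for per-rank shards):

import torch

# Sketch only: combining per-shard norms into the global norm that clipping needs.
full_grad = torch.tensor([12.0, 15.0, 18.0, 20.0, 30.0, 15.0])
shards = torch.chunk(full_grad, chunks=2)  # stand-ins for the per-rank Shard(0) pieces

p = 2.0
local_norms = torch.stack([torch.norm(s, p=p) for s in shards])
global_norm = torch.norm(local_norms, p=p)  # equals torch.norm(full_grad, p=p)
assert torch.allclose(global_norm, torch.norm(full_grad, p=p))

# For the inf-norm, the cross-shard reduction is a max instead of another p-norm.
inf_norm = torch.stack([s.abs().max() for s in shards]).max()
assert torch.allclose(inf_norm, full_grad.abs().max())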