
Commit 13e0bad

Shagun Gupta authored and facebook-github-bot committed
Bootcamp Task : Unit Tests Gradient Clipping for Dtensors (#3253)
Summary:
Pull Request resolved: #3253

Implemented unit tests covering norm-based clipping for two sharded DTensors. All test cases pass.

Differential Revision: D79301301
1 parent 886ea8d commit 13e0bad

File tree

1 file changed: +131 −28 lines

torchrec/optim/tests/test_clipping.py

Lines changed: 131 additions & 28 deletions
@@ -245,62 +245,73 @@ def test_clip_no_gradients_norm_meta_device(
 @unittest.skipIf(not torch.cuda.is_available(), "Skip when CUDA is not available")
 @instantiate_parametrized_tests
 class TestGradientClippingDTensor(DTensorTestBase):
+    """No tests for Replicated DTensors as handled prior to GradientClippingOptimizer"""
+
     def _get_params_to_pg(
         self, params: List[DTensor]
     ) -> Dict[DTensor, List[ProcessGroup]]:
         return {param: [param.device_mesh.get_group()] for param in params}

     @with_comms
     @parametrize("norm_type", ("inf", 1, 2))
-    def test_dtensor_clip_all_gradients_norm(
+    def test_tensor_and_sharded_dtensor_clip_all_gradients_norm(
         self, norm_type: Union[float, str]
     ) -> None:
         """
         Test to ensure that the gradient clipping optimizer clips gradients
-        correctly with mixed DTensor and tensor by comparing gradients to its
+        correctly with mixed sharded DTensor and tensor by comparing gradients to its
         torch.tensor counterpart.

         Note that clipping for DTensor may require communication.
         """

+        # data for testing clipping
+        data_1 = torch.tensor([1.0, 2.0, 3.0], device=self.device_type)
+        data_2 = torch.tensor([4.0, 5.0, 6.0], device=self.device_type)
+        data_1_grad = torch.tensor([12.0, 15.0, 18.0], device=self.device_type)
+        data_2_grad = torch.tensor([20.0, 30.0, 15.0], device=self.device_type)
+
         # create gradient clipping optimizer containing no dtensor for reference
-        ref_param_1 = torch.nn.Parameter(
-            torch.tensor([1.0, 2.0, 3.0], device=self.device_type)
-        )
-        ref_param_2 = torch.nn.Parameter(
-            torch.tensor([4.0, 5.0, 6.0], device=self.device_type)
-        )
+        ref_param_1 = torch.nn.Parameter(data_1.clone())
+        ref_param_2 = torch.nn.Parameter(data_2.clone())
+        ref_param_1.grad = data_1_grad.clone()
+        ref_param_2.grad = data_2_grad.clone()
         ref_keyed_optimizer = DummyKeyedOptimizer(
-            {"param_1": ref_param_1, "param_2": ref_param_2},
-            {},
-            [{"params": [ref_param_1, ref_param_2]}],
+            params={"param_1": ref_param_1, "param_2": ref_param_2},
+            state={},
+            param_groups=[{"params": [ref_param_1, ref_param_2]}],
         )
         ref_gradient_clipping_optimizer = GradientClippingOptimizer(
             optimizer=ref_keyed_optimizer,
             clipping=GradientClipping.NORM,
             max_gradient=10.0,
             norm_type=norm_type,
         )
-        ref_gradient_clipping_optimizer.zero_grad()
-        ref_param_1.grad = torch.tensor([12.0, 15.0, 18.0], device=self.device_type)
-        ref_param_2.grad = torch.tensor([20.0, 30.0, 15.0], device=self.device_type)
         ref_gradient_clipping_optimizer.step()

-        # create gradient clipping optimizer containing both DTensor and tensor
+        # create gradient clipping optimizer containing a DTensor and a tensor
         device_mesh = init_device_mesh(self.device_type, (self.world_size,))
         param_1 = distribute_tensor(
-            torch.tensor([1.0, 2.0, 3.0], requires_grad=True, device=self.device_type),
-            device_mesh,
-            [Shard(0)],
+            tensor=torch.tensor(
+                data_1.clone(), requires_grad=True, device=self.device_type
+            ),
+            device_mesh=device_mesh,
+            placements=[Shard(0)],
         )
         param_2 = torch.tensor(
-            [4.0, 5.0, 6.0], requires_grad=True, device=self.device_type
+            data_2.clone(), requires_grad=True, device=self.device_type
         )
+        param_1.grad = distribute_tensor(
+            tensor=data_1_grad.clone(),
+            device_mesh=device_mesh,
+            placements=[Shard(0)],
+        )
+        param_2.grad = data_2_grad.clone()
         param_to_pgs = self._get_params_to_pg([param_1])
         keyed_optimizer = DummyKeyedOptimizer(
-            {"dtensor_param_1": param_1, "dtensor_param_2": param_2},
-            {},
-            [{"params": [param_1, param_2]}],
+            params={"dtensor_param_1": param_1, "dtensor_param_2": param_2},
+            state={},
+            param_groups=[{"params": [param_1, param_2]}],
         )
         gradient_clipping_optimizer = GradientClippingOptimizer(
             optimizer=keyed_optimizer,
@@ -310,21 +321,113 @@ def test_dtensor_clip_all_gradients_norm(
             enable_global_grad_clip=True,
             param_to_pgs=param_to_pgs,  # pyre-ignore[6]
         )
-        gradient_clipping_optimizer.zero_grad()
+        gradient_clipping_optimizer.step()
+
+        for param_group, ref_param_group in zip(
+            gradient_clipping_optimizer.param_groups,
+            ref_gradient_clipping_optimizer.param_groups,
+            strict=True,
+        ):
+            for param, ref_param in zip(
+                param_group["params"], ref_param_group["params"]
+            ):
+                param_grad = (
+                    param.grad.full_tensor()  # pyre-ignore[16]
+                    if isinstance(param, DTensor)
+                    else param.grad
+                )
+                self.assertEqual(
+                    param_grad,
+                    ref_param.grad,
+                    f"Expect gradient to be the same. However, found {param_grad=}, {ref_param.grad=}",
+                )
+
+    @with_comms
+    @parametrize("norm_type", ("inf", 1, 2))
+    def test_multiple_sharded_dtensors_clip_all_gradients_norm(
+        self, norm_type: Union[float, str]
+    ) -> None:
+        """
+        Test to ensure that the gradient clipping optimizer clips gradients
+        correctly with multiple sharded DTensors by comparing gradients to their
+        torch.tensor counterpart.
+
+        Note that clipping for DTensor may require communication.
+        """
+
+        # data for testing clipping
+        data_1 = torch.tensor([1.0, 2.0, 3.0], device=self.device_type)
+        data_2 = torch.tensor([4.0, 5.0, 6.0], device=self.device_type)
+        data_1_grad = torch.tensor([12.0, 15.0, 18.0], device=self.device_type)
+        data_2_grad = torch.tensor([20.0, 30.0, 15.0], device=self.device_type)
+
+        # create gradient clipping optimizer containing no dtensor for reference
+        ref_param_1 = torch.nn.Parameter(data_1.clone())
+        ref_param_2 = torch.nn.Parameter(data_2.clone())
+        ref_param_1.grad = data_1_grad.clone()
+        ref_param_2.grad = data_2_grad.clone()
+        ref_keyed_optimizer = DummyKeyedOptimizer(
+            params={"param_1": ref_param_1, "param_2": ref_param_2},
+            state={},
+            param_groups=[{"params": [ref_param_1, ref_param_2]}],
+        )
+        ref_gradient_clipping_optimizer = GradientClippingOptimizer(
+            optimizer=ref_keyed_optimizer,
+            clipping=GradientClipping.NORM,
+            max_gradient=10.0,
+            norm_type=norm_type,
+        )
+        ref_gradient_clipping_optimizer.step()
+
+        # create gradient clipping optimizer containing 2 DTensors
+        device_mesh = init_device_mesh(self.device_type, (self.world_size,))
+        param_1 = distribute_tensor(
+            tensor=torch.tensor(
+                data_1.clone(), requires_grad=True, device=self.device_type
+            ),
+            device_mesh=device_mesh,
+            placements=[Shard(0)],
+        )
+        param_2 = distribute_tensor(
+            tensor=torch.tensor(
+                data_2.clone(), requires_grad=True, device=self.device_type
+            ),
+            device_mesh=device_mesh,
+            placements=[Shard(0)],
+        )
         param_1.grad = distribute_tensor(
-            torch.tensor([12.0, 15.0, 18.0], device=self.device_type),
-            device_mesh,
-            [Shard(0)],
+            tensor=data_1_grad.clone(),
+            device_mesh=device_mesh,
+            placements=[Shard(0)],
+        )
+        param_2.grad = distribute_tensor(
+            tensor=data_2_grad.clone(),
+            device_mesh=device_mesh,
+            placements=[Shard(0)],
+        )
+        param_to_pgs = self._get_params_to_pg([param_1, param_2])
+        keyed_optimizer = DummyKeyedOptimizer(
+            params={"dtensor_param_1": param_1, "dtensor_param_2": param_2},
+            state={},
+            param_groups=[{"params": [param_1, param_2]}],
+        )
+        gradient_clipping_optimizer = GradientClippingOptimizer(
+            optimizer=keyed_optimizer,
+            clipping=GradientClipping.NORM,
+            max_gradient=10.0,
+            norm_type=norm_type,
+            enable_global_grad_clip=True,
+            param_to_pgs=param_to_pgs,  # pyre-ignore[6]
         )
-        param_2.grad = torch.tensor([20.0, 30.0, 15.0], device=self.device_type)
         gradient_clipping_optimizer.step()

         for param_group, ref_param_group in zip(
             gradient_clipping_optimizer.param_groups,
             ref_gradient_clipping_optimizer.param_groups,
+            strict=True,
         ):
             for param, ref_param in zip(
-                param_group["params"], ref_param_group["params"]
+                param_group["params"], ref_param_group["params"], strict=True
             ):
                 param_grad = (
                     param.grad.full_tensor()  # pyre-ignore[16]
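
For reference, below is a minimal standalone sketch (not part of the diff above) of the norm-based clipping these tests verify, assuming GradientClippingOptimizer's NORM mode follows the usual torch.nn.utils.clip_grad_norm_ semantics of scaling every gradient by max_gradient / total_norm whenever the total norm exceeds max_gradient. The gradient values and max_gradient=10.0 are copied from the test data; everything else is illustrative, not TorchRec API.

# Illustrative sketch only (assumption: NORM clipping rescales gradients by
# max_gradient / total_norm when total_norm > max_gradient, as in
# torch.nn.utils.clip_grad_norm_). Gradient values match the test data above.
import torch

grads = [
    torch.tensor([12.0, 15.0, 18.0]),  # data_1_grad
    torch.tensor([20.0, 30.0, 15.0]),  # data_2_grad
]
max_gradient = 10.0

for norm_type in (float("inf"), 1.0, 2.0):
    # Global norm over all gradients, as the single-process reference computes it:
    # combining per-tensor p-norms with another p-norm yields the p-norm of the
    # concatenated gradients.
    total_norm = torch.norm(
        torch.stack([torch.norm(g, norm_type) for g in grads]), norm_type
    )
    clip_coef = min(max_gradient / total_norm.item(), 1.0)
    clipped = [g * clip_coef for g in grads]
    print(f"{norm_type=} total_norm={total_norm.item():.3f} {clipped=}")

# With sharded DTensors, each rank holds only a shard of every gradient, so the
# per-shard norms must be combined across the parameters' process groups
# (collective communication) before the same clip_coef can be applied, which is
# exactly what the tests check by comparing against the plain-tensor reference.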

0 commit comments
