@@ -245,19 +245,21 @@ def test_clip_no_gradients_norm_meta_device(
 @unittest.skipIf(not torch.cuda.is_available(), "Skip when CUDA is not available")
 @instantiate_parametrized_tests
 class TestGradientClippingDTensor(DTensorTestBase):
+    """No tests for replicated DTensors, as those are handled before GradientClippingOptimizer."""
+
     def _get_params_to_pg(
         self, params: List[DTensor]
     ) -> Dict[DTensor, List[ProcessGroup]]:
         return {param: [param.device_mesh.get_group()] for param in params}
 
     @with_comms
     @parametrize("norm_type", ("inf", 1, 2))
-    def test_dtensor_clip_all_gradients_norm(
+    def test_tensor_and_sharded_dtensor_clip_all_gradients_norm(
         self, norm_type: Union[float, str]
     ) -> None:
         """
         Test to ensure that the gradient clipping optimizer clips gradients
-        correctly with mixed DTensor and tensor by comparing gradients to its
+        correctly with mixed sharded DTensor and tensor by comparing gradients to its
         torch.tensor counterpart.
 
         Note that clipping for DTensor may require communication.
@@ -286,7 +288,7 @@ def test_dtensor_clip_all_gradients_norm(
         ref_param_2.grad = torch.tensor([20.0, 30.0, 15.0], device=self.device_type)
         ref_gradient_clipping_optimizer.step()
 
-        # create gradient clipping optimizer containing both DTensor and tensor
+        # create gradient clipping optimizer containing sharded DTensor and tensor
         device_mesh = init_device_mesh(self.device_type, (self.world_size,))
         param_1 = distribute_tensor(
             torch.tensor([1.0, 2.0, 3.0], requires_grad=True, device=self.device_type),
@@ -336,3 +338,96 @@ def test_dtensor_clip_all_gradients_norm(
             ref_param.grad,
             f"Expect gradient to be the same. However, found {param_grad=}, {ref_param.grad=}",
         )
+
+    @with_comms
+    @parametrize("norm_type", ("inf", 1, 2))
+    def test_multiple_sharded_dtensors_clip_all_gradients_norm(
+        self, norm_type: Union[float, str]
+    ) -> None:
+        """
+        Test to ensure that the gradient clipping optimizer clips gradients
+        correctly with multiple sharded DTensors by comparing gradients to their
+        torch.tensor counterpart.
+
+        Note that clipping for DTensor may require communication.
+        """
+
+        # create gradient clipping optimizer containing no DTensor for reference
+        ref_param_1 = torch.nn.Parameter(
+            torch.tensor([1.0, 2.0, 3.0], device=self.device_type)
+        )
+        ref_param_2 = torch.nn.Parameter(
+            torch.tensor([4.0, 5.0, 6.0], device=self.device_type)
+        )
+        ref_keyed_optimizer = DummyKeyedOptimizer(
+            {"param_1": ref_param_1, "param_2": ref_param_2},
+            {},
+            [{"params": [ref_param_1, ref_param_2]}],
+        )
+        ref_gradient_clipping_optimizer = GradientClippingOptimizer(
+            optimizer=ref_keyed_optimizer,
+            clipping=GradientClipping.NORM,
+            max_gradient=10.0,
+            norm_type=norm_type,
+        )
+        ref_gradient_clipping_optimizer.zero_grad()
+        ref_param_1.grad = torch.tensor([12.0, 15.0, 18.0], device=self.device_type)
+        ref_param_2.grad = torch.tensor([20.0, 30.0, 15.0], device=self.device_type)
+        ref_gradient_clipping_optimizer.step()
+
+        # create gradient clipping optimizer containing 2 sharded DTensors
+        device_mesh = init_device_mesh(self.device_type, (self.world_size,))
+        param_1 = distribute_tensor(
+            torch.tensor([1.0, 2.0, 3.0], requires_grad=True, device=self.device_type),
+            device_mesh,
+            [Shard(0)],
+        )
+        param_2 = distribute_tensor(
+            torch.tensor([4.0, 5.0, 6.0], requires_grad=True, device=self.device_type),
+            device_mesh,
+            [Shard(0)],
+        )
+        param_to_pgs = self._get_params_to_pg([param_1, param_2])
+        keyed_optimizer = DummyKeyedOptimizer(
+            {"dtensor_param_1": param_1, "dtensor_param_2": param_2},
+            {},
+            [{"params": [param_1, param_2]}],
+        )
+        gradient_clipping_optimizer = GradientClippingOptimizer(
+            optimizer=keyed_optimizer,
+            clipping=GradientClipping.NORM,
+            max_gradient=10.0,
+            norm_type=norm_type,
+            enable_global_grad_clip=True,
+            param_to_pgs=param_to_pgs,  # pyre-ignore[6]
+        )
+        gradient_clipping_optimizer.zero_grad()
+        param_1.grad = distribute_tensor(
+            torch.tensor([12.0, 15.0, 18.0], device=self.device_type),
+            device_mesh,
+            [Shard(0)],
+        )
+        param_2.grad = distribute_tensor(
+            torch.tensor([20.0, 30.0, 15.0], device=self.device_type),
+            device_mesh,
+            [Shard(0)],
+        )
+        gradient_clipping_optimizer.step()
+
+        for param_group, ref_param_group in zip(
+            gradient_clipping_optimizer.param_groups,
+            ref_gradient_clipping_optimizer.param_groups,
+        ):
+            for param, ref_param in zip(
+                param_group["params"], ref_param_group["params"]
+            ):
+                param_grad = (
+                    param.grad.full_tensor()  # pyre-ignore[16]
+                    if isinstance(param, DTensor)
+                    else param.grad
+                )
+                self.assertEqual(
+                    param_grad,
+                    ref_param.grad,
+                    f"Expect gradient to be the same. However, found {param_grad=}, {ref_param.grad=}",
+                )
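
For reference, the plain-tensor baseline that these tests compare against can be reproduced by hand. The sketch below is not part of the PR; it assumes the reference optimizer's GradientClipping.NORM path behaves like torch.nn.utils.clip_grad_norm_, and it reuses the gradient values from the tests above to show what the clipped gradients look like before they are checked against the DTensor path.

# Standalone sketch (not part of the PR). Assumption: the reference optimizer's
# NORM clipping matches torch.nn.utils.clip_grad_norm_ semantics.
import torch

param_1 = torch.nn.Parameter(torch.tensor([1.0, 2.0, 3.0]))
param_2 = torch.nn.Parameter(torch.tensor([4.0, 5.0, 6.0]))
param_1.grad = torch.tensor([12.0, 15.0, 18.0])
param_2.grad = torch.tensor([20.0, 30.0, 15.0])

# Clip the combined gradient norm across both parameters to 10.0. norm_type=2.0 is
# used here; the tests also parametrize over 1 and "inf" (float("inf") for clip_grad_norm_).
total_norm = torch.nn.utils.clip_grad_norm_(
    [param_1, param_2], max_norm=10.0, norm_type=2.0
)

# Before clipping, the combined 2-norm is sqrt(12^2 + 15^2 + 18^2 + 20^2 + 30^2 + 15^2) ≈ 47.1,
# so both gradients are scaled down by roughly 10.0 / 47.1.
print(total_norm)
print(param_1.grad)
print(param_2.grad)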