@@ -245,62 +245,73 @@ def test_clip_no_gradients_norm_meta_device(
 @unittest.skipIf(not torch.cuda.is_available(), "Skip when CUDA is not available")
 @instantiate_parametrized_tests
 class TestGradientClippingDTensor(DTensorTestBase):
+    """No tests for Replicated DTensors as handled prior to GradientClippingOptimizer"""
+
     def _get_params_to_pg(
         self, params: List[DTensor]
     ) -> Dict[DTensor, List[ProcessGroup]]:
         return {param: [param.device_mesh.get_group()] for param in params}
 
     @with_comms
     @parametrize("norm_type", ("inf", 1, 2))
-    def test_dtensor_clip_all_gradients_norm(
+    def test_tensor_and_sharded_dtensor_clip_all_gradients_norm(
         self, norm_type: Union[float, str]
     ) -> None:
         """
         Test to ensure that the gradient clipping optimizer clips gradients
-        correctly with mixed DTensor and tensor by comparing gradients to its
+        correctly with mixed sharded DTensor and tensor by comparing gradients to its
         torch.tensor counterpart.
 
         Note that clipping for DTensor may require communication.
         """
 
+        # data for testing clipping
+        data_1 = torch.tensor([1.0, 2.0, 3.0], device=self.device_type)
+        data_2 = torch.tensor([4.0, 5.0, 6.0], device=self.device_type)
+        data_1_grad = torch.tensor([12.0, 15.0, 18.0], device=self.device_type)
+        data_2_grad = torch.tensor([20.0, 30.0, 15.0], device=self.device_type)
+
         # create gradient clipping optimizer containing no dtensor for reference
-        ref_param_1 = torch.nn.Parameter(
-            torch.tensor([1.0, 2.0, 3.0], device=self.device_type)
-        )
-        ref_param_2 = torch.nn.Parameter(
-            torch.tensor([4.0, 5.0, 6.0], device=self.device_type)
-        )
+        ref_param_1 = torch.nn.Parameter(data_1.clone())
+        ref_param_2 = torch.nn.Parameter(data_2.clone())
+        ref_param_1.grad = data_1_grad.clone()
+        ref_param_2.grad = data_2_grad.clone()
         ref_keyed_optimizer = DummyKeyedOptimizer(
-            {"param_1": ref_param_1, "param_2": ref_param_2},
-            {},
-            [{"params": [ref_param_1, ref_param_2]}],
+            params={"param_1": ref_param_1, "param_2": ref_param_2},
+            state={},
+            param_groups=[{"params": [ref_param_1, ref_param_2]}],
         )
         ref_gradient_clipping_optimizer = GradientClippingOptimizer(
             optimizer=ref_keyed_optimizer,
             clipping=GradientClipping.NORM,
             max_gradient=10.0,
             norm_type=norm_type,
         )
-        ref_gradient_clipping_optimizer.zero_grad()
-        ref_param_1.grad = torch.tensor([12.0, 15.0, 18.0], device=self.device_type)
-        ref_param_2.grad = torch.tensor([20.0, 30.0, 15.0], device=self.device_type)
         ref_gradient_clipping_optimizer.step()
 
-        # create gradient clipping optimizer containing both DTensor and tensor
+        # create gradient clipping optimizer containing a DTensor and a tensor
         device_mesh = init_device_mesh(self.device_type, (self.world_size,))
         param_1 = distribute_tensor(
-            torch.tensor([1.0, 2.0, 3.0], requires_grad=True, device=self.device_type),
-            device_mesh,
-            [Shard(0)],
+            tensor=torch.tensor(
+                data_1.clone(), requires_grad=True, device=self.device_type
+            ),
+            device_mesh=device_mesh,
+            placements=[Shard(0)],
         )
         param_2 = torch.tensor(
-            [4.0, 5.0, 6.0], requires_grad=True, device=self.device_type
+            data_2.clone(), requires_grad=True, device=self.device_type
         )
+        param_1.grad = distribute_tensor(
+            tensor=data_1_grad.clone(),
+            device_mesh=device_mesh,
+            placements=[Shard(0)],
+        )
+        param_2.grad = data_2_grad.clone()
         param_to_pgs = self._get_params_to_pg([param_1])
         keyed_optimizer = DummyKeyedOptimizer(
-            {"dtensor_param_1": param_1, "dtensor_param_2": param_2},
-            {},
-            [{"params": [param_1, param_2]}],
+            params={"dtensor_param_1": param_1, "dtensor_param_2": param_2},
+            state={},
+            param_groups=[{"params": [param_1, param_2]}],
         )
         gradient_clipping_optimizer = GradientClippingOptimizer(
             optimizer=keyed_optimizer,
@@ -310,21 +321,113 @@ def test_dtensor_clip_all_gradients_norm(
             enable_global_grad_clip=True,
             param_to_pgs=param_to_pgs,  # pyre-ignore[6]
         )
-        gradient_clipping_optimizer.zero_grad()
+        gradient_clipping_optimizer.step()
+
+        for param_group, ref_param_group in zip(
+            gradient_clipping_optimizer.param_groups,
+            ref_gradient_clipping_optimizer.param_groups,
+            strict=True,
+        ):
+            for param, ref_param in zip(
+                param_group["params"], ref_param_group["params"], strict=True
+            ):
+                param_grad = (
+                    param.grad.full_tensor()  # pyre-ignore[16]
+                    if isinstance(param, DTensor)
+                    else param.grad
+                )
+                self.assertEqual(
+                    param_grad,
+                    ref_param.grad,
+                    f"Expect gradient to be the same. However, found {param_grad=}, {ref_param.grad=}",
+                )
+
+    @with_comms
+    @parametrize("norm_type", ("inf", 1, 2))
+    def test_multiple_sharded_dtensors_clip_all_gradients_norm(
+        self, norm_type: Union[float, str]
+    ) -> None:
+        """
+        Test to ensure that the gradient clipping optimizer clips gradients
+        correctly with multiple sharded DTensors by comparing gradients to their
+        torch.tensor counterpart.
+
+        Note that clipping for DTensor may require communication.
+        """
+
+        # data for testing clipping
+        data_1 = torch.tensor([1.0, 2.0, 3.0], device=self.device_type)
+        data_2 = torch.tensor([4.0, 5.0, 6.0], device=self.device_type)
+        data_1_grad = torch.tensor([12.0, 15.0, 18.0], device=self.device_type)
+        data_2_grad = torch.tensor([20.0, 30.0, 15.0], device=self.device_type)
+
+        # create gradient clipping optimizer containing no dtensor for reference
+        ref_param_1 = torch.nn.Parameter(data_1.clone())
+        ref_param_2 = torch.nn.Parameter(data_2.clone())
+        ref_param_1.grad = data_1_grad.clone()
+        ref_param_2.grad = data_2_grad.clone()
+        ref_keyed_optimizer = DummyKeyedOptimizer(
+            params={"param_1": ref_param_1, "param_2": ref_param_2},
+            state={},
+            param_groups=[{"params": [ref_param_1, ref_param_2]}],
+        )
+        ref_gradient_clipping_optimizer = GradientClippingOptimizer(
+            optimizer=ref_keyed_optimizer,
+            clipping=GradientClipping.NORM,
+            max_gradient=10.0,
+            norm_type=norm_type,
+        )
+        ref_gradient_clipping_optimizer.step()
+
+        # create gradient clipping optimizer containing 2 DTensors
+        device_mesh = init_device_mesh(self.device_type, (self.world_size,))
+        param_1 = distribute_tensor(
+            tensor=torch.tensor(
+                data_1.clone(), requires_grad=True, device=self.device_type
+            ),
+            device_mesh=device_mesh,
+            placements=[Shard(0)],
+        )
+        param_2 = distribute_tensor(
+            tensor=torch.tensor(
+                data_2.clone(), requires_grad=True, device=self.device_type
+            ),
+            device_mesh=device_mesh,
+            placements=[Shard(0)],
+        )
         param_1.grad = distribute_tensor(
-            torch.tensor([12.0, 15.0, 18.0], device=self.device_type),
-            device_mesh,
-            [Shard(0)],
+            tensor=data_1_grad.clone(),
+            device_mesh=device_mesh,
+            placements=[Shard(0)],
+        )
+        param_2.grad = distribute_tensor(
+            tensor=data_2_grad.clone(),
+            device_mesh=device_mesh,
+            placements=[Shard(0)],
+        )
+        param_to_pgs = self._get_params_to_pg([param_1, param_2])
+        keyed_optimizer = DummyKeyedOptimizer(
+            params={"dtensor_param_1": param_1, "dtensor_param_2": param_2},
+            state={},
+            param_groups=[{"params": [param_1, param_2]}],
+        )
+        gradient_clipping_optimizer = GradientClippingOptimizer(
+            optimizer=keyed_optimizer,
+            clipping=GradientClipping.NORM,
+            max_gradient=10.0,
+            norm_type=norm_type,
+            enable_global_grad_clip=True,
+            param_to_pgs=param_to_pgs,  # pyre-ignore[6]
         )
-        param_2.grad = torch.tensor([20.0, 30.0, 15.0], device=self.device_type)
         gradient_clipping_optimizer.step()
 
         for param_group, ref_param_group in zip(
             gradient_clipping_optimizer.param_groups,
             ref_gradient_clipping_optimizer.param_groups,
+            strict=True,
         ):
             for param, ref_param in zip(
-                param_group["params"], ref_param_group["params"]
+                param_group["params"], ref_param_group["params"], strict=True
             ):
                 param_grad = (
                     param.grad.full_tensor()  # pyre-ignore[16]
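Both tests gather each DTensor gradient with full_tensor() and compare it against a plain-tensor reference that was clipped by total norm. As an illustration only (not the optimizer's actual implementation), the standalone sketch below shows the clipping arithmetic the reference path is expected to encode for these exact gradients, using PyTorch's torch.nn.utils.clip_grad_norm_; the names p1 and p2 are invented for the example.

import torch

# Reference parameters and gradients from the tests above.
p1 = torch.nn.Parameter(torch.tensor([1.0, 2.0, 3.0]))
p2 = torch.nn.Parameter(torch.tensor([4.0, 5.0, 6.0]))
p1.grad = torch.tensor([12.0, 15.0, 18.0])
p2.grad = torch.tensor([20.0, 30.0, 15.0])

# clip_grad_norm_ treats all gradients as one flat vector: when its norm_type
# norm exceeds max_norm, every gradient is scaled by max_norm / total_norm.
total_norm = torch.nn.utils.clip_grad_norm_([p1, p2], max_norm=10.0, norm_type=2.0)
print(total_norm)        # about 47.1 for norm_type=2 with these values
print(p1.grad, p2.grad)  # both gradients scaled by roughly 10 / 47.1

In the DTensor cases, each rank holds only a shard of the gradient (Shard(0) over the device mesh), so computing the same global norm may require communication across the process groups collected by _get_params_to_pg, which is what the docstrings' note about communication refers to.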