@@ -54,11 +54,6 @@ def get_gradient_division() -> bool:
 
 
def set_use_sync_collectives(val: bool) -> None:
-    if val and torch._running_with_deploy():
-        raise RuntimeError(
-            "TorchRec sync_collectives are not supported in torch.deploy."
-        )
-
    global USE_SYNC_COLLECTIVES
    USE_SYNC_COLLECTIVES = val
 
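A minimal, hypothetical usage sketch (not part of the diff) of the simplified setter, assuming these helpers live in `torchrec.distributed.comm_ops`:

```python
# Hypothetical sketch: after this change, set_use_sync_collectives() no longer
# raises under torch.deploy -- it simply flips the module-level flag.
from torchrec.distributed import comm_ops  # assumed module path

comm_ops.set_use_sync_collectives(True)
assert comm_ops.USE_SYNC_COLLECTIVES  # the global toggled by the setter
```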
@@ -2356,202 +2351,213 @@ def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]:
        return (None, None, myreq.dummy_tensor)
 
 
-if not torch._running_with_deploy():  # noqa C901
-    # Torch Library op def can not be used in Deploy
-    class AllToAllSingle(torch.autograd.Function):
-        @staticmethod
-        # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
-        def forward(
-            # pyre-fixme[2]: Parameter must be annotated.
-            ctx,
-            input: Tensor,
-            output_split_sizes: List[int],
-            input_split_sizes: List[int],
-            group_name: str,
-            group_size: int,
-            gradient_division: bool,
-        ) -> Tensor:
-            ctx.output_split_sizes = input_split_sizes
-            ctx.input_split_sizes = output_split_sizes
-            ctx.group_name = group_name
-            ctx.group_size = group_size
-            ctx.gradient_division = gradient_division
-            return torch.distributed._functional_collectives.all_to_all_single(
-                input, output_split_sizes, input_split_sizes, group_name
-            )
-
-        @staticmethod
-        # pyre-ignore
-        def backward(ctx, grad):
-            grad = torch.distributed._functional_collectives.all_to_all_single(
-                grad,
-                ctx.output_split_sizes,
-                ctx.input_split_sizes,
-                ctx.group_name,
-            )
-            if ctx.gradient_division:
-                grad.div_(ctx.group_size)
-
-            return grad, None, None, None, None, None
-
-    # torchrec::reduce_scatter_tensor
-    @torch.library.custom_op("torchrec::reduce_scatter_tensor", mutates_args=())
-    def reduce_scatter_tensor(
+class AllToAllSingle(torch.autograd.Function):
+    @staticmethod
+    # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
+    def forward(
+        # pyre-fixme[2]: Parameter must be annotated.
+        ctx,
        input: Tensor,
-        reduceOp: str,
-        group_size: int,
+        output_split_sizes: List[int],
+        input_split_sizes: List[int],
        group_name: str,
-        gradient_division: bool,
-    ) -> Tensor:
-        out = torch.ops._c10d_functional.reduce_scatter_tensor(
-            input,
-            reduceOp,
-            group_size,
-            group_name,
-        )
-        return torch.ops._c10d_functional.wait_tensor(out)
-
-    @torch.library.register_fake("torchrec::reduce_scatter_tensor")
-    def reduce_scatter_tensor_fake(
-        input: Tensor,
-        reduceOp: str,
        group_size: int,
-        group_name: str,
        gradient_division: bool,
    ) -> Tensor:
-        return torch.ops._c10d_functional.reduce_scatter_tensor(
-            input,
-            reduceOp,
-            group_size,
-            group_name,
-        )
-
-    # pyre-ignore
-    def reduce_scatter_tensor_setup_context(ctx, inputs, output) -> None:
-        _, _, group_size, group_name, gradient_division = inputs
-        ctx.group_size = group_size
+        ctx.output_split_sizes = input_split_sizes
+        ctx.input_split_sizes = output_split_sizes
        ctx.group_name = group_name
+        ctx.group_size = group_size
        ctx.gradient_division = gradient_division
+        return torch.distributed._functional_collectives.all_to_all_single(
+            input, output_split_sizes, input_split_sizes, group_name
+        )
 
+    @staticmethod
    # pyre-ignore
-    def reduce_scatter_tensor_backward(ctx, grad):
-        # TODO(ivankobzarev): Support codecs(quantization) on backward
-        out = torch.ops._c10d_functional.all_gather_into_tensor(
+    def backward(ctx, grad):
+        grad = torch.distributed._functional_collectives.all_to_all_single(
            grad,
-            ctx.group_size,
+            ctx.output_split_sizes,
+            ctx.input_split_sizes,
            ctx.group_name,
        )
-        grad = torch.ops._c10d_functional.wait_tensor(out)
        if ctx.gradient_division:
            grad.div_(ctx.group_size)
 
        return grad, None, None, None, None, None
 
-    torch.library.register_autograd(
-        "torchrec::reduce_scatter_tensor",
-        reduce_scatter_tensor_backward,
-        setup_context=reduce_scatter_tensor_setup_context,
+
+# torchrec::reduce_scatter_tensor
+@torch.library.custom_op("torchrec::reduce_scatter_tensor", mutates_args=())
+def reduce_scatter_tensor(
+    input: Tensor,
+    reduceOp: str,
+    group_size: int,
+    group_name: str,
+    gradient_division: bool,
+) -> Tensor:
+    out = torch.ops._c10d_functional.reduce_scatter_tensor(
+        input,
+        reduceOp,
+        group_size,
+        group_name,
    )
+    return torch.ops._c10d_functional.wait_tensor(out)
 
-    # torchrec::all_gather_into_tensor
-    @torch.library.custom_op("torchrec::all_gather_into_tensor", mutates_args=())
-    def all_gather_into_tensor(
-        shard: Tensor,
-        gather_dim: int,
-        group_size: int,
-        group_name: str,
-        gradient_division: bool,
-    ) -> Tensor:
-        out = torch.ops._c10d_functional.all_gather_into_tensor(
-            shard, group_size, group_name
-        )
-        return torch.ops._c10d_functional.wait_tensor(out)
 
-    @torch.library.register_fake("torchrec::all_gather_into_tensor")
-    def all_gather_into_tensor_fake(
-        shard: Tensor,
-        gather_dim: int,
-        group_size: int,
-        group_name: str,
-        gradient_division: bool,
-    ) -> Tensor:
-        return torch.ops._c10d_functional.all_gather_into_tensor(
-            shard, group_size, group_name
-        )
+@torch.library.register_fake("torchrec::reduce_scatter_tensor")
+def reduce_scatter_tensor_fake(
+    input: Tensor,
+    reduceOp: str,
+    group_size: int,
+    group_name: str,
+    gradient_division: bool,
+) -> Tensor:
+    return torch.ops._c10d_functional.reduce_scatter_tensor(
+        input,
+        reduceOp,
+        group_size,
+        group_name,
+    )
 
-    # pyre-ignore
-    def all_gather_into_tensor_setup_context(ctx, inputs, output) -> None:
-        _, gather_dim, group_size, group_name, gradient_division = inputs
-        ctx.group_size = group_size
-        ctx.group_name = group_name
-        ctx.gradient_division = gradient_division
 
-    # pyre-ignore
-    def all_gather_into_tensor_backward(ctx, grad):
-        # TODO(ivankobzarev): Support codecs(quantization) on backward
-        out = torch.ops._c10d_functional.reduce_scatter_tensor(
-            grad,
-            "sum",
-            ctx.group_size,
-            ctx.group_name,
-        )
-        grad = torch.ops._c10d_functional.wait_tensor(out)
-        if ctx.gradient_division:
-            grad.div_(ctx.group_size)
+# pyre-ignore
+def reduce_scatter_tensor_setup_context(ctx, inputs, output) -> None:
+    _, _, group_size, group_name, gradient_division = inputs
+    ctx.group_size = group_size
+    ctx.group_name = group_name
+    ctx.gradient_division = gradient_division
 
-        return grad, None, None, None, None
 
-    torch.library.register_autograd(
-        "torchrec::all_gather_into_tensor",
-        all_gather_into_tensor_backward,
-        setup_context=all_gather_into_tensor_setup_context,
+# pyre-ignore
+def reduce_scatter_tensor_backward(ctx, grad):
+    # TODO(ivankobzarev): Support codecs(quantization) on backward
+    out = torch.ops._c10d_functional.all_gather_into_tensor(
+        grad,
+        ctx.group_size,
+        ctx.group_name,
    )
+    grad = torch.ops._c10d_functional.wait_tensor(out)
+    if ctx.gradient_division:
+        grad.div_(ctx.group_size)
 
-    @torch.library.custom_op("torchrec::_split_1d_cat_2d", mutates_args=())
-    def _split_1d_cat_2d_impl(
-        t: torch.Tensor, dim0: int, dim1s: List[int]
-    ) -> torch.Tensor:
-        torch._check_is_size(dim0)
-        [torch._check_is_size(dim1) for dim1 in dim1s]
-        splits: List[torch.Tensor] = t.split([dim0 * dim1 for dim1 in dim1s])
-        return torch.cat(
-            [s.reshape(dim0, dim1) for s, dim1 in zip(splits, dim1s)],
-            dim=1,
-        )
+    return grad, None, None, None, None, None
+
+
+torch.library.register_autograd(
+    "torchrec::reduce_scatter_tensor",
+    reduce_scatter_tensor_backward,
+    setup_context=reduce_scatter_tensor_setup_context,
+)
 
-    @torch.library.register_fake("torchrec::_split_1d_cat_2d")
-    def _split_1d_cat_2d_impl_abstract(
-        t: torch.Tensor, dim0: int, dim1s: List[int]
-    ) -> torch.Tensor:
-        return t.new_empty([dim0, sum(dim1s)])
 
-    @torch.library.custom_op(
-        "torchrec::_split_1d_cat_2d_backward_impl", mutates_args=()
+# torchrec::all_gather_into_tensor
+@torch.library.custom_op("torchrec::all_gather_into_tensor", mutates_args=())
+def all_gather_into_tensor(
+    shard: Tensor,
+    gather_dim: int,
+    group_size: int,
+    group_name: str,
+    gradient_division: bool,
+) -> Tensor:
+    out = torch.ops._c10d_functional.all_gather_into_tensor(
+        shard, group_size, group_name
    )
-    def _split_1d_cat_2d_backward_impl(
-        grad: torch.Tensor, dim1s: List[int]
-    ) -> torch.Tensor:
-        splits = grad.split(dim1s, dim=1)
-        return torch.cat([s.reshape(-1) for s in splits], dim=0)
-
-    @torch.library.register_fake("torchrec::_split_1d_cat_2d_backward_impl")
-    def _split_1d_cat_2d_backward_impl_fake(
-        grad: torch.Tensor, dim1s: List[int]
-    ) -> torch.Tensor:
-        return grad.new_empty([grad.numel()])
+    return torch.ops._c10d_functional.wait_tensor(out)
 
-    # pyre-ignore
-    def _split_1d_cat_2d_backward(ctx, grad):
-        ret = torch.ops.torchrec._split_1d_cat_2d_backward_impl(grad, ctx.dim1s)
-        return ret, None, None
 
-    # pyre-ignore
-    def _split_1d_cat_2d_setup_context(ctx, inputs, output):
-        (x, dim0, dim1s) = inputs
-        ctx.dim1s = dim1s
-
-    torch.library.register_autograd(
-        "torchrec::_split_1d_cat_2d",
-        _split_1d_cat_2d_backward,
-        setup_context=_split_1d_cat_2d_setup_context,
+@torch.library.register_fake("torchrec::all_gather_into_tensor")
+def all_gather_into_tensor_fake(
+    shard: Tensor,
+    gather_dim: int,
+    group_size: int,
+    group_name: str,
+    gradient_division: bool,
+) -> Tensor:
+    return torch.ops._c10d_functional.all_gather_into_tensor(
+        shard, group_size, group_name
    )
+
+
+# pyre-ignore
+def all_gather_into_tensor_setup_context(ctx, inputs, output) -> None:
+    _, gather_dim, group_size, group_name, gradient_division = inputs
+    ctx.group_size = group_size
+    ctx.group_name = group_name
+    ctx.gradient_division = gradient_division
+
+
+# pyre-ignore
+def all_gather_into_tensor_backward(ctx, grad):
+    # TODO(ivankobzarev): Support codecs(quantization) on backward
+    out = torch.ops._c10d_functional.reduce_scatter_tensor(
+        grad,
+        "sum",
+        ctx.group_size,
+        ctx.group_name,
+    )
+    grad = torch.ops._c10d_functional.wait_tensor(out)
+    if ctx.gradient_division:
+        grad.div_(ctx.group_size)
+
+    return grad, None, None, None, None
+
+
+torch.library.register_autograd(
+    "torchrec::all_gather_into_tensor",
+    all_gather_into_tensor_backward,
+    setup_context=all_gather_into_tensor_setup_context,
+)
+
+
+@torch.library.custom_op("torchrec::_split_1d_cat_2d", mutates_args=())
+def _split_1d_cat_2d_impl(t: torch.Tensor, dim0: int, dim1s: List[int]) -> torch.Tensor:
+    torch._check_is_size(dim0)
+    [torch._check_is_size(dim1) for dim1 in dim1s]
+    splits: List[torch.Tensor] = t.split([dim0 * dim1 for dim1 in dim1s])
+    return torch.cat(
+        [s.reshape(dim0, dim1) for s, dim1 in zip(splits, dim1s)],
+        dim=1,
+    )
+
+
+@torch.library.register_fake("torchrec::_split_1d_cat_2d")
+def _split_1d_cat_2d_impl_abstract(
+    t: torch.Tensor, dim0: int, dim1s: List[int]
+) -> torch.Tensor:
+    return t.new_empty([dim0, sum(dim1s)])
+
+
+@torch.library.custom_op("torchrec::_split_1d_cat_2d_backward_impl", mutates_args=())
+def _split_1d_cat_2d_backward_impl(
+    grad: torch.Tensor, dim1s: List[int]
+) -> torch.Tensor:
+    splits = grad.split(dim1s, dim=1)
+    return torch.cat([s.reshape(-1) for s in splits], dim=0)
+
+
+@torch.library.register_fake("torchrec::_split_1d_cat_2d_backward_impl")
+def _split_1d_cat_2d_backward_impl_fake(
+    grad: torch.Tensor, dim1s: List[int]
+) -> torch.Tensor:
+    return grad.new_empty([grad.numel()])
+
+
+# pyre-ignore
+def _split_1d_cat_2d_backward(ctx, grad):
+    ret = torch.ops.torchrec._split_1d_cat_2d_backward_impl(grad, ctx.dim1s)
+    return ret, None, None
+
+
+# pyre-ignore
+def _split_1d_cat_2d_setup_context(ctx, inputs, output):
+    (x, dim0, dim1s) = inputs
+    ctx.dim1s = dim1s
+
+
+torch.library.register_autograd(
+    "torchrec::_split_1d_cat_2d",
+    _split_1d_cat_2d_backward,
+    setup_context=_split_1d_cat_2d_setup_context,
+)
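The dedented ops register under `torch.ops.torchrec.*` at import time. Below is a small, hypothetical sketch (assuming torchrec is installed and these ops live in `torchrec.distributed.comm_ops`) exercising the purely local `torchrec::_split_1d_cat_2d` op and its registered autograd; it needs no process group:

```python
import torch

# Assumed module path; importing it registers the torchrec::* custom ops.
import torchrec.distributed.comm_ops  # noqa: F401

dim0, dim1s = 2, [3, 4]
# Flat 1-D input with dim0 * sum(dim1s) elements, as the op expects.
t = torch.randn(dim0 * sum(dim1s), requires_grad=True)

# Forward: split the flat buffer and cat the pieces into a (dim0, sum(dim1s)) block.
out = torch.ops.torchrec._split_1d_cat_2d(t, dim0, dim1s)
assert out.shape == (dim0, sum(dim1s))

# Backward: the registered autograd routes through _split_1d_cat_2d_backward_impl,
# flattening the 2-D gradient back to the original 1-D layout.
out.sum().backward()
assert t.grad is not None and t.grad.shape == t.shape
```

The collective ops (`torchrec::reduce_scatter_tensor`, `torchrec::all_gather_into_tensor`) follow the same custom_op + register_fake + register_autograd pattern, but additionally require an initialized process group and its `group_name`.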