@@ -8,5 +8,5 @@
 from functools import partial
-from typing import Callable, Literal
+from typing import Callable, List, Literal
 
 import torch
 import torch.distributed as dist
@@ -21,36 +21,73 @@
 )
 from torch.distributed.tensor.parallel import ParallelStyle
 from torch.distributed.tensor.placement_types import Placement
+from torch.distributed.distributed_c10d import (
+    _resolve_process_group,
+    _get_process_group_name
+)
 
 
-# from torch.distributed._functional_collectives import all_to_all_single_autograd
-# TODO: there is memory leak issue with AC + all_to_all_single_autograd
-# This is a temporary fix by @rakkit https://github.com/pytorch/torchtitan/issues/1467
-class _A2A(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x, out_splits, in_splits, group):
-        T_out = int(sum(out_splits))
-        y = x.new_empty((T_out,) + tuple(x.shape[1:]))  # allocate by output splits
-        dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group)
-
-        ctx.in_splits = in_splits
-        ctx.out_splits = out_splits
-        ctx.group = group
-        return y
-
-    @staticmethod
-    def backward(ctx, grad_y):
-        # grad wrt input has length sum(in_splits)
-        T_in = int(sum(ctx.in_splits))
-        grad_x = grad_y.new_empty((T_in,) + tuple(grad_y.shape[1:]))
-        dist.all_to_all_single(
-            grad_x, grad_y.contiguous(), ctx.in_splits, ctx.out_splits, group=ctx.group
-        )
-        return grad_x, None, None, None
+# This custom op syncs the inputs AND runs a2a. Doing both allows SAC in AC1 to avoid a sync
+# if the a2a is saved.
+@torch.library.custom_op("titan::_sync_and_all_to_all", mutates_args=())
+def _sync_and_all_to_all(x: torch.Tensor,
+                         out_splits: torch.Tensor,
+                         in_splits: torch.Tensor,
+                         group_name: str) -> List[torch.Tensor]:
+    group = None if group_name == "None" else _resolve_process_group(group_name)
+    if out_splits.is_cpu:
+        # custom ops don't allow returning outputs that are aliased to inputs
+        out_splits_cpu, in_splits_cpu = torch.empty((0,)), torch.empty((0,))
+        out_splits, in_splits = out_splits.tolist(), in_splits.tolist()
+    else:
+        out_splits_cpu, in_splits_cpu = out_splits.to(device="cpu", non_blocking=True), in_splits.to(device="cpu", non_blocking=True)
+        torch.cuda.current_stream().synchronize()
+        out_splits, in_splits = out_splits_cpu.tolist(), in_splits_cpu.tolist()
+    T_out = int(sum(out_splits))
+    y = x.new_empty((T_out,) + tuple(x.shape[1:]))
+    dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group, async_op=False)
+    return [y, out_splits_cpu, in_splits_cpu]
+
+
+@_sync_and_all_to_all.register_fake
+def _(x, out_splits, in_splits, group):
+    T_out = int(out_splits.to(torch.int64).sum())
+    out = [x.new_empty((T_out,) + tuple(x.shape[1:]))]
+    if out_splits.is_cuda:
+        out += [torch.empty_like(out_splits), torch.empty_like(in_splits)]
+    else:
+        out += [torch.empty((0,)), torch.empty((0,))]
+    return out
+
+
+def _backward(ctx, grads):
+    grad_y = grads[0]
+    out_splits = ctx.out_splits_cpu.tolist()
+    in_splits = ctx.in_splits_cpu.tolist()
+    T_in = int(sum(ctx.in_splits_cpu))
+    grad_x = grad_y.new_empty((T_in,) + tuple(grad_y.shape[1:]))
+    group = None if ctx.group_name == "None" else _resolve_process_group(ctx.group_name)
+    dist.all_to_all_single(
+        grad_x, grad_y.contiguous(), in_splits, out_splits, group=group, async_op=False
+    )
+    return grad_x, None, None, None
+
+
+def _setup_context(ctx, inputs, output):
+    _, out_splits, in_splits, group_name = inputs
+    _, out_splits_cpu, in_splits_cpu = output
+    ctx.group_name = group_name
+    ctx.in_splits_cpu = in_splits_cpu if not in_splits.is_cpu else in_splits
+    ctx.out_splits_cpu = out_splits_cpu if not out_splits.is_cpu else out_splits
+
+
+# Register autograd: PyTorch will use this in compiled graphs
+_sync_and_all_to_all.register_autograd(_backward, setup_context=_setup_context)
 
 
 def all_to_all_single_autograd(x, out_splits, in_splits, group):
-    return _A2A.apply(x, out_splits, in_splits, group)
+    pg_name = "None" if group is None else _get_process_group_name(group)
+    return tuple(torch.ops.titan._sync_and_all_to_all(x, out_splits, in_splits, pg_name))
 
 
 TOKEN_GROUP_ALIGN_SIZE_M = 8
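Note for reviewers: the block above follows the standard `torch.library` custom-op recipe — define the eager kernel, register a fake (meta) kernel so the op can be traced/compiled without real data, then attach autograd via `register_autograd`. Below is a minimal, self-contained sketch of that recipe (illustrative only, not part of this PR; the `demo::scale_twice` op is hypothetical):

```python
import torch

# Hypothetical toy op mirroring the pattern used by titan::_sync_and_all_to_all.
@torch.library.custom_op("demo::scale_twice", mutates_args=())
def scale_twice(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0  # eager implementation; returns a fresh tensor (no aliasing of inputs)

@scale_twice.register_fake
def _(x):
    # Fake kernel: only propagates shape/dtype, used under tracing/compile.
    return torch.empty_like(x)

def _demo_backward(ctx, grad_out):
    return grad_out * 2.0  # d(2x)/dx = 2

def _demo_setup_context(ctx, inputs, output):
    pass  # nothing needs to be saved for this toy op

scale_twice.register_autograd(_demo_backward, setup_context=_demo_setup_context)

x = torch.randn(3, requires_grad=True)
torch.ops.demo.scale_twice(x).sum().backward()
assert torch.allclose(x.grad, torch.full_like(x, 2.0))
```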
@@ -186,25 +223,25 @@ def _token_dispatch(self, mod, inputs, device_mesh):
         input_splits = (
             num_tokens_per_expert.view(ep_size, -1)
             .sum(dim=1)
-            .to(torch.device("cpu"), non_blocking=True)
         )
         output_splits = (
             num_tokens_per_expert_group.view(ep_size, -1)
             .sum(dim=1)
-            .to(torch.device("cpu"), non_blocking=True)
         )
         # NOTE: this would incur a device-to-host sync
-        torch.cuda.current_stream().synchronize()
-        self.input_splits = input_splits.tolist()
-        self.output_splits = output_splits.tolist()
+        self.input_splits = input_splits
+        self.output_splits = output_splits
 
         # perform all-to-all
-        routed_input = all_to_all_single_autograd(
+        routed_input, out_splits_cpu, in_splits_cpu = all_to_all_single_autograd(
             routed_input,
             self.output_splits,
             self.input_splits,
             device_mesh.get_group(),
         )
+        assert not self.input_splits.is_cpu
+        self.out_splits_cpu = out_splits_cpu
+        self.in_splits_cpu = in_splits_cpu
 
         # NOTE: After this all-to-all, the routed input is put on proper EP rank.
         # However, the num_tokens_per_expert_group is not of the final target format
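The dispatch path above no longer calls `.tolist()` on the split tensors itself; the device-to-host read now happens inside `titan::_sync_and_all_to_all`. As a standalone sketch of the copy-then-synchronize pattern that op relies on (illustrative only, the values are made up):

```python
import torch

if torch.cuda.is_available():
    # Hypothetical splits; the real code derives them from num_tokens_per_expert.
    splits = torch.tensor([4, 2, 3], device="cuda")
    splits_cpu = splits.to("cpu", non_blocking=True)  # asynchronous device-to-host copy
    torch.cuda.current_stream().synchronize()         # wait for the copy before the host reads it
    assert splits_cpu.tolist() == [4, 2, 3]           # now safe to read on the host
```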
@@ -227,10 +264,10 @@ def _partition_fn(name, mod, device_mesh):
 
     # performing all-to-all combine on the output
     def _token_combine(self, mod, routed_output, device_mesh):
-        routed_output = all_to_all_single_autograd(
+        routed_output, _, _ = all_to_all_single_autograd(
             routed_output,
-            self.input_splits,
-            self.output_splits,
+            self.in_splits_cpu,
+            self.out_splits_cpu,
             device_mesh.get_group(),
         )
         return routed_output
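For context on why swapping `in_splits_cpu` and `out_splits_cpu` is sufficient for the combine: if `send[i][j]` is the number of tokens rank `i` ships to rank `j` during dispatch, each rank's dispatch `input_splits` is its row of that matrix and its `output_splits` is its column, and the combine all-to-all simply reverses the flow. A small sketch under these assumptions (illustrative only, not from this PR):

```python
import torch

# Hypothetical 2-rank EP group: send[i][j] = tokens rank i sends to rank j at dispatch.
send = torch.tensor([[1, 3],
                     [2, 2]])

for rank in range(send.shape[0]):
    dispatch_in = send[rank].tolist()      # what this rank scatters at dispatch
    dispatch_out = send[:, rank].tolist()  # what this rank receives at dispatch
    # The combine step returns every token to its sender, so the two lists swap roles.
    combine_in, combine_out = dispatch_out, dispatch_in
    assert sum(combine_out) == sum(dispatch_in)
    print(rank, "dispatch:", dispatch_in, dispatch_out, "combine:", combine_in, combine_out)
```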