from functools import partial
- from typing import Callable, Literal
+ from typing import Callable, List, Literal

import torch
import torch.distributed as dist
...
)
from torch.distributed.tensor.parallel import ParallelStyle
from torch.distributed.tensor.placement_types import Placement
+ from torch.distributed.distributed_c10d import (
+     _resolve_process_group,
+     _get_process_group_name
+ )

- # from torch.distributed._functional_collectives import all_to_all_single_autograd
- # TODO: there is memory leak issue with AC + all_to_all_single_autograd
- # This is a temporary fix by @rakkit https://github.com/pytorch/torchtitan/issues/1467
- class _A2A(torch.autograd.Function):
-     @staticmethod
-     def forward(ctx, x, out_splits, in_splits, group):
-         T_out = int(sum(out_splits))
-         y = x.new_empty((T_out,) + tuple(x.shape[1:]))  # allocate by output splits
-         dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group)
-
-         ctx.in_splits = in_splits
-         ctx.out_splits = out_splits
-         ctx.group = group
-         return y
-
-     @staticmethod
-     def backward(ctx, grad_y):
-         # grad wrt input has length sum(in_splits)
-         T_in = int(sum(ctx.in_splits))
-         grad_x = grad_y.new_empty((T_in,) + tuple(grad_y.shape[1:]))
-         dist.all_to_all_single(
-             grad_x, grad_y.contiguous(), ctx.in_splits, ctx.out_splits, group=ctx.group
-         )
-         return grad_x, None, None, None
+ # This custom op syncs the inputs AND runs a2a. Doing both allows SAC in AC1 to avoid a sync
+ # if the a2a is saved.
+ @torch.library.custom_op("titan::_sync_and_all_to_all", mutates_args=())
+ def _sync_and_all_to_all(x: torch.Tensor,
+                          out_splits: torch.Tensor,
+                          in_splits: torch.Tensor,
+                          group_name: str) -> List[torch.Tensor]:
+     group = None if group_name == "None" else _resolve_process_group(group_name)
+     if out_splits.is_cpu:
+         # custom ops don't allow returning outputs that are aliased to inputs
+         out_splits_cpu, in_splits_cpu = torch.empty((0,)), torch.empty((0,))
+         out_splits, in_splits = out_splits.tolist(), in_splits.tolist()
+     else:
+         out_splits_cpu = out_splits.to(device="cpu", non_blocking=True)
+         in_splits_cpu = in_splits.to(device="cpu", non_blocking=True)
+         torch.cuda.current_stream().synchronize()
+         out_splits, in_splits = out_splits_cpu.tolist(), in_splits_cpu.tolist()
+     T_out = int(sum(out_splits))
+     y = x.new_empty((T_out,) + tuple(x.shape[1:]))
+     dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group, async_op=False)
+     return [y, out_splits_cpu, in_splits_cpu]
+
+
+ @_sync_and_all_to_all.register_fake
+ def _(x, out_splits, in_splits, group):
+     T_out = int(out_splits.to(torch.int64).sum())
+     out = [x.new_empty((T_out,) + tuple(x.shape[1:]))]
+     if out_splits.is_cuda:
+         out += [torch.empty_like(out_splits), torch.empty_like(in_splits)]
+     else:
+         out += [torch.empty((0,)), torch.empty((0,))]
+     return out
+
+
+ def _backward(ctx, grads):
+     grad_y = grads[0]
+     out_splits = ctx.out_splits_cpu.tolist()
+     in_splits = ctx.in_splits_cpu.tolist()
+     T_in = int(sum(ctx.in_splits_cpu))
+     grad_x = grad_y.new_empty((T_in,) + tuple(grad_y.shape[1:]))
+     group = None if ctx.group_name == "None" else _resolve_process_group(ctx.group_name)
+     dist.all_to_all_single(
+         grad_x, grad_y.contiguous(), in_splits, out_splits, group=group, async_op=False
+     )
+     return grad_x, None, None, None
+
+
+ def _setup_context(ctx, inputs, output):
+     _, out_splits, in_splits, group_name = inputs
+     _, out_splits_cpu, in_splits_cpu = output
+     ctx.group_name = group_name
+     ctx.in_splits_cpu = in_splits_cpu if not in_splits.is_cpu else in_splits
+     ctx.out_splits_cpu = out_splits_cpu if not out_splits.is_cpu else out_splits
+
+
+ # Register autograd: PyTorch will use this in compiled graphs
+ _sync_and_all_to_all.register_autograd(_backward, setup_context=_setup_context)


def all_to_all_single_autograd(x, out_splits, in_splits, group):
-     return _A2A.apply(x, out_splits, in_splits, group)
+     pg_name = "None" if group is None else _get_process_group_name(group)
+     return tuple(torch.ops.titan._sync_and_all_to_all(x, out_splits, in_splits, pg_name))


TOKEN_GROUP_ALIGN_SIZE_M = 8
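
For reference, a minimal single-process sketch of how the new wrapper behaves end to end. It assumes `all_to_all_single_autograd` from this file is in scope and that the build's gloo backend implements `all_to_all_single`; the world size, port, and tensor shapes are illustrative only, not part of this change:

    import torch
    import torch.distributed as dist
    # all_to_all_single_autograd is assumed imported from the module above.

    def _smoke_test():
        # Single rank: the all-to-all degenerates to an identity exchange.
        dist.init_process_group(
            "gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1
        )
        x = torch.randn(6, 4, requires_grad=True)
        in_splits = torch.tensor([6])   # rows this rank sends to each peer
        out_splits = torch.tensor([6])  # rows this rank receives from each peer

        y, out_cpu, in_cpu = all_to_all_single_autograd(
            x, out_splits, in_splits, dist.group.WORLD
        )
        assert y.shape == (6, 4)
        y.sum().backward()  # exercises the registered custom-op backward
        assert x.grad.shape == x.shape
        dist.destroy_process_group()
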
@@ -186,25 +224,22 @@ def _token_dispatch(self, mod, inputs, device_mesh):
        input_splits = (
            num_tokens_per_expert.view(ep_size, -1)
            .sum(dim=1)
-             .to(torch.device("cpu"), non_blocking=True)
        )
        output_splits = (
            num_tokens_per_expert_group.view(ep_size, -1)
            .sum(dim=1)
-             .to(torch.device("cpu"), non_blocking=True)
        )
-         # NOTE: this would incur a device-to-host sync
-         torch.cuda.current_stream().synchronize()
-         self.input_splits = input_splits.tolist()
-         self.output_splits = output_splits.tolist()

        # perform all-to-all
-         routed_input = all_to_all_single_autograd(
+         routed_input, out_splits_cpu, in_splits_cpu = all_to_all_single_autograd(
            routed_input,
-             self.output_splits,
-             self.input_splits,
+             output_splits,
+             input_splits,
            device_mesh.get_group(),
        )
+         assert not input_splits.is_cpu
+         self.out_splits_cpu = out_splits_cpu
+         self.in_splits_cpu = in_splits_cpu

        # NOTE: After this all-to-all, the routed input is put on proper EP rank.
        # However, the num_tokens_per_expert_group is not of the final target format
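
To make the split bookkeeping above concrete: rank r sends input_splits[p] rows to peer p and receives output_splits[p] rows from peer p, so output_splits on rank r is column r of everyone's input_splits. A dependency-free sketch (illustrative numbers, not the EP code path) that mimics a 2-rank exchange with torch.split/torch.cat:

    import torch

    ep_size = 2
    # Hypothetical tokens routed by each rank to each global expert, 2 experts per rank.
    tokens_per_expert = {
        0: torch.tensor([3, 1, 2, 4]),
        1: torch.tensor([2, 2, 1, 0]),
    }
    # input_splits[r][p]: rows rank r sends to rank p (summed over rank p's local experts).
    input_splits = {
        r: t.view(ep_size, -1).sum(dim=1).tolist() for r, t in tokens_per_expert.items()
    }
    # output_splits[r][p]: rows rank r receives from rank p, i.e. input_splits[p][r].
    output_splits = {
        r: [input_splits[p][r] for p in range(ep_size)] for r in range(ep_size)
    }

    # Simulate the exchange from rank 0's point of view.
    x0 = torch.randn(sum(input_splits[0]), 8)  # rank 0's routed input
    x1 = torch.randn(sum(input_splits[1]), 8)  # rank 1's routed input
    send0 = torch.split(x0, input_splits[0])   # chunks rank 0 sends to ranks 0 and 1
    send1 = torch.split(x1, input_splits[1])   # chunks rank 1 sends to ranks 0 and 1
    recv0 = torch.cat([send0[0], send1[0]])    # what rank 0 receives
    assert recv0.shape[0] == sum(output_splits[0])
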
@@ -227,10 +262,10 @@ def _partition_fn(name, mod, device_mesh):

    # performing all-to-all combine on the output
    def _token_combine(self, mod, routed_output, device_mesh):
-         routed_output = all_to_all_single_autograd(
+         routed_output, _, _ = all_to_all_single_autograd(
            routed_output,
-             self.input_splits,
-             self.output_splits,
+             self.in_splits_cpu,
+             self.out_splits_cpu,
            device_mesh.get_group(),
        )
        return routed_output
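
The comment on the new op explains the motivation: because the device-to-host sync and the a2a live inside one op, a selective activation checkpointing (SAC) policy that saves this op's outputs lets the backward recompute skip both. A minimal sketch of such a policy using torch.utils.checkpoint's selective-checkpoint API (the wrapper function and op list here are illustrative, not this repo's actual AC configuration):

    import torch
    from torch.utils.checkpoint import (
        checkpoint,
        CheckpointPolicy,
        create_selective_checkpoint_contexts,
    )

    # Save the fused sync+a2a results during AC so recompute never re-runs the
    # collective or its device-to-host sync.
    _save_list = {torch.ops.titan._sync_and_all_to_all.default}

    def _sac_policy(ctx, op, *args, **kwargs):
        return (
            CheckpointPolicy.MUST_SAVE
            if op in _save_list
            else CheckpointPolicy.PREFER_RECOMPUTE
        )

    def checkpointed(block, *inputs):
        return checkpoint(
            block,
            *inputs,
            use_reentrant=False,
            context_fn=lambda: create_selective_checkpoint_contexts(_sac_policy),
        )
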