from functools import partial
-from typing import Callable, Literal
+from typing import Callable, List, Literal

import torch
import torch.distributed as dist
import torch.nn as nn
+from torch.distributed.distributed_c10d import (
+    _get_process_group_name,
+    _resolve_process_group,
+)
from torch.distributed.tensor import (
    DeviceMesh,
    distribute_module,
⋯
from torch.distributed.tensor.placement_types import Placement


-# from torch.distributed._functional_collectives import all_to_all_single_autograd
-# TODO: there is memory leak issue with AC + all_to_all_single_autograd
-# This is a temporary fix by @rakkit https://github.com/pytorch/torchtitan/issues/1467
-class _A2A(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x, out_splits, in_splits, group):
-        T_out = int(sum(out_splits))
-        y = x.new_empty((T_out,) + tuple(x.shape[1:]))  # allocate by output splits
-        dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group)
+# This custom op syncs the inputs AND runs a2a. Doing both allows SAC in AC1 to avoid a sync
+# if the a2a is saved.
+@torch.library.custom_op("titan::_sync_and_all_to_all", mutates_args=())
+def _sync_and_all_to_all(
+    x: torch.Tensor, out_splits: torch.Tensor, in_splits: torch.Tensor, group_name: str
+) -> List[torch.Tensor]:
+    group = None if group_name == "None" else _resolve_process_group(group_name)
+    if out_splits.is_cpu:
+        # custom ops don't allow returning outputs that are aliased to inputs
+        out_splits_cpu, in_splits_cpu = torch.empty((0,)), torch.empty((0,))
+        out_splits, in_splits = out_splits.tolist(), in_splits.tolist()
+    else:
+        out_splits_cpu = out_splits.to(device="cpu", non_blocking=True)
+        in_splits_cpu = in_splits.to(device="cpu", non_blocking=True)
+        torch.cuda.current_stream().synchronize()
+        out_splits, in_splits = out_splits_cpu.tolist(), in_splits_cpu.tolist()
+    T_out = int(sum(out_splits))
+    y = x.new_empty((T_out,) + tuple(x.shape[1:]))
+    dist.all_to_all_single(
+        y, x.contiguous(), out_splits, in_splits, group=group, async_op=False
+    )
+    return [y, out_splits_cpu, in_splits_cpu]
+
+
+@_sync_and_all_to_all.register_fake
+def _(x, out_splits, in_splits, group):
+    T_out = int(out_splits.to(torch.int64).sum())
+    out = [x.new_empty((T_out,) + tuple(x.shape[1:]))]
+    if out_splits.is_cuda:
+        out += [torch.empty_like(out_splits), torch.empty_like(in_splits)]
+    else:
+        out += [torch.empty((0,)), torch.empty((0,))]
+    return out

-        ctx.in_splits = in_splits
-        ctx.out_splits = out_splits
-        ctx.group = group
-        return y

-    @staticmethod
-    def backward(ctx, grad_y):
-        # grad wrt input has length sum(in_splits)
-        T_in = int(sum(ctx.in_splits))
-        grad_x = grad_y.new_empty((T_in,) + tuple(grad_y.shape[1:]))
-        dist.all_to_all_single(
-            grad_x, grad_y.contiguous(), ctx.in_splits, ctx.out_splits, group=ctx.group
-        )
-        return grad_x, None, None, None
+def _backward(ctx, grads):
+    grad_y = grads[0]
+    out_splits = ctx.out_splits_cpu.tolist()
+    in_splits = ctx.in_splits_cpu.tolist()
+    T_in = int(sum(ctx.in_splits_cpu))
+    grad_x = grad_y.new_empty((T_in,) + tuple(grad_y.shape[1:]))
+    group = None if ctx.group_name == "None" else _resolve_process_group(ctx.group_name)
+    dist.all_to_all_single(
+        grad_x, grad_y.contiguous(), in_splits, out_splits, group=group, async_op=False
+    )
+    return grad_x, None, None, None
+
+
+def _setup_context(ctx, inputs, output):
+    _, out_splits, in_splits, group_name = inputs
+    _, out_splits_cpu, in_splits_cpu = output
+    ctx.group_name = group_name
+    ctx.in_splits_cpu = in_splits_cpu if not in_splits.is_cpu else in_splits
+    ctx.out_splits_cpu = out_splits_cpu if not out_splits.is_cpu else out_splits
+
+
+# Register autograd: PyTorch will use this in compiled graphs
+_sync_and_all_to_all.register_autograd(_backward, setup_context=_setup_context)


def all_to_all_single_autograd(x, out_splits, in_splits, group):
-    return _A2A.apply(x, out_splits, in_splits, group)
+    pg_name = "None" if group is None else _get_process_group_name(group)
+    return tuple(
+        torch.ops.titan._sync_and_all_to_all(x, out_splits, in_splits, pg_name)
+    )


TOKEN_GROUP_ALIGN_SIZE_M = 8
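For readers unfamiliar with the registration flow above, here is a self-contained toy sketch of the same `torch.library` pattern (custom op, fake kernel, autograd rule). It assumes PyTorch 2.4+; the `demo::double` op and its helpers are invented for illustration and are not part of this PR.

import torch


@torch.library.custom_op("demo::double", mutates_args=())
def double(x: torch.Tensor) -> torch.Tensor:
    # real kernel: runs eagerly on actual tensors
    return x * 2


@double.register_fake
def _(x):
    # fake kernel: only describes output metadata, used when tracing/compiling
    return torch.empty_like(x)


def _toy_backward(ctx, grad):
    # d(2x)/dx = 2, so the incoming gradient is scaled by 2
    return grad * 2


def _toy_setup_context(ctx, inputs, output):
    pass  # nothing to save for this toy op


double.register_autograd(_toy_backward, setup_context=_toy_setup_context)

x = torch.randn(3, requires_grad=True)
torch.ops.demo.double(x).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])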
@@ -183,28 +229,19 @@ def _token_dispatch(self, mod, inputs, device_mesh):
            num_tokens_per_expert,
            group=device_mesh.get_group(),
        )
-        input_splits = (
-            num_tokens_per_expert.view(ep_size, -1)
-            .sum(dim=1)
-            .to(torch.device("cpu"), non_blocking=True)
-        )
-        output_splits = (
-            num_tokens_per_expert_group.view(ep_size, -1)
-            .sum(dim=1)
-            .to(torch.device("cpu"), non_blocking=True)
-        )
-        # NOTE: this would incur a device-to-host sync
-        torch.cuda.current_stream().synchronize()
-        self.input_splits = input_splits.tolist()
-        self.output_splits = output_splits.tolist()
+        input_splits = num_tokens_per_expert.view(ep_size, -1).sum(dim=1)
+        output_splits = num_tokens_per_expert_group.view(ep_size, -1).sum(dim=1)

        # perform all-to-all
-        routed_input = all_to_all_single_autograd(
+        routed_input, out_splits_cpu, in_splits_cpu = all_to_all_single_autograd(
            routed_input,
-            self.output_splits,
-            self.input_splits,
+            output_splits,
+            input_splits,
            device_mesh.get_group(),
        )
+        assert not input_splits.is_cpu
+        self.out_splits_cpu = out_splits_cpu
+        self.in_splits_cpu = in_splits_cpu

        # NOTE: After this all-to-all, the routed input is put on proper EP rank.
        # However, the num_tokens_per_expert_group is not of the final target format
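A quick numeric check of the new split computation (made-up sizes, not from the PR): with `ep_size = 2` and two local experts per rank, the per-EP-rank send sizes are just a view-and-sum of the per-expert token counts, and they stay on the device until the custom op syncs them to CPU.

import torch

ep_size = 2
# tokens routed to experts [e0, e1 | e2, e3]; experts 0-1 live on rank 0, 2-3 on rank 1
num_tokens_per_expert = torch.tensor([3, 1, 2, 4])
input_splits = num_tokens_per_expert.view(ep_size, -1).sum(dim=1)
print(input_splits)  # tensor([4, 6]) -> send 4 tokens to rank 0, 6 to rank 1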
@@ -227,10 +264,10 @@ def _partition_fn(name, mod, device_mesh):

    # performing all-to-all combine on the output
    def _token_combine(self, mod, routed_output, device_mesh):
-        routed_output = all_to_all_single_autograd(
+        routed_output, _, _ = all_to_all_single_autograd(
            routed_output,
-            self.input_splits,
-            self.output_splits,
+            self.in_splits_cpu,
+            self.out_splits_cpu,
            device_mesh.get_group(),
        )
        return routed_output
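Finally, a hypothetical end-to-end usage sketch of the patched wrapper, not part of the diff: it assumes two CUDA ranks launched via `torchrun --nproc-per-node 2`, an NCCL build, and that the module above is importable as `torchtitan.distributed.expert_parallel` (path assumed); all tensor sizes are illustrative.

# Run with: torchrun --nproc-per-node 2 demo_a2a.py
import torch
import torch.distributed as dist

from torchtitan.distributed.expert_parallel import all_to_all_single_autograd  # assumed path


def main():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    device = torch.device("cuda", rank)

    # Each rank sends 2 tokens (hidden dim 8) to each of the 2 ranks.
    x = torch.randn(4, 8, device=device, requires_grad=True)
    in_splits = torch.tensor([2, 2], device=device)   # sizes sent to ranks 0, 1
    out_splits = torch.tensor([2, 2], device=device)  # sizes received from ranks 0, 1

    # The custom op copies the split tensors to CPU and syncs once, then returns the
    # CPU copies so callers (e.g. _token_combine) can reuse them without another sync.
    y, out_splits_cpu, in_splits_cpu = all_to_all_single_autograd(
        x, out_splits, in_splits, dist.group.WORLD
    )
    y.sum().backward()  # gradients flow through the registered backward all-to-all
    assert x.grad is not None and x.grad.shape == x.shape

    dist.destroy_process_group()


if __name__ == "__main__":
    main()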