33"""
44
55from __future__ import annotations
6- from typing import Optional , Any , Dict
6+ from abc import abstractmethod , ABC
7+ import logging
8+ import numpy
79import os
810import socket
9- import logging
11+ from typing import Callable , Optional , Any , Dict , Type , Union
1012
1113import torch
1214from torch .nn .parallel import DistributedDataParallel
1315
14- from returnn .config import Config
15- from returnn .util .basic import CollectionReadCheckCovered
16+ from returnn .util .basic import CollectionReadCheckCovered , OptionalNotImplementedError
1617
1718_logger = logging .getLogger ("returnn.torch.distributed" )
1819
1920
+class ParamSynchronizer(ABC):
+    """
+    Custom parameter synchronization primitive.
+
+    Contains a callback that is called after every train step to synchronize model parameters
+    across processes/nodes.
+    """
+
+    @abstractmethod
+    def __init__(self, *, rank: int, size: int, local_rank: int, local_size: int, **kwargs):
+        """
+        `__init__` is called after the default global process group has been created.
+        It can be used to initialize any additional custom process (sub)groups.
+
+        Note that this function is passed a randomly named kwarg on every invocation to ensure forward compatibility.
+
+        :param rank: global rank of the current process across all nodes
+        :param size: global world size across all nodes
+        :param local_rank: local rank of the current process on the current node
+        :param local_size: local world size on the current node
+        """
+        super().__init__()
+
+    def make_distributed_model(self, *, module: torch.nn.Module, **kwargs) -> DistributedDataParallel:
+        """
+        Creates an associated `DistributedDataParallel` for the given module for gradient synchronization.
+
+        This function can be left unimplemented if no gradient synchronization is done.
+
+        Note that this function is passed a randomly named kwarg on every invocation to ensure forward compatibility.
+        """
+        raise OptionalNotImplementedError
+
+    @abstractmethod
+    def step(self, *, module: torch.nn.Module, train_step_idx: int, **kwargs):
+        """
+        Parameter synchronization callback, called after every train step.
+
+        Note that this function is passed a randomly named kwarg on every invocation to ensure forward compatibility.
+
+        :param module: the NN being trained
+        :param train_step_idx: the current train step
+        :param kwargs: any additional kwargs
+        """
+        raise NotImplementedError
+
+    def __call__(self, *args, **kwargs):
+        """Forwards to :func:`step`."""
+        return self.step(*args, **kwargs)
+
+
 class DistributedContext:
     """
     This class setups some helper functions for torch distributed training
@@ -26,6 +78,9 @@ def __init__(self, options: Dict[str, Any]):
         import torch.distributed as dist

         self._opts = CollectionReadCheckCovered(options)
+        # Only used to generate the random kwargs that ensure forward compatibility,
+        # so the seed does not matter.
+        self._rng = numpy.random.default_rng()

         # when no backend is specified, both gloo and nccl backends will be created
         # the gloo backend will be used for collectives with CPU tensors and
@@ -42,8 +97,13 @@ def __init__(self, options: Dict[str, Any]):
            % (socket.gethostname(), os.getpid(), self._rank, self._size, self._local_rank, self._local_size)
        )

+        self._custom_sync_class: Optional[Union[Callable, Type[ParamSynchronizer]]] = self._opts.get(
+            "synchronizer", None
+        )
+        self._custom_sync: Optional[Callable] = None
         self._reduce_type = self._opts.get("reduce_type", "grad")
         self._param_sync_step: Optional[int] = self._opts.get("param_sync_step", None)
+
         if self._reduce_type == "param":
             assert isinstance(self._param_sync_step, int) and self._param_sync_step > 0, (
                 f"reduce_type param: param_sync_step must be a positive int,"
@@ -52,6 +112,23 @@ def __init__(self, options: Dict[str, Any]):
             _logger.info(f"reduce_type param: param_sync_step {self._param_sync_step}")
         elif self._reduce_type == "grad":
             _logger.info("reduce_type grad")
+        elif self._reduce_type == "custom":
+            if isinstance(self._custom_sync_class, type) and issubclass(self._custom_sync_class, ParamSynchronizer):
+                self._custom_sync = self._custom_sync_class(
+                    rank=self._rank,
+                    size=self._size,
+                    local_rank=self._local_rank,
+                    local_size=self._local_size,
+                    **{f"fwd_compatible_random_kwarg_{self._rng.integers(0, 100)}": None},
+                )
+            elif callable(self._custom_sync_class):
+                self._custom_sync = self._custom_sync_class
+            else:
+                raise ValueError(
+                    f"synchronizer must either be a callable or a class inheriting from {ParamSynchronizer.__name__}"
+                )
+
+            _logger.info(f"reduce_type custom: {type(self._custom_sync)}")
         else:
             raise ValueError(f"invalid reduce_type {self._reduce_type!r}")

@@ -70,6 +147,8 @@ def _check_no_unknown_opts(self):
         self._opts.get("options")
         if self._reduce_type == "param":
             self._opts.get("sync_on_cpu")
+        if self._reduce_type == "custom":
+            self._opts.get("synchronizer")

         self._opts.assert_all_read()

@@ -102,7 +181,24 @@ def maybe_make_distributed_module(self, module: torch.nn.Module) -> Optional[Dis
102181 """
103182 if self ._reduce_type == "param" :
104183 return None
105- assert self ._reduce_type == "grad"
184+ assert self ._reduce_type in ["custom" , "grad" ]
185+
186+ if self ._reduce_type == "custom" :
187+ assert isinstance (self ._custom_sync , (ParamSynchronizer , Callable ))
188+
189+ if isinstance (self ._custom_sync , ParamSynchronizer ):
190+ try :
191+ return self ._custom_sync .make_distributed_model (
192+ module = module , ** {f"fwd_compatible_random_kwarg_{ self ._rng .integers (0 , 100 )} " : None }
193+ )
194+ except OptionalNotImplementedError :
195+ pass
196+ else :
197+ # callable short form does not have support for DistributedDataParallel
198+ pass
199+
200+ return None
201+
106202 cls = self ._opts .get ("class" , DistributedDataParallel )
107203 if cls is not DistributedDataParallel :
108204 _logger .warning (f"Using custom class { cls } instead of DistributedDataParallel, might be unsupported." )
@@ -115,7 +211,14 @@ def maybe_make_distributed_module(self, module: torch.nn.Module) -> Optional[Dis

     def step_after_param_update(self, *, module: torch.nn.Module, epoch_step_idx: int):
         """one train step"""
-        if self._reduce_type == "param" and ((epoch_step_idx % self._param_sync_step) == (self._param_sync_step - 1)):
+        if self._reduce_type == "custom":
+            with torch.no_grad():  # TODO: do we want this for all syncers?
+                self._custom_sync(
+                    module=module,
+                    train_step_idx=epoch_step_idx,
+                    **{f"fwd_compatible_random_kwarg_{self._rng.integers(0, 100)}": None},
+                )
+        elif self._reduce_type == "param" and ((epoch_step_idx % self._param_sync_step) == (self._param_sync_step - 1)):
            _sync_params_avg(module=module, sync_on_cpu=self._opts.get("sync_on_cpu", False))

@@ -155,7 +258,7 @@ def _sync_params_avg(*, module: torch.nn.Module, sync_on_cpu: bool = False):

     if sync_on_cpu:
         for param in module.parameters():
-            # Separately move each param to CPU (instead of the whole module), to safe CPU memory.
+            # Separately move each param to CPU (instead of the whole module), to save CPU memory.
             param_cpu = param.to(torch.device("cpu"))
             # On CPU, we are likely using Gloo, and Gloo does not support AVG
             dist.all_reduce(param_cpu.data, op=dist.ReduceOp.SUM)
@@ -166,12 +269,11 @@ def _sync_params_avg(*, module: torch.nn.Module, sync_on_cpu: bool = False):
    if dist.get_backend() == "gloo":
        # Gloo does not support AVG
        reduce_op = dist.ReduceOp.SUM
+    elif hasattr(dist.ReduceOp, "AVG"):
+        reduce_op = dist.ReduceOp.AVG
     else:
-        if hasattr(dist.ReduceOp, "AVG"):
-            reduce_op = dist.ReduceOp.AVG
-        else:
-            # Older PyTorch versions do not have ReduceOp.AVG.
-            reduce_op = dist.ReduceOp.SUM
+        # Older PyTorch versions do not have ReduceOp.AVG.
+        reduce_op = dist.ReduceOp.SUM

     for param in module.parameters():
         dist.all_reduce(param.data, op=reduce_op)
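
For illustration, here is a minimal sketch of how the new `reduce_type: "custom"` path could be used, assuming the usual `torch_distributed` config dict and that `ParamSynchronizer` is importable from `returnn.torch.distributed`. The class name `MyAvgSynchronizer`, the `sync_every` parameter and the averaging logic are made up for this example; only the `ParamSynchronizer` interface and the `reduce_type`/`synchronizer` options come from this change.

import torch
import torch.distributed as dist

from returnn.torch.distributed import ParamSynchronizer


class MyAvgSynchronizer(ParamSynchronizer):
    """Example synchronizer: averages all parameters across ranks every `sync_every` train steps."""

    def __init__(self, *, rank: int, size: int, local_rank: int, local_size: int, sync_every: int = 100, **kwargs):
        super().__init__(rank=rank, size=size, local_rank=local_rank, local_size=local_size, **kwargs)
        self.size = size
        self.sync_every = sync_every

    def step(self, *, module: torch.nn.Module, train_step_idx: int, **kwargs):
        # Called by DistributedContext.step_after_param_update() after every train step.
        if (train_step_idx + 1) % self.sync_every != 0:
            return
        for param in module.parameters():
            # SUM plus division instead of ReduceOp.AVG, so this also works with Gloo and older PyTorch.
            dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
            param.data /= self.size


# In the config:
torch_distributed = {
    "reduce_type": "custom",
    "synchronizer": MyAvgSynchronizer,
}

Since `DistributedContext` instantiates the class itself (passing only `rank`, `size`, `local_rank`, `local_size` and the forward-compatibility kwarg), any extra constructor argument such as `sync_every` needs a default value. Alternatively, the callable short form (a plain function taking `module`, `train_step_idx` and `**kwargs`) can be used when no `DistributedDataParallel` wrapping is needed.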