Skip to content

Commit 85414c9

Browse files
tushar00jain authored and facebook-github-bot committed
fix bucketized allreduce (meta-pytorch#278)
Summary:
- Update the callback to work with the new ManagedWork.
- Provide an option to use bucketization via an environment variable.

Differential Revision: D84101245
1 parent d596ec7 commit 85414c9

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

torchft/local_sgd.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import logging
1313
import math
14+
import os
1415
from contextlib import nullcontext
1516
from types import TracebackType
1617
from typing import Any, Dict, List, Optional, Tuple, Type
@@ -25,6 +26,8 @@
2526

2627
logger: logging.Logger = logging.getLogger(__name__)
2728

29+
USE_BUCKETIZATION_ENV: str = "TORCHFT_USE_BUCKETIZATION"
30+
2831

2932
def extract_local_tensor(t: torch.Tensor) -> torch.Tensor:
3033
"""
@@ -171,7 +174,7 @@ def _average(self) -> list[torch.Tensor]:
171174

172175

173176
class _StreamingDiLoCoFragment:
174-
bucket_cap_mb: int = 32 * 1024 * 1024
177+
bucket_cap_mb: int = 1 * 1024 * 1024 * 1024
175178
use_bucketization: bool = False
176179

177180
def __init__(
@@ -220,7 +223,11 @@ def __init__(
220223
if bucket_cap_mb is not None:
221224
self.bucket_cap_mb = int(bucket_cap_mb * 1024 * 1024)
222225

223-
self.use_bucketization = use_bucketization
226+
if os.getenv(USE_BUCKETIZATION_ENV, "False") == "True":
227+
self.use_bucketization = True
228+
else:
229+
self.use_bucketization = use_bucketization
230+
224231
self.should_quantize = should_quantize
225232

226233
self._grads: Dict[str, torch.Tensor] = {}
@@ -535,14 +542,9 @@ def _bucketize_and_allreduce(
535542
def callback(
536543
fut: torch.futures.Future[list[torch.Tensor]],
537544
) -> list[torch.Tensor]:
538-
with torch.cuda.stream(self._stream) if self._stream else nullcontext():
539-
nonlocal bucket_tensors, flat_buffer
540-
# Setup stream dependency
541-
fut.wait()
542-
for t, pack_offset, numel in bucket_tensors:
543-
t.copy_(
544-
flat_buffer[pack_offset : pack_offset + numel].view_as(t)
545-
)
545+
nonlocal bucket_tensors, flat_buffer
546+
for t, pack_offset, numel in bucket_tensors:
547+
t.copy_(flat_buffer[pack_offset : pack_offset + numel].view_as(t))
546548

547549
return []
548550

0 commit comments

Comments
 (0)