Skip to content

Commit 82a6b62

Browse files
JKSenthilfacebook-github-bot
authored andcommitted
remove path sync in TensorBoardLogger
Reviewed By: galrotem Differential Revision: D56645859 fbshipit-source-id: b14b03c2440876b50e835d4df4f8ff48658451e2
1 parent 0159a07 commit 82a6b62

File tree

2 files changed

+1
-42
lines changed

2 files changed

+1
-42
lines changed

tests/utils/loggers/test_tensorboard.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,13 @@
99

1010
from __future__ import annotations
1111

12-
import os
1312
import tempfile
1413
import unittest
1514
from unittest.mock import Mock, patch
1615

17-
import torch.distributed.launcher as launcher
1816
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
19-
from torch import distributed as dist
2017

2118
from torchtnt.utils.loggers.tensorboard import TensorBoardLogger
22-
from torchtnt.utils.test_utils import get_pet_launch_config, skip_if_not_distributed
2319

2420

2521
class TensorBoardLoggerTest(unittest.TestCase):
@@ -74,26 +70,6 @@ def test_log_rank_zero(self: TensorBoardLoggerTest) -> None:
7470
logger = TensorBoardLogger(path=log_dir)
7571
self.assertEqual(logger.writer, None)
7672

77-
@staticmethod
78-
def _test_distributed() -> None:
79-
dist.init_process_group("gloo")
80-
rank = dist.get_rank()
81-
with tempfile.TemporaryDirectory() as log_dir:
82-
test_path = "correct"
83-
invalid_path = "invalid"
84-
if rank == 0:
85-
logger = TensorBoardLogger(os.path.join(log_dir, test_path))
86-
else:
87-
logger = TensorBoardLogger(os.path.join(log_dir, invalid_path))
88-
89-
assert test_path in logger.path
90-
assert invalid_path not in logger.path
91-
92-
@skip_if_not_distributed
93-
def test_multiple_workers(self: TensorBoardLoggerTest) -> None:
94-
config = get_pet_launch_config(2)
95-
launcher.elastic_launch(config, entrypoint=self._test_distributed)()
96-
9773
def test_add_scalars_call_is_correctly_passed_to_summary_writer(
9874
self: TensorBoardLoggerTest,
9975
) -> None:

torchtnt/utils/loggers/tensorboard.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,8 @@ class TensorBoardLogger(MetricLogger):
5353

5454
def __init__(self: TensorBoardLogger, path: str, *args: Any, **kwargs: Any) -> None:
5555
self._writer: Optional[SummaryWriter] = None
56-
56+
self._path: str = path
5757
self._rank: int = get_global_rank()
58-
self._sync_path_to_workers(path)
5958

6059
if self._rank == 0:
6160
logger.info(
@@ -69,22 +68,6 @@ def __init__(self: TensorBoardLogger, path: str, *args: Any, **kwargs: Any) -> N
6968

7069
atexit.register(self.close)
7170

72-
def _sync_path_to_workers(self: TensorBoardLogger, path: str) -> None:
73-
if not (dist.is_available() and dist.is_initialized()):
74-
self._path: str = path
75-
return
76-
77-
pg = PGWrapper(dist.group.WORLD)
78-
path_container: List[str] = [path] if self._rank == 0 else [""]
79-
pg.broadcast_object_list(path_container, 0)
80-
updated_path = path_container[0]
81-
if updated_path != path:
82-
# because the logger only logs on rank 0, if users pass in a different path
83-
# the logger will output the wrong `path` property, so we update it to match
84-
# the correct path.
85-
logger.info(f"Updating TensorBoard path to match rank 0: {updated_path}")
86-
self._path: str = updated_path
87-
8871
@property
8972
def writer(self: TensorBoardLogger) -> Optional[SummaryWriter]:
9073
return self._writer

0 commit comments

Comments
 (0)