Skip to content

Commit db8367b

Browse files
rkoblerfacebook-github-bot
authored andcommitted
Add log_histogram_raw to TensorBoardLogger (#1006)
Summary: Pull Request resolved: #1006 The `SummaryWriter` supports `log_histogram_raw`. Enabling this feature in torchtnt will allow logging of pre-computed histograms. Reviewed By: alanhdu Differential Revision: D75455963 fbshipit-source-id: 000550ce9072093dc61e22735bd8734d30434d7c
1 parent f7cf5b7 commit db8367b

File tree

2 files changed

+66
-1
lines changed

2 files changed

+66
-1
lines changed

tests/utils/loggers/test_tensorboard.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,61 @@ def test_log_dict(self: TensorBoardLoggerTest) -> None:
8181
)
8282
self.assertEqual(tensor_tag.step, 1)
8383

84+
def test_log_histogram_raw(self: TensorBoardLoggerTest) -> None:
85+
with tempfile.TemporaryDirectory() as log_dir:
86+
logger = TensorBoardLogger(path=log_dir)
87+
88+
# generate a histogram with 4 bins in the range [0, 1]
89+
data_range = [0.0, 1.0]
90+
bucket_counts = [1, 3, 5, 4]
91+
bucket_width = (data_range[1] - data_range[0]) / len(bucket_counts)
92+
bucket_limits = [
93+
ix * bucket_width + data_range[0]
94+
for ix in range(len(bucket_counts) + 1)
95+
]
96+
bucket_centers = [
97+
(lower + upper) / 2
98+
for lower, upper in zip(bucket_limits[:-1], bucket_limits[1:])
99+
]
100+
# sum of the binned values
101+
value_sum = float(
102+
sum(
103+
value * count for value, count in zip(bucket_centers, bucket_counts)
104+
)
105+
)
106+
107+
logger.log_histogram_raw(
108+
"histogram_raw",
109+
min=0,
110+
max=1,
111+
num=sum(bucket_counts),
112+
sum=value_sum,
113+
sum_squares=value_sum**2,
114+
bucket_limits=bucket_limits,
115+
# add an extra leading 0 to match the format of the histogram_raw
116+
bucket_counts=[0] + bucket_counts,
117+
)
118+
logger.close()
119+
120+
acc = EventAccumulator(log_dir)
121+
acc.Reload()
122+
123+
# check that the histogram is logged correctly
124+
self.assertIn("histogram_raw", acc.Tags()["histograms"])
125+
# ensure that we logged exactly one histogram
126+
self.assertEqual(len(acc.Histograms("histogram_raw")), 1)
127+
histogram_event = acc.Histograms("histogram_raw")[0]
128+
histogram_value = histogram_event.histogram_value
129+
# check that the histogram is logged correctly
130+
self.assertEqual(histogram_value.min, 0)
131+
self.assertEqual(histogram_value.max, 1)
132+
self.assertEqual(histogram_value.num, sum(bucket_counts))
133+
self.assertEqual(histogram_value.sum, value_sum)
134+
self.assertEqual(histogram_value.sum_squares, value_sum**2)
135+
self.assertListEqual(histogram_value.bucket_limit, bucket_limits)
136+
self.assertListEqual(histogram_value.bucket[1:], bucket_counts)
137+
self.assertEqual(histogram_value.bucket[0], 0)
138+
84139
def test_log_text(self: TensorBoardLoggerTest) -> None:
85140
with tempfile.TemporaryDirectory() as log_dir:
86141
logger = TensorBoardLogger(path=log_dir)

torchtnt/utils/loggers/tensorboard.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def log_scalars(
210210
)
211211

212212
def log_histogram(self: TensorBoardLogger, *args: Any, **kwargs: Any) -> None:
213-
"""Add histogram to TensorBoard.
213+
"""Compute and add histogram to TensorBoard.
214214
215215
Args:
216216
*args (Any): Positional arguments passed to SummaryWriter.add_histogram
@@ -219,6 +219,16 @@ def log_histogram(self: TensorBoardLogger, *args: Any, **kwargs: Any) -> None:
219219
if self._writer:
220220
self._writer.add_histogram(*args, **kwargs)
221221

222+
def log_histogram_raw(self: TensorBoardLogger, *args: Any, **kwargs: Any) -> None:
223+
"""Add pre-computed histogram to TensorBoard.
224+
225+
Args:
226+
*args (Any): Positional arguments passed to SummaryWriter.add_histogram_raw
227+
**kwargs(Any): Keyword arguments passed to SummaryWriter.add_histogram_raw
228+
"""
229+
if self._writer:
230+
self._writer.add_histogram_raw(*args, **kwargs)
231+
222232
def flush(self: TensorBoardLogger) -> None:
223233
"""Writes pending logs to disk."""
224234

0 commit comments

Comments
 (0)