Skip to content

Commit 2599057

Browse files
authored
Source Metadata: Parse source writer information from the first event of each run (#6014)
Motivation: Googlers, see go/tb-writer-source-metadata for context. Note that: - The current assumption is that each run would have at most one source writer. - The maximum allowed length for the writer name is set to 128. - For simplicity, only source writer name rather than the entire `SourceMetadata` proto is parsed (since this is the only field). - In the future if more fields are added, the `SourceWriter` property can be easily generalized to `SourceMetadata` in the `EventAccumulator`.
1 parent df448ad commit 2599057

11 files changed

+401
-51
lines changed

tensorboard/backend/event_processing/BUILD

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,30 @@ py_test(
208208
],
209209
)
210210

211+
py_library(
212+
name = "event_util",
213+
srcs = ["event_util.py"],
214+
srcs_version = "PY3",
215+
visibility = ["//visibility:public"],
216+
deps = [
217+
"//tensorboard/compat/proto:protos_all_py_pb2",
218+
"//tensorboard/util:tb_logging",
219+
],
220+
)
221+
222+
py_test(
223+
name = "event_util_test",
224+
size = "small",
225+
srcs = ["event_util_test.py"],
226+
srcs_version = "PY3",
227+
deps = [
228+
":event_util",
229+
"//tensorboard:test",
230+
"//tensorboard/compat/proto:protos_all_py_pb2",
231+
"//tensorboard/util:tb_logging",
232+
],
233+
)
234+
211235
py_library(
212236
name = "tag_types",
213237
srcs = ["tag_types.py"],
@@ -226,6 +250,7 @@ py_library(
226250
":directory_loader",
227251
":directory_watcher",
228252
":event_file_loader",
253+
":event_util",
229254
":io_wrapper",
230255
":plugin_asset_util",
231256
":reservoir",

tensorboard/backend/event_processing/event_accumulator.py

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,11 @@
1717
import collections
1818
import threading
1919

20+
from typing import Optional
21+
2022
from tensorboard.backend.event_processing import directory_watcher
2123
from tensorboard.backend.event_processing import event_file_loader
24+
from tensorboard.backend.event_processing import event_util
2225
from tensorboard.backend.event_processing import io_wrapper
2326
from tensorboard.backend.event_processing import plugin_asset_util
2427
from tensorboard.backend.event_processing import reservoir
@@ -224,6 +227,9 @@ def __init__(
224227
self.most_recent_wall_time = -1
225228
self.file_version = None
226229

230+
# Name of the source writer that writes the event.
231+
self._source_writer = None
232+
227233
# The attributes that get built up by the accumulator
228234
self.accumulated_attrs = (
229235
"scalars",
@@ -302,6 +308,21 @@ def FirstEventTimestamp(self):
302308
except StopIteration:
303309
raise ValueError("No event timestamp could be found")
304310

311+
@property
312+
def SourceWriter(self) -> Optional[str]:
313+
"""Returns the name of the event writer."""
314+
if self._source_writer is not None:
315+
return self._source_writer
316+
with self._generator_mutex:
317+
try:
318+
event = next(self._generator.Load())
319+
self._ProcessEvent(event)
320+
return self._source_writer
321+
except StopIteration:
322+
logger.info(
323+
"End of file in %s, no source writer was found.", self.path
324+
)
325+
305326
def PluginTagToContent(self, plugin_name):
306327
"""Returns a dict mapping tags to content specific to that plugin.
307328
@@ -339,8 +360,22 @@ def _ProcessEvent(self, event):
339360
if self._first_event_timestamp is None:
340361
self._first_event_timestamp = event.wall_time
341362

363+
if event.HasField("source_metadata"):
364+
new_source_writer = event_util.GetSourceWriter(
365+
event.source_metadata
366+
)
367+
if self._source_writer and self._source_writer != new_source_writer:
368+
# This should not happen.
369+
logger.warning(
370+
(
371+
"Found new source writer for event.proto. "
372+
"Old: {0}, New: {1}"
373+
).format(self._source_writer, new_source_writer)
374+
)
375+
self._source_writer = new_source_writer
376+
342377
if event.HasField("file_version"):
343-
new_file_version = _ParseFileVersion(event.file_version)
378+
new_file_version = event_util.ParseFileVersion(event.file_version)
344379
if self.file_version and self.file_version != new_file_version:
345380
## This should not happen.
346381
logger.warning(
@@ -824,27 +859,3 @@ def _GeneratorFromPath(path):
824859
event_file_loader.LegacyEventFileLoader,
825860
io_wrapper.IsSummaryEventsFile,
826861
)
827-
828-
829-
def _ParseFileVersion(file_version):
830-
"""Convert the string file_version in event.proto into a float.
831-
832-
Args:
833-
file_version: String file_version from event.proto
834-
835-
Returns:
836-
Version number as a float.
837-
"""
838-
tokens = file_version.split("brain.Event:")
839-
try:
840-
return float(tokens[-1])
841-
except ValueError:
842-
## This should never happen according to the definition of file_version
843-
## specified in event.proto.
844-
logger.warning(
845-
(
846-
"Invalid event.proto file_version. Defaulting to use of "
847-
"out-of-order event.step logic for purging expired events."
848-
)
849-
)
850-
return -1

tensorboard/backend/event_processing/event_accumulator_test.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -709,11 +709,65 @@ def testFirstEventTimestampLoadsEvent(self):
709709
gen.AddEvent(
710710
event_pb2.Event(wall_time=1, step=2, file_version="brain.Event:2")
711711
)
712-
713712
self.assertEqual(acc.FirstEventTimestamp(), 1)
714713
acc.Reload()
715714
self.assertEqual(acc.file_version, 2.0)
716715

716+
def testSourceWriter(self):
717+
gen = _EventGenerator(self)
718+
acc = ea.EventAccumulator(gen)
719+
gen.AddEvent(
720+
event_pb2.Event(
721+
wall_time=10,
722+
step=20,
723+
source_metadata=event_pb2.SourceMetadata(
724+
writer="custom_writer"
725+
),
726+
)
727+
)
728+
gen.AddScalar("s1", wall_time=30, step=40, value=20)
729+
self.assertEqual(acc.SourceWriter, "custom_writer")
730+
731+
def testReloadPopulatesSourceWriter(self):
732+
"""Test that Reload() means SourceWriter won't load events."""
733+
gen = _EventGenerator(self)
734+
acc = ea.EventAccumulator(gen)
735+
gen.AddEvent(
736+
event_pb2.Event(
737+
wall_time=1,
738+
step=2,
739+
source_metadata=event_pb2.SourceMetadata(
740+
writer="custom_writer"
741+
),
742+
)
743+
)
744+
acc.Reload()
745+
746+
def _Die(*args, **kwargs): # pylint: disable=unused-argument
747+
raise RuntimeError("Load() should not be called")
748+
749+
self.stubs.Set(gen, "Load", _Die)
750+
self.assertEqual(acc.SourceWriter, "custom_writer")
751+
752+
def testSourceWriterLoadsEvent(self):
753+
"""Test that SourceWriter doesn't discard the loaded event."""
754+
gen = _EventGenerator(self)
755+
acc = ea.EventAccumulator(gen)
756+
gen.AddEvent(
757+
event_pb2.Event(
758+
wall_time=1,
759+
step=2,
760+
file_version="brain.Event:2",
761+
source_metadata=event_pb2.SourceMetadata(
762+
writer="custom_writer"
763+
),
764+
)
765+
)
766+
767+
self.assertEqual(acc.SourceWriter, "custom_writer")
768+
acc.Reload()
769+
self.assertEqual(acc.file_version, 2.0)
770+
717771
def testTFSummaryScalar(self):
718772
"""Verify processing of tf.summary.scalar."""
719773
event_sink = _EventGenerator(self, zero_out_timestamps=True)

tensorboard/backend/event_processing/event_multiplexer.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import os
1919
import threading
2020

21+
from typing import Optional
2122

2223
from tensorboard.backend.event_processing import directory_watcher
2324
from tensorboard.backend.event_processing import event_accumulator
@@ -258,6 +259,21 @@ def FirstEventTimestamp(self, run):
258259
accumulator = self.GetAccumulator(run)
259260
return accumulator.FirstEventTimestamp()
260261

262+
def GetSourceWriter(self, run) -> Optional[str]:
263+
"""Returns the source writer name from the first event of the given run.
264+
265+
Assuming each run has only one source writer.
266+
267+
Args:
268+
run: A string name of the run from which the event source information
269+
is retrieved.
270+
271+
Returns:
272+
Name of the writer that wrote the events in the run.
273+
"""
274+
accumulator = self.GetAccumulator(run)
275+
return accumulator.SourceWriter
276+
261277
def Scalars(self, run, tag):
262278
"""Retrieve the scalar events associated with a run and tag.
263279

tensorboard/backend/event_processing/event_multiplexer_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ def Tags(self):
6767
def FirstEventTimestamp(self):
6868
return 0
6969

70+
@property
71+
def SourceWriter(self):
72+
return "%s_writer" % self._path
73+
7074
def _TagHelper(self, tag_name, enum):
7175
if tag_name not in self.Tags()[enum]:
7276
raise KeyError
@@ -151,6 +155,13 @@ def testReload(self):
151155
self.assertTrue(x.GetAccumulator("run1").reload_called)
152156
self.assertTrue(x.GetAccumulator("run2").reload_called)
153157

158+
def testGetSourceWriter(self):
159+
x = event_multiplexer.EventMultiplexer(
160+
{"run1": "path1", "run2": "path2"}
161+
)
162+
self.assertEqual(x.GetSourceWriter("run1"), "path1_writer")
163+
self.assertEqual(x.GetSourceWriter("run2"), "path2_writer")
164+
154165
def testScalars(self):
155166
"""Tests Scalars function returns suitable values."""
156167
x = event_multiplexer.EventMultiplexer(
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
"""Functionality for processing events."""
17+
18+
from typing import Optional
19+
20+
from tensorboard.compat.proto import event_pb2
21+
from tensorboard.util import tb_logging
22+
23+
logger = tb_logging.get_logger()
24+
25+
# Maxmimum length for event writer name.
26+
_MAX_WRITER_NAME_LEN = 128
27+
28+
29+
def ParseFileVersion(file_version: str) -> float:
30+
"""Convert the string file_version in event.proto into a float.
31+
32+
Args:
33+
file_version: String file_version from event.proto
34+
35+
Returns:
36+
Version number as a float.
37+
"""
38+
tokens = file_version.split("brain.Event:")
39+
try:
40+
return float(tokens[-1])
41+
except ValueError:
42+
## This should never happen according to the definition of file_version
43+
## specified in event.proto.
44+
logger.warning(
45+
(
46+
"Invalid event.proto file_version. Defaulting to use of "
47+
"out-of-order event.step logic for purging expired events."
48+
)
49+
)
50+
return -1
51+
52+
53+
def GetSourceWriter(
54+
source_metadata: event_pb2.SourceMetadata,
55+
) -> Optional[str]:
56+
"""Gets the source writer name from the source metadata proto."""
57+
writer_name = source_metadata.writer
58+
if not writer_name:
59+
return None
60+
# Checks the length of the writer name.
61+
if len(writer_name) > _MAX_WRITER_NAME_LEN:
62+
logger.error(
63+
"Source writer name `%s` is too long, maximum allowed length is %d.",
64+
writer_name,
65+
_MAX_WRITER_NAME_LEN,
66+
)
67+
return None
68+
return writer_name
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
"""Tests for event_util."""
17+
18+
from unittest import mock
19+
20+
from tensorboard import test as tb_test
21+
from tensorboard.backend.event_processing import event_util
22+
from tensorboard.compat.proto import event_pb2
23+
from tensorboard.util import tb_logging
24+
25+
logger = tb_logging.get_logger()
26+
27+
28+
class EventUtilTest(tb_test.TestCase):
29+
def testParseFileVersion_success(self):
30+
self.assertEqual(event_util.ParseFileVersion("brain.Event:1.0"), 1.0)
31+
32+
def testParseFileVersion_invalidFileVersion(self):
33+
with mock.patch.object(
34+
logger, "warning", autospec=True, spec_set=True
35+
) as mock_log:
36+
version = event_util.ParseFileVersion("invalid")
37+
self.assertEqual(version, -1)
38+
mock_log.assert_called_once_with(
39+
"Invalid event.proto file_version. Defaulting to use of "
40+
"out-of-order event.step logic for purging expired events."
41+
)
42+
43+
def testGetSourceWriter_success(self):
44+
expected_writer = "tensorboard.summary.writer.event_file_writer"
45+
actual_writer = event_util.GetSourceWriter(
46+
event_pb2.SourceMetadata(writer=expected_writer)
47+
)
48+
self.assertEqual(actual_writer, expected_writer)
49+
50+
def testGetSourceWriter_noWriter(self):
51+
actual_writer = event_util.GetSourceWriter(
52+
event_pb2.SourceMetadata(writer="")
53+
)
54+
self.assertIsNone(actual_writer)
55+
56+
def testGetSourceWriter_writerNameTooLong(self):
57+
long_writer_name = "a" * (event_util._MAX_WRITER_NAME_LEN + 1)
58+
with mock.patch.object(
59+
logger, "error", autospec=True, spec_set=True
60+
) as mock_log:
61+
actual_writer = event_util.GetSourceWriter(
62+
event_pb2.SourceMetadata(writer=long_writer_name)
63+
)
64+
self.assertIsNone(actual_writer)
65+
mock_log.assert_called_once_with(
66+
"Source writer name `%s` is too long, maximum allowed length is %d.",
67+
long_writer_name,
68+
event_util._MAX_WRITER_NAME_LEN,
69+
)
70+
71+
72+
if __name__ == "__main__":
73+
tb_test.main()

0 commit comments

Comments
 (0)