Skip to content

Commit 5ccf430

Browse files
authored
cleanup: replace namedtuple with dataclass - part 2 (#6003)
This is a follow-up PR after #5998.
1 parent dc06f6b commit 5ccf430

File tree

2 files changed

+125
-111
lines changed

2 files changed

+125
-111
lines changed

tensorboard/manager.py

Lines changed: 118 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -16,64 +16,45 @@
1616

1717

1818
import base64
19-
import collections
19+
import dataclasses
2020
import datetime
2121
import errno
2222
import json
2323
import os
2424
import subprocess
2525
import tempfile
2626
import time
27+
import typing
2728

29+
from typing import Optional
2830

2931
from tensorboard import version
3032
from tensorboard.util import tb_logging
3133

3234

33-
# Type descriptors for `TensorBoardInfo` fields.
34-
#
35-
# We represent timestamps as int-seconds-since-epoch rather than
36-
# datetime objects to work around a bug in Python on Windows. See:
37-
# https://github.com/tensorflow/tensorboard/issues/2017.
38-
_FieldType = collections.namedtuple(
39-
"_FieldType",
40-
(
41-
"serialized_type",
42-
"runtime_type",
43-
"serialize",
44-
"deserialize",
45-
),
46-
)
47-
_type_int = _FieldType(
48-
serialized_type=int,
49-
runtime_type=int,
50-
serialize=lambda n: n,
51-
deserialize=lambda n: n,
52-
)
53-
_type_str = _FieldType(
54-
serialized_type=str, # `json.loads` always gives Unicode
55-
runtime_type=str,
56-
serialize=str,
57-
deserialize=str,
58-
)
59-
60-
# Information about a running TensorBoard instance.
61-
_TENSORBOARD_INFO_FIELDS = collections.OrderedDict(
62-
(
63-
("version", _type_str),
64-
("start_time", _type_int), # seconds since epoch
65-
("pid", _type_int),
66-
("port", _type_int),
67-
("path_prefix", _type_str), # may be empty
68-
("logdir", _type_str), # may be empty
69-
("db", _type_str), # may be empty
70-
("cache_key", _type_str), # opaque, as given by `cache_key` below
71-
)
72-
)
73-
TensorBoardInfo = collections.namedtuple(
74-
"TensorBoardInfo",
75-
_TENSORBOARD_INFO_FIELDS,
76-
)
35+
@dataclasses.dataclass(frozen=True)
36+
class TensorBoardInfo:
37+
"""Holds the information about a running TensorBoard instance.
38+
39+
Attributes:
40+
version: Version of the running TensorBoard.
41+
start_time: Seconds since epoch, as an int (not a datetime; see issue #2017).
42+
pid: ID of the process running TensorBoard.
43+
port: Port on which TensorBoard is running.
44+
path_prefix: Relative prefix to the path, may be empty.
45+
logdir: Data location used by the TensorBoard server, may be empty.
46+
db: Database connection used by the TensorBoard server, may be empty.
47+
cache_key: Opaque, as given by `cache_key` below.
48+
"""
49+
50+
version: str
51+
start_time: int
52+
pid: int
53+
port: int
54+
path_prefix: str
55+
logdir: str
56+
db: str
57+
cache_key: str
7758

7859

7960
def data_source_from_info(info):
@@ -107,22 +88,19 @@ def _info_to_string(info):
10788
Returns:
10889
A string representation of the provided `TensorBoardInfo`.
10990
"""
110-
for key in _TENSORBOARD_INFO_FIELDS:
111-
field_type = _TENSORBOARD_INFO_FIELDS[key]
112-
if not isinstance(getattr(info, key), field_type.runtime_type):
91+
field_name_to_type = typing.get_type_hints(TensorBoardInfo)
92+
for key, field_type in field_name_to_type.items():
93+
if not isinstance(getattr(info, key), field_type):
11394
raise ValueError(
11495
"expected %r of type %s, but found: %r"
115-
% (key, field_type.runtime_type, getattr(info, key))
96+
% (key, field_type, getattr(info, key))
11697
)
11798
if info.version != version.VERSION:
11899
raise ValueError(
119100
"expected 'version' to be %r, but found: %r"
120101
% (version.VERSION, info.version)
121102
)
122-
json_value = {
123-
k: _TENSORBOARD_INFO_FIELDS[k].serialize(getattr(info, k))
124-
for k in _TENSORBOARD_INFO_FIELDS
125-
}
103+
json_value = dataclasses.asdict(info)
126104
return json.dumps(json_value, sort_keys=True, indent=4)
127105

128106

@@ -140,14 +118,14 @@ def _info_from_string(info_string):
140118
ValueError: If the provided string is not valid JSON, or if it is
141119
missing any required fields, or if any field is of incorrect type.
142120
"""
143-
121+
field_name_to_type = typing.get_type_hints(TensorBoardInfo)
144122
try:
145123
json_value = json.loads(info_string)
146124
except ValueError:
147125
raise ValueError("invalid JSON: %r" % (info_string,))
148126
if not isinstance(json_value, dict):
149127
raise ValueError("not a JSON object: %r" % (json_value,))
150-
expected_keys = frozenset(_TENSORBOARD_INFO_FIELDS)
128+
expected_keys = frozenset(field_name_to_type.keys())
151129
actual_keys = frozenset(json_value)
152130
missing_keys = expected_keys - actual_keys
153131
if missing_keys:
@@ -158,14 +136,13 @@ def _info_from_string(info_string):
158136

159137
# Validate and deserialize fields.
160138
fields = {}
161-
for key in _TENSORBOARD_INFO_FIELDS:
162-
field_type = _TENSORBOARD_INFO_FIELDS[key]
163-
if not isinstance(json_value[key], field_type.serialized_type):
139+
for key, field_type in field_name_to_type.items():
140+
if not isinstance(json_value[key], field_type):
164141
raise ValueError(
165142
"expected %r of type %s, but found: %r"
166-
% (key, field_type.serialized_type, json_value[key])
143+
% (key, field_type, json_value[key])
167144
)
168-
fields[key] = field_type.deserialize(json_value[key])
145+
fields[key] = json_value[key]
169146

170147
return TensorBoardInfo(**fields)
171148

@@ -325,50 +302,87 @@ def get_all():
325302
return results
326303

327304

328-
# The following five types enumerate the possible return values of the
329-
# `start` function.
330-
331-
# Indicates that a call to `start` was compatible with an existing
332-
# TensorBoard process, which can be reused according to the provided
333-
# info.
334-
StartReused = collections.namedtuple("StartReused", ("info",))
335-
336-
# Indicates that a call to `start` successfully launched a new
337-
# TensorBoard process, which is available with the provided info.
338-
StartLaunched = collections.namedtuple("StartLaunched", ("info",))
339-
340-
# Indicates that a call to `start` tried to launch a new TensorBoard
341-
# instance, but the subprocess exited with the given exit code and
342-
# output streams. (If the contents of the output streams are no longer
343-
# available---e.g., because the user has emptied /tmp/---then the
344-
# corresponding values will be `None`.)
345-
StartFailed = collections.namedtuple(
346-
"StartFailed",
347-
(
348-
"exit_code", # int, as `Popen.returncode` (negative for signal)
349-
"stdout", # str, or `None` if the stream could not be read
350-
"stderr", # str, or `None` if the stream could not be read
351-
),
352-
)
353-
354-
# Indicates that a call to `start` failed to invoke the subprocess.
355-
#
356-
# If the TensorBoard executable was chosen via the `TENSORBOARD_BINARY`
357-
# environment variable, then the `explicit_binary` field contains the
358-
# path to that binary; otherwise, the field is `None`.
359-
StartExecFailed = collections.namedtuple(
360-
"StartExecFailed",
361-
(
362-
"os_error", # `OSError` due to `Popen` invocation
363-
"explicit_binary", # `str` or `None`; see type-level comment
364-
),
365-
)
366-
367-
# Indicates that a call to `start` launched a TensorBoard process, but
368-
# that process neither exited nor wrote its info file within the allowed
369-
# timeout period. The process may still be running under the included
370-
# PID.
371-
StartTimedOut = collections.namedtuple("StartTimedOut", ("pid",))
305+
@dataclasses.dataclass(frozen=True)
306+
class StartReused:
307+
"""Possible return value of the `start` function.
308+
309+
Indicates that a call to `start` was compatible with an existing
310+
TensorBoard process, which can be reused according to the provided
311+
info.
312+
313+
Attributes:
314+
info: A `TensorBoardInfo` object.
315+
"""
316+
317+
info: TensorBoardInfo
318+
319+
320+
@dataclasses.dataclass(frozen=True)
321+
class StartLaunched:
322+
"""Possible return value of the `start` function.
323+
324+
Indicates that a call to `start` successfully launched a new
325+
TensorBoard process, which is available with the provided info.
326+
327+
Attributes:
328+
info: A `TensorBoardInfo` object.
329+
"""
330+
331+
info: TensorBoardInfo
332+
333+
334+
@dataclasses.dataclass(frozen=True)
335+
class StartFailed:
336+
"""Possible return value of the `start` function.
337+
338+
Indicates that a call to `start` tried to launch a new TensorBoard
339+
instance, but the subprocess exited with the given exit code and
340+
output streams. (If the contents of the output streams are no longer
341+
available---e.g., because the user has emptied /tmp/---then the
342+
corresponding values will be `None`.)
343+
344+
Attributes:
345+
exit_code: As `Popen.returncode` (negative for signal).
346+
stdout: Contents of stdout, or `None` if the stream could not be read.
347+
stderr: Contents of stderr, or `None` if the stream could not be read.
348+
"""
349+
350+
exit_code: int
351+
stdout: Optional[str]
352+
stderr: Optional[str]
353+
354+
355+
@dataclasses.dataclass(frozen=True)
356+
class StartExecFailed:
357+
"""Possible return value of the `start` function.
358+
359+
Indicates that a call to `start` failed to invoke the subprocess.
360+
361+
Attributes:
362+
os_error: `OSError` due to `Popen` invocation.
363+
explicit_binary: If the TensorBoard executable was chosen via the
364+
`TENSORBOARD_BINARY` environment variable, then this field contains
365+
the path to that binary; otherwise `None`.
366+
"""
367+
368+
os_error: OSError
369+
explicit_binary: Optional[str]
370+
371+
372+
@dataclasses.dataclass(frozen=True)
373+
class StartTimedOut:
374+
"""Possible return value of the `start` function.
375+
376+
Indicates that a call to `start` launched a TensorBoard process, but
377+
that process neither exited nor wrote its info file within the allowed
378+
timeout period. The process may still be running under the included
379+
PID.
380+
381+
Attributes:
382+
pid: ID of the process running TensorBoard.
383+
"""
384+
385+
pid: int
372386

373387

374388
def start(arguments, timeout=datetime.timedelta(seconds=60)):

tensorboard/manager_test.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# ==============================================================================
1515
"""Unit tests for `tensorboard.manager`."""
1616

17-
17+
import dataclasses
1818
import datetime
1919
import errno
2020
import json
@@ -63,15 +63,15 @@ def test_roundtrip_serialization(self):
6363

6464
def test_serialization_rejects_bad_types(self):
6565
bad_time = datetime.datetime.fromtimestamp(1549061116) # not an int
66-
info = _make_info()._replace(start_time=bad_time)
66+
info = dataclasses.replace(_make_info(), start_time=bad_time)
6767
with self.assertRaisesRegex(
6868
ValueError,
6969
r"expected 'start_time' of type.*int.*, but found: datetime\.",
7070
):
7171
manager._info_to_string(info)
7272

7373
def test_serialization_rejects_wrong_version(self):
74-
info = _make_info()._replace(version="reversion")
74+
info = dataclasses.replace(_make_info(), version="reversion")
7575
with self.assertRaisesRegex(
7676
ValueError,
7777
"expected 'version' to be '.*', but found: 'reversion'",
@@ -139,11 +139,11 @@ def test_deserialization_rejects_bad_types(self):
139139
manager._info_from_string(bad_input)
140140

141141
def test_logdir_data_source_format(self):
142-
info = _make_info()._replace(logdir="~/foo", db="")
142+
info = dataclasses.replace(_make_info(), logdir="~/foo", db="")
143143
self.assertEqual(manager.data_source_from_info(info), "logdir ~/foo")
144144

145145
def test_db_data_source_format(self):
146-
info = _make_info()._replace(logdir="", db="sqlite:~/bar")
146+
info = dataclasses.replace(_make_info(), logdir="", db="sqlite:~/bar")
147147
self.assertEqual(manager.data_source_from_info(info), "db sqlite:~/bar")
148148

149149

@@ -327,7 +327,7 @@ def test_write_info_file_rejects_bad_types(self):
327327
# The particulars of validation are tested more thoroughly in
328328
# `TensorBoardInfoTest` above.
329329
bad_time = datetime.datetime.fromtimestamp(1549061116)
330-
info = _make_info()._replace(start_time=bad_time)
330+
info = dataclasses.replace(_make_info(), start_time=bad_time)
331331
with self.assertRaisesRegex(
332332
ValueError,
333333
r"expected 'start_time' of type.*int.*, but found: datetime\.",
@@ -338,7 +338,7 @@ def test_write_info_file_rejects_bad_types(self):
338338
def test_write_info_file_rejects_wrong_version(self):
339339
# The particulars of validation are tested more thoroughly in
340340
# `TensorBoardInfoTest` above.
341-
info = _make_info()._replace(version="reversion")
341+
info = dataclasses.replace(_make_info(), version="reversion")
342342
with self.assertRaisesRegex(
343343
ValueError,
344344
"expected 'version' to be '.*', but found: 'reversion'",

0 commit comments

Comments
 (0)