Skip to content

Commit dc06f6b

Browse files
authored
cleanup: replace namedtuple with dataclass - part 1 (#5998)
See #5725 for context. Namedtuples used in `manager.py` is removed in a #6003.
1 parent 9696d8b commit dc06f6b

File tree

14 files changed

+370
-134
lines changed

14 files changed

+370
-134
lines changed

tensorboard/backend/event_processing/event_accumulator.py

Lines changed: 131 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,11 @@
1515
"""Takes a generator of values, and accumulates them for a frontend."""
1616

1717
import collections
18+
import dataclasses
1819
import threading
1920

21+
from typing import Sequence, Tuple
22+
2023
from tensorboard.backend.event_processing import directory_watcher
2124
from tensorboard.backend.event_processing import event_file_loader
2225
from tensorboard.backend.event_processing import io_wrapper
@@ -27,47 +30,137 @@
2730
from tensorboard.compat.proto import event_pb2
2831
from tensorboard.compat.proto import graph_pb2
2932
from tensorboard.compat.proto import meta_graph_pb2
33+
from tensorboard.compat.proto import tensor_pb2
3034
from tensorboard.plugins.distribution import compressor
3135
from tensorboard.util import tb_logging
3236

3337

3438
logger = tb_logging.get_logger()
3539

36-
namedtuple = collections.namedtuple
37-
ScalarEvent = namedtuple("ScalarEvent", ["wall_time", "step", "value"])
38-
39-
CompressedHistogramEvent = namedtuple(
40-
"CompressedHistogramEvent",
41-
["wall_time", "step", "compressed_histogram_values"],
42-
)
43-
44-
HistogramEvent = namedtuple(
45-
"HistogramEvent", ["wall_time", "step", "histogram_value"]
46-
)
47-
48-
HistogramValue = namedtuple(
49-
"HistogramValue",
50-
["min", "max", "num", "sum", "sum_squares", "bucket_limit", "bucket"],
51-
)
52-
53-
ImageEvent = namedtuple(
54-
"ImageEvent",
55-
["wall_time", "step", "encoded_image_string", "width", "height"],
56-
)
57-
58-
AudioEvent = namedtuple(
59-
"AudioEvent",
60-
[
61-
"wall_time",
62-
"step",
63-
"encoded_audio_string",
64-
"content_type",
65-
"sample_rate",
66-
"length_frames",
67-
],
68-
)
69-
70-
TensorEvent = namedtuple("TensorEvent", ["wall_time", "step", "tensor_proto"])
40+
41+
@dataclasses.dataclass(frozen=True)
42+
class ScalarEvent:
43+
"""Contains information of a scalar event.
44+
45+
Attributes:
46+
wall_time: Timestamp of the event in seconds.
47+
step: Global step of the event.
48+
value: A float or int value of the scalar.
49+
"""
50+
51+
wall_time: float
52+
step: int
53+
value: float
54+
55+
56+
@dataclasses.dataclass(frozen=True)
57+
class CompressedHistogramEvent:
58+
"""Contains information of a compressed histogram event.
59+
60+
Attributes:
61+
wall_time: Timestamp of the event in seconds.
62+
step: Global step of the event.
63+
compressed_histogram_values: A sequence of tuples of basis points and
64+
associated values in a compressed histogram.
65+
"""
66+
67+
wall_time: float
68+
step: int
69+
compressed_histogram_values: Sequence[Tuple[float, float]]
70+
71+
72+
@dataclasses.dataclass(frozen=True)
73+
class HistogramValue:
74+
"""Holds the information of the histogram values.
75+
76+
Attributes:
77+
min: A float or int min value.
78+
max: A float or int max value.
79+
num: Total number of values.
80+
sum: Sum of all values.
81+
sum_squares: Sum of squares for all values.
82+
bucket_limit: Upper values per bucket.
83+
bucket: Numbers of values per bucket.
84+
"""
85+
86+
min: float
87+
max: float
88+
num: int
89+
sum: float
90+
sum_squares: float
91+
bucket_limit: Sequence[float]
92+
bucket: Sequence[int]
93+
94+
95+
@dataclasses.dataclass(frozen=True)
96+
class HistogramEvent:
97+
"""Contains information of a histogram event.
98+
99+
Attributes:
100+
wall_time: Timestamp of the event in seconds.
101+
step: Global step of the event.
102+
histogram_value: Information of the histogram values.
103+
"""
104+
105+
wall_time: float
106+
step: int
107+
histogram_value: HistogramValue
108+
109+
110+
@dataclasses.dataclass(frozen=True)
111+
class ImageEvent:
112+
"""Contains information of an image event.
113+
114+
Attributes:
115+
wall_time: Timestamp of the event in seconds.
116+
step: Global step of the event.
117+
encoded_image_string: Image content encoded in bytes.
118+
width: Width of the image.
119+
height: Height of the image.
120+
"""
121+
122+
wall_time: float
123+
step: int
124+
encoded_image_string: bytes
125+
width: int
126+
height: int
127+
128+
129+
@dataclasses.dataclass(frozen=True)
130+
class AudioEvent:
131+
"""Contains information of an audio event.
132+
133+
Attributes:
134+
wall_time: Timestamp of the event in seconds.
135+
step: Global step of the event.
136+
encoded_audio_string: Audio content encoded in bytes.
137+
content_type: A string describes the type of the audio content.
138+
sample_rate: Sample rate of the audio in Hz. Must be positive.
139+
length_frames: Length of the audio in frames (samples per channel).
140+
"""
141+
142+
wall_time: float
143+
step: int
144+
encoded_audio_string: bytes
145+
content_type: str
146+
sample_rate: float
147+
length_frames: int
148+
149+
150+
@dataclasses.dataclass(frozen=True)
151+
class TensorEvent:
152+
"""A tensor event.
153+
154+
Attributes:
155+
wall_time: Timestamp of the event in seconds.
156+
step: Global step of the event.
157+
tensor_proto: A `TensorProto`.
158+
"""
159+
160+
wall_time: float
161+
step: int
162+
tensor_proto: tensor_pb2.TensorProto
163+
71164

72165
## Different types of summary events handled by the event_accumulator
73166
SUMMARY_TYPES = {
@@ -664,7 +757,8 @@ def _CheckForOutOfOrderStepAndMaybePurge(self, event):
664757
self.most_recent_step = event.step
665758
self.most_recent_wall_time = event.wall_time
666759

667-
def _ConvertHistogramProtoToTuple(self, histo):
760+
def _ConvertHistogramProtoToPopo(self, histo):
761+
"""Converts histogram proto to Python object."""
668762
return HistogramValue(
669763
min=histo.min,
670764
max=histo.max,
@@ -677,7 +771,7 @@ def _ConvertHistogramProtoToTuple(self, histo):
677771

678772
def _ProcessHistogram(self, tag, wall_time, step, histo):
679773
"""Processes a proto histogram by adding it to accumulated state."""
680-
histo = self._ConvertHistogramProtoToTuple(histo)
774+
histo = self._ConvertHistogramProtoToPopo(histo)
681775
histo_ev = HistogramEvent(wall_time, step, histo)
682776
self.histograms.AddItem(tag, histo_ev)
683777
self.compressed_histograms.AddItem(

tensorboard/backend/event_processing/event_accumulator_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def testCompressedHistograms(self):
374374

375375
# Create the expected values after compressing hst1
376376
expected_vals1 = [
377-
compressor.CompressedHistogramValue(bp, val)
377+
compressor.CompressedHistogramValue(bp, val).as_tuple()
378378
for bp, val in [
379379
(0, 1.0),
380380
(2500, 1.25),
@@ -390,7 +390,7 @@ def testCompressedHistograms(self):
390390

391391
# Create the expected values after compressing hst2
392392
expected_vals2 = [
393-
compressor.CompressedHistogramValue(bp, val)
393+
compressor.CompressedHistogramValue(bp, val).as_tuple()
394394
for bp, val in [
395395
(0, -2),
396396
(2500, 2),

tensorboard/backend/event_processing/event_file_inspector.py

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,12 @@
109109
"""
110110

111111

112-
import collections
112+
import dataclasses
113113
import itertools
114114
import os
115115

116+
from typing import Any, Generator, Mapping
117+
116118
from tensorboard.backend.event_processing import event_accumulator
117119
from tensorboard.backend.event_processing import event_file_loader
118120
from tensorboard.backend.event_processing import io_wrapper
@@ -145,25 +147,44 @@
145147
# All summary types that we can inspect.
146148
TRACKED_FIELDS = SHORT_FIELDS + LONG_FIELDS
147149

148-
# An `Observation` contains the data within each Event file that the inspector
149-
# cares about. The inspector accumulates Observations as it processes events.
150-
Observation = collections.namedtuple(
151-
"Observation", ["step", "wall_time", "tag"]
152-
)
150+
PRINT_SEPARATOR = "=" * 70 + "\n"
151+
153152

154-
# An InspectionUnit is created for each organizational structure in the event
155-
# files visible in the final terminal output. For instance, one InspectionUnit
156-
# is created for each subdirectory in logdir. When asked to inspect a single
157-
# event file, there may only be one InspectionUnit.
153+
@dataclasses.dataclass(frozen=True)
154+
class Observation:
155+
"""Contains the data within each Event file that the inspector cares about.
158156
159-
# The InspectionUnit contains the `name` of the organizational unit that will be
160-
# printed to console, a `generator` that yields `Event` protos, and a mapping
161-
# from string fields to `Observations` that the inspector creates.
162-
InspectionUnit = collections.namedtuple(
163-
"InspectionUnit", ["name", "generator", "field_to_obs"]
164-
)
157+
The inspector accumulates Observations as it processes events.
158+
159+
Attributes:
160+
step: Global step of the event.
161+
wall_time: Timestamp of the event in seconds.
162+
tag: Tag name associated with the event.
163+
"""
164+
165+
step: int
166+
wall_time: float
167+
tag: str
165168

166-
PRINT_SEPARATOR = "=" * 70 + "\n"
169+
170+
@dataclasses.dataclass(frozen=True)
171+
class InspectionUnit:
172+
"""Created for each organizational structure in the event files.
173+
174+
An InspectionUnit is visible in the final terminal output. For instance, one
175+
InspectionUnit is created for each subdirectory in logdir. When asked to inspect
176+
a single event file, there may only be one InspectionUnit.
177+
178+
Attributes:
179+
name: Name of the organizational unit that will be printed to console.
180+
generator: A generator that yields `Event` protos.
181+
field_to_obs: A mapping from string fields to `Observations` that the inspector
182+
creates.
183+
"""
184+
185+
name: str
186+
generator: Generator[event_pb2.Event, Any, Any]
187+
field_to_obs: Mapping[str, Observation]
167188

168189

169190
def get_field_to_observations_map(generator, query_for_tag=""):
@@ -181,9 +202,9 @@ def get_field_to_observations_map(generator, query_for_tag=""):
181202
def increment(stat, event, tag=""):
182203
assert stat in TRACKED_FIELDS
183204
field_to_obs[stat].append(
184-
Observation(
185-
step=event.step, wall_time=event.wall_time, tag=tag
186-
)._asdict()
205+
dataclasses.asdict(
206+
Observation(step=event.step, wall_time=event.wall_time, tag=tag)
207+
)
187208
)
188209

189210
field_to_obs = dict([(t, []) for t in TRACKED_FIELDS])

tensorboard/backend/event_processing/plugin_event_accumulator.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
"""Takes a generator of values, and accumulates them for a frontend."""
1616

1717
import collections
18+
import dataclasses
1819
import threading
1920

20-
2121
from tensorboard.backend.event_processing import directory_loader
2222
from tensorboard.backend.event_processing import directory_watcher
2323
from tensorboard.backend.event_processing import event_file_loader
@@ -29,14 +29,12 @@
2929
from tensorboard.compat.proto import event_pb2
3030
from tensorboard.compat.proto import graph_pb2
3131
from tensorboard.compat.proto import meta_graph_pb2
32+
from tensorboard.compat.proto import tensor_pb2
3233
from tensorboard.util import tb_logging
3334

3435

3536
logger = tb_logging.get_logger()
3637

37-
namedtuple = collections.namedtuple
38-
39-
TensorEvent = namedtuple("TensorEvent", ["wall_time", "step", "tensor_proto"])
4038

4139
# Legacy aliases
4240
TENSORS = tag_types.TENSORS
@@ -55,6 +53,21 @@
5553
_TENSOR_RESERVOIR_KEY = "." # arbitrary
5654

5755

56+
@dataclasses.dataclass(frozen=True)
57+
class TensorEvent:
58+
"""A tensor event.
59+
60+
Attributes:
61+
wall_time: Timestamp of the event in seconds.
62+
step: Global step of the event.
63+
tensor_proto: A `TensorProto`.
64+
"""
65+
66+
wall_time: float
67+
step: int
68+
tensor_proto: tensor_pb2.TensorProto
69+
70+
5871
class EventAccumulator(object):
5972
"""An `EventAccumulator` takes an event generator, and accumulates the
6073
values.

0 commit comments

Comments
 (0)