Skip to content

Commit 07486dd

Browse files
authored
Clean up usage of data attributes in list_session_groups. (#6503)
This is a cleanup PR to list_session_groups.py, reducing its reliance on data attributes. We now keep the set of data attributes small - to correspond exactly to the arguments of the `__init__` method (`self._request_context`, `self._backend_context`, `self._experiment_id`, and `self._request`). Two motivating factors: * Recently we have split list_session_groups.py into two major branches of logic: The branch that generates "classic" hparams from tensor metadata; and the branch that generates hparams from the data provider. Many of the existing data attributes applied only to the first branch and, in some cases, were not strictly guaranteed to be set since they are no longer set in `__init__`. * We can now move `self._extractors` and `self._filters` out of the`__init__` method now that we know they won't apply to the data provider-based hparams. It means no calculation for the "classic" hparams logic is necessary in the `__init__`.
1 parent 88b8fae commit 07486dd

File tree

1 file changed

+35
-30
lines changed

1 file changed

+35
-30
lines changed

tensorboard/plugins/hparams/list_session_groups.py

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,6 @@ def __init__(
4949
self._backend_context = backend_context
5050
self._experiment_id = experiment_id
5151
self._request = request
52-
self._extractors = _create_extractors(request.col_params)
53-
self._filters = _create_filters(request.col_params, self._extractors)
5452

5553
def run(self):
5654
"""Handles the request specified on construction.
@@ -82,27 +80,29 @@ def run(self):
8280

8381
def _session_groups_from_tags(self):
8482
"""Constructs lists of SessionGroups based on hparam tag metadata."""
85-
# Query for all Hparams summary metadata up front to minimize calls to
83+
# Query for all Hparams summary metadata one time to minimize calls to
8684
# the underlying DataProvider.
87-
self._hparams_run_to_tag_to_content = (
88-
self._backend_context.hparams_metadata(
89-
self._request_context, self._experiment_id
90-
)
85+
hparams_run_to_tag_to_content = self._backend_context.hparams_metadata(
86+
self._request_context, self._experiment_id
9187
)
92-
# Since an context.experiment() call may search through all the runs, we
93-
# cache it here.
94-
self._experiment = self._backend_context.experiment_from_metadata(
88+
# Construct the experiment one time since an context.experiment() call
89+
# may search through all the runs.
90+
experiment = self._backend_context.experiment_from_metadata(
9591
self._request_context,
9692
self._experiment_id,
97-
self._hparams_run_to_tag_to_content,
93+
hparams_run_to_tag_to_content,
9894
# Don't pass any information from the DataProvider since we are only
9995
# examining session groups based on tag metadata
10096
[],
10197
)
98+
extractors = _create_extractors(self._request.col_params)
99+
filters = _create_filters(self._request.col_params, extractors)
102100

103-
session_groups = self._build_session_groups()
104-
session_groups = self._filter(session_groups)
105-
self._sort(session_groups)
101+
session_groups = self._build_session_groups(
102+
hparams_run_to_tag_to_content, experiment
103+
)
104+
session_groups = self._filter(session_groups, filters)
105+
self._sort(session_groups, extractors)
106106
return session_groups
107107

108108
def _session_groups_from_data_provider(self):
@@ -151,7 +151,7 @@ def _session_groups_from_data_provider(self):
151151

152152
return session_groups
153153

154-
def _build_session_groups(self):
154+
def _build_session_groups(self, hparams_run_to_tag_to_content, experiment):
155155
"""Returns a list of SessionGroups protobuffers from the summary
156156
data."""
157157

@@ -167,13 +167,13 @@ def _build_session_groups(self):
167167
# contain metrics (may be in subdirectories).
168168
session_names = [
169169
run
170-
for (run, tags) in self._hparams_run_to_tag_to_content.items()
170+
for (run, tags) in hparams_run_to_tag_to_content.items()
171171
if metadata.SESSION_START_INFO_TAG in tags
172172
]
173173
metric_runs = set()
174174
metric_tags = set()
175175
for session_name in session_names:
176-
for metric in self._experiment.metric_infos:
176+
for metric in experiment.metric_infos:
177177
metric_name = metric.name
178178
(run, tag) = metrics.run_tag_from_session_and_metric(
179179
session_name, metric_name
@@ -190,7 +190,7 @@ def _build_session_groups(self):
190190
for (
191191
session_name,
192192
tag_to_content,
193-
) in self._hparams_run_to_tag_to_content.items():
193+
) in hparams_run_to_tag_to_content.items():
194194
if metadata.SESSION_START_INFO_TAG not in tag_to_content:
195195
continue
196196
start_info = metadata.parse_session_start_info_plugin_data(
@@ -202,7 +202,7 @@ def _build_session_groups(self):
202202
tag_to_content[metadata.SESSION_END_INFO_TAG]
203203
)
204204
session = self._build_session(
205-
session_name, start_info, end_info, all_metric_evals
205+
experiment, session_name, start_info, end_info, all_metric_evals
206206
)
207207
if session.status in self._request.allowed_statuses:
208208
self._add_session(session, start_info, groups_by_name)
@@ -257,7 +257,9 @@ def _add_session(self, session, start_info, groups_by_name):
257257
group.hparams[key].CopyFrom(value)
258258
groups_by_name[group_name] = group
259259

260-
def _build_session(self, name, start_info, end_info, all_metric_evals):
260+
def _build_session(
261+
self, experiment, name, start_info, end_info, all_metric_evals
262+
):
261263
"""Builds a session object."""
262264

263265
assert start_info is not None
@@ -266,7 +268,7 @@ def _build_session(self, name, start_info, end_info, all_metric_evals):
266268
start_time_secs=start_info.start_time_secs,
267269
model_uri=start_info.model_uri,
268270
metric_values=self._build_session_metric_values(
269-
name, all_metric_evals
271+
experiment, name, all_metric_evals
270272
),
271273
monitor_url=start_info.monitor_url,
272274
)
@@ -275,13 +277,14 @@ def _build_session(self, name, start_info, end_info, all_metric_evals):
275277
result.end_time_secs = end_info.end_time_secs
276278
return result
277279

278-
def _build_session_metric_values(self, session_name, all_metric_evals):
280+
def _build_session_metric_values(
281+
self, experiment, session_name, all_metric_evals
282+
):
279283
"""Builds the session metric values."""
280284

281285
# result is a list of api_pb2.MetricValue instances.
282286
result = []
283-
metric_infos = self._experiment.metric_infos
284-
for metric_info in metric_infos:
287+
for metric_info in experiment.metric_infos:
285288
metric_name = metric_info.name
286289
(run, tag) = metrics.run_tag_from_session_and_metric(
287290
session_name, metric_name
@@ -327,13 +330,15 @@ def _aggregate_metrics(self, session_group):
327330
% self._request.aggregation_type
328331
)
329332

330-
def _filter(self, session_groups):
331-
return [sg for sg in session_groups if self._passes_all_filters(sg)]
333+
def _filter(self, session_groups, filters):
334+
return [
335+
sg for sg in session_groups if self._passes_all_filters(sg, filters)
336+
]
332337

333-
def _passes_all_filters(self, session_group):
334-
return all(filter_fn(session_group) for filter_fn in self._filters)
338+
def _passes_all_filters(self, session_group, filters):
339+
return all(filter_fn(session_group) for filter_fn in filters)
335340

336-
def _sort(self, session_groups):
341+
def _sort(self, session_groups, extractors):
337342
"""Sorts 'session_groups' in place according to _request.col_params."""
338343

339344
# Sort by session_group name so we have a deterministic order.
@@ -344,7 +349,7 @@ def _sort(self, session_groups):
344349
# need to iterate on these columns in reverse order (thus the primary key
345350
# is the key used in the last sort).
346351
for col_param, extractor in reversed(
347-
list(zip(self._request.col_params, self._extractors))
352+
list(zip(self._request.col_params, extractors))
348353
):
349354
if col_param.order == api_pb2.ORDER_UNSPECIFIED:
350355
continue

0 commit comments

Comments
 (0)