Skip to content

Commit a495a48

Browse files
authored
Hparams: Add ListSessionGroups support for DataProvider results. (#6486)
Use the result from DataProvider.read_hyperparameters() to generate SessionGroups. Note: Like in #6391 we don't yet support retrieval of Metric values for these SessionGroups and we don't yet merge the new SessionGroups with the tensor-based SessionGroups.
1 parent 63b3bc6 commit a495a48

File tree

3 files changed

+513
-15
lines changed

3 files changed

+513
-15
lines changed

tensorboard/plugins/hparams/backend_context.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,12 @@ def hparams_from_data_provider(self, ctx, experiment_id):
196196
ctx, experiment_ids=[experiment_id]
197197
)
198198

199+
def session_groups_from_data_provider(self, ctx, experiment_id):
200+
"""Calls DataProvider.read_hyperparameters() and returns the result."""
201+
return self._tb_context.data_provider.read_hyperparameters(
202+
ctx, experiment_ids=[experiment_id]
203+
)
204+
199205
def _find_experiment_tag(self, hparams_run_to_tag_to_content):
200206
"""Finds the experiment associcated with the metadata.EXPERIMENT_TAG
201207
tag.

tensorboard/plugins/hparams/list_session_groups.py

Lines changed: 88 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -50,32 +50,105 @@ def __init__(
5050
self._request = request
5151
self._extractors = _create_extractors(request.col_params)
5252
self._filters = _create_filters(request.col_params, self._extractors)
53+
54+
def run(self):
55+
"""Handles the request specified on construction.
56+
57+
This operation first attempts to construct SessionGroup information
58+
from hparam tags metadata.EXPERIMENT_TAG and
59+
metadata.SESSION_START_INFO.
60+
61+
If no such tags are found, then will build SessionGroup information
62+
using the results from DataProvider.read_hyperparameters().
63+
64+
Returns:
65+
A ListSessionGroupsResponse object.
66+
"""
67+
68+
session_groups_from_tags = self._session_groups_from_tags()
69+
if session_groups_from_tags:
70+
return self._create_response(session_groups_from_tags)
71+
72+
session_groups_from_data_provider = (
73+
self._session_groups_from_data_provider()
74+
)
75+
if session_groups_from_data_provider:
76+
return self._create_response(session_groups_from_data_provider)
77+
78+
return api_pb2.ListSessionGroupsResponse(
79+
session_groups=[], total_size=0
80+
)
81+
82+
def _session_groups_from_tags(self):
83+
"""Constructs lists of SessionGroups based on hparam tag metadata."""
5384
# Query for all Hparams summary metadata up front to minimize calls to
5485
# the underlying DataProvider.
55-
self._hparams_run_to_tag_to_content = backend_context.hparams_metadata(
56-
request_context, experiment_id
86+
self._hparams_run_to_tag_to_content = (
87+
self._backend_context.hparams_metadata(
88+
self._request_context, self._experiment_id
89+
)
5790
)
5891
# Since an context.experiment() call may search through all the runs, we
5992
# cache it here.
60-
self._experiment = backend_context.experiment_from_metadata(
61-
request_context,
62-
experiment_id,
93+
self._experiment = self._backend_context.experiment_from_metadata(
94+
self._request_context,
95+
self._experiment_id,
6396
self._hparams_run_to_tag_to_content,
64-
self._backend_context.hparams_from_data_provider(
65-
request_context, experiment_id
66-
),
97+
# Don't pass any information from the DataProvider since we are only
98+
# examining session groups based on tag metadata
99+
[],
67100
)
68101

69-
def run(self):
70-
"""Handles the request specified on construction.
71-
72-
Returns:
73-
A ListSessionGroupsResponse object.
74-
"""
75102
session_groups = self._build_session_groups()
76103
session_groups = self._filter(session_groups)
77104
self._sort(session_groups)
78-
return self._create_response(session_groups)
105+
return session_groups
106+
107+
def _session_groups_from_data_provider(self):
108+
"""Constructs lists of SessionGroups based on DataProvider results."""
109+
response = self._backend_context.session_groups_from_data_provider(
110+
self._request_context, self._experiment_id
111+
)
112+
113+
session_groups = []
114+
for provider_group in response:
115+
sessions = [
116+
api_pb2.Session(name=f"{s.experiment_id}/{s.run}")
117+
for s in provider_group.sessions
118+
]
119+
name = (
120+
f"{provider_group.root.experiment_id}/{provider_group.root.run}"
121+
if provider_group.root.run
122+
else provider_group.root.experiment_id
123+
)
124+
session_group = api_pb2.SessionGroup(
125+
name=name,
126+
sessions=sessions,
127+
)
128+
129+
for provider_hparam in provider_group.hyperparameter_values:
130+
hparam = session_group.hparams[
131+
provider_hparam.hyperparameter_name
132+
]
133+
if (
134+
provider_hparam.domain_type
135+
== provider.HyperparameterDomainType.DISCRETE_STRING
136+
):
137+
hparam.string_value = provider_hparam.value
138+
elif provider_hparam.domain_type in [
139+
provider.HyperparameterDomainType.DISCRETE_FLOAT,
140+
provider.HyperparameterDomainType.INTERVAL,
141+
]:
142+
hparam.number_value = provider_hparam.value
143+
elif (
144+
provider_hparam.domain_type
145+
== provider.HyperparameterDomainType.DISCRETE_BOOL
146+
):
147+
hparam.bool_value = provider_hparam.value
148+
149+
session_groups.append(session_group)
150+
151+
return session_groups
79152

80153
def _build_session_groups(self):
81154
"""Returns a list of SessionGroups protobuffers from the summary

0 commit comments

Comments
 (0)