Skip to content

Commit 34bfbd9

Browse files
authored
Hparams: Add include_in_result field and implement support for hparams_to_include (#6559)
This PR extracts the filtering logic for hparams only (metrics filtering will be done later) from bmd3k's [commit](bmd3k@ab75cd1). Googlers, see comments in cl/559852837 for more context. Test internally at cl/563163418. #hparams
1 parent a0c241d commit 34bfbd9

File tree

5 files changed

+170
-5
lines changed

5 files changed

+170
-5
lines changed

tensorboard/data/provider.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,13 @@ def list_hyperparameters(self, ctx=None, *, experiment_ids):
397397
return ListHyperparametersResult(hyperparameters=[], session_groups=[])
398398

399399
def read_hyperparameters(
400-
self, ctx=None, *, experiment_ids, filters=None, sort=None
400+
self,
401+
ctx=None,
402+
*,
403+
experiment_ids,
404+
filters=None,
405+
sort=None,
406+
hparams_to_include=None,
401407
):
402408
"""Read hyperparameter values.
403409
@@ -409,6 +415,10 @@ def read_hyperparameters(
409415
returned session groups based on hyperparameter value.
410416
sort: A Sequence[HyperparameterSort] that specify how the results
411417
should be sorted.
418+
hparams_to_include: An optional Collection[str] of the full names of
419+
hyperparameters to include in the results. This collection will be
420+
augmented to include all the hyperparameters specified in `filters`
421+
and `sort`. If None, all hyperparameters will be returned.
412422
413423
Returns:
414424
A Sequence[HyperparameterSessionGroup] describing the groups and

tensorboard/plugins/hparams/api.proto

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -322,9 +322,9 @@ message ListSessionGroupsRequest {
322322
optional bool include_metrics = 8;
323323
}
324324

325-
// Defines parmeters for a ListSessionGroupsRequest for a specific column.
325+
// Defines parameters for a ListSessionGroupsRequest for a specific column.
326326
// See the comment for "ListSessionGroupsRequest" above for more details.
327-
// NEXT_TAG: 9
327+
// NEXT_TAG: 10
328328
message ColParams {
329329
oneof name {
330330
MetricName metric = 1;
@@ -373,6 +373,12 @@ message ColParams {
373373
// Specifies whether to exclude session groups whose column value is missing
374374
// from the response.
375375
bool exclude_missing_values = 8;
376+
377+
// Specifies whether to include the values for the field in the operation
378+
// result. Defaults to True.
379+
// Note: Hparams and metrics that do not appear in any `ColParams` in a
380+
// request will also not be included in the result.
381+
optional bool include_in_result = 9;
376382
}
377383

378384
enum SortOrder {

tensorboard/plugins/hparams/backend_context.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,11 +200,15 @@ def hparams_from_data_provider(self, ctx, experiment_id):
200200
)
201201

202202
def session_groups_from_data_provider(
203-
self, ctx, experiment_id, filters, sort
203+
self, ctx, experiment_id, filters, sort, hparams_to_include
204204
):
205205
"""Calls DataProvider.read_hyperparameters() and returns the result."""
206206
return self._tb_context.data_provider.read_hyperparameters(
207-
ctx, experiment_ids=[experiment_id], filters=filters, sort=sort
207+
ctx,
208+
experiment_ids=[experiment_id],
209+
filters=filters,
210+
sort=sort,
211+
hparams_to_include=hparams_to_include,
208212
)
209213

210214
def _find_experiment_tag(

tensorboard/plugins/hparams/list_session_groups.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,17 +114,29 @@ def _session_groups_from_tags(self):
114114
)
115115
session_groups = self._filter(session_groups, filters)
116116
self._sort(session_groups, extractors)
117+
118+
if _specifies_include(self._request.col_params):
119+
_reduce_to_hparams_to_include(
120+
session_groups, self._request.col_params
121+
)
122+
117123
return session_groups
118124

119125
def _session_groups_from_data_provider(self):
120126
"""Constructs lists of SessionGroups based on DataProvider results."""
121127
filters = _build_data_provider_filters(self._request.col_params)
122128
sort = _build_data_provider_sort(self._request.col_params)
129+
hparams_to_include = (
130+
_get_hparams_to_include(self._request.col_params)
131+
if _specifies_include(self._request.col_params)
132+
else None
133+
)
123134
response = self._backend_context.session_groups_from_data_provider(
124135
self._request_context,
125136
self._experiment_id,
126137
filters,
127138
sort,
139+
hparams_to_include,
128140
)
129141

130142
metric_infos = (
@@ -968,3 +980,62 @@ def _build_data_provider_sort_item(col_param):
968980
hyperparameter_name=col_param.hparam,
969981
sort_direction=sort_direction,
970982
)
983+
984+
985+
def _specifies_include(col_params):
986+
"""Determines whether any `ColParam` contains the `include_in_result` field.
987+
988+
In the case where none of the col_params contains the field, we should assume
989+
that all fields should be included in the response.
990+
"""
991+
return any(
992+
col_param.HasField("include_in_result") for col_param in col_params
993+
)
994+
995+
996+
def _get_hparams_to_include(col_params):
997+
"""Generates the list of hparams to include in the response.
998+
999+
The determination is based on the `include_in_result` field in ColParam. If
1000+
a ColParam either has `include_in_result: True` or does not specify the
1001+
field at all, then it should be included in the result.
1002+
1003+
Args:
1004+
col_params: A collection of `ColParams` protos.
1005+
1006+
Returns:
1007+
A list of names of hyperparameters to include in the response.
1008+
"""
1009+
hparams_to_include = []
1010+
for col_param in col_params:
1011+
if (
1012+
col_param.HasField("include_in_result")
1013+
and not col_param.include_in_result
1014+
):
1015+
# Explicitly set to exclude this hparam.
1016+
continue
1017+
if col_param.hparam:
1018+
hparams_to_include.append(col_param.hparam)
1019+
return hparams_to_include
1020+
1021+
1022+
def _reduce_to_hparams_to_include(session_groups, col_params):
1023+
"""Removes hparams from session_groups that should not be included.
1024+
1025+
Args:
1026+
session_groups: A collection of `SessionGroup` protos, which will be
1027+
modified in place.
1028+
col_params: A collection of `ColParams` protos.
1029+
"""
1030+
hparams_to_include = _get_hparams_to_include(col_params)
1031+
1032+
for session_group in session_groups:
1033+
new_hparams = {
1034+
hparam: value
1035+
for (hparam, value) in session_group.hparams.items()
1036+
if hparam in hparams_to_include
1037+
}
1038+
1039+
session_group.ClearField("hparams")
1040+
for (hparam, value) in new_hparams.items():
1041+
session_group.hparams[hparam].CopyFrom(value)

tensorboard/plugins/hparams/list_session_groups_test.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,43 @@ def test_some_allowed_statuses(self):
616616
],
617617
)
618618

619+
def test_include_in_result(self):
620+
request = """
621+
start_index: 0
622+
slice_size: 3
623+
allowed_statuses: [
624+
STATUS_UNKNOWN,
625+
STATUS_SUCCESS,
626+
STATUS_FAILURE,
627+
STATUS_RUNNING
628+
]
629+
aggregation_type: AGGREGATION_AVG
630+
col_params {
631+
hparam: "bool_hparam"
632+
include_in_result: True
633+
}
634+
col_params {
635+
hparam: "initial_temp"
636+
}
637+
col_params {
638+
hparam: "string_hparam"
639+
include_in_result: False
640+
}
641+
"""
642+
response = self._run_handler(request)
643+
644+
# Each of the session groups and sessions have two hparams and three metrics to include.
645+
# Only check the first two session groups.
646+
self.assertCountEqual(
647+
response.session_groups[0].hparams, ["bool_hparam", "initial_temp"]
648+
)
649+
self.assertLen(response.session_groups[0].metric_values, 3)
650+
651+
self.assertCountEqual(
652+
response.session_groups[0].hparams, ["bool_hparam", "initial_temp"]
653+
)
654+
self.assertLen(response.session_groups[0].metric_values, 3)
655+
619656
def test_some_allowed_statuses_empty_groups(self):
620657
request = """
621658
start_index: 0
@@ -2405,6 +2442,37 @@ def test_experiment_from_data_provider_include_metrics(
24052442
response.session_groups[0].sessions[0].metric_values, 2
24062443
)
24072444

2445+
def test_experiment_from_data_provider_sends_hparams_include_in_result(
2446+
self,
2447+
):
2448+
self._mock_tb_context.data_provider.list_tensors.side_effect = None
2449+
request = """
2450+
col_params: {
2451+
hparam: 'hparam1'
2452+
include_in_result: True
2453+
}
2454+
col_params: {
2455+
hparam: 'hparam2'
2456+
include_in_result: False
2457+
}
2458+
col_params: {
2459+
hparam: 'hparam3'
2460+
}
2461+
col_params: {
2462+
hparam: 'hparam4'
2463+
include_in_result: True
2464+
}
2465+
col_params: {
2466+
metric: {tag: 'metric1'}
2467+
include_in_result: True
2468+
}
2469+
"""
2470+
self._run_handler(request)
2471+
self.assertCountEqual(
2472+
self._get_read_hyperparameters_call_hparams_to_include(),
2473+
["hparam1", "hparam3", "hparam4"],
2474+
)
2475+
24082476
def _run_handler(self, request):
24092477
request_proto = api_pb2.ListSessionGroupsRequest()
24102478
text_format.Merge(request, request_proto)
@@ -2456,6 +2524,12 @@ def _get_read_hyperparameters_call_sort(self):
24562524
)
24572525
return call_args[1]["sort"]
24582526

2527+
def _get_read_hyperparameters_call_hparams_to_include(self):
2528+
call_args = (
2529+
self._mock_tb_context.data_provider.read_hyperparameters.call_args
2530+
)
2531+
return call_args[1]["hparams_to_include"]
2532+
24592533

24602534
def _reduce_session_group_to_names(session_group):
24612535
return [session.name for session in session_group.sessions]

0 commit comments

Comments
 (0)