Skip to content

Commit 0a8947b

Browse files
authored
Hparams: Adds filters to DataProvider.read_hyperparameters(). (#6506)
Adds support for specifying filters in DataProvider.read_hyperparameters(), represented as keyword argument `filters: Collection[HyperparameterFilter]`. Updates list_session_groups.py to translate ColParams from the HTTP request into HyperparameterFilter objects and include it in the request to the DataProvider implementation.
1 parent 79bec59 commit 0a8947b

File tree

4 files changed

+257
-7
lines changed

4 files changed

+257
-7
lines changed

tensorboard/data/provider.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,13 +396,15 @@ def list_hyperparameters(self, ctx=None, *, experiment_ids):
396396
"""
397397
return []
398398

399-
def read_hyperparameters(self, ctx=None, *, experiment_ids):
399+
def read_hyperparameters(self, ctx=None, *, experiment_ids, filters=None):
400400
"""Read hyperparameter values.
401401
402402
Args:
403403
ctx: A TensorBoard `RequestContext` value.
404404
experiment_ids: A Collection[string] of IDs of the enclosing
405405
experiments.
406+
filters: A Collection[HyperparameterFilter] that constrain the
407+
returned session groups based on hyperparameter value.
406408
407409
Returns:
408410
A Collection[HyperparameterSessionGroup] describing the groups and
@@ -668,6 +670,45 @@ class HyperparameterSessionGroup:
668670
hyperparameter_values: Collection[HyperparameterValue]
669671

670672

673+
class HyperparameterFilterType(enum.Enum):
674+
"""Describes how to represent filter values."""
675+
676+
# A regular expression string. Normally represented as str.
677+
REGEX = "regex"
678+
# A range of numeric values. Normally represented as Tuple[float, float].
679+
INTERVAL = "interval"
680+
# A finite set of values. Normally represented as Collection[float|str|bool].
681+
DISCRETE = "discrete"
682+
683+
684+
@dataclasses.dataclass(frozen=True)
685+
class HyperparameterFilter:
686+
"""A constraint based on hyperparameter value.
687+
688+
Attributes:
689+
hyperparameter_name: A string identifier for the hyperparameter to use for
690+
the filter. It corresponds to the hyperparameter_name field in the
691+
Hyperparameter class.
692+
filter_type: A HyperparameterFilterType describing how we represent the
693+
filter values in the 'filter' attribute.
694+
filter: A representation of the set of the filter values.
695+
696+
If filter_type is REGEX, a str containing the regular expression.
697+
If filter_type is INTERVAL, a Tuple[float, float] describing the min and
698+
max values of the filter interval.
699+
If filter_type is DISCRETE a Collection[float|str|bool] describing the
700+
finite set of filter values.
701+
"""
702+
703+
hyperparameter_name: str
704+
filter_type: HyperparameterFilterType
705+
filter: Union[
706+
str,
707+
Tuple[float, float],
708+
Collection[Union[float, str, bool]],
709+
]
710+
711+
671712
class _TimeSeries:
672713
"""Metadata about time series data for a particular run and tag.
673714

tensorboard/plugins/hparams/backend_context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,10 +192,10 @@ def hparams_from_data_provider(self, ctx, experiment_id):
192192
ctx, experiment_ids=[experiment_id]
193193
)
194194

195-
def session_groups_from_data_provider(self, ctx, experiment_id):
195+
def session_groups_from_data_provider(self, ctx, experiment_id, filters):
196196
"""Calls DataProvider.read_hyperparameters() and returns the result."""
197197
return self._tb_context.data_provider.read_hyperparameters(
198-
ctx, experiment_ids=[experiment_id]
198+
ctx, experiment_ids=[experiment_id], filters=filters
199199
)
200200

201201
def _find_experiment_tag(self, hparams_run_to_tag_to_content):

tensorboard/plugins/hparams/list_session_groups.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,11 @@ def _session_groups_from_tags(self):
107107

108108
def _session_groups_from_data_provider(self):
109109
"""Constructs lists of SessionGroups based on DataProvider results."""
110+
filters = _build_data_provider_filters(self._request.col_params)
110111
response = self._backend_context.session_groups_from_data_provider(
111-
self._request_context, self._experiment_id
112+
self._request_context,
113+
self._experiment_id,
114+
filters,
112115
)
113116

114117
session_groups = []
@@ -811,3 +814,45 @@ def _measurements(session_group, metric_name):
811814
if not metric_value:
812815
continue
813816
yield _Measurement(metric_value, session_index)
817+
818+
819+
def _build_data_provider_filters(col_params):
820+
"""Builds HyperparameterFilters from ColParams."""
821+
filters = []
822+
for col_param in col_params:
823+
fltr = _build_data_provider_filter(col_param)
824+
if fltr is None:
825+
continue
826+
filters.append(fltr)
827+
return filters
828+
829+
830+
def _build_data_provider_filter(col_param):
831+
"""Builds HyperparameterFilter from ColParam.
832+
833+
Args:
834+
col_param: ColParam that possibly contains filter information.
835+
836+
Returns:
837+
None if col_param does not specify filter information.
838+
"""
839+
if col_param.HasField("filter_regexp"):
840+
filter_type = provider.HyperparameterFilterType.REGEX
841+
fltr = col_param.filter_regexp
842+
elif col_param.HasField("filter_interval"):
843+
filter_type = provider.HyperparameterFilterType.INTERVAL
844+
fltr = (
845+
col_param.filter_interval.min_value,
846+
col_param.filter_interval.max_value,
847+
)
848+
elif col_param.HasField("filter_discrete"):
849+
filter_type = provider.HyperparameterFilterType.DISCRETE
850+
fltr = [_value_to_python(b) for b in col_param.filter_discrete.values]
851+
else:
852+
return None
853+
854+
return provider.HyperparameterFilter(
855+
hyperparameter_name=col_param.hparam,
856+
filter_type=filter_type,
857+
filter=fltr,
858+
)

tensorboard/plugins/hparams/list_session_groups_test.py

Lines changed: 167 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -363,9 +363,8 @@ def _mock_read_scalars(
363363

364364
def _mock_read_hyperparameters(
365365
self,
366-
ctx,
367-
*,
368-
experiment_ids,
366+
*args,
367+
**kwargs,
369368
):
370369
return self._hyperparameters
371370

@@ -1313,6 +1312,165 @@ def test_experiment_without_any_hparams(self):
13131312
response = self._run_handler(request)
13141313
self.assertProtoEquals("", response)
13151314

1315+
def test_experiment_from_data_provider_sends_regex_filter(self):
1316+
self._mock_tb_context.data_provider.list_tensors.side_effect = None
1317+
request = """
1318+
col_params: {
1319+
hparam: 'hparam1'
1320+
filter_regexp: 'v.*ue'
1321+
}
1322+
"""
1323+
self._run_handler(request)
1324+
self.assertEquals(
1325+
self._get_read_hyperparameters_call_filters(),
1326+
[
1327+
provider.HyperparameterFilter(
1328+
hyperparameter_name="hparam1",
1329+
filter_type=provider.HyperparameterFilterType.REGEX,
1330+
filter="v.*ue",
1331+
)
1332+
],
1333+
)
1334+
1335+
def test_experiment_from_data_provider_sends_interval_filter(self):
1336+
self._mock_tb_context.data_provider.list_tensors.side_effect = None
1337+
request = """
1338+
col_params: {
1339+
hparam: 'hparam1'
1340+
filter_interval: {
1341+
min_value: 0.1
1342+
max_value: 0.2
1343+
}
1344+
}
1345+
col_params: {
1346+
hparam: 'hparam2'
1347+
filter_interval: {
1348+
min_value: 0.1
1349+
max_value: Infinity
1350+
}
1351+
}
1352+
col_params: {
1353+
hparam: 'hparam3'
1354+
filter_interval: {
1355+
min_value: -Infinity
1356+
max_value: 0.2
1357+
}
1358+
}
1359+
col_params: {
1360+
hparam: 'hparam4'
1361+
filter_interval: {
1362+
}
1363+
}
1364+
"""
1365+
self._run_handler(request)
1366+
self.assertEquals(
1367+
self._get_read_hyperparameters_call_filters(),
1368+
[
1369+
provider.HyperparameterFilter(
1370+
hyperparameter_name="hparam1",
1371+
filter_type=provider.HyperparameterFilterType.INTERVAL,
1372+
filter=(0.1, 0.2),
1373+
),
1374+
provider.HyperparameterFilter(
1375+
hyperparameter_name="hparam2",
1376+
filter_type=provider.HyperparameterFilterType.INTERVAL,
1377+
filter=(0.1, float("inf")),
1378+
),
1379+
provider.HyperparameterFilter(
1380+
hyperparameter_name="hparam3",
1381+
filter_type=provider.HyperparameterFilterType.INTERVAL,
1382+
filter=(float("-inf"), 0.2),
1383+
),
1384+
provider.HyperparameterFilter(
1385+
hyperparameter_name="hparam4",
1386+
filter_type=provider.HyperparameterFilterType.INTERVAL,
1387+
filter=(0.0, 0.0),
1388+
),
1389+
],
1390+
)
1391+
1392+
def test_experiment_from_data_provider_sends_discrete_filter(self):
1393+
self._mock_tb_context.data_provider.list_tensors.side_effect = None
1394+
request = """
1395+
col_params: {
1396+
hparam: 'hparam1'
1397+
filter_discrete: {
1398+
values: {
1399+
bool_value: true
1400+
}
1401+
values: {
1402+
bool_value: false
1403+
}
1404+
}
1405+
}
1406+
col_params: {
1407+
hparam: 'hparam2'
1408+
filter_discrete: {
1409+
values: {
1410+
number_value: 2.0
1411+
}
1412+
}
1413+
}
1414+
col_params: {
1415+
hparam: 'hparam3'
1416+
filter_discrete: {
1417+
values: {
1418+
string_value: '3_string'
1419+
}
1420+
}
1421+
}
1422+
col_params: {
1423+
hparam: 'hparam4'
1424+
filter_discrete: {
1425+
values: {
1426+
string_value: '4_string'
1427+
}
1428+
values: {
1429+
bool_value: true
1430+
}
1431+
values: {
1432+
number_value: 4.0
1433+
}
1434+
}
1435+
}
1436+
col_params: {
1437+
hparam: 'hparam5'
1438+
filter_discrete: {}
1439+
}
1440+
"""
1441+
self._run_handler(request)
1442+
1443+
self.assertEquals(
1444+
self._get_read_hyperparameters_call_filters(),
1445+
[
1446+
provider.HyperparameterFilter(
1447+
hyperparameter_name="hparam1",
1448+
filter_type=provider.HyperparameterFilterType.DISCRETE,
1449+
filter=[True, False],
1450+
),
1451+
provider.HyperparameterFilter(
1452+
hyperparameter_name="hparam2",
1453+
filter_type=provider.HyperparameterFilterType.DISCRETE,
1454+
filter=[2.0],
1455+
),
1456+
provider.HyperparameterFilter(
1457+
hyperparameter_name="hparam3",
1458+
filter_type=provider.HyperparameterFilterType.DISCRETE,
1459+
filter=["3_string"],
1460+
),
1461+
provider.HyperparameterFilter(
1462+
hyperparameter_name="hparam4",
1463+
filter_type=provider.HyperparameterFilterType.DISCRETE,
1464+
filter=["4_string", True, 4.0],
1465+
),
1466+
provider.HyperparameterFilter(
1467+
hyperparameter_name="hparam5",
1468+
filter_type=provider.HyperparameterFilterType.DISCRETE,
1469+
filter=[],
1470+
),
1471+
],
1472+
)
1473+
13161474
def test_experiment_from_data_provider_with_no_sessions_or_hparam_values(
13171475
self,
13181476
):
@@ -1749,6 +1907,12 @@ def _serialized_plugin_data(self, data_oneof_field, text_protobuffer):
17491907
getattr(plugin_data, data_oneof_field).CopyFrom(protobuffer)
17501908
return metadata.create_summary_metadata(plugin_data).plugin_data.content
17511909

1910+
def _get_read_hyperparameters_call_filters(self):
1911+
call_args = (
1912+
self._mock_tb_context.data_provider.read_hyperparameters.call_args
1913+
)
1914+
return call_args[1]["filters"]
1915+
17521916

17531917
def _reduce_session_group_to_names(session_group):
17541918
return [session.name for session in session_group.sessions]

0 commit comments

Comments
 (0)