Skip to content

Commit a399065

Browse files
authored
Hparams: Adds sort argument to DataProvider.read_hyperparameters(). (#6518)
Adds support for specifying sort in DataProvider.read_hyperparameters(), represented as keyword argument `filters: Sequence[HyperparameterSort]`. Updates list_session_groups.py to translate ColParams from the HTTP request into HyperparameterSort objects and include it in the request to the DataProvider implementation.
1 parent 29dff2f commit a399065

File tree

4 files changed

+126
-3
lines changed

4 files changed

+126
-3
lines changed

tensorboard/data/provider.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,9 @@ def list_hyperparameters(self, ctx=None, *, experiment_ids):
396396
"""
397397
return []
398398

399-
def read_hyperparameters(self, ctx=None, *, experiment_ids, filters=None):
399+
def read_hyperparameters(
400+
self, ctx=None, *, experiment_ids, filters=None, sort=None
401+
):
400402
"""Read hyperparameter values.
401403
402404
Args:
@@ -405,6 +407,8 @@ def read_hyperparameters(self, ctx=None, *, experiment_ids, filters=None):
405407
experiments.
406408
filters: A Collection[HyperparameterFilter] that constrain the
407409
returned session groups based on hyperparameter value.
410+
sort: A Sequence[HyperparameterSort] that specify how the results
411+
should be sorted.
408412
409413
Returns:
410414
A Collection[HyperparameterSessionGroup] describing the groups and
@@ -709,6 +713,30 @@ class HyperparameterFilter:
709713
]
710714

711715

716+
class HyperparameterSortDirection(enum.Enum):
717+
"""Describes which direction to sort a value."""
718+
719+
# Sort values ascending.
720+
ASCENDING = "ascending"
721+
# Sort values descending.
722+
DESCENDING = "descending"
723+
724+
725+
@dataclasses.dataclass(frozen=True)
726+
class HyperparameterSort:
727+
"""A sort criterium based on hyperparameter value.
728+
729+
Attributes:
730+
hyperparameter_name: A string identifier for the hyperparameter to use for
731+
the sort. It corresponds to the hyperparameter_name field in the
732+
Hyperparameter class.
733+
sort_direction: The direction to sort.
734+
"""
735+
736+
hyperparameter_name: str
737+
sort_direction: HyperparameterSortDirection
738+
739+
712740
class _TimeSeries:
713741
"""Metadata about time series data for a particular run and tag.
714742

tensorboard/plugins/hparams/backend_context.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,10 +192,12 @@ def hparams_from_data_provider(self, ctx, experiment_id):
192192
ctx, experiment_ids=[experiment_id]
193193
)
194194

195-
def session_groups_from_data_provider(self, ctx, experiment_id, filters):
195+
def session_groups_from_data_provider(
196+
self, ctx, experiment_id, filters, sort
197+
):
196198
"""Calls DataProvider.read_hyperparameters() and returns the result."""
197199
return self._tb_context.data_provider.read_hyperparameters(
198-
ctx, experiment_ids=[experiment_id], filters=filters
200+
ctx, experiment_ids=[experiment_id], filters=filters, sort=sort
199201
)
200202

201203
def _find_experiment_tag(self, hparams_run_to_tag_to_content):

tensorboard/plugins/hparams/list_session_groups.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,12 @@ def _session_groups_from_tags(self):
108108
def _session_groups_from_data_provider(self):
109109
"""Constructs lists of SessionGroups based on DataProvider results."""
110110
filters = _build_data_provider_filters(self._request.col_params)
111+
sort = _build_data_provider_sort(self._request.col_params)
111112
response = self._backend_context.session_groups_from_data_provider(
112113
self._request_context,
113114
self._experiment_id,
114115
filters,
116+
sort,
115117
)
116118

117119
session_groups = []
@@ -856,3 +858,37 @@ def _build_data_provider_filter(col_param):
856858
filter_type=filter_type,
857859
filter=fltr,
858860
)
861+
862+
863+
def _build_data_provider_sort(col_params):
864+
"""Builds HyperparameterSorts from ColParams."""
865+
sort = []
866+
for col_param in col_params:
867+
sort_item = _build_data_provider_sort_item(col_param)
868+
if sort_item is None:
869+
continue
870+
sort.append(sort_item)
871+
return sort
872+
873+
874+
def _build_data_provider_sort_item(col_param):
875+
"""Builds HyperparameterSort from ColParam.
876+
877+
Args:
878+
col_param: ColParam that possibly contains sort information.
879+
880+
Returns:
881+
None if col_param does not specify sort information.
882+
"""
883+
if col_param.order == api_pb2.ORDER_UNSPECIFIED:
884+
return None
885+
886+
sort_direction = (
887+
provider.HyperparameterSortDirection.ASCENDING
888+
if col_param.order == api_pb2.ORDER_ASC
889+
else provider.HyperparameterSortDirection.DESCENDING
890+
)
891+
return provider.HyperparameterSort(
892+
hyperparameter_name=col_param.hparam,
893+
sort_direction=sort_direction,
894+
)

tensorboard/plugins/hparams/list_session_groups_test.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1312,6 +1312,26 @@ def test_experiment_without_any_hparams(self):
13121312
response = self._run_handler(request)
13131313
self.assertProtoEquals("", response)
13141314

1315+
def test_experiment_from_data_provider_sends_empty_filter_and_sort_from_col_params(
1316+
self,
1317+
):
1318+
self._mock_tb_context.data_provider.list_tensors.side_effect = None
1319+
# The request specifies col params but without filter or sort information.
1320+
request = """
1321+
col_params: {
1322+
hparam: 'hparam1'
1323+
}
1324+
col_params: {
1325+
hparam: 'hparam2'
1326+
}
1327+
"""
1328+
self._run_handler(request)
1329+
self.assertEquals(
1330+
self._get_read_hyperparameters_call_filters(),
1331+
[],
1332+
)
1333+
self.assertEquals(self._get_read_hyperparameters_call_sort(), [])
1334+
13151335
def test_experiment_from_data_provider_sends_regex_filter(self):
13161336
self._mock_tb_context.data_provider.list_tensors.side_effect = None
13171337
request = """
@@ -1471,6 +1491,37 @@ def test_experiment_from_data_provider_sends_discrete_filter(self):
14711491
],
14721492
)
14731493

1494+
def test_experiment_from_data_provider_sends_sort(self):
1495+
self._mock_tb_context.data_provider.list_tensors.side_effect = None
1496+
request = """
1497+
col_params: {
1498+
hparam: 'hparam1'
1499+
order: ORDER_ASC
1500+
}
1501+
col_params: {
1502+
hparam: 'hparam2'
1503+
order: ORDER_UNSPECIFIED
1504+
}
1505+
col_params: {
1506+
hparam: 'hparam3'
1507+
order: ORDER_DESC
1508+
}
1509+
"""
1510+
self._run_handler(request)
1511+
self.assertEquals(
1512+
self._get_read_hyperparameters_call_sort(),
1513+
[
1514+
provider.HyperparameterSort(
1515+
hyperparameter_name="hparam1",
1516+
sort_direction=provider.HyperparameterSortDirection.ASCENDING,
1517+
),
1518+
provider.HyperparameterSort(
1519+
hyperparameter_name="hparam3",
1520+
sort_direction=provider.HyperparameterSortDirection.DESCENDING,
1521+
),
1522+
],
1523+
)
1524+
14741525
def test_experiment_from_data_provider_with_no_sessions_or_hparam_values(
14751526
self,
14761527
):
@@ -1913,6 +1964,12 @@ def _get_read_hyperparameters_call_filters(self):
19131964
)
19141965
return call_args[1]["filters"]
19151966

1967+
def _get_read_hyperparameters_call_sort(self):
1968+
call_args = (
1969+
self._mock_tb_context.data_provider.read_hyperparameters.call_args
1970+
)
1971+
return call_args[1]["sort"]
1972+
19161973

19171974
def _reduce_session_group_to_names(session_group):
19181975
return [session.name for session in session_group.sessions]

0 commit comments

Comments
 (0)