Skip to content

Commit 2deb9a2

Browse files
authored
Change return type of list_hyperparameters() operation. (#6537)
Change the return type of DataProvider.list_hyperparameters() to be a complex "Result" type that can contain multiple pieces of information. Where previously it was a Collection[Hyperparameter], it is now a ListHyperparametersResult. It contains a "hyperparameters" field that contains the same data as the previous return type. We also add a "session_groups" field for future usage. We change the hparams plugin to be able to handle both the old and the new return types.
1 parent 5668e5e commit 2deb9a2

File tree

3 files changed

+132
-69
lines changed

3 files changed

+132
-69
lines changed

tensorboard/data/provider.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -388,13 +388,13 @@ def list_hyperparameters(self, ctx=None, *, experiment_ids):
388388
experiments.
389389
390390
Returns:
391-
A Collection[Hyperparameter] describing the hyperparameter metadata
392-
for the experiments.
391+
A ListHyperparametersResult describing the hyperparameter-related
392+
metadata for the experiments.
393393
394394
Raises:
395395
tensorboard.errors.PublicError: See `DataProvider` class docstring.
396396
"""
397-
return []
397+
return ListHyperparametersResult(hyperparameters=[], session_groups=[])
398398

399399
def read_hyperparameters(
400400
self, ctx=None, *, experiment_ids, filters=None, sort=None
@@ -737,6 +737,21 @@ class HyperparameterSort:
737737
sort_direction: HyperparameterSortDirection
738738

739739

740+
@dataclasses.dataclass(frozen=True)
741+
class ListHyperparametersResult:
742+
"""The result from calling list_hyperparameters().
743+
744+
Attributes:
745+
hyperparameters: The hyperparameteres belonging to the experiments in the
746+
request.
747+
session_groups: The session groups present in the experiments in the
748+
request.
749+
"""
750+
751+
hyperparameters: Collection[Hyperparameter]
752+
session_groups: Collection[HyperparameterSessionGroup]
753+
754+
740755
class _TimeSeries:
741756
"""Metadata about time series data for a particular run and tag.
742757

tensorboard/plugins/hparams/backend_context.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def experiment_from_metadata(
8181
summary metadata content for the keyed time series.
8282
data_provider_hparams: The ouput from an hparams_from_data_provider()
8383
call, corresponding to DataProvider.list_hyperparameters().
84-
A Collection[provider.Hyperparameter].
84+
A provider.ListHyperpararametersResult.
8585
8686
Returns:
8787
The experiment proto. If no data is found for an experiment proto to
@@ -323,18 +323,27 @@ def _experiment_from_data_provider_hparams(
323323
Args:
324324
data_provider_hparams: The ouput from an hparams_from_data_provider()
325325
call, corresponding to DataProvider.list_hyperparameters().
326-
A Collection[provider.Hyperparameter].
326+
A provider.ListHyperparametersResult.
327327
328328
Returns:
329329
The experiment proto. If there are no hyperparameters in the input,
330330
returns None.
331331
"""
332-
if not data_provider_hparams:
332+
if isinstance(data_provider_hparams, list):
333+
# TODO: Support old return value of Collection[provider.Hyperparameters]
334+
# until all internal implementations of DataProvider can be
335+
# migrated to use new return value of provider.ListHyperparametersResult.
336+
hyperparameters = data_provider_hparams
337+
else:
338+
# Is instance of provider.ListHyperparametersResult
339+
hyperparameters = data_provider_hparams.hyperparameters
340+
341+
if not hyperparameters:
333342
return None
334343

335344
hparam_infos = [
336345
self._convert_data_provider_hparam(dp_hparam)
337-
for dp_hparam in data_provider_hparams
346+
for dp_hparam in hyperparameters
338347
]
339348
return api_pb2.Experiment(hparam_infos=hparam_infos)
340349

tensorboard/plugins/hparams/backend_context_test.py

Lines changed: 101 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -378,18 +378,21 @@ def test_experiment_without_any_hparams(self):
378378
self.assertProtoEquals("", actual_exp)
379379

380380
def test_experiment_from_data_provider_differs(self):
381-
self._hyperparameters = [
382-
provider.Hyperparameter(
383-
hyperparameter_name="hparam1_name",
384-
hyperparameter_display_name="hparam1_display_name",
385-
differs=True,
386-
),
387-
provider.Hyperparameter(
388-
hyperparameter_name="hparam2_name",
389-
hyperparameter_display_name="hparam2_display_name",
390-
differs=False,
391-
),
392-
]
381+
self._hyperparameters = provider.ListHyperparametersResult(
382+
hyperparameters=[
383+
provider.Hyperparameter(
384+
hyperparameter_name="hparam1_name",
385+
hyperparameter_display_name="hparam1_display_name",
386+
differs=True,
387+
),
388+
provider.Hyperparameter(
389+
hyperparameter_name="hparam2_name",
390+
hyperparameter_display_name="hparam2_display_name",
391+
differs=False,
392+
),
393+
],
394+
session_groups=[],
395+
)
393396
self._mock_tb_context.data_provider.list_tensors.side_effect = None
394397
actual_exp = self._experiment_from_metadata()
395398
expected_exp = """
@@ -407,14 +410,17 @@ def test_experiment_from_data_provider_differs(self):
407410
self.assertProtoEquals(expected_exp, actual_exp)
408411

409412
def test_experiment_from_data_provider_interval_hparam(self):
410-
self._hyperparameters = [
411-
provider.Hyperparameter(
412-
hyperparameter_name="hparam1_name",
413-
hyperparameter_display_name="hparam1_display_name",
414-
domain_type=provider.HyperparameterDomainType.INTERVAL,
415-
domain=(-10.0, 15),
416-
)
417-
]
413+
self._hyperparameters = provider.ListHyperparametersResult(
414+
hyperparameters=[
415+
provider.Hyperparameter(
416+
hyperparameter_name="hparam1_name",
417+
hyperparameter_display_name="hparam1_display_name",
418+
domain_type=provider.HyperparameterDomainType.INTERVAL,
419+
domain=(-10.0, 15),
420+
)
421+
],
422+
session_groups=[],
423+
)
418424
self._mock_tb_context.data_provider.list_tensors.side_effect = None
419425
actual_exp = self._experiment_from_metadata()
420426
expected_exp = """
@@ -431,32 +437,35 @@ def test_experiment_from_data_provider_interval_hparam(self):
431437
self.assertProtoEquals(expected_exp, actual_exp)
432438

433439
def test_experiment_from_data_provider_discrete_bool_hparam(self):
434-
self._hyperparameters = [
435-
provider.Hyperparameter(
436-
hyperparameter_name="hparam1_name",
437-
hyperparameter_display_name="hparam1_display_name",
438-
domain_type=provider.HyperparameterDomainType.DISCRETE_BOOL,
439-
domain=[True],
440-
),
441-
provider.Hyperparameter(
442-
hyperparameter_name="hparam2_name",
443-
hyperparameter_display_name="hparam2_display_name",
444-
domain_type=provider.HyperparameterDomainType.DISCRETE_BOOL,
445-
domain=[True, False],
446-
),
447-
provider.Hyperparameter(
448-
hyperparameter_name="hparam3_name",
449-
hyperparameter_display_name="hparam3_display_name",
450-
domain_type=provider.HyperparameterDomainType.DISCRETE_BOOL,
451-
domain=[False],
452-
),
453-
provider.Hyperparameter(
454-
hyperparameter_name="hparam4_name",
455-
hyperparameter_display_name="hparam4_display_name",
456-
domain_type=provider.HyperparameterDomainType.DISCRETE_BOOL,
457-
domain=[],
458-
),
459-
]
440+
self._hyperparameters = provider.ListHyperparametersResult(
441+
hyperparameters=[
442+
provider.Hyperparameter(
443+
hyperparameter_name="hparam1_name",
444+
hyperparameter_display_name="hparam1_display_name",
445+
domain_type=provider.HyperparameterDomainType.DISCRETE_BOOL,
446+
domain=[True],
447+
),
448+
provider.Hyperparameter(
449+
hyperparameter_name="hparam2_name",
450+
hyperparameter_display_name="hparam2_display_name",
451+
domain_type=provider.HyperparameterDomainType.DISCRETE_BOOL,
452+
domain=[True, False],
453+
),
454+
provider.Hyperparameter(
455+
hyperparameter_name="hparam3_name",
456+
hyperparameter_display_name="hparam3_display_name",
457+
domain_type=provider.HyperparameterDomainType.DISCRETE_BOOL,
458+
domain=[False],
459+
),
460+
provider.Hyperparameter(
461+
hyperparameter_name="hparam4_name",
462+
hyperparameter_display_name="hparam4_display_name",
463+
domain_type=provider.HyperparameterDomainType.DISCRETE_BOOL,
464+
domain=[],
465+
),
466+
],
467+
session_groups=[],
468+
)
460469
self._mock_tb_context.data_provider.list_tensors.side_effect = None
461470
actual_exp = self._experiment_from_metadata()
462471
expected_exp = """
@@ -493,14 +502,17 @@ def test_experiment_from_data_provider_discrete_bool_hparam(self):
493502
self.assertProtoEquals(expected_exp, actual_exp)
494503

495504
def test_experiment_from_data_provider_discrete_float_hparam(self):
496-
self._hyperparameters = [
497-
provider.Hyperparameter(
498-
hyperparameter_name="hparam1_name",
499-
hyperparameter_display_name="hparam1_display_name",
500-
domain_type=provider.HyperparameterDomainType.DISCRETE_FLOAT,
501-
domain=[-1.0, 1.5, 0.0],
502-
),
503-
]
505+
self._hyperparameters = provider.ListHyperparametersResult(
506+
hyperparameters=[
507+
provider.Hyperparameter(
508+
hyperparameter_name="hparam1_name",
509+
hyperparameter_display_name="hparam1_display_name",
510+
domain_type=provider.HyperparameterDomainType.DISCRETE_FLOAT,
511+
domain=[-1.0, 1.5, 0.0],
512+
),
513+
],
514+
session_groups=[],
515+
)
504516
self._mock_tb_context.data_provider.list_tensors.side_effect = None
505517
actual_exp = self._experiment_from_metadata()
506518
expected_exp = """
@@ -520,14 +532,17 @@ def test_experiment_from_data_provider_discrete_float_hparam(self):
520532
self.assertProtoEquals(expected_exp, actual_exp)
521533

522534
def test_experiment_from_data_provider_discrete_string_hparam(self):
523-
self._hyperparameters = [
524-
provider.Hyperparameter(
525-
hyperparameter_name="hparam1_name",
526-
hyperparameter_display_name="hparam1_display_name",
527-
domain_type=provider.HyperparameterDomainType.DISCRETE_STRING,
528-
domain=["one", "two", "aaaa"],
529-
),
530-
]
535+
self._hyperparameters = provider.ListHyperparametersResult(
536+
hyperparameters=[
537+
provider.Hyperparameter(
538+
hyperparameter_name="hparam1_name",
539+
hyperparameter_display_name="hparam1_display_name",
540+
domain_type=provider.HyperparameterDomainType.DISCRETE_STRING,
541+
domain=["one", "two", "aaaa"],
542+
),
543+
],
544+
session_groups=[],
545+
)
531546
self._mock_tb_context.data_provider.list_tensors.side_effect = None
532547
actual_exp = self._experiment_from_metadata()
533548
expected_exp = """
@@ -546,6 +561,30 @@ def test_experiment_from_data_provider_discrete_string_hparam(self):
546561
"""
547562
self.assertProtoEquals(expected_exp, actual_exp)
548563

564+
def test_experiment_from_data_provider_old_response_type(self):
565+
self._hyperparameters = [
566+
provider.Hyperparameter(
567+
hyperparameter_name="hparam1_name",
568+
hyperparameter_display_name="hparam1_display_name",
569+
domain_type=provider.HyperparameterDomainType.INTERVAL,
570+
domain=(-10.0, 15),
571+
)
572+
]
573+
self._mock_tb_context.data_provider.list_tensors.side_effect = None
574+
actual_exp = self._experiment_from_metadata()
575+
expected_exp = """
576+
hparam_infos: {
577+
name: 'hparam1_name'
578+
display_name: 'hparam1_display_name'
579+
type: DATA_TYPE_FLOAT64
580+
domain_interval: {
581+
min_value: -10.0
582+
max_value: 15
583+
}
584+
}
585+
"""
586+
self.assertProtoEquals(expected_exp, actual_exp)
587+
549588
def _serialized_plugin_data(self, data_oneof_field, text_protobuffer):
550589
oneof_type_dict = {
551590
DATA_TYPE_EXPERIMENT: api_pb2.Experiment,

0 commit comments

Comments
 (0)