Skip to content

Commit d7cdb2f

Browse files
authored
Hparams: Apply limit to hparams retrieved from protos with _hparams_/experiment tag. (#6577)
Limit for summary hparams fetched from `data_provider.list_tensors` will be applied later (see `TODO` for details). #hparams
1 parent 81b1ab0 commit d7cdb2f

File tree

3 files changed

+220
-7
lines changed

3 files changed

+220
-7
lines changed

tensorboard/plugins/hparams/backend_context.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def experiment_from_metadata(
6060
include_metrics,
6161
hparams_run_to_tag_to_content,
6262
data_provider_hparams,
63+
hparams_limit=None,
6364
):
6465
"""Returns the experiment proto defining the experiment.
6566
@@ -85,6 +86,8 @@ def experiment_from_metadata(
8586
data_provider_hparams: The ouput from an hparams_from_data_provider()
8687
call, corresponding to DataProvider.list_hyperparameters().
8788
A provider.ListHyperpararametersResult.
89+
hparams_limit: Optional number of hyperparameter metadata to include in the
90+
result. If unset or zero, all metadata will be included.
8891
8992
Returns:
9093
The experiment proto. If no data is found for an experiment proto to
@@ -94,12 +97,15 @@ def experiment_from_metadata(
9497
hparams_run_to_tag_to_content, include_metrics
9598
)
9699
if experiment:
100+
_sort_and_reduce_to_hparams_limit(experiment, hparams_limit)
97101
return experiment
98102

99103
experiment_from_runs = self._compute_experiment_from_runs(
100104
ctx, experiment_id, include_metrics, hparams_run_to_tag_to_content
101105
)
102106
if experiment_from_runs:
107+
# TODO(yatbear): Apply `hparams_limit` to `experiment_from_runs` after `differs`
108+
# fields are populated in `_compute_hparam_info_from_values()`.
103109
return experiment_from_runs
104110

105111
experiment_from_data_provider_hparams = (
@@ -325,6 +331,7 @@ def _compute_hparam_info_from_values(self, name, values):
325331
if result.type == api_pb2.DATA_TYPE_UNSET:
326332
return None
327333

334+
# TODO(yatbear): Populate `differs` fields for hparams once go/tbpr/6574 is merged.
328335
if result.type == api_pb2.DATA_TYPE_STRING:
329336
distinct_string_values = set(
330337
_protobuf_value_to_string(v)
@@ -576,3 +583,28 @@ def _protobuf_value_to_string(value):
576583
# Remove the quotations.
577584
return value_in_json[1:-1]
578585
return value_in_json
586+
587+
588+
def _sort_and_reduce_to_hparams_limit(experiment, hparams_limit=None):
589+
"""Sorts and applies limit to the hparams in the given experiment proto.
590+
591+
Args:
592+
experiment: An api_pb2.Experiment proto, which will be modified in place.
593+
hparams_limit: Optional number of hyperparameter metadata to include in the
594+
result. If unset or zero, no limit will be applied.
595+
596+
Returns:
597+
None. `experiment` proto will be modified in place.
598+
"""
599+
if not hparams_limit:
600+
hparams_limit = -1
601+
602+
# Prioritizes returning HParamInfo protos with `differed` values.
603+
limited_hparam_infos = sorted(
604+
experiment.hparam_infos,
605+
key=lambda hparam_info: hparam_info.differs,
606+
reverse=True,
607+
)[:hparams_limit]
608+
609+
experiment.ClearField("hparam_infos")
610+
experiment.hparam_infos.extend(limited_hparam_infos)

tensorboard/plugins/hparams/backend_context_test.py

Lines changed: 179 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,9 @@ def _mock_list_hyperparameters(
153153
):
154154
return self._hyperparameters
155155

156-
def _experiment_from_metadata(self, *, include_metrics=True):
156+
def _experiment_from_metadata(
157+
self, *, include_metrics=True, hparams_limit=None
158+
):
157159
"""Calls the expected operations for generating an Experiment proto."""
158160
ctxt = backend_context.Context(self._mock_tb_context)
159161
request_ctx = context.RequestContext()
@@ -162,7 +164,10 @@ def _experiment_from_metadata(self, *, include_metrics=True):
162164
"123",
163165
include_metrics,
164166
ctxt.hparams_metadata(request_ctx, "123"),
165-
ctxt.hparams_from_data_provider(request_ctx, "123", limit=None),
167+
ctxt.hparams_from_data_provider(
168+
request_ctx, "123", limit=hparams_limit
169+
),
170+
hparams_limit,
166171
)
167172

168173
def test_experiment_with_experiment_tag(self):
@@ -897,6 +902,178 @@ def test_experiment_from_data_provider_old_response_type(self):
897902
"""
898903
self.assertProtoEquals(expected_exp, actual_exp)
899904

905+
def test_experiment_from_tags_with_hparams_limit_no_differed_hparams(self):
906+
experiment = """
907+
name: 'Test experiment'
908+
hparam_infos: {
909+
name: 'batch_size'
910+
type: DATA_TYPE_FLOAT64
911+
differs: false
912+
}
913+
hparam_infos: {
914+
name: 'lr'
915+
type: DATA_TYPE_FLOAT64
916+
differs: false
917+
}
918+
hparam_infos: {
919+
name: 'use_batch_norm'
920+
type: DATA_TYPE_BOOL
921+
differs: false
922+
}
923+
hparam_infos: {
924+
name: 'model_type'
925+
type: DATA_TYPE_STRING
926+
differs: false
927+
}
928+
"""
929+
t = provider.TensorTimeSeries(
930+
max_step=0,
931+
max_wall_time=0,
932+
plugin_content=self._serialized_plugin_data(
933+
DATA_TYPE_EXPERIMENT, experiment
934+
),
935+
description="",
936+
display_name="",
937+
)
938+
self._mock_tb_context.data_provider.list_tensors.side_effect = None
939+
self._mock_tb_context.data_provider.list_tensors.return_value = {
940+
"train": {metadata.EXPERIMENT_TAG: t}
941+
}
942+
expected_exp = """
943+
name: 'Test experiment'
944+
hparam_infos: {
945+
name: 'batch_size'
946+
type: DATA_TYPE_FLOAT64
947+
differs: false
948+
}
949+
hparam_infos: {
950+
name: 'lr'
951+
type: DATA_TYPE_FLOAT64
952+
differs: false
953+
}
954+
"""
955+
actual_exp = self._experiment_from_metadata(
956+
include_metrics=False, hparams_limit=2
957+
)
958+
self.assertProtoEquals(expected_exp, actual_exp)
959+
960+
def test_experiment_from_tags_with_hparams_limit_returns_differed_hparams_first(
961+
self,
962+
):
963+
experiment = """
964+
name: 'Test experiment'
965+
hparam_infos: {
966+
name: 'batch_size'
967+
type: DATA_TYPE_FLOAT64
968+
differs: false
969+
}
970+
hparam_infos: {
971+
name: 'lr'
972+
type: DATA_TYPE_FLOAT64
973+
differs: true
974+
}
975+
hparam_infos: {
976+
name: 'use_batch_norm'
977+
type: DATA_TYPE_BOOL
978+
differs: false
979+
}
980+
hparam_infos: {
981+
name: 'model_type'
982+
type: DATA_TYPE_STRING
983+
differs: true
984+
}
985+
"""
986+
t = provider.TensorTimeSeries(
987+
max_step=0,
988+
max_wall_time=0,
989+
plugin_content=self._serialized_plugin_data(
990+
DATA_TYPE_EXPERIMENT, experiment
991+
),
992+
description="",
993+
display_name="",
994+
)
995+
self._mock_tb_context.data_provider.list_tensors.side_effect = None
996+
self._mock_tb_context.data_provider.list_tensors.return_value = {
997+
"train": {metadata.EXPERIMENT_TAG: t}
998+
}
999+
expected_exp = """
1000+
name: 'Test experiment'
1001+
hparam_infos: {
1002+
name: 'lr'
1003+
type: DATA_TYPE_FLOAT64
1004+
differs: true
1005+
},
1006+
hparam_infos: {
1007+
name: 'model_type'
1008+
type: DATA_TYPE_STRING
1009+
differs: true
1010+
}
1011+
"""
1012+
actual_exp = self._experiment_from_metadata(
1013+
include_metrics=False, hparams_limit=2
1014+
)
1015+
self.assertProtoEquals(expected_exp, actual_exp)
1016+
1017+
def test_experiment_from_tags_sorts_differed_hparams_first(self):
1018+
experiment = """
1019+
name: 'Test experiment'
1020+
hparam_infos: {
1021+
name: 'batch_size'
1022+
type: DATA_TYPE_FLOAT64
1023+
differs: false
1024+
}
1025+
hparam_infos: {
1026+
name: 'lr'
1027+
type: DATA_TYPE_FLOAT64
1028+
differs: true
1029+
}
1030+
hparam_infos: {
1031+
name: 'use_batch_norm'
1032+
type: DATA_TYPE_BOOL
1033+
differs: false
1034+
}
1035+
hparam_infos: {
1036+
name: 'model_type'
1037+
type: DATA_TYPE_STRING
1038+
differs: true
1039+
}
1040+
"""
1041+
t = provider.TensorTimeSeries(
1042+
max_step=0,
1043+
max_wall_time=0,
1044+
plugin_content=self._serialized_plugin_data(
1045+
DATA_TYPE_EXPERIMENT, experiment
1046+
),
1047+
description="",
1048+
display_name="",
1049+
)
1050+
self._mock_tb_context.data_provider.list_tensors.side_effect = None
1051+
self._mock_tb_context.data_provider.list_tensors.return_value = {
1052+
"train": {metadata.EXPERIMENT_TAG: t}
1053+
}
1054+
expected_exp = """
1055+
name: 'Test experiment'
1056+
hparam_infos: {
1057+
name: 'lr'
1058+
type: DATA_TYPE_FLOAT64
1059+
differs: true
1060+
}
1061+
hparam_infos: {
1062+
name: 'model_type'
1063+
type: DATA_TYPE_STRING
1064+
differs: true
1065+
}
1066+
hparam_infos: {
1067+
name: 'batch_size'
1068+
type: DATA_TYPE_FLOAT64
1069+
differs: false
1070+
}
1071+
"""
1072+
actual_exp = self._experiment_from_metadata(
1073+
include_metrics=False, hparams_limit=None
1074+
)
1075+
self.assertProtoEquals(expected_exp, actual_exp)
1076+
9001077
def _serialized_plugin_data(self, data_oneof_field, text_protobuffer):
9011078
oneof_type_dict = {
9021079
DATA_TYPE_EXPERIMENT: api_pb2.Experiment,

tensorboard/plugins/hparams/get_experiment.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,20 @@ def run(self):
4646
Returns:
4747
An Experiment object.
4848
"""
49+
data_provider_hparams = (
50+
self._backend_context.hparams_from_data_provider(
51+
self._request_context,
52+
self._experiment_id,
53+
limit=self._hparams_limit,
54+
)
55+
)
4956
return self._backend_context.experiment_from_metadata(
5057
self._request_context,
5158
self._experiment_id,
5259
self._include_metrics,
5360
self._backend_context.hparams_metadata(
5461
self._request_context, self._experiment_id
5562
),
56-
self._backend_context.hparams_from_data_provider(
57-
self._request_context,
58-
self._experiment_id,
59-
limit=self._hparams_limit,
60-
),
63+
data_provider_hparams,
64+
self._hparams_limit,
6165
)

0 commit comments

Comments
 (0)