Skip to content

Commit 8616842

Browse files
authored
Hparams: Generate metric_infos from data provider session groups. (#6539)
In calls to the /experiment endpoint for hparams, we now generate metric_infos based on the session_groups returned by the DataProvider. We are able to reuse most of the logic that generates metric_infos for the classic tensor-based hparams. The end result is that metrics are listed in the left pane of the hparams dashboard. We have not yet integrated generating metric values into the /session_groups operation so the values do not yet appear in the table view.
1 parent 2deb9a2 commit 8616842

File tree

2 files changed

+91
-15
lines changed

2 files changed

+91
-15
lines changed

tensorboard/plugins/hparams/backend_context.py

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,9 @@ def experiment_from_metadata(
9898
return experiment_from_runs
9999

100100
experiment_from_data_provider_hparams = (
101-
self._experiment_from_data_provider_hparams(data_provider_hparams)
101+
self._experiment_from_data_provider_hparams(
102+
ctx, experiment_id, data_provider_hparams
103+
)
102104
)
103105
return (
104106
experiment_from_data_provider_hparams
@@ -224,7 +226,7 @@ def _compute_experiment_from_runs(
224226
"""
225227
hparam_infos = self._compute_hparam_infos(hparams_run_to_tag_to_content)
226228
if hparam_infos:
227-
metric_infos = self._compute_metric_infos(
229+
metric_infos = self._compute_metric_infos_from_runs(
228230
ctx, experiment_id, hparams_run_to_tag_to_content
229231
)
230232
else:
@@ -316,6 +318,8 @@ def _compute_hparam_info_from_values(self, name, values):
316318

317319
def _experiment_from_data_provider_hparams(
318320
self,
321+
ctx,
322+
experiment_id,
319323
data_provider_hparams,
320324
):
321325
"""Returns an experiment protobuffer based on data provider hparams.
@@ -334,18 +338,24 @@ def _experiment_from_data_provider_hparams(
334338
# until all internal implementations of DataProvider can be
335339
# migrated to use new return value of provider.ListHyperparametersResult.
336340
hyperparameters = data_provider_hparams
341+
session_groups = []
337342
else:
338343
# Is instance of provider.ListHyperparametersResult
339344
hyperparameters = data_provider_hparams.hyperparameters
340-
341-
if not hyperparameters:
342-
return None
345+
session_groups = data_provider_hparams.session_groups
343346

344347
hparam_infos = [
345348
self._convert_data_provider_hparam(dp_hparam)
346349
for dp_hparam in hyperparameters
347350
]
348-
return api_pb2.Experiment(hparam_infos=hparam_infos)
351+
metric_infos = (
352+
self.compute_metric_infos_from_data_provider_session_groups(
353+
ctx, experiment_id, session_groups
354+
)
355+
)
356+
return api_pb2.Experiment(
357+
hparam_infos=hparam_infos, metric_infos=metric_infos
358+
)
349359

350360
def _convert_data_provider_hparam(self, dp_hparam):
351361
"""Builds an HParamInfo message from data provider Hyperparameter.
@@ -374,19 +384,37 @@ def _convert_data_provider_hparam(self, dp_hparam):
374384
hparam_info.domain_discrete.extend(dp_hparam.domain)
375385
return hparam_info
376386

377-
def _compute_metric_infos(
387+
def _compute_metric_infos_from_runs(
378388
self, ctx, experiment_id, hparams_run_to_tag_to_content
379389
):
390+
session_runs = set(
391+
run
392+
for run, tags in hparams_run_to_tag_to_content.items()
393+
if metadata.SESSION_START_INFO_TAG in tags
394+
)
380395
return (
381396
api_pb2.MetricInfo(name=api_pb2.MetricName(group=group, tag=tag))
382397
for tag, group in self._compute_metric_names(
383-
ctx, experiment_id, hparams_run_to_tag_to_content
398+
ctx, experiment_id, session_runs
384399
)
385400
)
386401

387-
def _compute_metric_names(
388-
self, ctx, experiment_id, hparams_run_to_tag_to_content
402+
def compute_metric_infos_from_data_provider_session_groups(
403+
self, ctx, experiment_id, session_groups
389404
):
405+
session_runs = set(
406+
f"{s.experiment_id}/{s.run}"
407+
for sg in session_groups
408+
for s in sg.sessions
409+
)
410+
return [
411+
api_pb2.MetricInfo(name=api_pb2.MetricName(group=group, tag=tag))
412+
for tag, group in self._compute_metric_names(
413+
ctx, experiment_id, session_runs
414+
)
415+
]
416+
417+
def _compute_metric_names(self, ctx, experiment_id, session_runs):
390418
"""Computes the list of metric names from all the scalar (run, tag)
391419
pairs.
392420
@@ -412,11 +440,6 @@ def _compute_metric_names(
412440
A python list containing pairs. Each pair is a (tag, group) pair
413441
representing a metric name used in some session.
414442
"""
415-
session_runs = set(
416-
run
417-
for run, tags in hparams_run_to_tag_to_content.items()
418-
if metadata.SESSION_START_INFO_TAG in tags
419-
)
420443
metric_names_set = set()
421444
scalars_run_to_tag_to_content = self.scalars_metadata(
422445
ctx, experiment_id

tensorboard/plugins/hparams/backend_context_test.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,59 @@ def test_experiment_from_data_provider_discrete_string_hparam(self):
561561
"""
562562
self.assertProtoEquals(expected_exp, actual_exp)
563563

564+
def test_experiment_from_data_provider_session_group(self):
565+
self._mock_tb_context.data_provider.list_tensors.side_effect = None
566+
# The sessions chosen here mimic those returned in the implementation
567+
# of _mock_list_tensors. These work nicely with the scalars returned
568+
# in _mock_list_scalars and generate the same set of metric_infos as
569+
# the tensor-based tests in this file.
570+
self._hyperparameters = provider.ListHyperparametersResult(
571+
hyperparameters=[],
572+
session_groups=[
573+
provider.HyperparameterSessionGroup(
574+
root=provider.HyperparameterSessionRun(
575+
experiment_id="exp", run=""
576+
),
577+
sessions=[
578+
provider.HyperparameterSessionRun(
579+
experiment_id="exp", run="session_1"
580+
),
581+
provider.HyperparameterSessionRun(
582+
experiment_id="exp", run="session_2"
583+
),
584+
],
585+
hyperparameter_values=[],
586+
),
587+
provider.HyperparameterSessionGroup(
588+
root=provider.HyperparameterSessionRun(
589+
experiment_id="exp", run=""
590+
),
591+
sessions=[
592+
provider.HyperparameterSessionRun(
593+
experiment_id="exp", run="session_3"
594+
),
595+
],
596+
hyperparameter_values=[],
597+
),
598+
],
599+
)
600+
actual_exp = self._experiment_from_metadata()
601+
expected_exp = """
602+
metric_infos: {
603+
name: {group: '', tag: 'accuracy'}
604+
}
605+
metric_infos: {
606+
name: {group: '', tag: 'loss'}
607+
}
608+
metric_infos: {
609+
name: {group: 'eval', tag: 'loss'}
610+
}
611+
metric_infos: {
612+
name: {group: 'train', tag: 'loss'}
613+
}
614+
"""
615+
self.assertProtoEquals(expected_exp, actual_exp)
616+
564617
def test_experiment_from_data_provider_old_response_type(self):
565618
self._hyperparameters = [
566619
provider.Hyperparameter(

0 commit comments

Comments
 (0)