@@ -98,7 +98,9 @@ def experiment_from_metadata(
98
98
return experiment_from_runs
99
99
100
100
experiment_from_data_provider_hparams = (
101
- self ._experiment_from_data_provider_hparams (data_provider_hparams )
101
+ self ._experiment_from_data_provider_hparams (
102
+ ctx , experiment_id , data_provider_hparams
103
+ )
102
104
)
103
105
return (
104
106
experiment_from_data_provider_hparams
@@ -224,7 +226,7 @@ def _compute_experiment_from_runs(
224
226
"""
225
227
hparam_infos = self ._compute_hparam_infos (hparams_run_to_tag_to_content )
226
228
if hparam_infos :
227
- metric_infos = self ._compute_metric_infos (
229
+ metric_infos = self ._compute_metric_infos_from_runs (
228
230
ctx , experiment_id , hparams_run_to_tag_to_content
229
231
)
230
232
else :
@@ -316,6 +318,8 @@ def _compute_hparam_info_from_values(self, name, values):
316
318
317
319
def _experiment_from_data_provider_hparams (
318
320
self ,
321
+ ctx ,
322
+ experiment_id ,
319
323
data_provider_hparams ,
320
324
):
321
325
"""Returns an experiment protobuffer based on data provider hparams.
@@ -334,18 +338,24 @@ def _experiment_from_data_provider_hparams(
334
338
# until all internal implementations of DataProvider can be
335
339
# migrated to use new return value of provider.ListHyperparametersResult.
336
340
hyperparameters = data_provider_hparams
341
+ session_groups = []
337
342
else :
338
343
# Is instance of provider.ListHyperparametersResult
339
344
hyperparameters = data_provider_hparams .hyperparameters
340
-
341
- if not hyperparameters :
342
- return None
345
+ session_groups = data_provider_hparams .session_groups
343
346
344
347
hparam_infos = [
345
348
self ._convert_data_provider_hparam (dp_hparam )
346
349
for dp_hparam in hyperparameters
347
350
]
348
- return api_pb2 .Experiment (hparam_infos = hparam_infos )
351
+ metric_infos = (
352
+ self .compute_metric_infos_from_data_provider_session_groups (
353
+ ctx , experiment_id , session_groups
354
+ )
355
+ )
356
+ return api_pb2 .Experiment (
357
+ hparam_infos = hparam_infos , metric_infos = metric_infos
358
+ )
349
359
350
360
def _convert_data_provider_hparam (self , dp_hparam ):
351
361
"""Builds an HParamInfo message from data provider Hyperparameter.
@@ -374,19 +384,37 @@ def _convert_data_provider_hparam(self, dp_hparam):
374
384
hparam_info .domain_discrete .extend (dp_hparam .domain )
375
385
return hparam_info
376
386
377
- def _compute_metric_infos (
387
+ def _compute_metric_infos_from_runs (
378
388
self , ctx , experiment_id , hparams_run_to_tag_to_content
379
389
):
390
+ session_runs = set (
391
+ run
392
+ for run , tags in hparams_run_to_tag_to_content .items ()
393
+ if metadata .SESSION_START_INFO_TAG in tags
394
+ )
380
395
return (
381
396
api_pb2 .MetricInfo (name = api_pb2 .MetricName (group = group , tag = tag ))
382
397
for tag , group in self ._compute_metric_names (
383
- ctx , experiment_id , hparams_run_to_tag_to_content
398
+ ctx , experiment_id , session_runs
384
399
)
385
400
)
386
401
387
- def _compute_metric_names (
388
- self , ctx , experiment_id , hparams_run_to_tag_to_content
402
+ def compute_metric_infos_from_data_provider_session_groups (
403
+ self , ctx , experiment_id , session_groups
389
404
):
405
+ session_runs = set (
406
+ f"{ s .experiment_id } /{ s .run } "
407
+ for sg in session_groups
408
+ for s in sg .sessions
409
+ )
410
+ return [
411
+ api_pb2 .MetricInfo (name = api_pb2 .MetricName (group = group , tag = tag ))
412
+ for tag , group in self ._compute_metric_names (
413
+ ctx , experiment_id , session_runs
414
+ )
415
+ ]
416
+
417
+ def _compute_metric_names (self , ctx , experiment_id , session_runs ):
390
418
"""Computes the list of metric names from all the scalar (run, tag)
391
419
pairs.
392
420
@@ -412,11 +440,6 @@ def _compute_metric_names(
412
440
A python list containing pairs. Each pair is a (tag, group) pair
413
441
representing a metric name used in some session.
414
442
"""
415
- session_runs = set (
416
- run
417
- for run , tags in hparams_run_to_tag_to_content .items ()
418
- if metadata .SESSION_START_INFO_TAG in tags
419
- )
420
443
metric_names_set = set ()
421
444
scalars_run_to_tag_to_content = self .scalars_metadata (
422
445
ctx , experiment_id
0 commit comments