Skip to content

Commit 6fac09a

Browse files
authored
Hparams: Populate differs field for TF summary hparams. (#6582)
Currently `differs` fields are not populated for hparams written by TF summary API. Also changed the boolean hparam domain discrete values to be the actual values used in the sessions rather than both True and False. #hparams
1 parent d7cdb2f commit 6fac09a

File tree

2 files changed

+116
-8
lines changed

2 files changed

+116
-8
lines changed

tensorboard/plugins/hparams/backend_context.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,7 @@ def experiment_from_metadata(
104104
ctx, experiment_id, include_metrics, hparams_run_to_tag_to_content
105105
)
106106
if experiment_from_runs:
107-
# TODO(yatbear): Apply `hparams_limit` to `experiment_from_runs` after `differs`
108-
# fields are populated in `_compute_hparam_info_from_values()`.
107+
# TODO(yatbear): Apply `hparams_limit` to `experiment_from_runs`.
109108
return experiment_from_runs
110109

111110
experiment_from_data_provider_hparams = (
@@ -331,24 +330,27 @@ def _compute_hparam_info_from_values(self, name, values):
331330
if result.type == api_pb2.DATA_TYPE_UNSET:
332331
return None
333332

334-
# TODO(yatbear): Populate `differs` fields for hparams once go/tbpr/6574 is merged.
335333
if result.type == api_pb2.DATA_TYPE_STRING:
336334
distinct_string_values = set(
337335
_protobuf_value_to_string(v)
338336
for v in values
339337
if _can_be_converted_to_string(v)
340338
)
341339
result.domain_discrete.extend(distinct_string_values)
340+
result.differs = len(distinct_string_values) > 1
342341

343342
if result.type == api_pb2.DATA_TYPE_BOOL:
344-
result.domain_discrete.extend([True, False])
343+
distinct_bool_values = set(v.bool_value for v in values)
344+
result.domain_discrete.extend(distinct_bool_values)
345+
result.differs = len(distinct_bool_values) > 1
345346

346347
if result.type == api_pb2.DATA_TYPE_FLOAT64:
347348
# Always uses interval domain type for numeric hparam values.
348349
distinct_float_values = sorted([v.number_value for v in values])
349350
if distinct_float_values:
350351
result.domain_interval.min_value = distinct_float_values[0]
351352
result.domain_interval.max_value = distinct_float_values[-1]
353+
result.differs = len(set(distinct_float_values)) > 1
352354

353355
return result
354356

tensorboard/plugins/hparams/backend_context_test.py

Lines changed: 110 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ def test_experiment_with_session_tags(self):
256256
min_value: 100.0
257257
max_value: 300.0
258258
}
259+
differs: true
259260
},
260261
hparam_infos: {
261262
name: 'lr'
@@ -264,6 +265,7 @@ def test_experiment_with_session_tags(self):
264265
min_value: 0.01
265266
max_value: 0.05
266267
}
268+
differs: true
267269
},
268270
hparam_infos: {
269271
name: 'model_type'
@@ -272,6 +274,106 @@ def test_experiment_with_session_tags(self):
272274
values: [{string_value: 'CNN'},
273275
{string_value: 'LATTICE'}]
274276
}
277+
differs: true
278+
}
279+
metric_infos: {
280+
name: {group: '', tag: 'accuracy'}
281+
}
282+
metric_infos: {
283+
name: {group: '', tag: 'loss'}
284+
}
285+
metric_infos: {
286+
name: {group: 'eval', tag: 'loss'}
287+
}
288+
metric_infos: {
289+
name: {group: 'train', tag: 'loss'}
290+
}
291+
"""
292+
actual_exp = self._experiment_from_metadata()
293+
_canonicalize_experiment(actual_exp)
294+
self.assertProtoEquals(expected_exp, actual_exp)
295+
296+
def test_experiment_with_session_tags_differs_field(self):
297+
self.session_1_start_info_ = """
298+
hparams: [
299+
{key: 'bool_hparam_differs_true' value: {bool_value: false}},
300+
{key: 'bool_hparam_differs_true' value: {bool_value: true}},
301+
{key: 'float_hparam_differs_false' value: {number_value: 1024}},
302+
{key: 'float_hparam_differs_true' value: {number_value: 0.01}},
303+
{key: 'string_hparams_differs_false' value: {string_value: 'momentum'}},
304+
{key: 'string_hparams_differs_true' value: {string_value: 'CNN'}}
305+
]
306+
"""
307+
self.session_2_start_info_ = """
308+
hparams:[
309+
{key: 'bool_hparam_differs_true' value: {bool_value: false}},
310+
{key: 'float_hparam_differs_false' value: {number_value: 1024}},
311+
{key: 'float_hparam_differs_true' value: {number_value: 0.02}},
312+
{key: 'string_hparams_differs_false' value: {string_value: 'momentum'}},
313+
{key: 'string_hparams_differs_true' value: {string_value: 'LATTICE'}}
314+
]
315+
"""
316+
self.session_3_start_info_ = """
317+
hparams:[
318+
{key: 'bool_hparam_differs_false' value: {bool_value: false}},
319+
{key: 'bool_hparam_differs_true' value: {bool_value: false}},
320+
{key: 'float_hparam_differs_false' value: {number_value: 1024}},
321+
{key: 'float_hparam_differs_true' value: {number_value: 0.05}},
322+
{key: 'string_hparams_differs_false' value: {string_value: 'momentum'}},
323+
{key: 'string_hparams_differs_true' value: {string_value: 'CNN'}}
324+
]
325+
"""
326+
expected_exp = """
327+
hparam_infos: {
328+
name: 'bool_hparam_differs_false'
329+
type: DATA_TYPE_BOOL
330+
domain_discrete: {
331+
values: [{bool_value: false}]
332+
}
333+
differs: false
334+
}
335+
hparam_infos: {
336+
name: 'bool_hparam_differs_true'
337+
type: DATA_TYPE_BOOL
338+
domain_discrete: {
339+
values: [{bool_value: false}, {bool_value: true}]
340+
}
341+
differs: true
342+
}
343+
hparam_infos: {
344+
name: 'float_hparam_differs_false'
345+
type: DATA_TYPE_FLOAT64
346+
domain_interval {
347+
min_value: 1024
348+
max_value: 1024
349+
}
350+
differs: false
351+
}
352+
hparam_infos: {
353+
name: 'float_hparam_differs_true'
354+
type: DATA_TYPE_FLOAT64
355+
domain_interval {
356+
min_value: 0.01
357+
max_value: 0.05
358+
}
359+
differs: true
360+
}
361+
hparam_infos: {
362+
name: 'string_hparams_differs_false'
363+
type: DATA_TYPE_STRING
364+
domain_discrete: {
365+
values: [{string_value: 'momentum'}]
366+
}
367+
differs: false
368+
}
369+
hparam_infos: {
370+
name: 'string_hparams_differs_true'
371+
type: DATA_TYPE_STRING
372+
domain_discrete: {
373+
values: [{string_value: 'CNN'},
374+
{string_value: 'LATTICE'}]
375+
}
376+
differs: true
275377
}
276378
metric_infos: {
277379
name: {group: '', tag: 'accuracy'}
@@ -317,6 +419,7 @@ def test_experiment_with_session_tags_different_hparam_types(self):
317419
values: [{string_value: '100.0'},
318420
{string_value: 'true'}]
319421
}
422+
differs: true
320423
}
321424
hparam_infos: {
322425
name: 'lr'
@@ -325,6 +428,7 @@ def test_experiment_with_session_tags_different_hparam_types(self):
325428
values: [{string_value: '0.01'},
326429
{string_value: '0.02'}]
327430
}
431+
differs: true
328432
}
329433
hparam_infos: {
330434
name: 'model_type'
@@ -333,6 +437,7 @@ def test_experiment_with_session_tags_different_hparam_types(self):
333437
values: [{string_value: 'CNN'},
334438
{string_value: 'LATTICE'}]
335439
}
440+
differs: true
336441
}
337442
metric_infos: {
338443
name: {group: '', tag: 'accuracy'}
@@ -354,12 +459,12 @@ def test_experiment_with_session_tags_different_hparam_types(self):
354459
def test_experiment_with_session_tags_bool_types(self):
355460
self.session_1_start_info_ = """
356461
hparams:[
357-
{key: 'batch_size' value: {bool_value: true}}
462+
{key: 'use_batch_norm' value: {bool_value: true}}
358463
]
359464
"""
360465
self.session_2_start_info_ = """
361466
hparams:[
362-
{key: 'batch_size' value: {bool_value: true}}
467+
{key: 'use_batch_norm' value: {bool_value: true}}
363468
]
364469
"""
365470
self.session_3_start_info_ = """
@@ -368,11 +473,12 @@ def test_experiment_with_session_tags_bool_types(self):
368473
"""
369474
expected_exp = """
370475
hparam_infos: {
371-
name: 'batch_size'
476+
name: 'use_batch_norm'
372477
type: DATA_TYPE_BOOL
373478
domain_discrete: {
374-
values: [{bool_value: true}, {bool_value: false}]
479+
values: [{bool_value: true}]
375480
}
481+
differs: false
376482
}
377483
metric_infos: {
378484
name: {group: '', tag: 'accuracy'}

0 commit comments

Comments
 (0)