Skip to content

Commit 07d74ac

Browse files
authored
Hparams: Treat NaN/Infinity numeric hparam values as unset values. (#6496)
## Motivation for features / changes

There are cases where users might log numeric hparam values that are NaN or Infinity. Recall that we use proto's Value message for logging and storing hparam values. While processing these values in the Python layer of the Hparams plugin, we may pass these Value messages through json_format.MessageToJson, at which point the json_format library raises an error. The limitation is documented at https://protobuf.dev/reference/protobuf/google.protobuf/#value: "[A]ttempting to serialize NaN or Infinity results in error."

## Technical description of changes

We now avoid calling MessageToJson on these problematic hparam values. Instead, we treat them as "unset" values -- equivalent to not having set an hparam value at all. This means the values will not contribute to the calculation of discrete domains, and they will appear blank in the hparams UI rather than appearing as "NaN".

## Alternate designs / implementations considered (or N/A)

Given an infinite amount of time:
* Fix json_format to handle NaN and Infinity :O.
* Change the hparams api.proto to use something other than proto's Value for representing hparam values.

Other reasonable options:
* Change the NaN numeric value to a "NaN" string. This was my first choice, but it unfortunately leads to other complications - for instance, the plugin will raise errors when it encounters a "NaN" string for an hparam that is supposed to have a numeric value.
1 parent 6febb50 commit 07d74ac

File tree

7 files changed

+290
-5
lines changed

7 files changed

+290
-5
lines changed

tensorboard/plugins/hparams/BUILD

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ py_library(
2929
"download_data.py",
3030
"get_experiment.py",
3131
"hparams_plugin.py",
32+
"json_format_compat.py",
3233
"list_metric_evals.py",
3334
"list_session_groups.py",
3435
"metrics.py",
@@ -115,6 +116,15 @@ py_test(
115116
],
116117
)
117118

119+
py_test(
120+
name = "json_format_compat_test",
121+
size = "small",
122+
srcs = [
123+
"json_format_compat_test.py",
124+
],
125+
deps = [":hparams_plugin"],
126+
)
127+
118128
py_binary(
119129
name = "hparams_demo",
120130
srcs = ["hparams_demo.py"],

tensorboard/plugins/hparams/backend_context.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from tensorboard.data import provider
2424
from tensorboard.plugins.hparams import api_pb2
25+
from tensorboard.plugins.hparams import json_format_compat
2526
from tensorboard.plugins.hparams import metadata
2627
from google.protobuf import json_format
2728
from tensorboard.plugins.scalar import metadata as scalar_metadata
@@ -282,11 +283,6 @@ def _compute_hparam_info_from_values(self, name, values):
282283
# If all values have the same type, then that is the type used.
283284
# Otherwise, the returned type is DATA_TYPE_STRING.
284285
result = api_pb2.HParamInfo(name=name, type=api_pb2.DATA_TYPE_UNSET)
285-
distinct_values = set(
286-
_protobuf_value_to_string(v)
287-
for v in values
288-
if _protobuf_value_type(v)
289-
)
290286
for v in values:
291287
v_type = _protobuf_value_type(v)
292288
if not v_type:
@@ -304,6 +300,11 @@ def _compute_hparam_info_from_values(self, name, values):
304300
return None
305301

306302
if result.type == api_pb2.DATA_TYPE_STRING:
303+
distinct_values = set(
304+
_protobuf_value_to_string(v)
305+
for v in values
306+
if _can_be_converted_to_string(v)
307+
)
307308
result.domain_discrete.extend(distinct_values)
308309

309310
if result.type == api_pb2.DATA_TYPE_BOOL:
@@ -453,6 +454,12 @@ def _find_longest_parent_path(path_set, path):
453454
return path
454455

455456

457+
def _can_be_converted_to_string(value):
    """Returns whether `value` may be turned into a string.

    A value qualifies only when `_protobuf_value_type` reports a type for it
    and `json_format_compat.is_serializable_value` accepts it.

    Args:
      value: The protobuf value to inspect.

    Returns:
      False when no type is reported for `value`; otherwise the result of
      `json_format_compat.is_serializable_value(value)`.
    """
    has_known_type = bool(_protobuf_value_type(value))
    return has_known_type and json_format_compat.is_serializable_value(value)
461+
462+
456463
def _protobuf_value_type(value):
457464
"""Returns the type of the google.protobuf.Value message as an
458465
api.DataType.

tensorboard/plugins/hparams/backend_context_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,33 @@ def test_experiment_with_bool_types(self):
344344
_canonicalize_experiment(actual_exp)
345345
self.assertProtoEquals(expected_exp, actual_exp)
346346

347+
# Checks the HParamInfo computed for a hparam ('maybe_invalid') that is a
# string in one session and a NaN / Infinity number in two other sessions.
# The expectation below is a DATA_TYPE_STRING hparam whose discrete domain
# lists only the real string value, i.e. the NaN and Infinity values are not
# expected to show up in the domain.
def test_experiment_with_string_domain_and_invalid_number_values(self):
348+
self.session_1_start_info_ = """
349+
hparams:[
350+
{key: 'maybe_invalid' value: {string_value: 'force_to_string_type'}}
351+
]
352+
"""
353+
self.session_2_start_info_ = """
354+
hparams:[
355+
{key: 'maybe_invalid' value: {number_value: NaN}}
356+
]
357+
"""
358+
self.session_3_start_info_ = """
359+
hparams:[
360+
{key: 'maybe_invalid' value: {number_value: Infinity}}
361+
]
362+
"""
363+
# Only the string value from session 1 is expected in the discrete domain.
expected_hparam_info = """
364+
name: 'maybe_invalid'
365+
type: DATA_TYPE_STRING
366+
domain_discrete: {
367+
values: [{string_value: 'force_to_string_type'}]
368+
}
369+
"""
370+
actual_exp = self._experiment_from_metadata()
371+
self.assertLen(actual_exp.hparam_infos, 1)
372+
self.assertProtoEquals(expected_hparam_info, actual_exp.hparam_infos[0])
373+
347374
def test_experiment_without_any_hparams(self):
348375
request_ctx = context.RequestContext()
349376
actual_exp = self._experiment_from_metadata()

tensorboard/plugins/hparams/json_format_compat.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
import math
17+
18+
19+
def is_serializable_value(value):
    """Returns whether a protobuf Value will be serializable by MessageToJson.

    The json_format documentation states that "attempting to serialize NaN or
    Infinity results in error."

    https://protobuf.dev/reference/protobuf/google.protobuf/#value

    A Value can also hold a ListValue or a Struct, whose elements are Values
    themselves. Those nested elements are serialized by the same code path, so
    they are checked recursively: a single NaN or Infinity anywhere inside the
    Value makes the whole Value unserializable.

    Args:
      value: A value of type protobuf.Value.

    Returns:
      True if the Value should be serializable without error by MessageToJson.
      False, otherwise.
    """
    if value.HasField("number_value"):
        # isfinite() is False for NaN, +Infinity and -Infinity alike.
        return math.isfinite(value.number_value)

    if value.HasField("list_value"):
        return all(
            is_serializable_value(element)
            for element in value.list_value.values
        )

    if value.HasField("struct_value"):
        return all(
            is_serializable_value(element)
            for element in value.struct_value.fields.values()
        )

    # Unset, null, string and bool values never trigger the limitation.
    return True

tensorboard/plugins/hparams/json_format_compat_test.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from absl.testing import absltest
16+
from google.protobuf import struct_pb2
17+
from tensorboard.plugins.hparams import json_format_compat
18+
19+
20+
class TestCase(absltest.TestCase):
    """Exercises `json_format_compat.is_serializable_value`."""

    def test_real_value_is_serializable(self):
        # A finite number, a string (even one spelling "nan") and a bool are
        # all expected to be reported as serializable.
        ordinary_values = (
            struct_pb2.Value(number_value=1.0),
            struct_pb2.Value(string_value="nan"),
            struct_pb2.Value(bool_value=False),
        )
        for candidate in ordinary_values:
            self.assertTrue(
                json_format_compat.is_serializable_value(candidate)
            )

    def test_empty_value_is_serializable(self):
        # A Value with no field set is expected to be serializable.
        unset = struct_pb2.Value()
        self.assertTrue(json_format_compat.is_serializable_value(unset))

    def test_nan_value_is_not_serializable(self):
        not_a_number = struct_pb2.Value(number_value=float("nan"))
        self.assertFalse(
            json_format_compat.is_serializable_value(not_a_number)
        )

    def test_infinity_value_is_not_serializable(self):
        # Both signs of infinity are expected to be rejected.
        for literal in ("inf", "-inf"):
            infinite = struct_pb2.Value(number_value=float(literal))
            self.assertFalse(
                json_format_compat.is_serializable_value(infinite)
            )
61+
62+
63+
if __name__ == "__main__":
64+
absltest.main()

tensorboard/plugins/hparams/list_session_groups.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from tensorboard.data import provider
2727
from tensorboard.plugins.hparams import api_pb2
2828
from tensorboard.plugins.hparams import error
29+
from tensorboard.plugins.hparams import json_format_compat
2930
from tensorboard.plugins.hparams import metadata
3031
from tensorboard.plugins.hparams import metrics
3132

@@ -246,6 +247,13 @@ def _add_session(self, session, start_info, groups_by_name):
246247
# There doesn't seem to be a way to initialize a protobuffer map in the
247248
# constructor.
248249
for (key, value) in start_info.hparams.items():
250+
if not json_format_compat.is_serializable_value(value):
251+
# NaN and Infinity number_values cannot be serialized by higher-level layers
252+
# that are using json_format.MessageToJson(). To work around
253+
# the issue we do not copy them to the session group and
254+
# effectively treat them as "unset".
255+
continue
256+
249257
group.hparams[key].CopyFrom(value)
250258
groups_by_name[group_name] = group
251259

tensorboard/plugins/hparams/list_session_groups_test.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1172,6 +1172,137 @@ def test_sort_one_column_with_missing_values(self):
11721172
expected_total_size=3,
11731173
)
11741174

1175+
# Replacement for the data provider's `list_tensors` used by the tests that
# install it as a mock side effect. It describes four sessions (one per run,
# each in its own group) that all set the hparam 'maybe_bad': sessions 1 and 4
# use ordinary numbers (1 and 4.0) while sessions 2 and 3 use nan and
# -infinity. The ctx / experiment_id / plugin_name / run_tag_filter arguments
# are accepted but not read by the body.
# Returns a dict mapping run name -> tag -> provider.TensorTimeSeries whose
# plugin_content is the serialized session start info.
def _mock_list_tensors_invalid_number_values(
1176+
self, ctx, *, experiment_id, plugin_name, run_tag_filter
1177+
):
1178+
hparams_content = {
1179+
"session_1": {
1180+
metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data(
1181+
DATA_TYPE_SESSION_START_INFO,
1182+
"""
1183+
hparams:{ key: 'maybe_bad' value: { number_value: 1 } }
1184+
group_name: 'group_1'
1185+
""",
1186+
)
1187+
},
1188+
"session_2": {
1189+
metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data(
1190+
DATA_TYPE_SESSION_START_INFO,
1191+
"""
1192+
hparams:{ key: 'maybe_bad' value: { number_value: nan } }
1193+
group_name: 'group_2'
1194+
""",
1195+
),
1196+
},
1197+
"session_3": {
1198+
metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data(
1199+
DATA_TYPE_SESSION_START_INFO,
1200+
"""
1201+
hparams:{ key: 'maybe_bad' value: { number_value: -infinity } }
1202+
group_name: 'group_3'
1203+
""",
1204+
),
1205+
},
1206+
"session_4": {
1207+
metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data(
1208+
DATA_TYPE_SESSION_START_INFO,
1209+
"""
1210+
hparams:{ key: 'maybe_bad' value: { number_value: 4.0 } }
1211+
group_name: 'group_4'
1212+
""",
1213+
),
1214+
},
1215+
}
1216+
# Wrap every serialized payload in a TensorTimeSeries, keyed by run and tag.
result = {}
1217+
for (run, tag_to_content) in hparams_content.items():
1218+
result.setdefault(run, {})
1219+
for (tag, content) in tag_to_content.items():
1220+
t = provider.TensorTimeSeries(
1221+
max_step=0,
1222+
max_wall_time=0,
1223+
plugin_content=content,
1224+
description="",
1225+
display_name="",
1226+
)
1227+
result[run][tag] = t
1228+
return result
1229+
1230+
# Lists session groups using the mock that contains nan / -infinity values
# for 'maybe_bad'. All four groups are expected back; the first and last keep
# their numeric value (1 and 4), while the second and third are expected to
# have no 'maybe_bad' entry at all (hparams.get returns None).
def test_hparams_with_invalid_number_values(self):
1231+
self._mock_tb_context.data_provider.list_tensors.side_effect = (
1232+
self._mock_list_tensors_invalid_number_values
1233+
)
1234+
request = """
1235+
start_index: 0
1236+
slice_size: 10
1237+
allowed_statuses: [STATUS_UNKNOWN]
1238+
"""
1239+
groups = self._run_handler(request).session_groups
1240+
self.assertLen(groups, 4)
1241+
self.assertEqual(1, groups[0].hparams.get("maybe_bad").number_value)
1242+
self.assertEqual(None, groups[1].hparams.get("maybe_bad"))
1243+
self.assertEqual(None, groups[2].hparams.get("maybe_bad"))
1244+
self.assertEqual(4, groups[3].hparams.get("maybe_bad").number_value)
1245+
1246+
# Sorts the mock session groups by 'maybe_bad' in descending order. The
# expected order puts the groups with ordinary values first (group_4, then
# group_1), followed by group_2 and group_3, whose values are nan / -infinity
# in the mock data.
def test_sort_hparams_with_invalid_number_values(self):
1247+
self._mock_tb_context.data_provider.list_tensors.side_effect = (
1248+
self._mock_list_tensors_invalid_number_values
1249+
)
1250+
self._verify_handler(
1251+
request="""
1252+
start_index: 0
1253+
slice_size: 10
1254+
allowed_statuses: [STATUS_UNKNOWN]
1255+
col_params: {
1256+
hparam: 'maybe_bad'
1257+
order: ORDER_DESC
1258+
}
1259+
""",
1260+
expected_session_group_names=[
1261+
"group_4",
1262+
"group_1",
1263+
"group_2",
1264+
"group_3",
1265+
],
1266+
expected_total_size=4,
1267+
)
1268+
1269+
# Filters 'maybe_bad' to the interval [2.0, 10.0] (descending order) without
# asking to exclude missing values. group_1 (value 1) is expected to be
# filtered out, while group_2 and group_3 (nan / -infinity in the mock data)
# are still expected in the result after group_4.
def test_filter_hparams_include_invalid_number_values(self):
1270+
self._mock_tb_context.data_provider.list_tensors.side_effect = (
1271+
self._mock_list_tensors_invalid_number_values
1272+
)
1273+
self._verify_handler(
1274+
request="""
1275+
start_index: 0
1276+
slice_size: 10
1277+
allowed_statuses: [STATUS_UNKNOWN]
1278+
col_params: {
1279+
hparam: 'maybe_bad'
1280+
order: ORDER_DESC
1281+
filter_interval: { min_value: 2.0 max_value: 10.0 }
1282+
}
1283+
""",
1284+
expected_session_group_names=["group_4", "group_2", "group_3"],
1285+
expected_total_size=3,
1286+
)
1287+
1288+
# Requests 'maybe_bad' with exclude_missing_values set. Only group_1 and
# group_4 (the groups with ordinary numeric values in the mock data) are
# expected; group_2 and group_3 (nan / -infinity) are expected to be dropped.
# NOTE(review): "filer" in the method name looks like a typo for "filter"
# (compare the sibling test_filter_hparams_include_invalid_number_values);
# confirm before renaming.
def test_filer_hparams_exclude_invalid_number_values(self):
1289+
self._mock_tb_context.data_provider.list_tensors.side_effect = (
1290+
self._mock_list_tensors_invalid_number_values
1291+
)
1292+
self._verify_handler(
1293+
request="""
1294+
start_index: 0
1295+
slice_size: 10
1296+
allowed_statuses: [STATUS_UNKNOWN]
1297+
col_params: {
1298+
hparam: 'maybe_bad'
1299+
exclude_missing_values: true
1300+
}
1301+
""",
1302+
expected_session_group_names=["group_1", "group_4"],
1303+
expected_total_size=2,
1304+
)
1305+
11751306
def test_experiment_without_any_hparams(self):
11761307
self._mock_tb_context.data_provider.list_tensors.side_effect = None
11771308
self._hyperparameters = []

0 commit comments

Comments
 (0)