Skip to content

Commit e923b8a

Browse files
LouYu2015tensorflower-gardener
authored andcommitted
Fix parsing error when handling validation expressions with inequality
Currently, when there an expression such as "a >= 1", the parser will think that the operator is ">" and the right hand size is "= 1". However, the operator should be ">=". PiperOrigin-RevId: 592669272
1 parent b092458 commit e923b8a

File tree

2 files changed

+15
-10
lines changed

2 files changed

+15
-10
lines changed

official/modeling/hyperparams/params_dict.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -297,24 +297,17 @@ def _get_kvs(tokens, params_dict):
297297
raise KeyError(
298298
'Found inconsistency between key `{}` and key `{}`.'.format(
299299
tokens[0], tokens[1]))
300-
elif '<' in restriction:
301-
tokens = restriction.split('<')
302-
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
303-
if left_v >= right_v:
304-
raise KeyError(
305-
'Found inconsistency between key `{}` and key `{}`.'.format(
306-
tokens[0], tokens[1]))
307300
elif '<=' in restriction:
308301
tokens = restriction.split('<=')
309302
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
310303
if left_v > right_v:
311304
raise KeyError(
312305
'Found inconsistency between key `{}` and key `{}`.'.format(
313306
tokens[0], tokens[1]))
314-
elif '>' in restriction:
315-
tokens = restriction.split('>')
307+
elif '<' in restriction:
308+
tokens = restriction.split('<')
316309
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
317-
if left_v <= right_v:
310+
if left_v >= right_v:
318311
raise KeyError(
319312
'Found inconsistency between key `{}` and key `{}`.'.format(
320313
tokens[0], tokens[1]))
@@ -325,6 +318,13 @@ def _get_kvs(tokens, params_dict):
325318
raise KeyError(
326319
'Found inconsistency between key `{}` and key `{}`.'.format(
327320
tokens[0], tokens[1]))
321+
elif '>' in restriction:
322+
tokens = restriction.split('>')
323+
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
324+
if left_v <= right_v:
325+
raise KeyError(
326+
'Found inconsistency between key `{}` and key `{}`.'.format(
327+
tokens[0], tokens[1]))
328328
else:
329329
raise ValueError('Unsupported relation in restriction.')
330330

official/modeling/hyperparams/params_dict_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def test_validate(self):
167167
'a': 10
168168
}
169169
}, ['b == c'])
170+
params.validate()
170171

171172
# Raise error due to inconsistency
172173
with self.assertRaises(KeyError):
@@ -198,6 +199,10 @@ def test_validate(self):
198199
}, ['a == None', 'c.a == 1'])
199200
params.validate()
200201

202+
# Valid restrictions with inequality.
203+
params = params_dict.ParamsDict({'a': 1}, ['a >= 1'])
204+
params.validate()
205+
201206

202207
class ParamsDictIOTest(tf.test.TestCase):
203208

0 commit comments

Comments
 (0)