Skip to content

Commit 14b035d

Browse files
lingvo-botcopybara-github
authored andcommitted
Encode dicts of params into text.
PiperOrigin-RevId: 491943111
1 parent be05a35 commit 14b035d

File tree

3 files changed

+34
-10
lines changed

3 files changed

+34
-10
lines changed

lingvo/core/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -833,6 +833,7 @@ pytype_library(
833833
":symbolic",
834834
# Implicit python proto dependency.
835835
"//lingvo:compat",
836+
# Implicit numpy dependency.
836837
# Implicit typing_extensions dependency.
837838
],
838839
)

lingvo/core/hyperparams.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,8 @@ class _SortedDict(dict):
104104

105105
def __repr__(self):
106106
return '{' + ', '.join(
107-
'%r: %r' % item for item in sorted(self.items())) + '}'
107+
'%r: %s' % (k, _QuoteString(v) if isinstance(v, str) else repr(v))
108+
for k, v in sorted(self.items())) + '}'
108109

109110

110111
class _Param:
@@ -390,17 +391,18 @@ def _GetNested(self, name: str) -> Tuple[ParamsT, str]:
390391
for i, part in enumerate(parts[:-1]):
391392
# Get the value (nested Params object) associated with name 'part'.
392393
try:
393-
if is_list := re.match(r'^(.+)\[(.+)\]$', part):
394-
part = is_list.group(1)
395-
list_index = int(is_list.group(2))
394+
if is_list_or_dict := re.match(r'^(.+)\[(.+)\]$', part):
395+
part = is_list_or_dict.group(1)
396+
list_index = ast.literal_eval(is_list_or_dict.group(2))
396397
# pylint: disable=protected-access
397398
curr = curr._params[part].Get()
398-
if is_list:
399+
if is_list_or_dict:
399400
curr = curr[list_index]
400401
except KeyError:
401402
raise AttributeError('.'.join(parts[:i + 1]))
402403
assert isinstance(curr, Params), ('Cannot introspect %s for %s' %
403404
(type(curr), '.'.join(parts[:i + 1])))
405+
404406
return curr, parts[-1]
405407

406408
def Set(self: ParamsT, **kwargs: Any) -> ParamsT:
@@ -697,7 +699,7 @@ def _Visit(key: str, val: Any):
697699
elif isinstance(val, dict):
698700
if enter_fn(key, val):
699701
for k, v in val.items():
700-
_Visit(_SubKey(key, k), v)
702+
_Visit(f'{key}[\'{k}\']', v)
701703
exit_fn(key, val)
702704
else:
703705
visit_fn(key, val)
@@ -803,6 +805,9 @@ def _Enter(key: str, p: Any) -> bool:
803805
isinstance(x, tuple) and len(x) == 2 and isinstance(x[0], str) and
804806
isinstance(x[1], Params) for x in p)):
805807
return True
808+
elif (isinstance(p, dict) and p and all(
809+
isinstance(k, str) and isinstance(v, Params) for k, v in p.items())):
810+
return True
806811
return False
807812

808813
kv = {}

lingvo/core/hyperparams_test.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -424,8 +424,8 @@ def Visit(key, value):
424424
outer.Visit(Visit)
425425

426426
self.assertListEqual(visit_keys, [
427-
'inner.a[0][0]', 'inner.a[0][1][0]', 'inner.a[0][1][1]',
428-
'inner.a[0][1][2][x]'
427+
'inner[\'a\'][0][0]', 'inner[\'a\'][0][1][0]', 'inner[\'a\'][0][1][1]',
428+
'inner[\'a\'][0][1][2][x]'
429429
])
430430

431431
def testToText(self):
@@ -446,6 +446,10 @@ def testToText(self):
446446
outer.Define('plain_dict', {'a': 10}, '')
447447
outer.Define('complex_dict', {'a': 10, 'b': inner}, '')
448448
outer.Define('complex_dict_escape', {'a': 'abc"\'\ndef'}, '')
449+
outer.Define('complex_dict_with_params', {
450+
'a': inner,
451+
'b': inner.Copy()
452+
}, '')
449453
outer.Define('some_class', complex(0, 1), '')
450454
outer.Define('optional_bool', None, '')
451455
outer.Define('enum', TestEnum.B, '')
@@ -460,7 +464,12 @@ def testToText(self):
460464
'\n' + outer.ToText(), r"""
461465
class : type/__main__/TestClass1
462466
complex_dict : {'a': 10, 'b': {'bar': 2.71, 'baz': 'hello'}}
463-
complex_dict_escape : {'a': 'abc"\'\ndef'}
467+
complex_dict_escape : {'a': 'abc"\'
468+
def'}
469+
complex_dict_with_params['a'].bar : 2.71
470+
complex_dict_with_params['a'].baz : 'hello'
471+
complex_dict_with_params['b'].bar : 2.71
472+
complex_dict_with_params['b'].baz : 'hello'
464473
dataclass : {'a': [42], 'b': 'float32'}
465474
dtype : float32
466475
dtype2 : int32
@@ -488,6 +497,10 @@ class : type/__main__/TestClass1
488497
dtype2 : float32
489498
inner.baz : 'world'
490499
# foo : 123
500+
complex_dict_with_params['a'].bar : 2.71
501+
complex_dict_with_params['a'].baz : 'world'
502+
complex_dict_with_params['b'].bar : 2.71
503+
complex_dict_with_params['b'].baz : 'hello'
491504
optional_bool : true
492505
list_of_params[0].bar : 2.72
493506
seqlen : [1, 2.0, '3', [4]]
@@ -507,7 +520,12 @@ class : type/__main__/TestClass2
507520
'\n' + outer.ToText(), r"""
508521
class : type/__main__/TestClass2
509522
complex_dict : {'a': 10, 'b': {'bar': 2.71, 'baz': 'world'}}
510-
complex_dict_escape : {'a': 'abc"\'\ndef'}
523+
complex_dict_escape : {'a': 'abc"\'
524+
def'}
525+
complex_dict_with_params['a'].bar : 2.71
526+
complex_dict_with_params['a'].baz : 'world'
527+
complex_dict_with_params['b'].bar : 2.71
528+
complex_dict_with_params['b'].baz : 'hello'
511529
dataclass : {'a': 27, 'b': 'int32'}
512530
dtype : float32
513531
dtype2 : float32

0 commit comments

Comments
 (0)