Skip to content

Commit 2c75e99

Browse files
lingvo-botcopybara-github
authored andcommitted
Now that we require Py3.8, use the walrus operator in a few places.
PiperOrigin-RevId: 487858073
1 parent 6569919 commit 2c75e99

File tree

4 files changed

+8
-15
lines changed

4 files changed

+8
-15
lines changed

lingvo/core/base_layer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -545,8 +545,7 @@ def GetDescendant(self, path: str) -> BaseLayerT:
545545
# If child_name is being indexed as a list then we separate the name and
546546
# the index.
547547
index = None
548-
match = re.match(r'(.*)\[([-]?[0-9]+)\]$', child_name)
549-
if match:
548+
if match := re.match(r'(.*)\[([-]?[0-9]+)\]$', child_name):
550549
child_name, index = match.group(1), int(match.group(2))
551550

552551
# Validate that child_name is a child of the current parent layer.

lingvo/core/hyperparams.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,7 @@ def _UnquoteString(quoted):
7272

7373
def _EndsWithTerminalQuote(s, quote_char):
7474
"""Returns whether a string ends with a valid terminal quote."""
75-
endm = re.search(r'(\\*)%s$' % quote_char, s)
76-
if not endm:
75+
if not (endm := re.search(r'(\\*)%s$' % quote_char, s)):
7776
return False
7877
backslashes = endm.group(1)
7978
if len(backslashes) % 2 == 0:
@@ -362,8 +361,7 @@ def _GetNested(self, name: str) -> Tuple[ParamsT, str]:
362361
for i, part in enumerate(parts[:-1]):
363362
# Get the value (nested Params object) associated with name 'part'.
364363
try:
365-
is_list = re.match(r'^(.+)\[(.+)\]$', part)
366-
if is_list:
364+
if is_list := re.match(r'^(.+)\[(.+)\]$', part):
367365
part = is_list.group(1)
368366
list_index = int(is_list.group(2))
369367
# pylint: disable=protected-access

lingvo/core/py_utils.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1471,8 +1471,7 @@ def GetVariableName(name):
14711471
('Renaming variables is not supported in eager mode. '
14721472
'Please look into migrating away from variable renaming.'), 1)
14731473
for regexp, name_format in renames:
1474-
match = re.match(regexp, name)
1475-
if match:
1474+
if match := re.match(regexp, name):
14761475
if matched:
14771476
tf.logging.warning('Multiple matches for: %s', name)
14781477
matched = True
@@ -2702,9 +2701,8 @@ def _GetVarsToLoad(all_vars,
27022701
for model_var in all_vars:
27032702
loaded = False
27042703
for regexp, name_format in variable_loading_rules:
2705-
match = re.match(regexp, model_var.name)
27062704
# Skip if var doesn't match the loading rules, or if it should be ignored.
2707-
if not match:
2705+
if not (match := re.match(regexp, model_var.name)):
27082706
if not suppress_logging:
27092707
tf.logging.debug('Loading rules do not match %s.', model_var.name)
27102708
continue
@@ -5142,9 +5140,7 @@ def RecordFormatFromFilePattern(file_pattern):
51425140
- record_format: String record format, e.g., "tfrecord", etc.
51435141
- file_pattern: The file pattern without any prefixes.
51445142
"""
5145-
result = re.match(_RECORD_FORMAT_RE, file_pattern)
5146-
5147-
if result is None:
5143+
if (result := re.match(_RECORD_FORMAT_RE, file_pattern)) is None:
51485144
# TODO(vrv): Fix all callers so that file_pattern must contain
51495145
# the record format prefix.
51505146
return 'sstable', file_pattern

lingvo/models_test_helper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,10 +181,10 @@ def CreateTestMethodsForAllRegisteredModels(cls,
181181
print(f'Creating tests for {task_regexes}, excluding {exclude_regexes}')
182182
valid_models = []
183183
for model_name in sorted(model_names):
184-
if not any([re.search(regex, model_name) for regex in task_regexes]):
184+
if not any(re.search(regex, model_name) for regex in task_regexes):
185185
print(f'Skipping tests for registered model {model_name}')
186186
continue
187-
if any([re.search(regex, model_name) for regex in exclude_regexes]):
187+
if any(re.search(regex, model_name) for regex in exclude_regexes):
188188
print(f'Explicitly excluding tests for registered model {model_name}')
189189
continue
190190
valid_models.append(model_name)

0 commit comments

Comments
 (0)