Skip to content

Commit 523ff4a

Browse files
tonybruguiercopybara-github
authored andcommitted
Add a 'keep' parameter to the function CopyFieldsTo in hyperparams.py
PiperOrigin-RevId: 490126504
1 parent 8ca46ae commit 523ff4a

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

lingvo/core/hyperparams.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,35 @@ def CopyFieldsTo(from_p, to_p, skip=None, ignore_unknown_keys=False):
210210
return to_p
211211

212212

213+
def CopyFieldsSubsetTo(from_p, to_p, fields_to_set):
214+
"""Copy fields from one Params to another, with optional skipped params.
215+
216+
Preserves `type(to_p.Instantiate())`. Use `from_p.Copy()` instead if requiring
217+
a deep copy of `from_p`, without updating `to_p`.
218+
219+
Args:
220+
from_p: Source params to copy from.
221+
to_p: Destination params to copy to.
222+
fields_to_set: A string, a list of strings or None. Params to copy.
223+
224+
Returns:
225+
The updated to_p.
226+
"""
227+
if not isinstance(fields_to_set, list):
228+
fields_to_set = [fields_to_set]
229+
230+
for key, value in from_p.IterParams():
231+
if key == 'cls':
232+
continue
233+
if key not in fields_to_set:
234+
continue
235+
if isinstance(value, Params):
236+
to_p.Set(**{key: value.Copy()})
237+
else:
238+
to_p.Set(**{key: value})
239+
return to_p
240+
241+
213242
ParamsT = TypeVar('ParamsT', bound='Params')
214243

215244

lingvo/core/hyperparams_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,20 @@ def testCopyFieldsToDoesNotCopyClass(self):
166166
hyperparams.CopyFieldsTo(source, dest)
167167
self.assertEqual(dest.cls, hyperparams.InstantiableParams)
168168

169+
def testCopyFieldsSubsetTo(self):
170+
source = hyperparams.Params()
171+
dest = hyperparams.Params()
172+
source.Define('a', 'a', '')
173+
source.Define('b', 'b', '')
174+
source.Define('c', 'c', '')
175+
dest.Define('a', '', '')
176+
dest.Define('d', 'd', '')
177+
hyperparams.CopyFieldsSubsetTo(source, dest, ['a'])
178+
self.assertEqual(source.a, dest.a)
179+
self.assertNotIn('b', dest)
180+
self.assertNotIn('c', dest)
181+
self.assertEqual(dest.d, 'd')
182+
169183
def testDefineExisting(self):
170184
p = hyperparams.Params()
171185
p.Define('foo', 1, '')

0 commit comments

Comments
 (0)