Skip to content

Commit 50d55da

Browse files
lingvo-botcopybara-github
authored andcommitted
Changes to step decoder with adaptive joint network and prediction network.
PiperOrigin-RevId: 477291556
1 parent 8ce7aa8 commit 50d55da

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

lingvo/core/step.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,25 @@ class Step(base_layer.BaseLayer):
4343
This can be seen as an RNNCell extended with optional external inputs.
4444
"""
4545

46+
adaptive_task_ids = None
47+
adaptive_task_dim = 0
48+
49+
@classmethod
50+
def Params(cls):
51+
p = super().Params()
52+
p.Define('adaptive_task_ids', None,
53+
'Name of task IDs to adaptively control joint network.')
54+
p.Define('adaptive_task_dim', 0,
55+
'Dimension of task IDs to adaptively control joint network.')
56+
return p
57+
58+
def __init__(self, params):
59+
super().__init__(params)
60+
p = params
61+
# set the adaptive controllers
62+
self.adaptive_task_ids = p.adaptive_task_ids
63+
self.adaptive_task_dim = p.adaptive_task_dim
64+
4665
def PrepareExternalInputs(self, theta, external_inputs):
4766
"""Returns the prepared external inputs, e.g., packed_src for attention."""
4867
if not external_inputs:

0 commit comments

Comments
 (0)