@@ -1013,13 +1013,13 @@ def zero_state(self, batch_size):
10131013 else :
10141014 return py_utils .NestedMap ()
10151015
1016- def StreamStep (self , theta , inputs , paddings , state0 ):
1016+ def StreamStep (self , theta , in_nmap , state0 ):
10171017 """Streams t steps.
10181018
10191019 Args:
10201020 theta: A NestedMap of read-only layer params.
1021- inputs : A tensor of shape [b, t, d].
1022- paddings: A 0/1 valued tensor of shape [b, t].
1021+ in_nmap : A NestedMap including features tensor of shape [b, t, d], and
1022+ paddings tensor of shape [b, t].
10231023 state0: A NestedMap of tensors of the same struct as returned by
10241024 zero_state().
10251025
@@ -1033,7 +1033,7 @@ def StreamStep(self, theta, inputs, paddings, state0):
10331033 assert not p .remat
10341034
10351035 with tf .name_scope (f'{ p .name } /StreamStep' ):
1036- features , aux_loss = inputs , None
1036+ features , paddings , aux_loss = in_nmap . features , in_nmap . paddings , None
10371037
10381038 if self .has_fflayer_start :
10391039 features , paddings , aux_loss = self ._MoeOrFFLayer (
@@ -1065,6 +1065,12 @@ def StreamStep(self, theta, inputs, paddings, state0):
10651065 paddings , aux_loss )
10661066 outputs = self .final_ln .FProp (theta .final_ln , features )
10671067
1068+ if p .adapter_tpl :
1069+ adapter_in_map = in_nmap .DeepCopy ()
1070+ adapter_in_map .features , adapter_in_map .padding = outputs , paddings
1071+ adapter_out_nmap = self .adapter .FProp (theta .adapter , adapter_in_map )
1072+ outputs , paddings = adapter_out_nmap .features , adapter_out_nmap .paddings
1073+
10681074 state1 = py_utils .NestedMap (
10691075 lconv_state = lconv_state1 , atten_state = atten_state1 )
10701076 return outputs , paddings , state1
0 commit comments