@@ -153,6 +153,9 @@ def _model(params, model_params, inputs, key, opt):
153153 batch .update (all_atom .atom37_to_frames (** batch ))
154154 else :
155155 batch = None
156+
157+ inputs ["batch" ] = batch
158+
156159 #######################################################################
157160 # OUTPUTS
158161 #######################################################################
@@ -174,20 +177,22 @@ def _model(params, model_params, inputs, key, opt):
174177 aux ["pae" ] = jnp .full ((L ,L ),jnp .nan ).at [p [:,None ],p [None ,:]].set (aux ["pae" ])
175178
176179 if self ._args ["recycle_mode" ] == "average" : aux ["prev" ] = outputs ["prev" ]
177-
180+
178181 #######################################################################
179182 # LOSS
180183 #######################################################################
181- inputs ["batch" ] = batch
182- if self ._args ["debug" ]: aux ["debug" ] = {"inputs" :inputs , "outputs" :outputs , "opt" :opt }
183184
184185 aux ["losses" ] = {}
185186 self ._get_loss (inputs = inputs , outputs = outputs , opt = opt , aux = aux )
186187
188+ inputs ["seq" ] = aux ["seq" ]
187189 if self ._loss_callback is not None :
188190 loss_fns = self ._loss_callback if isinstance (self ._loss_callback ,list ) else [self ._loss_callback ]
189191 for loss_fn in loss_fns :
190192 aux ["losses" ].update (loss_fn (inputs , outputs , opt ))
193+
194+ if self ._args ["debug" ]:
195+ aux ["debug" ] = {"inputs" :inputs , "outputs" :outputs , "opt" :opt }
191196
192197 # weighted loss
193198 w = opt ["weights" ]
0 commit comments