Skip to content

Commit 8632eb7

Browse files
authored
adding "seq" to inputs (for custom loss)
1 parent 2a508c5 commit 8632eb7

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

colabdesign/af/model.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,9 @@ def _model(params, model_params, inputs, key, opt):
153153
batch.update(all_atom.atom37_to_frames(**batch))
154154
else:
155155
batch = None
156+
157+
inputs["batch"] = batch
158+
156159
#######################################################################
157160
# OUTPUTS
158161
#######################################################################
@@ -174,20 +177,22 @@ def _model(params, model_params, inputs, key, opt):
174177
aux["pae"] = jnp.full((L,L),jnp.nan).at[p[:,None],p[None,:]].set(aux["pae"])
175178

176179
if self._args["recycle_mode"] == "average": aux["prev"] = outputs["prev"]
177-
180+
178181
#######################################################################
179182
# LOSS
180183
#######################################################################
181-
inputs["batch"] = batch
182-
if self._args["debug"]: aux["debug"] = {"inputs":inputs, "outputs":outputs, "opt":opt}
183184

184185
aux["losses"] = {}
185186
self._get_loss(inputs=inputs, outputs=outputs, opt=opt, aux=aux)
186187

188+
inputs["seq"] = aux["seq"]
187189
if self._loss_callback is not None:
188190
loss_fns = self._loss_callback if isinstance(self._loss_callback,list) else [self._loss_callback]
189191
for loss_fn in loss_fns:
190192
aux["losses"].update(loss_fn(inputs, outputs, opt))
193+
194+
if self._args["debug"]:
195+
aux["debug"] = {"inputs":inputs, "outputs":outputs, "opt":opt}
191196

192197
# weighted loss
193198
w = opt["weights"]

0 commit comments

Comments
 (0)