Skip to content

Commit 755afd9

Browse files
committed
fixing save-all/save-recycles option(s)
1 parent c4d0071 commit 755afd9

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

colabfold/batch.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -428,23 +428,20 @@ def callback(prediction_result, recycles):
428428
print_line += f" {y}={prediction_result[x]:.3g}"
429429
logger.info(f"{tag} recycle={recycles}{print_line}")
430430

431-
if save_recycles or save_all:
432-
prediction_result = _jnp_to_np(prediction_result)
433-
prediction_result["representations"] = prediction_result.pop("prev")
434-
435431
if save_recycles:
436-
final_atom_mask = prediction_result["structure_module"]["final_atom_mask"]
437-
b_factors = prediction_result["plddt"][:, None] * final_atom_mask
432+
result = _jnp_to_np(prediction_result)
433+
final_atom_mask = result["structure_module"]["final_atom_mask"]
434+
b_factors = result["plddt"][:, None] * final_atom_mask
438435
unrelaxed_protein = protein.from_prediction(features=input_features,
439-
result=prediction_result, b_factors=b_factors,
436+
result=result, b_factors=b_factors,
440437
remove_leading_feature_dimension=("ptm" in model_type))
441438

442439
unrelaxed_pdb_lines = protein.to_pdb(class_to_np(unrelaxed_protein))
443440
files.get("unrelaxed",f"r{recycles}.pdb").write_text(unrelaxed_pdb_lines)
444441

445-
if save_all:
446-
with files.get("all",f"r{recycles}.pickle").open("wb") as handle:
447-
pickle.dump(prediction_result, handle)
442+
if save_all:
443+
with files.get("all",f"r{recycles}.pickle").open("wb") as handle:
444+
pickle.dump(result, handle)
448445

449446
prediction_result, recycles = \
450447
model_runner.predict(input_features, random_seed=seed, prediction_callback=callback)
@@ -489,7 +486,11 @@ def callback(prediction_result, recycles):
489486
files.get("unrelaxed","pdb").write_text(protein_lines)
490487
unrelaxed_pdb_lines.append(protein_lines)
491488

489+
492490
# save raw outputs
491+
if save_all:
492+
with files.get("all","pickle").open("wb") as handle:
493+
pickle.dump(prediction_result, handle)
493494
if save_single_representations or save_pair_representations:
494495
rep = prediction_result["representations"]
495496
if save_single_representations:

0 commit comments

Comments
 (0)