@@ -428,23 +428,20 @@ def callback(prediction_result, recycles):
428428 print_line += f" { y } ={ prediction_result [x ]:.3g} "
429429 logger .info (f"{ tag } recycle={ recycles } { print_line } " )
430430
431- if save_recycles or save_all :
432- prediction_result = _jnp_to_np (prediction_result )
433- prediction_result ["representations" ] = prediction_result .pop ("prev" )
434-
435431 if save_recycles :
436- final_atom_mask = prediction_result ["structure_module" ]["final_atom_mask" ]
437- b_factors = prediction_result ["plddt" ][:, None ] * final_atom_mask
432+ result = _jnp_to_np (prediction_result )
433+ final_atom_mask = result ["structure_module" ]["final_atom_mask" ]
434+ b_factors = result ["plddt" ][:, None ] * final_atom_mask
438435 unrelaxed_protein = protein .from_prediction (features = input_features ,
439- result = prediction_result , b_factors = b_factors ,
436+ result = result , b_factors = b_factors ,
440437 remove_leading_feature_dimension = ("ptm" in model_type ))
441438
442439 unrelaxed_pdb_lines = protein .to_pdb (class_to_np (unrelaxed_protein ))
443440 files .get ("unrelaxed" ,f"r{ recycles } .pdb" ).write_text (unrelaxed_pdb_lines )
444441
445- if save_all :
446- with files .get ("all" ,f"r{ recycles } .pickle" ).open ("wb" ) as handle :
447- pickle .dump (prediction_result , handle )
442+ if save_all :
443+ with files .get ("all" ,f"r{ recycles } .pickle" ).open ("wb" ) as handle :
444+ pickle .dump (result , handle )
448445
449446 prediction_result , recycles = \
450447 model_runner .predict (input_features , random_seed = seed , prediction_callback = callback )
@@ -489,7 +486,11 @@ def callback(prediction_result, recycles):
489486 files .get ("unrelaxed" ,"pdb" ).write_text (protein_lines )
490487 unrelaxed_pdb_lines .append (protein_lines )
491488
489+
492490 # save raw outputs
491+ if save_all :
492+ with files .get ("all" ,"pickle" ).open ("wb" ) as handle :
493+ pickle .dump (prediction_result , handle )
493494 if save_single_representations or save_pair_representations :
494495 rep = prediction_result ["representations" ]
495496 if save_single_representations :
0 commit comments