@@ -446,7 +446,6 @@ def callback(prediction_result, recycles):
446446 prediction_result , recycles = \
447447 model_runner .predict (input_features , random_seed = seed , prediction_callback = callback )
448448 prediction_result = _jnp_to_np (prediction_result )
449- prediction_result ["representations" ] = prediction_result .pop ("prev" )
450449 prediction_times .append (time .time () - start )
451450
452451 ########################
@@ -480,23 +479,23 @@ def callback(prediction_result, recycles):
480479
481480 #########################
482481 # save results
483- #########################
482+ #########################
483+
484484 # save pdb
485485 protein_lines = protein .to_pdb (unrelaxed_protein )
486486 files .get ("unrelaxed" ,"pdb" ).write_text (protein_lines )
487487 unrelaxed_pdb_lines .append (protein_lines )
488488
489-
490489 # save raw outputs
491490 if save_all :
492491 with files .get ("all" ,"pickle" ).open ("wb" ) as handle :
493492 pickle .dump (prediction_result , handle )
494- if save_single_representations or save_pair_representations :
495- rep = prediction_result [ "representations" ]
496- if save_single_representations :
497- np . save ( files . get ( "single_repr" , "npy" ), rep [ "prev_msa_first_row" ])
498- if save_pair_representations :
499- np . save ( files . get ( "pair_repr" , "npy" ), rep ["prev_pair" ])
493+ if save_single_representations :
494+ np . save ( files . get ( "single_repr" , "npy" ),
495+ prediction_result [ "prev" ][ "prev_msa_first_row" ])
496+ if save_pair_representations :
497+ np . save ( files . get ( "pair_repr" , "npy" ),
498+ prediction_result [ "prev" ] ["prev_pair" ])
500499
501500 # write an easy-to-use format (pAE and pLDDT)
502501 with files .get ("scores" ,"json" ).open ("w" ) as handle :
0 commit comments