@@ -403,6 +403,7 @@ def predict_structure(
403403 if "multimer" in model_type :
404404 # TODO: add multimer padding
405405 input_features = processed_feature_dict
406+ input_features ["asym_id" ] = input_features ["asym_id" ] - input_features ["asym_id" ][...,0 ]
406407 else :
407408 # TODO: move asym_id processing to "process_features"
408409 r = processed_feature_dict ["aatype" ].shape [0 ]
@@ -427,28 +428,24 @@ def callback(prediction_result, recycles):
427428 print_line += f" { y } ={ prediction_result [x ]:.3g} "
428429 logger .info (f"{ tag } recycle={ recycles } { print_line } " )
429430
430- if save_recycles or save_all :
431- prediction_result = _jnp_to_np (prediction_result )
432- prediction_result ["representations" ] = prediction_result .pop ("prev" )
433-
434431 if save_recycles :
435- final_atom_mask = prediction_result ["structure_module" ]["final_atom_mask" ]
436- b_factors = prediction_result ["plddt" ][:, None ] * final_atom_mask
432+ result = _jnp_to_np (prediction_result )
433+ final_atom_mask = result ["structure_module" ]["final_atom_mask" ]
434+ b_factors = result ["plddt" ][:, None ] * final_atom_mask
437435 unrelaxed_protein = protein .from_prediction (features = input_features ,
438- result = prediction_result , b_factors = b_factors ,
436+ result = result , b_factors = b_factors ,
439437 remove_leading_feature_dimension = ("ptm" in model_type ))
440438
441439 unrelaxed_pdb_lines = protein .to_pdb (class_to_np (unrelaxed_protein ))
442440 files .get ("unrelaxed" ,f"r{ recycles } .pdb" ).write_text (unrelaxed_pdb_lines )
443441
444- if save_all :
445- with files .get ("all" ,f"r{ recycles } .pickle" ).open ("wb" ) as handle :
446- pickle .dump (prediction_result , handle )
442+ if save_all :
443+ with files .get ("all" ,f"r{ recycles } .pickle" ).open ("wb" ) as handle :
444+ pickle .dump (result , handle )
447445
448446 prediction_result , recycles = \
449447 model_runner .predict (input_features , random_seed = seed , prediction_callback = callback )
450448 prediction_result = _jnp_to_np (prediction_result )
451- prediction_result ["representations" ] = prediction_result .pop ("prev" )
452449 prediction_times .append (time .time () - start )
453450
454451 ########################
@@ -482,19 +479,23 @@ def callback(prediction_result, recycles):
482479
483480 #########################
484481 # save results
485- #########################
482+ #########################
483+
486484 # save pdb
487485 protein_lines = protein .to_pdb (unrelaxed_protein )
488486 files .get ("unrelaxed" ,"pdb" ).write_text (protein_lines )
489487 unrelaxed_pdb_lines .append (protein_lines )
490488
491489 # save raw outputs
492- if save_single_representations or save_pair_representations :
493- rep = prediction_result ["representations" ]
494- if save_single_representations :
495- np .save (files .get ("single_repr" ,"npy" ), rep ["prev_msa_first_row" ])
496- if save_pair_representations :
497- np .save (files .get ("pair_repr" ,"npy" ), rep ["prev_pair" ])
490+ if save_all :
491+ with files .get ("all" ,"pickle" ).open ("wb" ) as handle :
492+ pickle .dump (prediction_result , handle )
493+ if save_single_representations :
494+ np .save (files .get ("single_repr" ,"npy" ),
495+ prediction_result ["prev" ]["prev_msa_first_row" ])
496+ if save_pair_representations :
497+ np .save (files .get ("pair_repr" ,"npy" ),
498+ prediction_result ["prev" ]["prev_pair" ])
498499
499500 # write an easy-to-use format (pAE and pLDDT)
500501 with files .get ("scores" ,"json" ).open ("w" ) as handle :
@@ -1186,6 +1187,7 @@ def run(
11861187 dpi : int = 200 ,
11871188 max_seq : Optional [int ] = None ,
11881189 max_extra_seq : Optional [int ] = None ,
1190+ use_cluster_profile : bool = True ,
11891191 feature_dict_callback : Callable [[Any ], Any ] = None ,
11901192 ** kwargs
11911193):
@@ -1234,7 +1236,6 @@ def run(
12341236 pair_mode = old_names .get (pair_mode ,pair_mode )
12351237 feature_dict_callback = kwargs .pop ("input_features_callback" , feature_dict_callback )
12361238 use_dropout = kwargs .pop ("training" , use_dropout )
1237- use_cluster_profile = kwargs .pop ("use_cluster_profile" , None )
12381239 use_fuse = kwargs .pop ("use_fuse" , True )
12391240 use_bfloat16 = kwargs .pop ("use_bfloat16" , True )
12401241 max_msa = kwargs .pop ("max_msa" ,None )
@@ -1659,7 +1660,7 @@ def main():
16591660 help = "rank models by auto, plddt or ptmscore" ,
16601661 type = str ,
16611662 default = "auto" ,
1662- choices = ["auto" , "plddt" , "ptmscore " , "multimer" ],
1663+ choices = ["auto" , "plddt" , "ptm" , "iptm " , "multimer" ],
16631664 )
16641665 parser .add_argument (
16651666 "--pair-mode" ,
@@ -1711,6 +1712,12 @@ def main():
17111712 type = str ,
17121713 default = None ,
17131714 )
1715+ parser .add_argument (
1716+ "--disable-cluster-profile" ,
1717+ default = False ,
1718+ action = "store_true" ,
1719+ help = "EXPERIMENTAL: for multimer models, disable cluster profiles" ,
1720+ )
17141721 parser .add_argument (
17151722 "--zip" ,
17161723 default = False ,
@@ -1798,6 +1805,7 @@ def main():
17981805 max_seq = args .max_seq ,
17991806 max_extra_seq = args .max_extra_seq ,
18001807 max_msa = args .max_msa ,
1808+ use_cluster_profile = not args .disable_cluster_profile ,
18011809 use_gpu_relax = args .use_gpu_relax ,
18021810 save_all = args .save_all ,
18031811 save_recycles = args .save_recycles ,
0 commit comments