@@ -299,19 +299,19 @@ def relax_me(pdb_filename=None, pdb_lines=None, pdb_obj=None, use_gpu=False):
     from alphafold.common import residue_constants
     from alphafold.relax import relax

-    if pdb_obj is None:
+    if pdb_obj is None:
         if pdb_lines is None:
             pdb_lines = Path(pdb_filename).read_text()
         pdb_obj = protein.from_pdb_string(pdb_lines)
-
+
     amber_relaxer = relax.AmberRelaxation(
         max_iterations=0,
         tolerance=2.39,
         stiffness=10.0,
         exclude_residues=[],
         max_outer_iterations=3,
         use_gpu=use_gpu)
-
+
     relaxed_pdb_lines, _, _ = amber_relaxer.process(prot=pdb_obj)
     return relaxed_pdb_lines

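For orientation, the helper above simply wraps AlphaFold's AmberRelaxation with fixed settings and returns the relaxed PDB text. A minimal usage sketch, assuming an environment where alphafold.relax is importable; the file names are placeholders, not part of the commit:

    # Hypothetical use of relax_me() as defined above.
    from pathlib import Path

    relaxed_pdb = relax_me(pdb_filename="model.pdb", use_gpu=False)  # CPU AMBER relax
    Path("model_relaxed.pdb").write_text(relaxed_pdb)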
@@ -321,7 +321,7 @@ def __init__(self, prefix: str, result_dir: Path):
         self.result_dir = result_dir
         self.tag = None
         self.files = {}
-
+
     def get(self, x: str, ext: str) -> Path:
         if self.tag not in self.files:
             self.files[self.tag] = []
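The two methods touched here belong to a small tag-keyed file registry: set_tag() selects the current model tag and get() hands out result paths while recording them, so the files can be renamed once models are re-ranked (see the rename loop further down). A self-contained sketch of that pattern; the class name and the body of get() beyond what the diff shows are inferred for illustration, not the exact ColabFold implementation:

    from pathlib import Path

    class FileRegistry:
        """Illustrative stand-in for the registry whose __init__/get appear above."""
        def __init__(self, prefix: str, result_dir: Path):
            self.prefix = prefix
            self.result_dir = result_dir
            self.tag = None
            self.files = {}

        def set_tag(self, tag: str):
            self.tag = tag

        def get(self, x: str, ext: str) -> Path:
            if self.tag not in self.files:
                self.files[self.tag] = []
            # path pattern mirrors the rename step: {prefix}_{x}_{tag}.{ext}
            file = self.result_dir.joinpath(f"{self.prefix}_{x}_{self.tag}.{ext}")
            self.files[self.tag].append([x, ext, file])
            return file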
@@ -366,13 +366,13 @@ def predict_structure(

     # iterate through random seeds
     for seed_num, seed in enumerate(range(random_seed, random_seed + num_seeds)):
-
+
         # iterate through models
         for model_num, (model_name, model_runner, params) in enumerate(model_runner_and_params):
-
+
             # swap params to avoid recompiling
             model_runner.params = params
-
+
             #########################
             # process input features
             #########################
@@ -383,24 +383,24 @@ def predict_structure(
                     input_features["asym_id"] = input_features["asym_id"] - input_features["asym_id"][...,0]
             else:
                 if model_num == 0:
-                    input_features = model_runner.process_features(feature_dict, random_seed=seed)
+                    input_features = model_runner.process_features(feature_dict, random_seed=seed)
                     r = input_features["aatype"].shape[0]
                     input_features["asym_id"] = np.tile(feature_dict["asym_id"],r).reshape(r,-1)
                     if seq_len < pad_len:
-                        input_features = pad_input(input_features, model_runner,
+                        input_features = pad_input(input_features, model_runner,
                             model_name, pad_len, use_templates)
                         logger.info(f"Padding length to {pad_len}")
-
+

             tag = f"{model_type}_{model_name}_seed_{seed:03d}"
             model_names.append(tag)
             files.set_tag(tag)
-
+
             ########################
             # predict
             ########################
             start = time.time()
-
+
             # monitor intermediate results
             def callback(result, recycles):
                 if recycles == 0: result.pop("tol",None)
@@ -419,12 +419,12 @@ def callback(result, recycles):
                         result=result, b_factors=b_factors,
                         remove_leading_feature_dimension=("multimer" not in model_type))
                     files.get("unrelaxed",f"r{recycles}.pdb").write_text(protein.to_pdb(unrelaxed_protein))
-
+
                     if save_all:
                         with files.get("all",f"r{recycles}.pickle").open("wb") as handle:
                             pickle.dump(result, handle)
                     del unrelaxed_protein
-
+
             return_representations = save_all or save_single_representations or save_pair_representations

             # predict
@@ -439,9 +439,9 @@ def callback(result, recycles):
             ########################
             # parse results
             ########################
-
+
             # summary metrics
-            mean_scores.append(result["ranking_confidence"])
+            mean_scores.append(result["ranking_confidence"])
             if recycles == 0: result.pop("tol",None)
             if not is_complex: result.pop("iptm",None)
             print_line = ""
@@ -469,7 +469,7 @@ def callback(result, recycles):

             #########################
             # save results
-            #########################
+            #########################

             # save pdb
             protein_lines = protein.to_pdb(unrelaxed_protein)
@@ -498,12 +498,12 @@ def callback(result, recycles):
                 del pae
                 del plddt
                 json.dump(scores, handle)
-
+
             del result, unrelaxed_protein

             # early stop criteria fulfilled
             if mean_scores[-1] > stop_at_score: break
-
+
         # early stop criteria fulfilled
         if mean_scores[-1] > stop_at_score: break

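Note that the stop_at_score check above appears twice on purpose: the first break leaves the model loop, the second leaves the seed loop. A compact standalone illustration of that double break (the confidence values are made up):

    stop_at_score = 90.0
    mean_scores = []
    for seed in range(3):                              # outer loop: random seeds
        for model in range(5):                         # inner loop: models
            mean_scores.append(70.0 + 5 * model)       # fake ranking confidences
            if mean_scores[-1] > stop_at_score: break  # stop trying more models
        if mean_scores[-1] > stop_at_score: break      # ...and more seeds
    print(len(mean_scores))  # 15: no score exceeded 90, so nothing stopped early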
@@ -514,7 +514,7 @@ def callback(result, recycles):
     ###################################################
     # rerank models based on predicted confidence
     ###################################################
-
+
     rank, metric = [],[]
     result_files = []
     logger.info(f"reranking models by '{rank_by}' metric")
@@ -527,7 +527,7 @@ def callback(result, recycles):
         if n < num_relax:
             start = time.time()
             pdb_lines = relax_me(pdb_lines=unrelaxed_pdb_lines[key], use_gpu=use_gpu_relax)
-            files.get("relaxed","pdb").write_text(pdb_lines)
+            files.get("relaxed","pdb").write_text(pdb_lines)
             logger.info(f"Relaxation took {(time.time() - start):.1f}s")

533533 # rename files to include rank
@@ -538,7 +538,7 @@ def callback(result, recycles):
538538 new_file = result_dir .joinpath (f"{ prefix } _{ x } _{ new_tag } .{ ext } " )
539539 file .rename (new_file )
540540 result_files .append (new_file )
541-
541+
542542 return {"rank" :rank ,
543543 "metric" :metric ,
544544 "result_files" :result_files }
@@ -649,10 +649,10 @@ def get_queries(
     # sort by seq. len
     if sort_queries_by == "length":
         queries.sort(key=lambda t: len("".join(t[1])))
-
+
     elif sort_queries_by == "random":
         random.shuffle(queries)
-
+
     is_complex = False
     for job_number, (raw_jobname, query_sequence, a3m_lines) in enumerate(queries):
         if isinstance(query_sequence, list):
@@ -719,6 +719,7 @@ def pad_sequences(
 def get_msa_and_templates(
     jobname: str,
     query_sequences: Union[str, List[str]],
+    a3m_lines: Optional[List[str]],
     result_dir: Path,
     msa_mode: str,
     use_templates: bool,
@@ -749,17 +750,29 @@ def get_msa_and_templates(
     # get template features
     template_features = []
     if use_templates:
-        a3m_lines_mmseqs2, template_paths = run_mmseqs2(
-            query_seqs_unique,
-            str(result_dir.joinpath(jobname)),
-            use_env,
-            use_templates=True,
-            host_url=host_url,
-        )
+        # Skip template search when custom_template_path is provided
         if custom_template_path is not None:
+            if a3m_lines is None:
+                a3m_lines_mmseqs2 = run_mmseqs2(
+                    query_seqs_unique,
+                    str(result_dir.joinpath(jobname)),
+                    use_env,
+                    use_templates=False,
+                    host_url=host_url,
+                )
+            else:
+                a3m_lines_mmseqs2 = a3m_lines
             template_paths = {}
             for index in range(0, len(query_seqs_unique)):
                 template_paths[index] = custom_template_path
+        else:
+            a3m_lines_mmseqs2, template_paths = run_mmseqs2(
+                query_seqs_unique,
+                str(result_dir.joinpath(jobname)),
+                use_env,
+                use_templates=True,
+                host_url=host_url,
+            )
         if template_paths is None:
             logger.info("No template detected")
             for index in range(0, len(query_seqs_unique)):
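The net effect of the restructuring above: with a custom_template_path the MMseqs2 template search is skipped entirely, and a caller-supplied MSA (the new a3m_lines argument) is reused instead of being recomputed; otherwise a single run_mmseqs2 call still returns both the MSA and the template hits. A self-contained sketch of that decision, where _search_msa_only and _search_msa_and_templates are placeholder stubs standing in for the two run_mmseqs2 call variants (they are not ColabFold API):

    from typing import Dict, Optional, Tuple

    def _search_msa_only() -> str:
        return ">query\nMKT\n"                      # stub: MSA search without templates

    def _search_msa_and_templates() -> Tuple[str, Dict[int, str]]:
        return ">query\nMKT\n", {0: "pdb70"}        # stub: combined MSA + template search

    def choose_msa_and_templates(a3m_lines: Optional[str],
                                 custom_template_path: Optional[str],
                                 n_unique_seqs: int) -> Tuple[str, Dict[int, str]]:
        if custom_template_path is not None:
            # Custom templates: never query the template server; reuse a
            # caller-supplied MSA when one was passed in.
            msa = a3m_lines if a3m_lines is not None else _search_msa_only()
            template_paths = {i: custom_template_path for i in range(n_unique_seqs)}
        else:
            # Default path: one server call returns both the MSA and the templates.
            msa, template_paths = _search_msa_and_templates()
        return msa, template_paths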
@@ -966,7 +979,7 @@ def generate_input_feature(

         # bugfix
         a3m_lines = f">0\n{full_sequence}\n"
-        a3m_lines += pair_msa(query_seqs_unique, query_seqs_cardinality, paired_msa, unpaired_msa)
+        a3m_lines += pair_msa(query_seqs_unique, query_seqs_cardinality, paired_msa, unpaired_msa)

         input_feature = build_monomer_feature(full_sequence, a3m_lines, mk_mock_template(full_sequence))
         input_feature["residue_index"] = np.concatenate([np.arange(L) for L in Ls])
@@ -987,7 +1000,7 @@ def generate_input_feature(
         chain_cnt = 0
         # for each unique sequence
         for sequence_index, sequence in enumerate(query_seqs_unique):
-
+
             # get unpaired msa
             if unpaired_msa is None:
                 input_msa = f">{101 + sequence_index}\n{sequence}"
@@ -1243,7 +1256,7 @@ def run(
     if max_msa is not None:
         max_seq, max_extra_seq = [int(x) for x in max_msa.split(":")]

-    if kwargs.pop("use_amber", False) and num_relax == 0:
+    if kwargs.pop("use_amber", False) and num_relax == 0:
         num_relax = num_models * num_seeds

     if len(kwargs) > 0:
@@ -1263,22 +1276,22 @@ def run(
         L = len("".join(query_sequence))
         if L > max_len: max_len = L
         if N > max_num: max_num = N
-
+
     # get max sequences
     # 512 5120 = alphafold_ptm (models 1,3,4)
     # 512 1024 = alphafold_ptm (models 2,5)
     # 508 2048 = alphafold-multimer_v3 (models 1,2,3)
     # 508 1152 = alphafold-multimer_v3 (models 4,5)
     # 252 1152 = alphafold-multimer_v[1,2]
-
+
     set_if = lambda x,y: y if x is None else x
     if model_type in ["alphafold2_multimer_v1","alphafold2_multimer_v2"]:
         (max_seq, max_extra_seq) = (set_if(max_seq,252), set_if(max_extra_seq,1152))
     elif model_type == "alphafold2_multimer_v3":
         (max_seq, max_extra_seq) = (set_if(max_seq,508), set_if(max_extra_seq,2048))
     else:
         (max_seq, max_extra_seq) = (set_if(max_seq,512), set_if(max_extra_seq,5120))
-
+
     if msa_mode == "single_sequence":
         num_seqs = 1
         if is_complex and "multimer" not in model_type: num_seqs += max_num
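The defaults above come straight from the comment table; set_if only fills in a value the user left as None, so an explicit max_msa setting always wins. A small standalone check (numbers copied from the table, override value made up):

    set_if = lambda x, y: y if x is None else x

    max_seq, max_extra_seq = None, 1024            # e.g. user pinned only max_extra_seq
    model_type = "alphafold2_multimer_v3"
    if model_type in ["alphafold2_multimer_v1", "alphafold2_multimer_v2"]:
        max_seq, max_extra_seq = set_if(max_seq, 252), set_if(max_extra_seq, 1152)
    elif model_type == "alphafold2_multimer_v3":
        max_seq, max_extra_seq = set_if(max_seq, 508), set_if(max_extra_seq, 2048)
    else:
        max_seq, max_extra_seq = set_if(max_seq, 512), set_if(max_extra_seq, 5120)
    print(max_seq, max_extra_seq)                  # -> 508 1024 (default + user override)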
@@ -1337,7 +1350,7 @@ def run(
     first_job = True
     for job_number, (raw_jobname, query_sequence, a3m_lines) in enumerate(queries):
         jobname = safe_filename(raw_jobname)
-
+
         #######################################
         # check if job has already finished
         #######################################
@@ -1359,59 +1372,59 @@ def run(
         # generate MSA (a3m_lines) and templates
         ###########################################
         try:
-            if a3m_lines is None:
+            if a3m_lines is None:
                 (unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality, template_features) \
-                = get_msa_and_templates(jobname, query_sequence, result_dir, msa_mode, use_templates,
+                = get_msa_and_templates(jobname, query_sequence, a3m_lines, result_dir, msa_mode, use_templates,
                     custom_template_path, pair_mode, pairing_strategy, host_url)
-
-            elif a3m_lines is not None:
+
+            elif a3m_lines is not None:
                 (unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality, template_features) \
                 = unserialize_msa(a3m_lines, query_sequence)
-                if use_templates:
+                if use_templates:
                     (_, _, _, _, template_features) \
-                    = get_msa_and_templates(jobname, query_seqs_unique, result_dir, 'single_sequence', use_templates,
+                    = get_msa_and_templates(jobname, query_seqs_unique, a3m_lines, result_dir, 'single_sequence', use_templates,
                         custom_template_path, pair_mode, pairing_strategy, host_url)
-
+
             # save a3m
             msa = msa_to_str(unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality)
             result_dir.joinpath(f"{jobname}.a3m").write_text(msa)
-
+
         except Exception as e:
             logger.exception(f"Could not get MSA/templates for {jobname}: {e}")
             continue
-
+
         #######################
         # generate features
         #######################
         try:
             (feature_dict, domain_names) \
             = generate_input_feature(query_seqs_unique, query_seqs_cardinality, unpaired_msa, paired_msa,
                 template_features, is_complex, model_type, max_seq=max_seq)
-
+
             # to allow display of MSA info during colab/chimera run (thanks tomgoddard)
             if feature_dict_callback is not None:
                 feature_dict_callback(feature_dict)
-
+
         except Exception as e:
             logger.exception(f"Could not generate input features {jobname}: {e}")
             continue
-
+
         ######################
         # predict structures
         ######################
         try:
             # get list of lengths
-            query_sequence_len_array = sum([[len(x)] * y
+            query_sequence_len_array = sum([[len(x)] * y
                 for x,y in zip(query_seqs_unique, query_seqs_cardinality)],[])
-
+
             # decide how much to pad (to avoid recompiling)
             if seq_len > pad_len:
                 if isinstance(recompile_padding, float):
                     pad_len = math.ceil(seq_len * recompile_padding)
                 else:
                     pad_len = seq_len + recompile_padding
                 pad_len = min(pad_len, max_len)
-
+
             # prep model and params
             if first_job:
                 # if one job input adjust max settings
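The padding rule in this hunk grows the compiled length only when the current query is longer than what has already been compiled for: a float recompile_padding acts as a multiplicative factor, an int as an additive margin, capped at max_len. A minimal self-contained illustration of that branch (lengths and values are made up):

    import math

    def next_pad_len(seq_len, pad_len, max_len, recompile_padding):
        # mirrors the logic above: only grow, never shrink
        if seq_len > pad_len:
            if isinstance(recompile_padding, float):
                pad_len = math.ceil(seq_len * recompile_padding)
            else:
                pad_len = seq_len + recompile_padding
            pad_len = min(pad_len, max_len)
        return pad_len

    print(next_pad_len(seq_len=120, pad_len=0,   max_len=300, recompile_padding=1.25))  # 150
    print(next_pad_len(seq_len=120, pad_len=150, max_len=300, recompile_padding=1.25))  # 150 (reuse, no recompile)
    print(next_pad_len(seq_len=400, pad_len=150, max_len=300, recompile_padding=10))    # 300 (capped at max_len)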
@@ -1423,7 +1436,7 @@ def run(
                 num_seqs = int(len(feature_dict["msa"]))

                 if use_templates: num_seqs += 4
-
+
                 # adjust max settings
                 max_seq = min(num_seqs, max_seq)
                 max_extra_seq = max(min(num_seqs - max_seq, max_extra_seq), 1)
@@ -1498,7 +1511,7 @@ def run(
             scores_file = result_dir.joinpath(f"{jobname}_scores_{r}.json")
             with scores_file.open("r") as handle:
                 scores.append(json.load(handle))
-
+
         # write alphafold-db format (pAE)
         if "pae" in scores[0]:
             af_pae_file = result_dir.joinpath(f"{jobname}_predicted_aligned_error_v1.json")
@@ -1535,7 +1548,7 @@ def run(
             with zipfile.ZipFile(result_zip, "w") as result_zip:
                 for file in result_files:
                     result_zip.write(file, arcname=file.name)
-
+
             # Delete only after the zip was successful, and also not the bibtex and config because we need those again
             for file in result_files[:-2]:
                 file.unlink()
@@ -1737,7 +1750,7 @@ def main():
     )

     args = parser.parse_args()
-
+
     # disable unified memory
     if args.disable_unified_memory:
         for k in ENV.keys():
@@ -1756,7 +1769,7 @@ def main():

     queries, is_complex = get_queries(args.input, args.sort_queries_by)
     model_type = set_model_type(is_complex, args.model_type)
-
+
     download_alphafold_params(model_type, data_dir)

     if args.msa_mode != "single_sequence" and not args.templates: