
Commit 34921de (parent: affc6e0)

skip template search when custom_template_path is provided

1 file changed: colabfold/batch.py (71 additions, 58 deletions)

@@ -299,19 +299,19 @@ def relax_me(pdb_filename=None, pdb_lines=None, pdb_obj=None, use_gpu=False):
     from alphafold.common import residue_constants
     from alphafold.relax import relax

-    if pdb_obj is None:
+    if pdb_obj is None:
         if pdb_lines is None:
             pdb_lines = Path(pdb_filename).read_text()
         pdb_obj = protein.from_pdb_string(pdb_lines)
-
+
     amber_relaxer = relax.AmberRelaxation(
         max_iterations=0,
         tolerance=2.39,
         stiffness=10.0,
         exclude_residues=[],
         max_outer_iterations=3,
         use_gpu=use_gpu)
-
+
     relaxed_pdb_lines, _, _ = amber_relaxer.process(prot=pdb_obj)
     return relaxed_pdb_lines

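The hunk above only trims whitespace, but it shows the full relaxation path. As a usage reference, here is a minimal sketch of driving relax_me directly, assuming the AlphaFold relax dependencies are installed; the input path is a placeholder, not something produced by this commit:

    from colabfold.batch import relax_me

    # "predicted_model.pdb" is a hypothetical path to an existing prediction.
    relaxed_pdb = relax_me(pdb_filename="predicted_model.pdb", use_gpu=False)

    # relax_me returns the relaxed structure as PDB text.
    with open("predicted_model_relaxed.pdb", "w") as handle:
        handle.write(relaxed_pdb)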

@@ -321,7 +321,7 @@ def __init__(self, prefix: str, result_dir: Path):
         self.result_dir = result_dir
         self.tag = None
         self.files = {}
-
+
     def get(self, x: str, ext:str) -> Path:
         if self.tag not in self.files:
             self.files[self.tag] = []
@@ -366,13 +366,13 @@ def predict_structure(

     # iterate through random seeds
     for seed_num, seed in enumerate(range(random_seed, random_seed+num_seeds)):
-
+
         # iterate through models
         for model_num, (model_name, model_runner, params) in enumerate(model_runner_and_params):
-
+
             # swap params to avoid recompiling
             model_runner.params = params
-
+
             #########################
             # process input features
             #########################
@@ -383,24 +383,24 @@ def predict_structure(
                     input_features["asym_id"] = input_features["asym_id"] - input_features["asym_id"][...,0]
             else:
                 if model_num == 0:
-                    input_features = model_runner.process_features(feature_dict, random_seed=seed)
+                    input_features = model_runner.process_features(feature_dict, random_seed=seed)
                     r = input_features["aatype"].shape[0]
                     input_features["asym_id"] = np.tile(feature_dict["asym_id"],r).reshape(r,-1)
                     if seq_len < pad_len:
-                        input_features = pad_input(input_features, model_runner,
+                        input_features = pad_input(input_features, model_runner,
                             model_name, pad_len, use_templates)
                         logger.info(f"Padding length to {pad_len}")
-
+

             tag = f"{model_type}_{model_name}_seed_{seed:03d}"
             model_names.append(tag)
             files.set_tag(tag)
-
+
             ########################
             # predict
             ########################
             start = time.time()
-
+
             # monitor intermediate results
             def callback(result, recycles):
                 if recycles == 0: result.pop("tol",None)
@@ -419,12 +419,12 @@ def callback(result, recycles):
                         result=result, b_factors=b_factors,
                         remove_leading_feature_dimension=("multimer" not in model_type))
                     files.get("unrelaxed",f"r{recycles}.pdb").write_text(protein.to_pdb(unrelaxed_protein))
-
+
                     if save_all:
                         with files.get("all",f"r{recycles}.pickle").open("wb") as handle:
                             pickle.dump(result, handle)
                     del unrelaxed_protein
-
+
             return_representations = save_all or save_single_representations or save_pair_representations

             # predict
@@ -439,9 +439,9 @@ def callback(result, recycles):
             ########################
             # parse results
             ########################
-
+
             # summary metrics
-            mean_scores.append(result["ranking_confidence"])
+            mean_scores.append(result["ranking_confidence"])
             if recycles == 0: result.pop("tol",None)
             if not is_complex: result.pop("iptm",None)
             print_line = ""
@@ -469,7 +469,7 @@ def callback(result, recycles):

             #########################
             # save results
-            #########################
+            #########################

             # save pdb
             protein_lines = protein.to_pdb(unrelaxed_protein)
@@ -498,12 +498,12 @@ def callback(result, recycles):
                 del pae
                 del plddt
                 json.dump(scores, handle)
-
+
             del result, unrelaxed_protein

             # early stop criteria fulfilled
             if mean_scores[-1] > stop_at_score: break
-
+
         # early stop criteria fulfilled
         if mean_scores[-1] > stop_at_score: break

@@ -514,7 +514,7 @@ def callback(result, recycles):
     ###################################################
     # rerank models based on predicted confidence
     ###################################################
-
+
     rank, metric = [],[]
     result_files = []
     logger.info(f"reranking models by '{rank_by}' metric")
@@ -527,7 +527,7 @@ def callback(result, recycles):
         if n < num_relax:
             start = time.time()
             pdb_lines = relax_me(pdb_lines=unrelaxed_pdb_lines[key], use_gpu=use_gpu_relax)
-            files.get("relaxed","pdb").write_text(pdb_lines)
+            files.get("relaxed","pdb").write_text(pdb_lines)
             logger.info(f"Relaxation took {(time.time() - start):.1f}s")

     # rename files to include rank
@@ -538,7 +538,7 @@ def callback(result, recycles):
             new_file = result_dir.joinpath(f"{prefix}_{x}_{new_tag}.{ext}")
             file.rename(new_file)
             result_files.append(new_file)
-
+
     return {"rank":rank,
             "metric":metric,
             "result_files":result_files}
@@ -649,10 +649,10 @@ def get_queries(
     # sort by seq. len
     if sort_queries_by == "length":
         queries.sort(key=lambda t: len("".join(t[1])))
-
+
     elif sort_queries_by == "random":
         random.shuffle(queries)
-
+
     is_complex = False
     for job_number, (raw_jobname, query_sequence, a3m_lines) in enumerate(queries):
         if isinstance(query_sequence, list):
@@ -719,6 +719,7 @@ def pad_sequences(
 def get_msa_and_templates(
     jobname: str,
     query_sequences: Union[str, List[str]],
+    a3m_lines: Optional[List[str]],
     result_dir: Path,
     msa_mode: str,
     use_templates: bool,
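Note that the new a3m_lines parameter is inserted between query_sequences and result_dir rather than appended, so every positional caller has to be updated in the same commit (see the call-site hunks further down). A minimal keyword-style call sketch; the parameter names after use_templates are inferred from those positional call sites, and all argument values are placeholder assumptions rather than anything taken from this diff:

    from pathlib import Path
    from colabfold.batch import get_msa_and_templates

    (unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality, template_features) = \
        get_msa_and_templates(
            jobname="example_job",
            query_sequences="MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",  # placeholder query
            a3m_lines=None,                 # None -> the MSA is still fetched/built as before
            result_dir=Path("results"),
            msa_mode="mmseqs2_uniref_env",
            use_templates=True,
            custom_template_path=None,      # None -> server-side template search as before
            pair_mode="unpaired_paired",
            pairing_strategy="greedy",
            host_url="https://api.colabfold.com",
        )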
@@ -749,17 +750,29 @@ def get_msa_and_templates(
     # get template features
     template_features = []
     if use_templates:
-        a3m_lines_mmseqs2, template_paths = run_mmseqs2(
-            query_seqs_unique,
-            str(result_dir.joinpath(jobname)),
-            use_env,
-            use_templates=True,
-            host_url=host_url,
-        )
+        # Skip template search when custom_template_path is provided
         if custom_template_path is not None:
+            if a3m_lines is None:
+                a3m_lines_mmseqs2 = run_mmseqs2(
+                    query_seqs_unique,
+                    str(result_dir.joinpath(jobname)),
+                    use_env,
+                    use_templates=False,
+                    host_url=host_url,
+                )
+            else:
+                a3m_lines_mmseqs2 = a3m_lines
             template_paths = {}
             for index in range(0, len(query_seqs_unique)):
                 template_paths[index] = custom_template_path
+        else:
+            a3m_lines_mmseqs2, template_paths = run_mmseqs2(
+                query_seqs_unique,
+                str(result_dir.joinpath(jobname)),
+                use_env,
+                use_templates=True,
+                host_url=host_url,
+            )
         if template_paths is None:
             logger.info("No template detected")
             for index in range(0, len(query_seqs_unique)):
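Distilled, the new behaviour of the template block is: with a custom_template_path the MMseqs2 server is never asked for templates, the caller-supplied MSA is reused if present (otherwise an MSA-only search runs), and every unique sequence is simply mapped to the custom template directory. A standalone sketch of that control flow; the helper name is hypothetical and run_mmseqs2 is assumed to come from colabfold.colabfold, as batch.py uses it:

    from pathlib import Path
    from colabfold.colabfold import run_mmseqs2  # assumed import, mirroring batch.py

    def resolve_msa_and_templates(query_seqs_unique, jobname, result_dir: Path, use_env,
                                  custom_template_path, a3m_lines, host_url):
        """Hypothetical helper mirroring the branch added in this commit."""
        if custom_template_path is not None:
            # Skip the server-side template search entirely.
            if a3m_lines is None:
                # No MSA supplied -> fetch one, but without templates.
                a3m_lines_mmseqs2 = run_mmseqs2(query_seqs_unique, str(result_dir.joinpath(jobname)),
                                                use_env, use_templates=False, host_url=host_url)
            else:
                # MSA supplied by the caller -> reuse it as-is.
                a3m_lines_mmseqs2 = a3m_lines
            # Every unique sequence gets its templates from the custom directory.
            template_paths = {index: custom_template_path for index in range(len(query_seqs_unique))}
        else:
            # Previous behaviour: one combined MSA + template search.
            a3m_lines_mmseqs2, template_paths = run_mmseqs2(query_seqs_unique, str(result_dir.joinpath(jobname)),
                                                            use_env, use_templates=True, host_url=host_url)
        return a3m_lines_mmseqs2, template_paths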
@@ -966,7 +979,7 @@ def generate_input_feature(

         # bugfix
         a3m_lines = f">0\n{full_sequence}\n"
-        a3m_lines += pair_msa(query_seqs_unique, query_seqs_cardinality, paired_msa, unpaired_msa)
+        a3m_lines += pair_msa(query_seqs_unique, query_seqs_cardinality, paired_msa, unpaired_msa)

         input_feature = build_monomer_feature(full_sequence, a3m_lines, mk_mock_template(full_sequence))
         input_feature["residue_index"] = np.concatenate([np.arange(L) for L in Ls])
@@ -987,7 +1000,7 @@ def generate_input_feature(
         chain_cnt = 0
         # for each unique sequence
         for sequence_index, sequence in enumerate(query_seqs_unique):
-
+
             # get unpaired msa
             if unpaired_msa is None:
                 input_msa = f">{101 + sequence_index}\n{sequence}"
@@ -1243,7 +1256,7 @@ def run(
     if max_msa is not None:
         max_seq, max_extra_seq = [int(x) for x in max_msa.split(":")]

-    if kwargs.pop("use_amber", False) and num_relax == 0:
+    if kwargs.pop("use_amber", False) and num_relax == 0:
         num_relax = num_models * num_seeds

     if len(kwargs) > 0:
@@ -1263,22 +1276,22 @@ def run(
         L = len("".join(query_sequence))
         if L > max_len: max_len = L
         if N > max_num: max_num = N
-
+
     # get max sequences
     # 512 5120 = alphafold_ptm (models 1,3,4)
     # 512 1024 = alphafold_ptm (models 2,5)
     # 508 2048 = alphafold-multimer_v3 (models 1,2,3)
     # 508 1152 = alphafold-multimer_v3 (models 4,5)
     # 252 1152 = alphafold-multimer_v[1,2]
-
+
     set_if = lambda x,y: y if x is None else x
     if model_type in ["alphafold2_multimer_v1","alphafold2_multimer_v2"]:
         (max_seq, max_extra_seq) = (set_if(max_seq,252), set_if(max_extra_seq,1152))
     elif model_type == "alphafold2_multimer_v3":
         (max_seq, max_extra_seq) = (set_if(max_seq,508), set_if(max_extra_seq,2048))
     else:
         (max_seq, max_extra_seq) = (set_if(max_seq,512), set_if(max_extra_seq,5120))
-
+
     if msa_mode == "single_sequence":
         num_seqs = 1
         if is_complex and "multimer" not in model_type: num_seqs += max_num
@@ -1337,7 +1350,7 @@ def run(
     first_job = True
     for job_number, (raw_jobname, query_sequence, a3m_lines) in enumerate(queries):
         jobname = safe_filename(raw_jobname)
-
+
         #######################################
         # check if job has already finished
         #######################################
@@ -1359,59 +1372,59 @@ def run(
         # generate MSA (a3m_lines) and templates
         ###########################################
         try:
-            if a3m_lines is None:
+            if a3m_lines is None:
                 (unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality, template_features) \
-                = get_msa_and_templates(jobname, query_sequence, result_dir, msa_mode, use_templates,
+                = get_msa_and_templates(jobname, query_sequence, a3m_lines, result_dir, msa_mode, use_templates,
                     custom_template_path, pair_mode, pairing_strategy, host_url)
-
-            elif a3m_lines is not None:
+
+            elif a3m_lines is not None:
                 (unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality, template_features) \
                 = unserialize_msa(a3m_lines, query_sequence)
-                if use_templates:
+                if use_templates:
                     (_, _, _, _, template_features) \
-                    = get_msa_and_templates(jobname, query_seqs_unique, result_dir, 'single_sequence', use_templates,
+                    = get_msa_and_templates(jobname, query_seqs_unique, a3m_lines, result_dir, 'single_sequence', use_templates,
                         custom_template_path, pair_mode, pairing_strategy, host_url)
-
+
             # save a3m
             msa = msa_to_str(unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality)
             result_dir.joinpath(f"{jobname}.a3m").write_text(msa)
-
+
         except Exception as e:
             logger.exception(f"Could not get MSA/templates for {jobname}: {e}")
             continue
-
+
         #######################
         # generate features
         #######################
         try:
             (feature_dict, domain_names) \
                 = generate_input_feature(query_seqs_unique, query_seqs_cardinality, unpaired_msa, paired_msa,
                                          template_features, is_complex, model_type, max_seq=max_seq)
-
+
             # to allow display of MSA info during colab/chimera run (thanks tomgoddard)
             if feature_dict_callback is not None:
                 feature_dict_callback(feature_dict)
-
+
         except Exception as e:
             logger.exception(f"Could not generate input features {jobname}: {e}")
             continue
-
+
         ######################
         # predict structures
         ######################
         try:
             # get list of lengths
-            query_sequence_len_array = sum([[len(x)] * y
+            query_sequence_len_array = sum([[len(x)] * y
                 for x,y in zip(query_seqs_unique, query_seqs_cardinality)],[])
-
+
             # decide how much to pad (to avoid recompiling)
             if seq_len > pad_len:
                 if isinstance(recompile_padding, float):
                     pad_len = math.ceil(seq_len * recompile_padding)
                 else:
                     pad_len = seq_len + recompile_padding
                 pad_len = min(pad_len, max_len)
-
+
             # prep model and params
             if first_job:
                 # if one job input adjust max settings
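Of the two updated call sites above, the first sits inside the "if a3m_lines is None:" branch, so the argument it forwards is always None there; the new parameter only carries information at the second call site, where the MSA has already been deserialized and only template features are still needed. A sketch of that path with placeholder values; argument order follows the call sites above, and everything not visible in this diff is an assumption:

    from pathlib import Path
    from colabfold.batch import get_msa_and_templates, unserialize_msa

    # Placeholder inputs: an .a3m produced earlier and its query sequence.
    a3m_lines = [Path("precomputed_msa.a3m").read_text()]
    (unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality, template_features) = \
        unserialize_msa(a3m_lines, "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")

    # With a custom template directory, this call now returns without contacting
    # the MMseqs2 server; before this commit it always ran a template search.
    (_, _, _, _, template_features) = get_msa_and_templates(
        "example_job", query_seqs_unique, a3m_lines, Path("results"),
        "single_sequence", True, "my_templates/", "unpaired_paired",
        "greedy", "https://api.colabfold.com")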
@@ -1423,7 +1436,7 @@ def run(
                        num_seqs = int(len(feature_dict["msa"]))

                    if use_templates: num_seqs += 4
-
+
                    # adjust max settings
                    max_seq = min(num_seqs, max_seq)
                    max_extra_seq = max(min(num_seqs - max_seq, max_extra_seq), 1)
@@ -1498,7 +1511,7 @@ def run(
                scores_file = result_dir.joinpath(f"{jobname}_scores_{r}.json")
                with scores_file.open("r") as handle:
                    scores.append(json.load(handle))
-
+
            # write alphafold-db format (pAE)
            if "pae" in scores[0]:
                af_pae_file = result_dir.joinpath(f"{jobname}_predicted_aligned_error_v1.json")
@@ -1535,7 +1548,7 @@ def run(
            with zipfile.ZipFile(result_zip, "w") as result_zip:
                for file in result_files:
                    result_zip.write(file, arcname=file.name)
-
+
            # Delete only after the zip was successful, and also not the bibtex and config because we need those again
            for file in result_files[:-2]:
                file.unlink()
@@ -1737,7 +1750,7 @@ def main():
     )

     args = parser.parse_args()
-
+
     # disable unified memory
     if args.disable_unified_memory:
         for k in ENV.keys():
@@ -1756,7 +1769,7 @@ def main():

     queries, is_complex = get_queries(args.input, args.sort_queries_by)
     model_type = set_model_type(is_complex, args.model_type)
-
+
     download_alphafold_params(model_type, data_dir)

     if args.msa_mode != "single_sequence" and not args.templates:
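Taken together with the CLI entry point above, the practical effect is that a batch run which supplies both a precomputed MSA and a custom template directory no longer triggers an MMseqs2 template search. A rough end-to-end sketch using the Python API; run() parameter names beyond those visible in this diff, and all paths, are assumptions for illustration:

    from pathlib import Path
    from colabfold.batch import get_queries, run, set_model_type
    from colabfold.download import download_alphafold_params

    # Queries loaded from an .a3m keep their MSA in a3m_lines, so no MSA search
    # is needed; the custom template directory suppresses the template search.
    queries, is_complex = get_queries("precomputed_msa.a3m")
    model_type = set_model_type(is_complex, "auto")
    download_alphafold_params(model_type, Path("params"))

    run(
        queries=queries,
        result_dir=Path("results"),
        use_templates=True,
        custom_template_path="my_templates/",  # directory of template mmCIF files
        num_models=1,
        is_complex=is_complex,
        model_type=model_type,
        data_dir=Path("params"),
    )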
