Skip to content

Commit f74e92c

Browse files
committed
Add missing directory AFfine
1 parent 94895a1 commit f74e92c

File tree

5 files changed

+37
-14
lines changed

5 files changed

+37
-14
lines changed

AFfine/af2_util.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,11 @@ def get_atom_positions_from_pdb(pdb_file_path: str, aligned_sequences: tuple[str
242242
if pep_len and anchors:
243243
mhc_len = all_positions.shape[0] - pep_len
244244
full_anchors = [i - 1 + mhc_len for i in anchors] #(2-1) + 180 = 181
245+
full_anchors = sorted(full_anchors)
246+
if len(full_anchors) == 2: # in case of MHC-I distance between two anchors is long and distrubs the folding, therefore, we add two other positions
247+
full_anchors = [full_anchors[0], full_anchors[0]+2, full_anchors[1]-2, full_anchors[1]]
248+
print('mhc_1 anchor initial guess for positions:', full_anchors)
249+
all_positions *= 0. # zero for mhc 1
245250
core_region = [i for i in range(mhc_len, mhc_len + pep_len) if i not in full_anchors] # (i in (180, 189) if not anchor)
246251
mask = np.ones([num_res_query, residue_constants.atom_type_num, 3], dtype=np.float32)
247252
mask[core_region] = 0.

AFfine/alphafold/model/model.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ class RunModel:
5151

5252
def __init__(self,
5353
config: ml_collections.ConfigDict,
54-
params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None):
54+
params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None,
55+
return_representations: bool = False): # added by https://github.com/AmirAsgary
5556
self.config = config
5657
self.params = params
5758

@@ -73,7 +74,10 @@ def _forward_fn(batch,
7374
self.apply = jax.jit(hk.transform(partial(_forward_fn, is_training=True, compute_loss=True)).apply)
7475
self.init = jax.jit(hk.transform(partial(_forward_fn, is_training=True, compute_loss=False)).init)
7576
self.apply_infer = jax.jit(hk.transform(partial(_forward_fn, is_training=False, compute_loss=True)).apply)
76-
self.apply_predict = jax.jit(hk.transform(partial(_forward_fn, is_training=False, compute_loss=False)).apply)
77+
#self.apply_predict = jax.jit(hk.transform(partial(_forward_fn, is_training=False, compute_loss=False)).apply)
78+
self.apply_predict = jax.jit(hk.transform(partial(_forward_fn, is_training=False, compute_loss=False,
79+
return_representations=return_representations)).apply, static_argnames=("return_representations",) # added by https://github.com/AmirAsgary
80+
)
7781
def init_params(self, feat: features.FeatureDict, random_seed: int = 0):
7882
"""Initializes the model parameters.
7983

AFfine/predict_utils.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -222,9 +222,7 @@ def predict_structure(
222222

223223
metric_tags = 'plddt ptm predicted_aligned_error'.split()
224224
if return_all_outputs:
225-
metric_tags = ['distogram', 'masked_msa', 'predicted_aligned_error', 'predicted_lddt',
226-
'structure_module', 'plddt', 'aligned_confidence_probs',
227-
'max_predicted_aligned_error', 'ptm', 'representations']
225+
metric_tags = 'plddt ptm predicted_aligned_error representations'.split() # added by Amir
228226
all_metrics = {} # eventual return value
229227

230228
metrics = {} # stupid duplication
@@ -247,12 +245,11 @@ def predict_structure(
247245
template_sequence=None, aln=aln, anchors=anchors,
248246
peptide_seq=peptide_seq) #added by https://github.com/AmirAsgary/
249247
prediction_result = model_runner.predict(processed_feature_dict, initial_guess,
250-
return_representations=return_representations) #added by https://github.com/AmirAsgary/
248+
return_representations=return_all_outputs) #added by https://github.com/AmirAsgary/
251249
else:
252-
prediction_result = model_runner.predict(processed_feature_dict, return_representations=return_representations) #added by https://github.com/AmirAsgary/
250+
prediction_result = model_runner.predict(processed_feature_dict, return_representations=return_all_outputs) #added by https://github.com/AmirAsgary/
253251

254252
###
255-
print('DEBUUUUUUUUUUUUG: \n',prediction_result.keys(), '\n ----------------------------')
256253
unrelaxed_protein = protein.from_prediction(
257254
processed_feature_dict, prediction_result)
258255
unrelaxed_pdb_lines.append(protein.to_pdb(unrelaxed_protein))
@@ -282,13 +279,16 @@ def predict_structure(
282279

283280

284281
#plddts_ranked[f"model_{n+1}"] = plddts[r]
285-
286282
if dump_metrics:
287283
metrics_prefix = f'{prefix}_model_{n+1}_{model_names[r]}'
288284
for tag in metric_tags:
289285
m = metrics[tag][r]
290286
if m is not None:
291-
np.save(f'{metrics_prefix}_{tag}.npy', m)
287+
if tag != 'representations': # added by Amir
288+
np.save(f'{metrics_prefix}_{tag}.npy', m)
289+
else:
290+
with open(f'{metrics_prefix}_{tag}.pkl', 'wb') as f:
291+
pickle.dump(m, f)
292292

293293
return all_metrics
294294

@@ -348,7 +348,7 @@ def load_model_runners(
348348
model_name=model_name, data_dir=data_dir)
349349

350350
model_runners[model_name] = model.RunModel(
351-
model_config, model_params)
351+
model_config, model_params, args.return_all_outputs) #return_all_outputs added by https://github.com/AmirAsgary
352352
return model_runners
353353

354354

AFfine/run_finetuning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
'data.get_model_haiku_params; should contain params/ subfolder')
7575

7676
flags.DEFINE_bool('only_fit_binder', False, help='if True, dont fit alphafold params')
77-
flags.DEFINE_bool('freeze_binder', False, help='if True, dont fit binder params')
77+
flags.DEFINE_bool('freeze_binder', False, help='if True, dont fit binder params') # binder means --> logistic regression part
7878
flags.DEFINE_bool('freeze_everything', False, help='if True, dont fit anything')
7979
flags.DEFINE_bool('no_ramp', False, help='if True, dont ramp')
8080
flags.DEFINE_bool('no_valid', False, help='if True, dont compute valid stats')

AFfine/run_prediction.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070

7171
parser.add_argument('--no_initial_guess', action='store_true', default=False, help='When active, no intial guess is used to direct modeling and only template is used.')
7272
parser.add_argument('--return_all_outputs', action='store_true', default=False, help='Save all alphafold outputs including evoformer output')
73+
parser.add_argument('--use_msa', action='store_true', default=False, help='If Enabled, use MSA for prediction. If not, only template is used.')
7374
args = parser.parse_args()
7475

7576
import os
@@ -117,6 +118,9 @@
117118
template_pdb_dict = targetl.template_pdb_dict
118119
with open(template_pdb_dict, 'r') as f:
119120
template_pdb_dict = json.load(f)
121+
else: # added by Amir for initial guess condition and getting dict from input tsv
122+
template_pdb_dict = None
123+
120124

121125
print(alignfile)
122126
assert exists(alignfile)
@@ -173,11 +177,21 @@
173177
)
174178
template_features_list.append(template_features)
175179

180+
181+
176182
all_template_features = predict_utils.compile_template_features(
177183
template_features_list)
178184

179-
msa=[query_sequence]
180-
deletion_matrix=[[0]*len(query_sequence)]
185+
if not args.use_msa: # added by Amir for using msa or not
186+
# if we are not using MSA, we need to set the deletion matrix
187+
msa=[query_sequence]
188+
deletion_matrix=[[0]*len(query_sequence)]
189+
else: # added by Amir for using msa or not
190+
# generate arbeitary msa from input sequence
191+
msa = [query_sequence] + msa
192+
193+
194+
181195

182196
all_metrics = predict_utils.run_alphafold_prediction(
183197
query_sequence=query_sequence,

0 commit comments

Comments
 (0)