Skip to content

Commit 4ce69b3

Browse files
committed
add mask for binder loss
1 parent 85679d6 commit 4ce69b3

File tree

2 files changed

+98
-68
lines changed

2 files changed

+98
-68
lines changed

colabdesign/af/alphafold/common/confidence.py

Lines changed: 74 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -14,74 +14,79 @@
1414

1515
"""Functions for processing confidence metrics."""
1616

17-
from typing import Dict, Optional, Tuple
17+
import jax.numpy as jnp
18+
import jax
1819
import numpy as np
20+
from colabdesign.af.alphafold.common import residue_constants
1921
import scipy.special
2022

23+
def compute_tol(prev_pos, current_pos, mask, use_jnp=False):
24+
# Early-stopping criterion, following the one used in
25+
# AF2Complex: https://www.nature.com/articles/s41467-022-29394-2
26+
_np = jnp if use_jnp else np
27+
dist = lambda x:_np.sqrt(((x[:,None] - x[None,:])**2).sum(-1))
28+
ca_idx = residue_constants.atom_order['CA']
29+
sq_diff = _np.square(dist(prev_pos[:,ca_idx])-dist(current_pos[:,ca_idx]))
30+
mask_2d = mask[:,None] * mask[None,:]
31+
return _np.sqrt((sq_diff * mask_2d).sum()/mask_2d.sum() + 1e-8)
2132

22-
def compute_plddt(logits: np.ndarray) -> np.ndarray:
23-
"""Computes per-residue pLDDT from logits.
2433

34+
def compute_plddt(logits, use_jnp=False):
35+
"""Computes per-residue pLDDT from logits.
2536
Args:
2637
logits: [num_res, num_bins] output from the PredictedLDDTHead.
27-
2838
Returns:
2939
plddt: [num_res] per-residue pLDDT.
3040
"""
41+
if use_jnp:
42+
_np, _softmax = jnp, jax.nn.softmax
43+
else:
44+
_np, _softmax = np, scipy.special.softmax
45+
3146
num_bins = logits.shape[-1]
3247
bin_width = 1.0 / num_bins
33-
bin_centers = np.arange(start=0.5 * bin_width, stop=1.0, step=bin_width)
34-
probs = scipy.special.softmax(logits, axis=-1)
35-
predicted_lddt_ca = np.sum(probs * bin_centers[None, :], axis=-1)
48+
bin_centers = _np.arange(start=0.5 * bin_width, stop=1.0, step=bin_width)
49+
probs = _softmax(logits, axis=-1)
50+
predicted_lddt_ca = (probs * bin_centers[None, :]).sum(-1)
3651
return predicted_lddt_ca * 100
3752

38-
39-
def _calculate_bin_centers(breaks: np.ndarray):
53+
def _calculate_bin_centers(breaks, use_jnp=False):
4054
"""Gets the bin centers from the bin edges.
41-
4255
Args:
4356
breaks: [num_bins - 1] the error bin edges.
44-
4557
Returns:
4658
bin_centers: [num_bins] the error bin centers.
4759
"""
48-
step = (breaks[1] - breaks[0])
60+
_np = jnp if use_jnp else np
61+
step = breaks[1] - breaks[0]
4962

5063
# Add half-step to get the center
5164
bin_centers = breaks + step / 2
52-
# Add a catch-all bin at the end.
53-
bin_centers = np.concatenate([bin_centers, [bin_centers[-1] + step]],
54-
axis=0)
55-
return bin_centers
5665

66+
# Add a catch-all bin at the end.
67+
return _np.append(bin_centers, bin_centers[-1] + step)
5768

5869
def _calculate_expected_aligned_error(
59-
alignment_confidence_breaks: np.ndarray,
60-
aligned_distance_error_probs: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
70+
alignment_confidence_breaks,
71+
aligned_distance_error_probs,
72+
use_jnp=False):
6173
"""Calculates expected aligned distance errors for every pair of residues.
62-
6374
Args:
6475
alignment_confidence_breaks: [num_bins - 1] the error bin edges.
6576
aligned_distance_error_probs: [num_res, num_res, num_bins] the predicted
6677
probs for each error bin, for each pair of residues.
67-
6878
Returns:
6979
predicted_aligned_error: [num_res, num_res] the expected aligned distance
7080
error for each pair of residues.
7181
max_predicted_aligned_error: The maximum predicted error possible.
7282
"""
73-
bin_centers = _calculate_bin_centers(alignment_confidence_breaks)
74-
83+
bin_centers = _calculate_bin_centers(alignment_confidence_breaks, use_jnp=use_jnp)
7584
# Tuple of expected aligned distance error and max possible error.
76-
return (np.sum(aligned_distance_error_probs * bin_centers, axis=-1),
77-
np.asarray(bin_centers[-1]))
78-
85+
pae = (aligned_distance_error_probs * bin_centers).sum(-1)
86+
return (pae, bin_centers[-1])
7987

80-
def compute_predicted_aligned_error(
81-
logits: np.ndarray,
82-
breaks: np.ndarray) -> Dict[str, np.ndarray]:
88+
def compute_predicted_aligned_error(logits, breaks, use_jnp=False):
8389
"""Computes aligned confidence metrics from logits.
84-
8590
Args:
8691
logits: [num_res, num_res, num_bins] the logits output from
8792
PredictedAlignedErrorHead.
@@ -94,62 +99,71 @@ def compute_predicted_aligned_error(
9499
error for each pair of residues.
95100
max_predicted_aligned_error: The maximum predicted error possible.
96101
"""
97-
aligned_confidence_probs = scipy.special.softmax(
98-
logits,
99-
axis=-1)
100-
predicted_aligned_error, max_predicted_aligned_error = (
101-
_calculate_expected_aligned_error(
102-
alignment_confidence_breaks=breaks,
103-
aligned_distance_error_probs=aligned_confidence_probs))
102+
_softmax = jax.nn.softmax if use_jnp else scipy.special.softmax
103+
aligned_confidence_probs = _softmax(logits,axis=-1)
104+
predicted_aligned_error, max_predicted_aligned_error = \
105+
_calculate_expected_aligned_error(breaks, aligned_confidence_probs, use_jnp=use_jnp)
106+
104107
return {
105108
'aligned_confidence_probs': aligned_confidence_probs,
106109
'predicted_aligned_error': predicted_aligned_error,
107110
'max_predicted_aligned_error': max_predicted_aligned_error,
108111
}
109112

110-
111-
def predicted_tm_score(
112-
logits: np.ndarray,
113-
breaks: np.ndarray,
114-
residue_weights: Optional[np.ndarray] = None) -> np.ndarray:
115-
"""Computes predicted TM alignment score.
113+
def predicted_tm_score(logits, breaks, residue_weights = None,
114+
asym_id = None, use_jnp=False):
115+
"""Computes predicted TM alignment or predicted interface TM alignment score.
116116
117117
Args:
118118
logits: [num_res, num_res, num_bins] the logits output from
119119
PredictedAlignedErrorHead.
120120
breaks: [num_bins - 1] the error bin edges.
121121
residue_weights: [num_res] the per residue weights to use for the
122122
expectation.
123+
asym_id: [num_res] the asymmetric unit ID - the chain ID. Only needed for
124+
ipTM calculation.
123125
124126
Returns:
125-
ptm_score: the predicted TM alignment score.
127+
ptm_score: The predicted TM alignment score (pTM) or, when asym_id is given, the predicted interface TM score (ipTM).
126128
"""
129+
if use_jnp:
130+
_np, _softmax = jnp, jax.nn.softmax
131+
else:
132+
_np, _softmax = np, scipy.special.softmax
127133

128134
# residue_weights has to be in [0, 1], but can be floating-point, i.e. the
129135
# exp. resolved head's probability.
130136
if residue_weights is None:
131-
residue_weights = np.ones(logits.shape[0])
137+
residue_weights = _np.ones(logits.shape[0])
132138

133-
bin_centers = _calculate_bin_centers(breaks)
139+
bin_centers = _calculate_bin_centers(breaks, use_jnp=use_jnp)
140+
num_res = residue_weights.shape[0]
134141

135-
num_res = np.sum(residue_weights)
136142
# Clip num_res to avoid negative/undefined d0.
137-
clipped_num_res = max(num_res, 19)
143+
clipped_num_res = _np.maximum(residue_weights.sum(), 19)
138144

139-
# Compute d_0(num_res) as defined by TM-score, eqn. (5) in
140-
# http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf
141-
# Yang & Skolnick "Scoring function for automated
142-
# assessment of protein structure template quality" 2004
145+
# Compute d_0(num_res) as defined by TM-score, eqn. (5) in Yang & Skolnick
146+
# "Scoring function for automated assessment of protein structure template
147+
# quality", 2004: http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf
143148
d0 = 1.24 * (clipped_num_res - 15) ** (1./3) - 1.8
144149

145-
# Convert logits to probs
146-
probs = scipy.special.softmax(logits, axis=-1)
150+
# Convert logits to probs.
151+
probs = _softmax(logits, axis=-1)
152+
153+
# TM-Score term for every bin.
154+
tm_per_bin = 1. / (1 + _np.square(bin_centers) / _np.square(d0))
155+
# E_distances tm(distance).
156+
predicted_tm_term = (probs * tm_per_bin).sum(-1)
157+
158+
if asym_id is None:
159+
pair_mask = _np.full((num_res,num_res),True)
160+
else:
161+
pair_mask = asym_id[:, None] != asym_id[None, :]
162+
163+
predicted_tm_term *= pair_mask
147164

148-
# TM-Score term for every bin
149-
tm_per_bin = 1. / (1 + np.square(bin_centers) / np.square(d0))
150-
# E_distances tm(distance)
151-
predicted_tm_term = np.sum(probs * tm_per_bin, axis=-1)
165+
pair_residue_weights = pair_mask * (residue_weights[None, :] * residue_weights[:, None])
166+
normed_residue_mask = pair_residue_weights / (1e-8 + pair_residue_weights.sum(-1, keepdims=True))
167+
per_alignment = (predicted_tm_term * normed_residue_mask).sum(-1)
152168

153-
normed_residue_mask = residue_weights / (1e-8 + residue_weights.sum())
154-
per_alignment = np.sum(predicted_tm_term * normed_residue_mask, axis=-1)
155-
return np.asarray(per_alignment[(per_alignment * residue_weights).argmax()])
169+
return (per_alignment * residue_weights).max()

colabdesign/af/loss.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from colabdesign.shared.utils import Key, copy_dict
66
from colabdesign.shared.protein import jnp_rmsd_w, _np_kabsch, _np_rmsd, _np_get_6D_loss
77
from colabdesign.af.alphafold.model import model, folding, all_atom
8-
from colabdesign.af.alphafold.common import confidence_jax, residue_constants
8+
from colabdesign.af.alphafold.common import confidence, residue_constants
99

1010
####################################################
1111
# AF_LOSS - setup loss function
@@ -36,12 +36,16 @@ def _loss_binder(self, inputs, outputs, aux):
3636
'''get losses'''
3737
opt = inputs["opt"]
3838
zeros = jnp.zeros(sum(self._lengths))
39-
binder_id = zeros.at[self._target_len:].set(1)
39+
mask_1d = inputs["seq_mask"]
40+
binder_id = zeros.at[-self._binder_len:].set(1)
41+
binder_id = jnp.where(mask_1d, binder_id, 0)
4042
if "hotspot" in opt:
4143
target_id = zeros.at[opt["hotspot"]].set(1)
44+
target_id = jnp.where(mask_1d, target_id, 0)
4245
i_con_loss = get_con_loss(inputs, outputs, opt["i_con"], mask_1d=target_id, mask_1b=binder_id)
4346
else:
4447
target_id = zeros.at[:self._target_len].set(1)
48+
target_id = jnp.where(mask_1d, target_id, 0)
4549
i_con_loss = get_con_loss(inputs, outputs, opt["i_con"], mask_1d=binder_id, mask_1b=target_id)
4650

4751
# unsupervised losses
@@ -68,10 +72,17 @@ def _loss_binder(self, inputs, outputs, aux):
6872
# compute fape
6973
fape = get_fape_loss(inputs, outputs, clamp=opt["fape_cutoff"], return_mtx=True)
7074

75+
mask_1d = inputs["batch"]["all_atom_mask"][:,1]
76+
mask_2d = mask_1d[:,None] * mask_1d[None,:]
77+
def exclude_target(x):
78+
x = x[-self._binder_len:,:]
79+
m = mask_2d[-self._binder_len:,:]
80+
return (x*m).sum() / (m.sum() + 1e-8)
81+
7182
aux["losses"].update({
7283
"rmsd": aln["rmsd"],
73-
"dgram_cce": cce[self._target_len:,:].mean(),
74-
"fape": fape[self._target_len:,:].mean()
84+
"dgram_cce": exclude_target(cce),
85+
"fape": exclude_target(fape)
7586
})
7687

7788
else:
@@ -198,10 +209,15 @@ def get_pae(outputs):
198209
return (prob*bin_centers).sum(-1)
199210

200211
def get_ptm(inputs, outputs, interface=False):
201-
pae = outputs["predicted_aligned_error"]
202-
if "asym_id" not in pae:
203-
pae["asym_id"] = inputs["asym_id"]
204-
return confidence_jax.predicted_tm_score_jax(**pae, interface=interface)
212+
pae = {"residue_weights":inputs["seq_mask"],
213+
**outputs["predicted_aligned_error"]}
214+
if interface:
215+
if "asym_id" not in pae:
216+
pae["asym_id"] = inputs["asym_id"]
217+
else:
218+
if "asym_id" in pae:
219+
pae.pop("asym_id")
220+
return confidence.predicted_tm_score(**pae, use_jnp=True)
205221

206222
def get_contact_map(outputs, dist=8.0):
207223
'''get contact map from distogram'''

0 commit comments

Comments
 (0)