"""Functions for processing confidence metrics."""

from typing import Dict, Optional, Tuple

import jax
import jax.numpy as jnp
import numpy as np
import scipy.special

from colabdesign.af.alphafold.common import residue_constants

def compute_tol(prev_pos, current_pos, mask, use_jnp=False):
  """Computes the masked RMS change in pairwise CA-CA distances.

  Early stopping criterion based on the one used in
  AF2Complex: https://www.nature.com/articles/s41467-022-29394-2

  Args:
    prev_pos: atom positions from the previous iteration, indexed as
      [residue, atom, coordinate]; the atom axis is addressed through
      residue_constants.atom_order.
    current_pos: atom positions from the current iteration, same layout as
      prev_pos.
    mask: [num_res] per-residue mask/weights. Residue pair (i, j) contributes
      with weight mask[i] * mask[j].
    use_jnp: if True compute with jax.numpy, otherwise with numpy.

  Returns:
    Scalar sqrt(mean_masked((d_prev - d_curr)^2) + 1e-8), where d_* are the
    CA-CA distance matrices. If the mask sums to zero the mean is 0/0 and the
    result is NaN.
  """
  _np = jnp if use_jnp else np
  ca_idx = residue_constants.atom_order['CA']

  def _pairwise_dist(x):
    # [N, C] coordinates -> [N, N] Euclidean distance matrix.
    return _np.sqrt(_np.square(x[:, None] - x[None, :]).sum(-1))

  prev_dist = _pairwise_dist(prev_pos[:, ca_idx])
  current_dist = _pairwise_dist(current_pos[:, ca_idx])
  sq_diff = _np.square(prev_dist - current_dist)

  mask_2d = mask[:, None] * mask[None, :]
  # The epsilon is added to the mean (not the denominator), so the square root
  # is never taken at exactly zero.
  return _np.sqrt((sq_diff * mask_2d).sum() / mask_2d.sum() + 1e-8)
def compute_plddt(logits, use_jnp=False):
  """Computes per-residue pLDDT from logits.

  Args:
    logits: [num_res, num_bins] output from the PredictedLDDTHead.
    use_jnp: if True compute with jax.numpy / jax.nn.softmax, otherwise with
      numpy / scipy.special.softmax.

  Returns:
    plddt: [num_res] per-residue pLDDT, scaled to the range [0, 100].
  """
  if use_jnp:
    _np, _softmax = jnp, jax.nn.softmax
  else:
    _np, _softmax = np, scipy.special.softmax

  # The bins evenly partition [0, 1]; each bin is represented by its center.
  num_bins = logits.shape[-1]
  bin_width = 1.0 / num_bins
  bin_centers = _np.arange(start=0.5 * bin_width, stop=1.0, step=bin_width)

  # Expected lDDT under the predicted distribution over bins.
  probs = _softmax(logits, axis=-1)
  predicted_lddt_ca = (probs * bin_centers[None, :]).sum(-1)
  return predicted_lddt_ca * 100
def _calculate_bin_centers(breaks, use_jnp=False):
  """Gets the bin centers from the bin edges.

  Args:
    breaks: [num_bins - 1] the error bin edges. Must hold at least two
      entries; the spacing is taken from the first two.
    use_jnp: if True compute with jax.numpy, otherwise with numpy.

  Returns:
    bin_centers: [num_bins] the error bin centers.
  """
  _np = jnp if use_jnp else np
  step = breaks[1] - breaks[0]

  # Add half-step to get the center.
  bin_centers = breaks + step / 2

  # Add a catch-all bin at the end.
  return _np.append(bin_centers, bin_centers[-1] + step)
def _calculate_expected_aligned_error(
    alignment_confidence_breaks,
    aligned_distance_error_probs,
    use_jnp=False):
  """Calculates expected aligned distance errors for every pair of residues.

  Args:
    alignment_confidence_breaks: [num_bins - 1] the error bin edges.
    aligned_distance_error_probs: [num_res, num_res, num_bins] the predicted
      probs for each error bin, for each pair of residues.
    use_jnp: if True compute with jax.numpy, otherwise with numpy.

  Returns:
    predicted_aligned_error: [num_res, num_res] the expected aligned distance
      error for each pair of residues.
    max_predicted_aligned_error: The maximum predicted error possible, i.e.
      the center of the final (catch-all) bin.
  """
  bin_centers = _calculate_bin_centers(
      alignment_confidence_breaks, use_jnp=use_jnp)

  # Expectation of the bin centers under the per-pair distribution.
  predicted_aligned_error = (aligned_distance_error_probs * bin_centers).sum(-1)

  # Tuple of expected aligned distance error and max possible error.
  return predicted_aligned_error, bin_centers[-1]
def compute_predicted_aligned_error(logits, breaks, use_jnp=False):
  """Computes aligned confidence metrics from logits.

  Args:
    logits: [num_res, num_res, num_bins] the logits output from
      PredictedAlignedErrorHead.
    breaks: [num_bins - 1] the error bin edges.
    use_jnp: if True compute with jax.numpy / jax.nn.softmax, otherwise with
      numpy / scipy.special.softmax.

  Returns:
    A dict with:
      aligned_confidence_probs: [num_res, num_res, num_bins] softmax of the
        logits over the bin axis.
      predicted_aligned_error: [num_res, num_res] the expected aligned distance
        error for each pair of residues.
      max_predicted_aligned_error: The maximum predicted error possible.
  """
  _softmax = jax.nn.softmax if use_jnp else scipy.special.softmax
  aligned_confidence_probs = _softmax(logits, axis=-1)

  predicted_aligned_error, max_predicted_aligned_error = (
      _calculate_expected_aligned_error(
          alignment_confidence_breaks=breaks,
          aligned_distance_error_probs=aligned_confidence_probs,
          use_jnp=use_jnp))

  return {
      'aligned_confidence_probs': aligned_confidence_probs,
      'predicted_aligned_error': predicted_aligned_error,
      'max_predicted_aligned_error': max_predicted_aligned_error,
  }
def predicted_tm_score(logits, breaks, residue_weights=None,
                       asym_id=None, use_jnp=False):
  """Computes predicted TM alignment or predicted interface TM alignment score.

  Args:
    logits: [num_res, num_res, num_bins] the logits output from
      PredictedAlignedErrorHead.
    breaks: [num_bins - 1] the error bin edges.
    residue_weights: [num_res] the per residue weights to use for the
      expectation. Defaults to all ones.
    asym_id: [num_res] the asymmetric unit ID - the chain ID. Only needed for
      ipTM calculation; when given, only residue pairs from different chains
      contribute.
    use_jnp: if True compute with jax.numpy / jax.nn.softmax, otherwise with
      numpy / scipy.special.softmax.

  Returns:
    ptm_score: The predicted TM alignment or the predicted iTM score.
  """
  if use_jnp:
    _np, _softmax = jnp, jax.nn.softmax
  else:
    _np, _softmax = np, scipy.special.softmax

  # residue_weights has to be in [0, 1], but can be floating-point, i.e. the
  # exp. resolved head's probability.
  if residue_weights is None:
    residue_weights = _np.ones(logits.shape[0])

  bin_centers = _calculate_bin_centers(breaks, use_jnp=use_jnp)
  num_res = residue_weights.shape[0]

  # Clip num_res to avoid negative/undefined d0.
  clipped_num_res = _np.maximum(residue_weights.sum(), 19)

  # Compute d_0(num_res) as defined by TM-score, eqn. (5) in Yang & Skolnick
  # "Scoring function for automated assessment of protein structure template
  # quality", 2004: http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf
  d0 = 1.24 * (clipped_num_res - 15) ** (1. / 3) - 1.8

  # Convert logits to probs.
  probs = _softmax(logits, axis=-1)

  # TM-Score term for every bin.
  tm_per_bin = 1. / (1 + _np.square(bin_centers) / _np.square(d0))
  # E_distances tm(distance).
  predicted_tm_term = (probs * tm_per_bin).sum(-1)

  if asym_id is None:
    # pTM: every residue pair counts.
    pair_mask = _np.full((num_res, num_res), True)
  else:
    # ipTM: only inter-chain residue pairs count.
    pair_mask = asym_id[:, None] != asym_id[None, :]

  predicted_tm_term = predicted_tm_term * pair_mask

  pair_residue_weights = pair_mask * (
      residue_weights[None, :] * residue_weights[:, None])
  normed_residue_mask = pair_residue_weights / (
      1e-8 + pair_residue_weights.sum(-1, keepdims=True))
  per_alignment = (predicted_tm_term * normed_residue_mask).sum(-1)

  # The weights only select which alignment to report; the reported score is
  # the unscaled per-alignment value, so fractional weights do not shrink it.
  best_alignment = (per_alignment * residue_weights).argmax()
  return per_alignment[best_alignment]