Skip to content

Commit e7bb3de

Browse files
authored
adding partial and bias support to design_semigreedy()
2 parents 5cd4191 + 3b0d23c commit e7bb3de

File tree

1 file changed

+23
-13
lines changed

1 file changed

+23
-13
lines changed

colabdesign/af/design.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -360,21 +360,31 @@ def design_semigreedy(self, iters=100, tries=20, num_recycles=None, num_models=1
360360
if self._k == 0:
361361
self.run(num_recycles=num_recycles, backprop=False)
362362

363-
def mut(seq, plddt=None):
363+
def mut(seq, plddt=None, bias=None):
364364
'''mutate random position'''
365+
# bias mutations towards positions with low pLDDT
366+
# https://www.biorxiv.org/content/10.1101/2021.08.24.457549v1
365367
L,A = seq.shape[-2:]
366-
while True:
367-
if plddt is None:
368-
i = jax.random.randint(self.key(),[],0,L)
369-
else:
370-
p = (1-plddt)/(1-plddt).sum(-1,keepdims=True)
371-
# bias mutations towards positions with low pLDDT
372-
# https://www.biorxiv.org/content/10.1101/2021.08.24.457549v1
373-
i = jax.random.choice(self.key(),jnp.arange(L),[],p=p)
374368

375-
a = jax.random.randint(self.key(),[],0,A)
376-
if seq[0,i,a] == 0: break
377-
return seq.at[:,i,:].set(jnp.eye(A)[a])
369+
# sample position
370+
pi = jnp.ones(L) if plddt is None else jax.nn.relu(1-plddt)
371+
if "fix_pos" in self.opt: pi = pi.at[self.opt["fix_pos"]].set(0)
372+
373+
assert sum(pi) > 0
374+
i = jax.random.choice(self.key(),jnp.arange(L),[],p=pi/pi.sum())
375+
376+
# sample amino acid
377+
if isinstance(bias,float):
378+
pa = jax.nn.relu(1-seq[0,i])
379+
else:
380+
if bias.ndim == 2: bias = bias[i]
381+
pa = jax.nn.softmax(bias - seq[0,i] * 1e8)
382+
383+
assert sum(pa) > 0
384+
a = jax.random.choice(self.key(),jnp.arange(A),[],p=pa/pa.sum())
385+
386+
# return mutant
387+
return seq.at[0,i,:].set(jax.nn.one_hot(a,A))
378388

379389
def get_seq():
380390
return jax.nn.one_hot(self._params["seq"].argmax(-1),20)
@@ -392,7 +402,7 @@ def get_seq():
392402

393403
buff = []
394404
for _ in range(tries):
395-
self.set_seq(seq=mut(seq, plddt), set_state=False)
405+
self.set_seq(seq=mut(seq, plddt, bias=self.opt["bias"]), set_state=False)
396406
self.run(num_recycles=num_recycles, backprop=False)
397407
buff.append({"aux":self.aux, "seq":self._params["seq"]})
398408

0 commit comments

Comments
 (0)