@@ -360,21 +360,31 @@ def design_semigreedy(self, iters=100, tries=20, num_recycles=None, num_models=1
360360 if self ._k == 0 :
361361 self .run (num_recycles = num_recycles , backprop = False )
362362
363- def mut (seq , plddt = None ):
363+ def mut (seq , plddt = None , bias = None ):
364364 '''mutate random position'''
365+ # bias mutations towards positions with low pLDDT
366+ # https://www.biorxiv.org/content/10.1101/2021.08.24.457549v1
365367 L ,A = seq .shape [- 2 :]
366- while True :
367- if plddt is None :
368- i = jax .random .randint (self .key (),[],0 ,L )
369- else :
370- p = (1 - plddt )/ (1 - plddt ).sum (- 1 ,keepdims = True )
371- # bias mutations towards positions with low pLDDT
372- # https://www.biorxiv.org/content/10.1101/2021.08.24.457549v1
373- i = jax .random .choice (self .key (),jnp .arange (L ),[],p = p )
374368
375- a = jax .random .randint (self .key (),[],0 ,A )
376- if seq [0 ,i ,a ] == 0 : break
377- return seq .at [:,i ,:].set (jnp .eye (A )[a ])
369+ # sample position
370+ pi = jnp .ones (L ) if plddt is None else jax .nn .relu (1 - plddt )
371+ if "fix_pos" in self .opt : pi = pi .at [self .opt ["fix_pos" ]].set (0 )
372+
373+ assert sum (pi ) > 0
374+ i = jax .random .choice (self .key (),jnp .arange (L ),[],p = pi / pi .sum ())
375+
376+ # sample amino acid
377+ if isinstance (bias ,float ):
378+ pa = jax .nn .relu (1 - seq [0 ,i ])
379+ else :
380+ if bias .ndim == 2 : bias = bias [i ]
381+ pa = jax .nn .softmax (bias - seq [0 ,i ] * 1e8 )
382+
383+ assert sum (pa ) > 0
384+ a = jax .random .choice (self .key (),jnp .arange (A ),[],p = pa / pa .sum ())
385+
386+ # return mutant
387+ return seq .at [0 ,i ,:].set (jax .nn .one_hot (a ,A ))
378388
379389 def get_seq ():
380390 return jax .nn .one_hot (self ._params ["seq" ].argmax (- 1 ),20 )
@@ -392,7 +402,7 @@ def get_seq():
392402
393403 buff = []
394404 for _ in range (tries ):
395- self .set_seq (seq = mut (seq , plddt ), set_state = False )
405+ self .set_seq (seq = mut (seq , plddt , bias = self . opt [ "bias" ] ), set_state = False )
396406 self .run (num_recycles = num_recycles , backprop = False )
397407 buff .append ({"aux" :self .aux , "seq" :self ._params ["seq" ]})
398408
0 commit comments