Skip to content

Commit 396b0c4

Browse files
sokryptonIeremie
andcommitted
adding support for esmfold_v0
Co-Authored-By: Ieremie Ioan <[email protected]>
1 parent 33ae16f commit 396b0c4

File tree

1 file changed

+37
-24
lines changed

1 file changed

+37
-24
lines changed

ESMFold.ipynb

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -59,31 +59,35 @@
5959
"%%time\n",
6060
"#@title install\n",
6161
"#@markdown install ESMFold, OpenFold and download Params (~2min 30s)\n",
62-
"\n",
62+
"version = \"1\" # @param [\"0\", \"1\"]\n",
63+
"model_name = \"esmfold_v0.model\" if version == \"0\" else \"esmfold.model\"\n",
6364
"import os, time\n",
64-
"if not os.path.isfile(\"esmfold.model\"):\n",
65+
"if not os.path.isfile(model_name):\n",
6566
" # download esmfold params\n",
6667
" os.system(\"apt-get install aria2 -qq\")\n",
67-
" os.system(\"aria2c -q -x 16 https://colabfold.steineggerlab.workers.dev/esm/esmfold.model &\")\n",
68+
" os.system(f\"aria2c -q -x 16 https://colabfold.steineggerlab.workers.dev/esm/{model_name} &\")\n",
6869
"\n",
69-
" # install libs\n",
70-
" os.system(\"pip install -q omegaconf pytorch_lightning biopython ml_collections einops py3Dmol\")\n",
71-
" os.system(\"pip install -q git+https://github.com/NVIDIA/dllogger.git\")\n",
70+
" if not os.path.isfile(\"finished_install\"):\n",
71+
" print(\"installing esmfold...\")\n",
72+
" # install libs\n",
73+
" os.system(\"pip install -q omegaconf pytorch_lightning biopython ml_collections einops py3Dmol\")\n",
74+
" os.system(\"pip install -q git+https://github.com/NVIDIA/dllogger.git\")\n",
7275
"\n",
73-
" # install openfold\n",
74-
" commit = \"6908936b68ae89f67755240e2f588c09ec31d4c8\"\n",
75-
" os.system(f\"pip install -q git+https://github.com/aqlaboratory/openfold.git@{commit}\")\n",
76+
" # install openfold\n",
77+
" commit = \"6908936b68ae89f67755240e2f588c09ec31d4c8\"\n",
78+
" os.system(f\"pip install -q git+https://github.com/aqlaboratory/openfold.git@{commit}\")\n",
7679
"\n",
77-
" # install esmfold\n",
78-
" os.system(f\"pip install -q git+https://github.com/sokrypton/esm.git\")\n",
80+
" # install esmfold\n",
81+
" os.system(f\"pip install -q git+https://github.com/sokrypton/esm.git\")\n",
82+
" os.system(\"touch finished_install\")\n",
7983
"\n",
8084
" # wait for Params to finish downloading...\n",
81-
" if not os.path.isfile(\"esmfold.model\"):\n",
82-
" # backup source!\n",
83-
" os.system(\"aria2c -q -x 16 https://files.ipd.uw.edu/pub/esmfold/esmfold.model\")\n",
84-
" else:\n",
85-
" while os.path.isfile(\"esmfold.model.aria2\"):\n",
86-
" time.sleep(5)"
85+
" while not os.path.isfile(model_name):\n",
86+
" time.sleep(5)\n",
87+
" if os.path.isfile(f\"{model_name}.aria2\"):\n",
88+
" print(\"downloading params...\")\n",
89+
" while os.path.isfile(f\"{model_name}.aria2\"):\n",
90+
" time.sleep(5)"
8791
]
8892
},
8993
{
@@ -94,14 +98,16 @@
9498
"from string import ascii_uppercase, ascii_lowercase\n",
9599
"import hashlib, re, os\n",
96100
"import numpy as np\n",
101+
"import torch\n",
97102
"from jax.tree_util import tree_map\n",
98103
"import matplotlib.pyplot as plt\n",
99104
"from scipy.special import softmax\n",
105+
"import gc\n",
100106
"\n",
101107
"def parse_output(output):\n",
102108
" pae = (output[\"aligned_confidence_probs\"][0] * np.arange(64)).mean(-1) * 31\n",
103109
" plddt = output[\"plddt\"][0,:,1]\n",
104-
" \n",
110+
"\n",
105111
" bins = np.append(0,np.linspace(2.3125,21.6875,63))\n",
106112
" sm_contacts = softmax(output[\"distogram_logits\"],-1)[0]\n",
107113
" sm_contacts = sm_contacts[...,bins<8].sum(-1)\n",
@@ -128,7 +134,7 @@
128134
"if copies == \"\" or copies <= 0: copies = 1\n",
129135
"sequence = \":\".join([sequence] * copies)\n",
130136
"num_recycles = 3 #@param [\"0\", \"1\", \"2\", \"3\", \"6\", \"12\", \"24\"] {type:\"raw\"}\n",
131-
"chain_linker = 25 \n",
137+
"chain_linker = 25\n",
132138
"\n",
133139
"ID = jobname+\"_\"+get_hash(sequence)[:5]\n",
134140
"seqs = sequence.split(\":\")\n",
@@ -141,10 +147,17 @@
141147
"elif len(u_seqs) == 1: mode = \"homo\"\n",
142148
"else: mode = \"hetero\"\n",
143149
"\n",
144-
"if \"model\" not in dir():\n",
145-
" import torch\n",
146-
" model = torch.load(\"esmfold.model\")\n",
150+
"if \"model\" not in dir() or model_name != model_name_:\n",
151+
" if \"model\" in dir():\n",
152+
" # delete old model from memory\n",
153+
" del model\n",
154+
" gc.collect()\n",
155+
" if torch.cuda.is_available():\n",
156+
" torch.cuda.empty_cache()\n",
157+
"\n",
158+
" model = torch.load(model_name)\n",
147159
" model.eval().cuda().requires_grad_(False)\n",
160+
" model_name_ = model_name\n",
148161
"\n",
149162
"# optimized for Tesla T4\n",
150163
"if length > 700:\n",
@@ -193,7 +206,7 @@
193206
" size=(800,480), hbondCutoff=4.0,\n",
194207
" Ls=None,\n",
195208
" animate=False):\n",
196-
" \n",
209+
"\n",
197210
" if chains is None:\n",
198211
" chains = 1 if Ls is None else len(Ls)\n",
199212
" view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js', width=size[0], height=size[1])\n",
@@ -215,7 +228,7 @@
215228
" view.addStyle({'and':[{'resn':\"GLY\"},{'atom':'CA'}]},\n",
216229
" {'sphere':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
217230
" view.addStyle({'and':[{'resn':\"PRO\"},{'atom':['C','O'],'invert':True}]},\n",
218-
" {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}}) \n",
231+
" {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
219232
" if show_mainchains:\n",
220233
" BB = ['C','O','N','CA']\n",
221234
" view.addStyle({'atom':BB},{'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",

0 commit comments

Comments
 (0)