|
59 | 59 | "%%time\n", |
60 | 60 | "#@title install\n", |
61 | 61 | "#@markdown install ESMFold, OpenFold and download Params (~2min 30s)\n", |
62 | | - "\n", |
| 62 | + "version = \"1\" # @param [\"0\", \"1\"]\n", |
| 63 | + "model_name = \"esmfold_v0.model\" if version == \"0\" else \"esmfold.model\"\n", |
63 | 64 | "import os, time\n", |
64 | | - "if not os.path.isfile(\"esmfold.model\"):\n", |
| 65 | + "if not os.path.isfile(model_name):\n", |
65 | 66 | " # download esmfold params\n", |
66 | 67 | " os.system(\"apt-get install aria2 -qq\")\n", |
67 | | - " os.system(\"aria2c -q -x 16 https://colabfold.steineggerlab.workers.dev/esm/esmfold.model &\")\n", |
| 68 | + " os.system(f\"aria2c -q -x 16 https://colabfold.steineggerlab.workers.dev/esm/{model_name} &\")\n", |
68 | 69 | "\n", |
69 | | - " # install libs\n", |
70 | | - " os.system(\"pip install -q omegaconf pytorch_lightning biopython ml_collections einops py3Dmol\")\n", |
71 | | - " os.system(\"pip install -q git+https://github.com/NVIDIA/dllogger.git\")\n", |
| 70 | + " if not os.path.isfile(\"finished_install\"):\n", |
| 71 | + " print(\"installing esmfold...\")\n", |
| 72 | + " # install libs\n", |
| 73 | + " os.system(\"pip install -q omegaconf pytorch_lightning biopython ml_collections einops py3Dmol\")\n", |
| 74 | + " os.system(\"pip install -q git+https://github.com/NVIDIA/dllogger.git\")\n", |
72 | 75 | "\n", |
73 | | - " # install openfold\n", |
74 | | - " commit = \"6908936b68ae89f67755240e2f588c09ec31d4c8\"\n", |
75 | | - " os.system(f\"pip install -q git+https://github.com/aqlaboratory/openfold.git@{commit}\")\n", |
| 76 | + " # install openfold\n", |
| 77 | + " commit = \"6908936b68ae89f67755240e2f588c09ec31d4c8\"\n", |
| 78 | + " os.system(f\"pip install -q git+https://github.com/aqlaboratory/openfold.git@{commit}\")\n", |
76 | 79 | "\n", |
77 | | - " # install esmfold\n", |
78 | | - " os.system(f\"pip install -q git+https://github.com/sokrypton/esm.git\")\n", |
| 80 | + " # install esmfold\n", |
| 81 | + " os.system(f\"pip install -q git+https://github.com/sokrypton/esm.git\")\n", |
| 82 | + " os.system(\"touch finished_install\")\n", |
79 | 83 | "\n", |
80 | 84 | " # wait for Params to finish downloading...\n", |
81 | | - " if not os.path.isfile(\"esmfold.model\"):\n", |
82 | | - " # backup source!\n", |
83 | | - " os.system(\"aria2c -q -x 16 https://files.ipd.uw.edu/pub/esmfold/esmfold.model\")\n", |
84 | | - " else:\n", |
85 | | - " while os.path.isfile(\"esmfold.model.aria2\"):\n", |
86 | | - " time.sleep(5)" |
| 85 | + " while not os.path.isfile(model_name):\n", |
| 86 | + " time.sleep(5)\n", |
| 87 | + " if os.path.isfile(f\"{model_name}.aria2\"):\n", |
| 88 | + " print(\"downloading params...\")\n", |
| 89 | + " while os.path.isfile(f\"{model_name}.aria2\"):\n", |
| 90 | + " time.sleep(5)" |
87 | 91 | ] |
88 | 92 | }, |
89 | 93 | { |
|
94 | 98 | "from string import ascii_uppercase, ascii_lowercase\n", |
95 | 99 | "import hashlib, re, os\n", |
96 | 100 | "import numpy as np\n", |
| 101 | + "import torch\n", |
97 | 102 | "from jax.tree_util import tree_map\n", |
98 | 103 | "import matplotlib.pyplot as plt\n", |
99 | 104 | "from scipy.special import softmax\n", |
| 105 | + "import gc\n", |
100 | 106 | "\n", |
101 | 107 | "def parse_output(output):\n", |
102 | 108 | " pae = (output[\"aligned_confidence_probs\"][0] * np.arange(64)).mean(-1) * 31\n", |
103 | 109 | " plddt = output[\"plddt\"][0,:,1]\n", |
104 | | - " \n", |
| 110 | + "\n", |
105 | 111 | " bins = np.append(0,np.linspace(2.3125,21.6875,63))\n", |
106 | 112 | " sm_contacts = softmax(output[\"distogram_logits\"],-1)[0]\n", |
107 | 113 | " sm_contacts = sm_contacts[...,bins<8].sum(-1)\n", |
|
128 | 134 | "if copies == \"\" or copies <= 0: copies = 1\n", |
129 | 135 | "sequence = \":\".join([sequence] * copies)\n", |
130 | 136 | "num_recycles = 3 #@param [\"0\", \"1\", \"2\", \"3\", \"6\", \"12\", \"24\"] {type:\"raw\"}\n", |
131 | | - "chain_linker = 25 \n", |
| 137 | + "chain_linker = 25\n", |
132 | 138 | "\n", |
133 | 139 | "ID = jobname+\"_\"+get_hash(sequence)[:5]\n", |
134 | 140 | "seqs = sequence.split(\":\")\n", |
|
141 | 147 | "elif len(u_seqs) == 1: mode = \"homo\"\n", |
142 | 148 | "else: mode = \"hetero\"\n", |
143 | 149 | "\n", |
144 | | - "if \"model\" not in dir():\n", |
145 | | - " import torch\n", |
146 | | - " model = torch.load(\"esmfold.model\")\n", |
| 150 | + "if \"model\" not in dir() or model_name != model_name_:\n", |
| 151 | + " if \"model\" in dir():\n", |
| 152 | + " # delete old model from memory\n", |
| 153 | + " del model\n", |
| 154 | + " gc.collect()\n", |
| 155 | + " if torch.cuda.is_available():\n", |
| 156 | + " torch.cuda.empty_cache()\n", |
| 157 | + "\n", |
| 158 | + " model = torch.load(model_name)\n", |
147 | 159 | " model.eval().cuda().requires_grad_(False)\n", |
| 160 | + " model_name_ = model_name\n", |
148 | 161 | "\n", |
149 | 162 | "# optimized for Tesla T4\n", |
150 | 163 | "if length > 700:\n", |
|
193 | 206 | " size=(800,480), hbondCutoff=4.0,\n", |
194 | 207 | " Ls=None,\n", |
195 | 208 | " animate=False):\n", |
196 | | - " \n", |
| 209 | + "\n", |
197 | 210 | " if chains is None:\n", |
198 | 211 | " chains = 1 if Ls is None else len(Ls)\n", |
199 | 212 | " view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js', width=size[0], height=size[1])\n", |
|
215 | 228 | " view.addStyle({'and':[{'resn':\"GLY\"},{'atom':'CA'}]},\n", |
216 | 229 | " {'sphere':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n", |
217 | 230 | " view.addStyle({'and':[{'resn':\"PRO\"},{'atom':['C','O'],'invert':True}]},\n", |
218 | | - " {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}}) \n", |
| 231 | + " {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n", |
219 | 232 | " if show_mainchains:\n", |
220 | 233 | " BB = ['C','O','N','CA']\n", |
221 | 234 | " view.addStyle({'atom':BB},{'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n", |
|
0 commit comments