Skip to content

Commit

Permalink
adding support for esmfold_v0
Browse files Browse the repository at this point in the history
  • Loading branch information
sokrypton committed Aug 26, 2023
1 parent 33ae16f commit a382230
Showing 1 changed file with 34 additions and 22 deletions.
56 changes: 34 additions & 22 deletions ESMFold.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -59,30 +59,33 @@
"%%time\n",
"#@title install\n",
"#@markdown install ESMFold, OpenFold and download Params (~2min 30s)\n",
"\n",
"version = \"1\" # @param [\"0\", \"1\"]\n",
"model_name = \"esmfold_v0.model\" if version == \"0\" else \"esmfold_v0.model\"\n",
"import os, time\n",
"if not os.path.isfile(\"esmfold.model\"):\n",
"if not os.path.isfile(model_name):\n",
" # download esmfold params\n",
" os.system(\"apt-get install aria2 -qq\")\n",
" os.system(\"aria2c -q -x 16 https://colabfold.steineggerlab.workers.dev/esm/esmfold.model &\")\n",
" os.system(f\"aria2c -q -x 16 https://colabfold.steineggerlab.workers.dev/esm/{model_name} &\")\n",
"\n",
" # install libs\n",
" os.system(\"pip install -q omegaconf pytorch_lightning biopython ml_collections einops py3Dmol\")\n",
" os.system(\"pip install -q git+https://github.com/NVIDIA/dllogger.git\")\n",
" if not os.path.isfile(\"finished_install\"):\n",
" # install libs\n",
" os.system(\"pip install -q omegaconf pytorch_lightning biopython ml_collections einops py3Dmol\")\n",
" os.system(\"pip install -q git+https://github.com/NVIDIA/dllogger.git\")\n",
"\n",
" # install openfold\n",
" commit = \"6908936b68ae89f67755240e2f588c09ec31d4c8\"\n",
" os.system(f\"pip install -q git+https://github.com/aqlaboratory/openfold.git@{commit}\")\n",
" # install openfold\n",
" commit = \"6908936b68ae89f67755240e2f588c09ec31d4c8\"\n",
" os.system(f\"pip install -q git+https://github.com/aqlaboratory/openfold.git@{commit}\")\n",
"\n",
" # install esmfold\n",
" os.system(f\"pip install -q git+https://github.com/sokrypton/esm.git\")\n",
" # install esmfold\n",
" os.system(f\"pip install -q git+https://github.com/sokrypton/esm.git\")\n",
" os.system(\"touch finished_install\")\n",
"\n",
" # wait for Params to finish downloading...\n",
" if not os.path.isfile(\"esmfold.model\"):\n",
" # backup source!\n",
" os.system(\"aria2c -q -x 16 https://files.ipd.uw.edu/pub/esmfold/esmfold.model\")\n",
" if not os.path.isfile(model_name):\n",
" print(\"ERROR: downloading esmfold params\")\n",
" else:\n",
" while os.path.isfile(\"esmfold.model.aria2\"):\n",
" print(\"waiting for param download...\")\n",
" while os.path.isfile(f\"{model_name}.aria2\"):\n",
" time.sleep(5)"
]
},
Expand All @@ -94,14 +97,16 @@
"from string import ascii_uppercase, ascii_lowercase\n",
"import hashlib, re, os\n",
"import numpy as np\n",
"import torch\n",
"from jax.tree_util import tree_map\n",
"import matplotlib.pyplot as plt\n",
"from scipy.special import softmax\n",
"import gc\n",
"\n",
"def parse_output(output):\n",
" pae = (output[\"aligned_confidence_probs\"][0] * np.arange(64)).mean(-1) * 31\n",
" plddt = output[\"plddt\"][0,:,1]\n",
" \n",
"\n",
" bins = np.append(0,np.linspace(2.3125,21.6875,63))\n",
" sm_contacts = softmax(output[\"distogram_logits\"],-1)[0]\n",
" sm_contacts = sm_contacts[...,bins<8].sum(-1)\n",
Expand All @@ -128,7 +133,7 @@
"if copies == \"\" or copies <= 0: copies = 1\n",
"sequence = \":\".join([sequence] * copies)\n",
"num_recycles = 3 #@param [\"0\", \"1\", \"2\", \"3\", \"6\", \"12\", \"24\"] {type:\"raw\"}\n",
"chain_linker = 25 \n",
"chain_linker = 25\n",
"\n",
"ID = jobname+\"_\"+get_hash(sequence)[:5]\n",
"seqs = sequence.split(\":\")\n",
Expand All @@ -141,10 +146,17 @@
"elif len(u_seqs) == 1: mode = \"homo\"\n",
"else: mode = \"hetero\"\n",
"\n",
"if \"model\" not in dir():\n",
" import torch\n",
" model = torch.load(\"esmfold.model\")\n",
"if \"model\" not in dir() or model_name != model_name_:\n",
" if \"model\" in dir():\n",
" # delete old model from memory\n",
" del model\n",
" gc.collect()\n",
" if torch.cuda.is_available():\n",
" torch.cuda.empty_cache()\n",
"\n",
" model = torch.load(model_name)\n",
" model.eval().cuda().requires_grad_(False)\n",
" model_name_ = model_name\n",
"\n",
"# optimized for Tesla T4\n",
"if length > 700:\n",
Expand Down Expand Up @@ -193,7 +205,7 @@
" size=(800,480), hbondCutoff=4.0,\n",
" Ls=None,\n",
" animate=False):\n",
" \n",
"\n",
" if chains is None:\n",
" chains = 1 if Ls is None else len(Ls)\n",
" view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js', width=size[0], height=size[1])\n",
Expand All @@ -215,7 +227,7 @@
" view.addStyle({'and':[{'resn':\"GLY\"},{'atom':'CA'}]},\n",
" {'sphere':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
" view.addStyle({'and':[{'resn':\"PRO\"},{'atom':['C','O'],'invert':True}]},\n",
" {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}}) \n",
" {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
" if show_mainchains:\n",
" BB = ['C','O','N','CA']\n",
" view.addStyle({'atom':BB},{'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
Expand Down

0 comments on commit a382230

Please sign in to comment.