Skip to content

Commit

Permalink
fix param switch (need some time for download to start)
Browse files Browse the repository at this point in the history
  • Loading branch information
sokrypton committed Mar 29, 2024
1 parent 0614fdf commit fe88237
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions RoseTTAFold2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@
"import os, time, sys\n",
"os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"max_split_size_mb:512\"\n",
"\n",
"if params == \"RF2_jan24\" and not os.path.isfile(\"RF2_jan24.tgz\"):\n",
"if params == \"RF2_jan24\" and not os.path.isfile(f\"{params}.tgz\"):\n",
" # send param download into background\n",
" os.system(\"(apt-get install aria2; aria2c -q -x 16 https://files.ipd.uw.edu/dimaio/RF2_jan24.tgz) &\")\n",
"\n",
"if params == \"RF2_apr23\" and not os.path.isfile(\"RF2_apr23.tgz\"):\n",
"if params == \"RF2_apr23\" and not os.path.isfile(f\"{params}.tgz\"):\n",
" # send param download into background\n",
" os.system(\"(apt-get install aria2; aria2c -q -x 16 https://files.ipd.uw.edu/dimaio/RF2_apr23.tgz) &\")\n",
"\n",
Expand All @@ -86,19 +86,18 @@
" os.makedirs(\"hhsuite\", exist_ok=True)\n",
" os.system(f\"curl -fsSL https://github.com/soedinglab/hh-suite/releases/download/v3.3.0/hhsuite-3.3.0-SSE2-Linux.tar.gz | tar xz -C hhsuite/\")\n",
"\n",
"if not os.path.isfile(f\"{params}.pt\"):\n",
" time.sleep(5)\n",
"\n",
"if os.path.isfile(f\"{params}.tgz.aria2\"):\n",
" print(\"downloading RoseTTAFold2 params\")\n",
" while os.path.isfile(f\"{params}.tgz.aria2\"):\n",
" time.sleep(5)\n",
"\n",
"if params == \"RF2_jan24\":\n",
" model_params = f\"{params}.pt\"\n",
"if params == \"RF2_apr23\":\n",
" model_params = f\"weights/{params}.pt\"\n",
"\n",
"if not os.path.isfile(model_params):\n",
"if not os.path.isfile(f\"{params}.pt\"):\n",
" os.system(f\"tar -zxvf {params}.tgz\")\n",
" if params == \"RF2_apr23\":\n",
" os.system(f\"mv weights/{params}.pt .\")\n",
"\n",
"if not \"IMPORTED\" in dir():\n",
" if 'RoseTTAFold2/network' not in sys.path:\n",
Expand All @@ -125,16 +124,16 @@
"\n",
" IMPORTED = True\n",
"\n",
"if not \"pred\" in dir() or model_params_sele != model_params:\n",
"if not \"pred\" in dir() or params_sele != params:\n",
" from predict import Predictor\n",
" print(\"compile RoseTTAFold2\")\n",
"\n",
" if (torch.cuda.is_available()):\n",
" pred = Predictor(model_params, torch.device(\"cuda:0\"))\n",
" pred = Predictor(f\"{params}.pt\", torch.device(\"cuda:0\"))\n",
" else:\n",
" print (\"WARNING: using CPU\")\n",
" pred = Predictor(model_params, torch.device(\"cpu\"))\n",
" model_params_sele = model_params\n",
" pred = Predictor(f\"{params}.pt\", torch.device(\"cpu\"))\n",
" params_sele = params\n",
"\n",
"def get_unique_sequences(seq_list):\n",
" unique_seqs = list(OrderedDict.fromkeys(seq_list))\n",
Expand Down

0 comments on commit fe88237

Please sign in to comment.