Skip to content

Commit

Permalink
colabfold_search: allow continuing interrupted runs, disabling unpacking databases
Browse files Browse the repository at this point in the history
  • Loading branch information
milot-mirdita committed Jul 14, 2024
1 parent a9293c9 commit f4d9e37
Showing 1 changed file with 126 additions and 92 deletions.
218 changes: 126 additions & 92 deletions colabfold/mmseqs/search.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""
Functionality for running mmseqs locally. Takes in a fasta file, outputs final.a3m
Note: Currently needs mmseqs compiled from source
"""

import logging
Expand All @@ -18,8 +16,28 @@

logger = logging.getLogger(__name__)

# For each MMseqs2 module whose finished result can be detected on disk,
# the index (within the module's argument list as passed to run_mmseqs)
# of the output database path. run_mmseqs checks for the companion
# ".dbtype" file at that path and skips the module when it already exists,
# which is what allows interrupted runs to be resumed.
MODULE_OUTPUT_POS = dict(
    align=4,
    convertalis=4,
    expandaln=5,
    filterresult=4,
    lndb=2,
    mergedbs=2,
    mvdb=2,
    pairaln=4,
    result2msa=4,
    search=3,
)

def run_mmseqs(mmseqs: Path, params: List[Union[str, Path]]):
module = params[0]
if module in MODULE_OUTPUT_POS:
output_pos = MODULE_OUTPUT_POS[module]
output_path = Path(params[output_pos]).with_suffix('.dbtype')
if output_path.exists():
logger.info(f"Skipping {module} because {output_path} already exists")
return

params_log = " ".join(str(i) for i in params)
logger.info(f"Running {mmseqs} {params_log}")
# hide MMseqs2 verbose parameters list that clogs up the log
Expand All @@ -46,6 +64,7 @@ def mmseqs_search_monomer(
s: float = 8,
db_load_mode: int = 2,
threads: int = 32,
unpack: bool = True,
):
"""Run mmseqs with a local colabfold database set
Expand Down Expand Up @@ -86,8 +105,6 @@ def mmseqs_search_monomer(
dbSuffix2 = ".idx"
dbSuffix3 = ".idx"

# fmt: off
# @formatter:off
search_param = ["--num-iterations", "3", "--db-load-mode", str(db_load_mode), "-a", "-e", "0.1", "--max-seqs", "10000"]
search_param += ["--prefilter-mode", str(prefilter_mode)]
if s is not None:
Expand All @@ -98,24 +115,27 @@ def mmseqs_search_monomer(
filter_param = ["--filter-msa", str(filter), "--filter-min-enable", "1000", "--diff", str(diff), "--qid", "0.0,0.2,0.4,0.6,0.8,1.0", "--qsc", "0", "--max-seq-id", "0.95",]
expand_param = ["--expansion-mode", "0", "-e", str(expand_eval), "--expand-filter-clusters", str(filter), "--max-seq-id", "0.95",]

run_mmseqs(mmseqs, ["search", base.joinpath("qdb"), dbbase.joinpath(uniref_db), base.joinpath("res"), base.joinpath("tmp"), "--threads", str(threads)] + search_param)
run_mmseqs(mmseqs, ["mvdb", base.joinpath("tmp/latest/profile_1"), base.joinpath("prof_res")])
run_mmseqs(mmseqs, ["lndb", base.joinpath("qdb_h"), base.joinpath("prof_res_h")])
run_mmseqs(mmseqs, ["expandaln", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"), base.joinpath("res"), dbbase.joinpath(f"{uniref_db}{dbSuffix2}"), base.joinpath("res_exp"), "--db-load-mode", str(db_load_mode), "--threads", str(threads)] + expand_param)
run_mmseqs(mmseqs, ["align", base.joinpath("prof_res"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"), base.joinpath("res_exp"), base.joinpath("res_exp_realign"), "--db-load-mode", str(db_load_mode), "-e", str(align_eval), "--max-accept", str(max_accept), "--threads", str(threads), "--alt-ali", "10", "-a"])
run_mmseqs(mmseqs, ["filterresult", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"),
base.joinpath("res_exp_realign"), base.joinpath("res_exp_realign_filter"), "--db-load-mode",
str(db_load_mode), "--qid", "0", "--qsc", str(qsc), "--diff", "0", "--threads",
str(threads), "--max-seq-id", "1.0", "--filter-min-enable", "100"])
run_mmseqs(mmseqs, ["result2msa", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"),
base.joinpath("res_exp_realign_filter"), base.joinpath("uniref.a3m"), "--msa-format-mode",
"6", "--db-load-mode", str(db_load_mode), "--threads", str(threads)] + filter_param)
subprocess.run([mmseqs] + ["rmdb", base.joinpath("res_exp_realign")])
subprocess.run([mmseqs] + ["rmdb", base.joinpath("res_exp")])
subprocess.run([mmseqs] + ["rmdb", base.joinpath("res")])
subprocess.run([mmseqs] + ["rmdb", base.joinpath("res_exp_realign_filter")])
if not base.joinpath("uniref.a3m").with_suffix('.a3m.dbtype').exists():
run_mmseqs(mmseqs, ["search", base.joinpath("qdb"), dbbase.joinpath(uniref_db), base.joinpath("res"), base.joinpath("tmp"), "--threads", str(threads)] + search_param)
run_mmseqs(mmseqs, ["mvdb", base.joinpath("tmp/latest/profile_1"), base.joinpath("prof_res")])
run_mmseqs(mmseqs, ["lndb", base.joinpath("qdb_h"), base.joinpath("prof_res_h")])
run_mmseqs(mmseqs, ["expandaln", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"), base.joinpath("res"), dbbase.joinpath(f"{uniref_db}{dbSuffix2}"), base.joinpath("res_exp"), "--db-load-mode", str(db_load_mode), "--threads", str(threads)] + expand_param)
run_mmseqs(mmseqs, ["align", base.joinpath("prof_res"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"), base.joinpath("res_exp"), base.joinpath("res_exp_realign"), "--db-load-mode", str(db_load_mode), "-e", str(align_eval), "--max-accept", str(max_accept), "--threads", str(threads), "--alt-ali", "10", "-a"])
run_mmseqs(mmseqs, ["filterresult", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"),
base.joinpath("res_exp_realign"), base.joinpath("res_exp_realign_filter"), "--db-load-mode",
str(db_load_mode), "--qid", "0", "--qsc", str(qsc), "--diff", "0", "--threads",
str(threads), "--max-seq-id", "1.0", "--filter-min-enable", "100"])
run_mmseqs(mmseqs, ["result2msa", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"),
base.joinpath("res_exp_realign_filter"), base.joinpath("uniref.a3m"), "--msa-format-mode",
"6", "--db-load-mode", str(db_load_mode), "--threads", str(threads)] + filter_param)
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_exp_realign_filter")])
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_exp_realign")])
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_exp")])
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res")])
else:
logger.info(f"Skipping {uniref_db} search because uniref.a3m already exists")

if use_env:
if use_env and not base.joinpath("bfd.mgnify30.metaeuk30.smag30.a3m").with_suffix('.a3m.dbtype').exists():
run_mmseqs(mmseqs, ["search", base.joinpath("prof_res"), dbbase.joinpath(metagenomic_db), base.joinpath("res_env"),
base.joinpath("tmp3"), "--threads", str(threads)] + search_param)
run_mmseqs(mmseqs, ["expandaln", base.joinpath("prof_res"), dbbase.joinpath(f"{metagenomic_db}{dbSuffix1}"), base.joinpath("res_env"),
Expand All @@ -133,45 +153,49 @@ def mmseqs_search_monomer(
base.joinpath("res_env_exp_realign_filter"),
base.joinpath("bfd.mgnify30.metaeuk30.smag30.a3m"), "--msa-format-mode", "6",
"--db-load-mode", str(db_load_mode), "--threads", str(threads)] + filter_param)

run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_env_exp_realign_filter")])
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_env_exp_realign")])
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_env_exp")])
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_env")])
elif use_env:
logger.info(f"Skipping {metagenomic_db} search because bfd.mgnify30.metaeuk30.smag30.a3m already exists")

run_mmseqs(mmseqs, ["mergedbs", base.joinpath("qdb"), base.joinpath("final.a3m"), base.joinpath("uniref.a3m"), base.joinpath("bfd.mgnify30.metaeuk30.smag30.a3m")])
run_mmseqs(mmseqs, ["rmdb", base.joinpath("bfd.mgnify30.metaeuk30.smag30.a3m")])
else:
run_mmseqs(mmseqs, ["mvdb", base.joinpath("uniref.a3m"), base.joinpath("final.a3m")])

if use_templates:
if use_templates and not base.joinpath("res_pdb.m8").with_suffix('.m8.dbtype').exists():
run_mmseqs(mmseqs, ["search", base.joinpath("prof_res"), dbbase.joinpath(template_db), base.joinpath("res_pdb"),
base.joinpath("tmp2"), "--db-load-mode", str(db_load_mode), "--threads", str(threads), "-s", "7.5", "-a", "-e", "0.1", "--prefilter-mode", str(prefilter_mode)])
run_mmseqs(mmseqs, ["convertalis", base.joinpath("prof_res"), dbbase.joinpath(f"{template_db}{dbSuffix3}"), base.joinpath("res_pdb"),
base.joinpath(f"{template_db}"), "--format-output",
base.joinpath("res_pdb.m8"), "--format-output",
"query,target,fident,alnlen,mismatch,gapopen,qstart,qend,tstart,tend,evalue,bits,cigar",
"--db-output", "1",
"--db-load-mode", str(db_load_mode), "--threads", str(threads)])
run_mmseqs(mmseqs, ["unpackdb", base.joinpath(f"{template_db}"), base.joinpath("."), "--unpack-name-mode", "0", "--unpack-suffix", ".m8"])
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_pdb")])
run_mmseqs(mmseqs, ["rmdb", base.joinpath(f"{template_db}")])
elif use_templates:
logger.info(f"Skipping {template_db} search because res_pdb.m8 already exists")

run_mmseqs(mmseqs, ["unpackdb", base.joinpath("final.a3m"), base.joinpath("."), "--unpack-name-mode", "0", "--unpack-suffix", ".a3m"])
run_mmseqs(mmseqs, ["rmdb", base.joinpath("final.a3m")])
run_mmseqs(mmseqs, ["rmdb", base.joinpath("uniref.a3m")])
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res")])
# @formatter:on
# fmt: on
if use_env:
run_mmseqs(mmseqs, ["mergedbs", base.joinpath("qdb"), base.joinpath("final.a3m"), base.joinpath("uniref.a3m"), base.joinpath("bfd.mgnify30.metaeuk30.smag30.a3m")])
run_mmseqs(mmseqs, ["rmdb", base.joinpath("bfd.mgnify30.metaeuk30.smag30.a3m")])
run_mmseqs(mmseqs, ["rmdb", base.joinpath("uniref.a3m")])
else:
run_mmseqs(mmseqs, ["mvdb", base.joinpath("uniref.a3m"), base.joinpath("final.a3m")])
run_mmseqs(mmseqs, ["rmdb", base.joinpath("uniref.a3m")])

for file in base.glob("prof_res*"):
file.unlink()
if unpack:
run_mmseqs(mmseqs, ["unpackdb", base.joinpath("final.a3m"), base.joinpath("."), "--unpack-name-mode", "0", "--unpack-suffix", ".a3m"])
run_mmseqs(mmseqs, ["rmdb", base.joinpath("final.a3m")])

if use_templates:
run_mmseqs(mmseqs, ["unpackdb", base.joinpath("res_pdb.m8"), base.joinpath("."), "--unpack-name-mode", "0", "--unpack-suffix", ".m8"])
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_pdb.m8")])

run_mmseqs(mmseqs, ["rmdb", base.joinpath("prof_res")])
run_mmseqs(mmseqs, ["rmdb", base.joinpath("prof_res_h")])
shutil.rmtree(base.joinpath("tmp"))
if use_templates:
shutil.rmtree(base.joinpath("tmp2"))
if use_env:
shutil.rmtree(base.joinpath("tmp3"))


def mmseqs_search_pair(
dbbase: Path,
base: Path,
Expand All @@ -184,6 +208,7 @@ def mmseqs_search_pair(
threads: int = 64,
db_load_mode: int = 2,
pairing_strategy: int = 0,
unpack: bool = True,
):
if not dbbase.joinpath(f"{uniref_db}.dbtype").is_file():
raise FileNotFoundError(f"Database {uniref_db} does not exist")
Expand Down Expand Up @@ -225,14 +250,15 @@ def mmseqs_search_pair(
run_mmseqs(mmseqs, ["align", base.joinpath("qdb"), dbbase.joinpath(f"{db}{dbSuffix1}"), base.joinpath("res_exp_realign_pair"), base.joinpath("res_exp_realign_pair_bt"), "--db-load-mode", str(db_load_mode), "-e", "inf", "-a", "--threads", str(threads), ],)
run_mmseqs(mmseqs, ["pairaln", base.joinpath("qdb"), dbbase.joinpath(f"{db}"), base.joinpath("res_exp_realign_pair_bt"), base.joinpath("res_final"), "--db-load-mode", str(db_load_mode), "--pairing-mode", str(pairing_strategy), "--pairing-dummy-mode", "1", "--threads", str(threads),],)
run_mmseqs(mmseqs, ["result2msa", base.joinpath("qdb"), dbbase.joinpath(f"{db}{dbSuffix1}"), base.joinpath("res_final"), base.joinpath("pair.a3m"), "--db-load-mode", str(db_load_mode), "--msa-format-mode", "5", "--threads", str(threads),],)
run_mmseqs(mmseqs, ["unpackdb", base.joinpath("pair.a3m"), base.joinpath("."), "--unpack-name-mode", "0", "--unpack-suffix", output,],)
if unpack:
run_mmseqs(mmseqs, ["unpackdb", base.joinpath("pair.a3m"), base.joinpath("."), "--unpack-name-mode", "0", "--unpack-suffix", output,],)
run_mmseqs(mmseqs, ["rmdb", base.joinpath("pair.a3m")])
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res")])
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_exp")])
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_exp_realign")])
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_exp_realign_pair")])
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_exp_realign_pair_bt")])
run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_final")])
run_mmseqs(mmseqs, ["rmdb", base.joinpath("pair.a3m")])
shutil.rmtree(base.joinpath("tmp"))
# @formatter:on
# fmt: on
Expand Down Expand Up @@ -340,6 +366,9 @@ def main():
default=0,
help="Database preload mode 0: auto, 1: fread, 2: mmap, 3: mmap+touch",
)
parser.add_argument(
"--unpack", type=int, default=1, choices=[0, 1], help="Unpack results to loose files or keep MMseqs2 databases."
)
parser.add_argument(
"--threads", type=int, default=64, help="Number of threads to use."
)
Expand Down Expand Up @@ -416,6 +445,7 @@ def main():
s=args.s,
db_load_mode=args.db_load_mode,
threads=args.threads,
unpack=args.unpack,
)
if is_complex is True:
mmseqs_search_pair(
Expand All @@ -429,6 +459,7 @@ def main():
threads=args.threads,
pairing_strategy=args.pairing_strategy,
pair_env=False,
unpack=args.unpack,
)
if args.use_env_pairing:
mmseqs_search_pair(
Expand All @@ -443,63 +474,66 @@ def main():
threads=args.threads,
pairing_strategy=args.pairing_strategy,
pair_env=True,
unpack=args.unpack,
)

id = 0
for job_number, (
raw_jobname,
query_sequences,
query_seqs_cardinality,
) in enumerate(queries_unique):
unpaired_msa = []
paired_msa = None
if len(query_seqs_cardinality) > 1:
paired_msa = []
for seq in query_sequences:
with args.base.joinpath(f"{id}.a3m").open("r") as f:
unpaired_msa.append(f.read())
args.base.joinpath(f"{id}.a3m").unlink()

if args.use_env_pairing:
with open(args.base.joinpath(f"{id}.paired.a3m"), 'a') as file_pair:
with open(args.base.joinpath(f"{id}.env.paired.a3m"), 'r') as file_pair_env:
while chunk := file_pair_env.read(10 * 1024 * 1024):
file_pair.write(chunk)
args.base.joinpath(f"{id}.env.paired.a3m").unlink()

if args.unpack:
id = 0
for job_number, (
raw_jobname,
query_sequences,
query_seqs_cardinality,
) in enumerate(queries_unique):
unpaired_msa = []
paired_msa = None
if len(query_seqs_cardinality) > 1:
with args.base.joinpath(f"{id}.paired.a3m").open("r") as f:
paired_msa.append(f.read())
args.base.joinpath(f"{id}.paired.a3m").unlink()
id += 1
msa = msa_to_str(
unpaired_msa, paired_msa, query_sequences, query_seqs_cardinality
paired_msa = []
for seq in query_sequences:
with args.base.joinpath(f"{id}.a3m").open("r") as f:
unpaired_msa.append(f.read())
args.base.joinpath(f"{id}.a3m").unlink()

if args.use_env_pairing:
with open(args.base.joinpath(f"{id}.paired.a3m"), 'a') as file_pair:
with open(args.base.joinpath(f"{id}.env.paired.a3m"), 'r') as file_pair_env:
while chunk := file_pair_env.read(10 * 1024 * 1024):
file_pair.write(chunk)
args.base.joinpath(f"{id}.env.paired.a3m").unlink()

if len(query_seqs_cardinality) > 1:
with args.base.joinpath(f"{id}.paired.a3m").open("r") as f:
paired_msa.append(f.read())
args.base.joinpath(f"{id}.paired.a3m").unlink()
id += 1
msa = msa_to_str(
unpaired_msa, paired_msa, query_sequences, query_seqs_cardinality
)
args.base.joinpath(f"{job_number}.a3m").write_text(msa)

if args.unpack:
# rename a3m files
for job_number, (raw_jobname, query_sequences, query_seqs_cardinality) in enumerate(queries_unique):
os.rename(
args.base.joinpath(f"{job_number}.a3m"),
args.base.joinpath(f"{safe_filename(raw_jobname)}.a3m"),
)
args.base.joinpath(f"{job_number}.a3m").write_text(msa)

# rename a3m files
for job_number, (raw_jobname, query_sequences, query_seqs_cardinality) in enumerate(queries_unique):
os.rename(
args.base.joinpath(f"{job_number}.a3m"),
args.base.joinpath(f"{safe_filename(raw_jobname)}.a3m"),
)

# rename m8 files
if args.use_templates:
id = 0
for raw_jobname, query_sequences, query_seqs_cardinality in queries_unique:
with args.base.joinpath(f"{safe_filename(raw_jobname)}_{args.db2}.m8").open(
"w"
) as f:
for _ in range(len(query_seqs_cardinality)):
with args.base.joinpath(f"{id}.m8").open("r") as g:
f.write(g.read())
os.remove(args.base.joinpath(f"{id}.m8"))
id += 1
# rename m8 files
if args.use_templates:
id = 0
for raw_jobname, query_sequences, query_seqs_cardinality in queries_unique:
with args.base.joinpath(f"{safe_filename(raw_jobname)}_{args.db2}.m8").open(
"w"
) as f:
for _ in range(len(query_seqs_cardinality)):
with args.base.joinpath(f"{id}.m8").open("r") as g:
f.write(g.read())
os.remove(args.base.joinpath(f"{id}.m8"))
id += 1
run_mmseqs(args.mmseqs, ["rmdb", args.base.joinpath("qdb")])
run_mmseqs(args.mmseqs, ["rmdb", args.base.joinpath("qdb_h")])

query_file.unlink()
run_mmseqs(args.mmseqs, ["rmdb", args.base.joinpath("qdb")])
run_mmseqs(args.mmseqs, ["rmdb", args.base.joinpath("qdb_h")])


if __name__ == "__main__":
Expand Down

0 comments on commit f4d9e37

Please sign in to comment.