From 06e56981432be868939b9d40cadb0c529e99e38d Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Fri, 11 Mar 2022 11:17:21 -0800 Subject: [PATCH] [wip] Load compressed sequences into database This was tried, and was ultimately too much overhead for `augur filter`. The usage of sequence data in `augur filter` is trivial - all the logic is done on the tabular metadata file, which loads nicely into the database. Sequence data is simply iterated via a generator, and anything that has passed metadata and sequence index filters gets written out. This means no sequence data needs to be kept in memory, eliminating the need to load it all into a database, which comes with storage overhead: sequence data at large scale should be compressed. Uncompressed open (GenBank) SARS-CoV-2 data as of today yielded a SQLite DB file of >80GB before I terminated the process due to limited local storage. The compression of all open (GenBank) SARS-CoV-2 sequences to date using zlib took 4.5 hours on my local machine, compared to 7 minutes using the current approach of reading and outputting on the fly. 
--- augur/filter_support/db/base.py | 14 ++++++- augur/filter_support/db/sqlite.py | 29 +++++++++++++ augur/io_support/db/sqlite_sequences.py | 56 +++++++++++++++++++++++++ 3 files changed, 98 insertions(+), 1 deletion(-) create mode 100644 augur/io_support/db/sqlite_sequences.py diff --git a/augur/filter_support/db/base.py b/augur/filter_support/db/base.py index af69b2847..efe9716d8 100644 --- a/augur/filter_support/db/base.py +++ b/augur/filter_support/db/base.py @@ -99,6 +99,9 @@ def db_set_sanitized_identifiers(self): pass @abc.abstractmethod def db_load_metadata(self): pass + @abc.abstractmethod + def db_load_sequences(self): pass + @abc.abstractmethod def db_load_sequence_index(self, path:str): pass @@ -279,6 +282,9 @@ def db_create_filter_reason_table(self, exclude_by:List[FilterOption], include_b @abc.abstractmethod def db_create_output_table(self, input_table:str): pass + @abc.abstractmethod + def db_create_sequence_output_table(self, input_table:str): pass + def subsample(self): """Apply subsampling to update filter reason table.""" self.create_priorities_table() @@ -349,7 +355,10 @@ def db_create_group_sizes_table(self, group_by_cols:List[str], sequences_per_gro def write_outputs(self): """Write various outputs.""" if self.args.output: - self.read_and_output_sequences() + self.db_load_sequences() + self.db_create_sequence_output_table() + self.db_output_sequences() + # self.read_and_output_sequences() if self.args.output_strains: self.db_output_strains() if self.args.output_metadata: @@ -408,6 +417,9 @@ def db_output_strains(self): pass @abc.abstractmethod def db_output_metadata(self): pass + @abc.abstractmethod + def db_output_sequences(self): pass + @abc.abstractmethod def db_output_log(self): pass diff --git a/augur/filter_support/db/sqlite.py b/augur/filter_support/db/sqlite.py index 95bbd408b..cc91d1093 100644 --- a/augur/filter_support/db/sqlite.py +++ b/augur/filter_support/db/sqlite.py @@ -4,9 +4,11 @@ import pandas as pd import sqlite3 from 
tempfile import NamedTemporaryFile + from augur.filter_support.exceptions import FilterException from augur.io_support.db.sqlite import load_tsv, cleanup, ROW_ORDER_COLUMN, sanitize_identifier +from augur.io_support.db.sqlite_sequences import SEQUENCE_ID_COLUMN, load_fasta, write_fasta from augur.utils import read_strains from augur.filter_support.db.base import DUMMY_COL, FilterBase, FilterCallableReturn, FilterOption from augur.filter_support.date_parsing import ASSERT_ONLY_LESS_SIGNIFICANT_AMBIGUITY_ERROR, InvalidDateFormat, get_year, get_month, get_day, get_date_min, get_date_max, get_date_errors @@ -16,6 +18,7 @@ # internal database globals # table names METADATA_TABLE_NAME = 'metadata' +SEQUENCES_TABLE_NAME = 'sequences' SEQUENCE_INDEX_TABLE_NAME = 'sequence_index' PRIORITIES_TABLE_NAME = 'priorities' DATE_TABLE_NAME = 'metadata_date_expanded' @@ -23,6 +26,7 @@ EXTENDED_FILTERED_TABLE_NAME = 'metadata_filtered_extended' GROUP_SIZES_TABLE_NAME = 'group_sizes' OUTPUT_METADATA_TABLE_NAME = 'metadata_output' +OUTPUT_SEQUENCES_TABLE_NAME = 'sequences_output' # column names DATE_YEAR_COL = 'year' DATE_MONTH_COL = 'month' @@ -79,6 +83,18 @@ def db_load_metadata(self): load_tsv(self.args.metadata, self.get_db_context(), METADATA_TABLE_NAME) self.db_create_strain_index(METADATA_TABLE_NAME) + def db_load_sequences(self): + """Loads a sequences file into the database. + + Retrieves the filename from `self.args`. + """ + load_fasta(self.args.sequences, self.get_db_context(), SEQUENCES_TABLE_NAME) + with self.get_db_context() as con: + con.execute(f""" + CREATE UNIQUE INDEX sequence_id_index_test + ON {SEQUENCES_TABLE_NAME} ({SEQUENCE_ID_COLUMN}) + """) + def db_load_sequence_index(self, path:str): """Loads a sequence index file into the database. @@ -601,6 +617,16 @@ def db_create_output_table(self): WHERE NOT f.{EXCLUDE_COL} OR f.{INCLUDE_COL} """) + def db_create_sequence_output_table(self): + """Creates a final intermediate table to be used for output. 
+ """ + with self.get_db_context() as con: + con.execute(f"""CREATE TABLE {OUTPUT_SEQUENCES_TABLE_NAME} AS + SELECT s.* FROM {SEQUENCES_TABLE_NAME} s + JOIN {OUTPUT_METADATA_TABLE_NAME} m + ON (s.{SEQUENCE_ID_COLUMN} = m.{self.sanitized_metadata_id_column}) + """) + def db_get_counts_per_group(self, group_by_cols:List[str]) -> List[int]: """ Returns @@ -760,6 +786,9 @@ def db_output_log(self): """, self.get_db_context()) df.to_csv(self.args.output_log, sep='\t', index=None) + def db_output_sequences(self): + write_fasta(self.args.output, self.get_db_context(), OUTPUT_SEQUENCES_TABLE_NAME) + def db_get_metadata_strains(self) -> Set[str]: with self.get_db_context() as con: cur = con.execute(f""" diff --git a/augur/io_support/db/sqlite_sequences.py b/augur/io_support/db/sqlite_sequences.py new file mode 100644 index 000000000..d646d2bb8 --- /dev/null +++ b/augur/io_support/db/sqlite_sequences.py @@ -0,0 +1,56 @@ +import sqlite3 +import zlib +from Bio import SeqIO +from Bio.Seq import Seq +from Bio.SeqRecord import SeqRecord + +from augur.io import open_file + +SEQUENCE_ID_COLUMN = 'id' +SEQUENCE_VALUE_COLUMN = 'seq' + +def load_fasta(fasta_file:str, connection:sqlite3.Connection, table_name:str): + """Loads sequence data from a FASTA file.""" + with connection: + create_table_statement = f""" + CREATE TABLE {table_name} ( + {SEQUENCE_ID_COLUMN} TEXT, + {SEQUENCE_VALUE_COLUMN} BLOB + ) + """ + connection.execute(create_table_statement) + + insert_statement = f""" + INSERT INTO {table_name} + VALUES (?,?) 
+ """ + # TODO: format=VCF + rows = _iter_sequences(fasta_file) + try: + with connection: + connection.executemany(insert_statement, rows) + except sqlite3.ProgrammingError as e: + raise ValueError(f'Failed to load {fasta_file}.') from e + + +def _iter_sequences(fasta_file:str, format="fasta"): + """Yield sequences.""" + with open_file(fasta_file) as f: + records = SeqIO.parse(f, format) + for record in records: + # yield (record.id, str(record.seq)) + yield (record.id, zlib.compress(str(record.seq).encode())) + +def write_fasta(fasta_file:str, connection:sqlite3.Connection, table_name:str): + rows = connection.execute(f""" + SELECT {SEQUENCE_ID_COLUMN}, {SEQUENCE_VALUE_COLUMN} + FROM {table_name} + """) + with open_file(fasta_file, 'w') as f: + for row in rows: + record = SeqRecord( + Seq(zlib.decompress(row[1]).decode('UTF-8')), + id=row[0], + description='' + ) + SeqIO.write(record, f, "fasta-2line")