Skip to content

Commit

Permalink
[wip] Load compressed sequences into database
Browse files Browse the repository at this point in the history
This was tried, but it ultimately added too much overhead for `augur filter`.

The usage of sequence data in `augur filter` is trivial - all the logic
is done on the tabular metadata file which loads nicely into the
database. Sequence data is simply iterated via a generator, and anything
that has passed metadata and sequence index filters gets written out.
This means no sequence data needs to be kept in memory, eliminating the
need for loading all into a database, which comes with storage overhead:

Sequence data at large scale should be compressed. Uncompressed open
(GenBank) SARS-CoV-2 data as of today yielded a SQLite DB file of >80GB
before I terminated the process due to limited local storage.

The compression of all open (GenBank) SARS-CoV-2 sequences to date using
zlib took 4.5 hours on my local machine, compared to 7 minutes using the
current approach of reading and outputting on the fly.
  • Loading branch information
victorlin committed Mar 11, 2022
1 parent a73aec3 commit 06e5698
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 1 deletion.
14 changes: 13 additions & 1 deletion augur/filter_support/db/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ def db_set_sanitized_identifiers(self): pass
# Load the metadata file (path taken from self.args.metadata) into a database table.
@abc.abstractmethod
def db_load_metadata(self): pass

# Load the sequences FASTA file (path taken from self.args.sequences) into a database table.
@abc.abstractmethod
def db_load_sequences(self): pass

# Load a sequence index file located at `path` into a database table.
@abc.abstractmethod
def db_load_sequence_index(self, path:str): pass

Expand Down Expand Up @@ -279,6 +282,9 @@ def db_create_filter_reason_table(self, exclude_by:List[FilterOption], include_b
# Create the final metadata output table from `input_table`, keeping rows that
# were not excluded (or were force-included) — presumably per the filter reason
# table; confirm against the concrete implementation.
@abc.abstractmethod
def db_create_output_table(self, input_table:str): pass

# Create the sequences output table from `input_table` (the loaded sequences),
# restricted to strains kept in the metadata output.
# NOTE(review): the SQLite implementation of this method takes no input_table
# argument and the caller passes none — reconcile the signatures.
@abc.abstractmethod
def db_create_sequence_output_table(self, input_table:str): pass

def subsample(self):
"""Apply subsampling to update filter reason table."""
self.create_priorities_table()
Expand Down Expand Up @@ -349,7 +355,10 @@ def db_create_group_sizes_table(self, group_by_cols:List[str], sequences_per_gro
def write_outputs(self):
"""Write various outputs."""
if self.args.output:
self.read_and_output_sequences()
self.db_load_sequences()
self.db_create_sequence_output_table()
self.db_output_sequences()
# self.read_and_output_sequences()
if self.args.output_strains:
self.db_output_strains()
if self.args.output_metadata:
Expand Down Expand Up @@ -408,6 +417,9 @@ def db_output_strains(self): pass
# Write the filtered metadata out — presumably to self.args.output_metadata,
# mirroring the guard in write_outputs; confirm against the concrete implementation.
@abc.abstractmethod
def db_output_metadata(self): pass

# Write the sequences output table as FASTA to self.args.output.
@abc.abstractmethod
def db_output_sequences(self): pass

# Write the filter-reason log as a TSV file to self.args.output_log.
@abc.abstractmethod
def db_output_log(self): pass

Expand Down
29 changes: 29 additions & 0 deletions augur/filter_support/db/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
import pandas as pd
import sqlite3
from tempfile import NamedTemporaryFile

from augur.filter_support.exceptions import FilterException

from augur.io_support.db.sqlite import load_tsv, cleanup, ROW_ORDER_COLUMN, sanitize_identifier
from augur.io_support.db.sqlite_sequences import SEQUENCE_ID_COLUMN, load_fasta, write_fasta
from augur.utils import read_strains
from augur.filter_support.db.base import DUMMY_COL, FilterBase, FilterCallableReturn, FilterOption
from augur.filter_support.date_parsing import ASSERT_ONLY_LESS_SIGNIFICANT_AMBIGUITY_ERROR, InvalidDateFormat, get_year, get_month, get_day, get_date_min, get_date_max, get_date_errors
Expand All @@ -16,13 +18,15 @@
# internal database globals
# table names
METADATA_TABLE_NAME = 'metadata'
SEQUENCES_TABLE_NAME = 'sequences'
SEQUENCE_INDEX_TABLE_NAME = 'sequence_index'
PRIORITIES_TABLE_NAME = 'priorities'
DATE_TABLE_NAME = 'metadata_date_expanded'
METADATA_FILTER_REASON_TABLE_NAME = 'metadata_filtered_reason'
EXTENDED_FILTERED_TABLE_NAME = 'metadata_filtered_extended'
GROUP_SIZES_TABLE_NAME = 'group_sizes'
OUTPUT_METADATA_TABLE_NAME = 'metadata_output'
OUTPUT_SEQUENCES_TABLE_NAME = 'sequences_output'
# column names
DATE_YEAR_COL = 'year'
DATE_MONTH_COL = 'month'
Expand Down Expand Up @@ -79,6 +83,18 @@ def db_load_metadata(self):
load_tsv(self.args.metadata, self.get_db_context(), METADATA_TABLE_NAME)
self.db_create_strain_index(METADATA_TABLE_NAME)

def db_load_sequences(self):
    """Load the FASTA file from `self.args.sequences` into the database.

    Creates the sequences table via `load_fasta`, then adds a unique index
    on the sequence ID column so that joins against metadata are fast and
    duplicate sequence IDs in the input are rejected up front.
    """
    load_fasta(self.args.sequences, self.get_db_context(), SEQUENCES_TABLE_NAME)
    with self.get_db_context() as con:
        # Fixed leftover debug name: was "sequence_id_index_test".
        # UNIQUE means a duplicate FASTA ID raises sqlite3.IntegrityError here.
        con.execute(f"""
            CREATE UNIQUE INDEX sequence_id_index
            ON {SEQUENCES_TABLE_NAME} ({SEQUENCE_ID_COLUMN})
        """)

def db_load_sequence_index(self, path:str):
"""Loads a sequence index file into the database.
Expand Down Expand Up @@ -601,6 +617,16 @@ def db_create_output_table(self):
WHERE NOT f.{EXCLUDE_COL} OR f.{INCLUDE_COL}
""")

def db_create_sequence_output_table(self, input_table: str = SEQUENCES_TABLE_NAME):
    """Create the final sequences table to be used for FASTA output.

    Joins the loaded sequences against the filtered metadata output table so
    that only sequences whose ID matches a retained strain are kept.

    Parameters
    ----------
    input_table
        Name of the table holding the loaded sequences. Defaults to
        SEQUENCES_TABLE_NAME, the table created by db_load_sequences.
        The default keeps existing zero-argument callers working while
        matching the abstract base's `input_table` parameter.
    """
    with self.get_db_context() as con:
        con.execute(f"""CREATE TABLE {OUTPUT_SEQUENCES_TABLE_NAME} AS
            SELECT s.* FROM {input_table} s
            JOIN {OUTPUT_METADATA_TABLE_NAME} m
                ON (s.{SEQUENCE_ID_COLUMN} = m.{self.sanitized_metadata_id_column})
        """)

def db_get_counts_per_group(self, group_by_cols:List[str]) -> List[int]:
"""
Returns
Expand Down Expand Up @@ -760,6 +786,9 @@ def db_output_log(self):
""", self.get_db_context())
df.to_csv(self.args.output_log, sep='\t', index=None)

def db_output_sequences(self):
    """Write the filtered sequences table to the FASTA path in `self.args.output`."""
    destination = self.args.output
    connection = self.get_db_context()
    write_fasta(destination, connection, OUTPUT_SEQUENCES_TABLE_NAME)

def db_get_metadata_strains(self) -> Set[str]:
with self.get_db_context() as con:
cur = con.execute(f"""
Expand Down
56 changes: 56 additions & 0 deletions augur/io_support/db/sqlite_sequences.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import sqlite3
import zlib
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord

from augur.io import open_file

SEQUENCE_ID_COLUMN = 'id'
SEQUENCE_VALUE_COLUMN = 'seq'

def load_fasta(fasta_file:str, connection:sqlite3.Connection, table_name:str):
    """Load sequence records from a FASTA file into a database table.

    The table has two columns: the record ID (TEXT) and the compressed
    sequence (BLOB). Raises ValueError if the bulk insert fails.
    """
    with connection:
        connection.execute(f"""
            CREATE TABLE {table_name} (
                {SEQUENCE_ID_COLUMN} TEXT,
                {SEQUENCE_VALUE_COLUMN} BLOB
            )
        """)

    # TODO: format=VCF
    try:
        with connection:
            connection.executemany(
                f"INSERT INTO {table_name} VALUES (?,?)",
                _iter_sequences(fasta_file),
            )
    except sqlite3.ProgrammingError as e:
        raise ValueError(f'Failed to load {fasta_file}.') from e


def _iter_sequences(fasta_file:str, format="fasta"):
    """Yield (record ID, zlib-compressed sequence bytes) tuples.

    Sequences are compressed before storage to reduce database size; the
    reader side must apply the matching zlib decompression.
    (Removed a commented-out uncompressed variant of the yield.)
    """
    with open_file(fasta_file) as f:
        for record in SeqIO.parse(f, format):
            # str.encode() defaults to UTF-8; compress the sequence text.
            yield (record.id, zlib.compress(str(record.seq).encode()))

def write_fasta(fasta_file:str, connection:sqlite3.Connection, table_name:str):
    """Write a sequences table out as a FASTA file (two lines per record).

    Each row's BLOB is zlib-decompressed back into the sequence text,
    the inverse of the compression applied when the table was loaded.
    """
    query = f"""
        SELECT {SEQUENCE_ID_COLUMN}, {SEQUENCE_VALUE_COLUMN}
        FROM {table_name}
    """
    with open_file(fasta_file, 'w') as f:
        for seq_id, compressed_seq in connection.execute(query):
            sequence = Seq(zlib.decompress(compressed_seq).decode('UTF-8'))
            record = SeqRecord(sequence, id=seq_id, description='')
            SeqIO.write(record, f, "fasta-2line")

0 comments on commit 06e5698

Please sign in to comment.