Merge pull request #46 from martinghunt/sample_reads
Sample reads
martinghunt committed Apr 7, 2016
2 parents 32b390b + 2223c41 commit 8c0bca0
Showing 10 changed files with 288 additions and 37 deletions.
99 changes: 78 additions & 21 deletions ariba/cluster.py
@@ -1,8 +1,10 @@
import os
import random
import math
import shutil
import sys
import pyfastaq
from ariba import assembly, assembly_compare, assembly_variants, bam_parse, best_seq_chooser, external_progs, flag, mapping, report, samtools_variants
from ariba import assembly, assembly_compare, assembly_variants, bam_parse, best_seq_chooser, common, external_progs, flag, mapping, report, samtools_variants

class Error (Exception): pass

@@ -13,6 +15,9 @@ def __init__(self,
root_dir,
name,
refdata,
total_reads,
total_reads_bases,
assembly_coverage=100,
assembly_kmer=21,
assembler='spades',
max_insert=1000,
@@ -34,6 +39,7 @@ def __init__(self,
spades_other_options=None,
clean=1,
extern_progs=None,
random_seed=42,
):

self.root_dir = os.path.abspath(root_dir)
@@ -42,19 +48,24 @@ def __init__(self,

self.name = name
self.refdata = refdata
self.total_reads = total_reads
self.total_reads_bases = total_reads_bases
self.assembly_coverage = assembly_coverage
self.assembly_kmer = assembly_kmer
self.assembler = assembler
self.sspace_k = sspace_k
self.sspace_sd = sspace_sd
self.reads_insert = reads_insert
self.spades_other_options = spades_other_options

self.reads1 = os.path.join(self.root_dir, 'reads_1.fq')
self.reads2 = os.path.join(self.root_dir, 'reads_2.fq')
self.all_reads1 = os.path.join(self.root_dir, 'reads_1.fq')
self.all_reads2 = os.path.join(self.root_dir, 'reads_2.fq')
self.reads_for_assembly1 = os.path.join(self.root_dir, 'reads_for_assembly_1.fq')
self.reads_for_assembly2 = os.path.join(self.root_dir, 'reads_for_assembly_2.fq')
self.reference_fa = os.path.join(self.root_dir, 'reference.fa')
self.references_fa = os.path.join(self.root_dir, 'references.fa')

for fname in [self.reads1, self.reads2, self.references_fa]:
for fname in [self.all_reads1, self.all_reads2, self.references_fa]:
if not os.path.exists(fname):
raise Error('File ' + fname + ' not found. Cannot continue')

@@ -92,7 +103,6 @@ def __init__(self,
self.mummer_variants = {}
self.variant_depths = {}
self.percent_identities = {}
self.total_reads = self._count_reads(self.reads1, self.reads2)

# The log filehandle self.log_fh is set at the start of the run() method.
# Lots of other methods use self.log_fh. But for unit testing, run() isn't
@@ -109,13 +119,7 @@ def __init__(self,
else:
self.extern_progs = extern_progs


@staticmethod
def _count_reads(reads1, reads2):
count1 = pyfastaq.tasks.count_sequences(reads1)
count2 = pyfastaq.tasks.count_sequences(reads2)
assert(count1 == count2)
return count1 + count2
self.random_seed = random_seed


def _clean(self):
@@ -132,8 +136,8 @@ def _clean(self):
shutil.rmtree(self.assembly_dir)

to_delete = [
self.reads1,
self.reads2,
self.all_reads1,
self.all_reads2,
self.references_fa,
self.references_fa + '.fai',
self.final_assembly_bam + '.read_depths.gz',
@@ -153,14 +157,64 @@ def _clean(self):
raise Error('Error deleting file', filename)


@staticmethod
def _number_of_reads_for_assembly(reference_fa, insert_size, total_bases, total_reads, coverage):
file_reader = pyfastaq.sequences.file_reader(reference_fa)
ref_length = sum([len(x) for x in file_reader])
assert ref_length > 0
ref_length += 2 * insert_size
mean_read_length = total_bases / total_reads
wanted_bases = coverage * ref_length
wanted_reads = int(math.ceil(wanted_bases / mean_read_length))
wanted_reads += wanted_reads % 2
return wanted_reads


@staticmethod
def _make_reads_for_assembly(number_of_wanted_reads, total_reads, reads_in1, reads_in2, reads_out1, reads_out2, random_seed=None):
'''Makes fastq files that are a random subset of the input files. Returns the total number of reads in the output files.
If the number of wanted reads is >= total reads, then just makes symlinks instead of making
new copies of the input files.'''
random.seed(random_seed)

if number_of_wanted_reads < total_reads:
reads_written = 0
percent_wanted = 100 * number_of_wanted_reads / total_reads
file_reader1 = pyfastaq.sequences.file_reader(reads_in1)
file_reader2 = pyfastaq.sequences.file_reader(reads_in2)
out1 = pyfastaq.utils.open_file_write(reads_out1)
out2 = pyfastaq.utils.open_file_write(reads_out2)

for read1 in file_reader1:
try:
read2 = next(file_reader2)
except StopIteration:
pyfastaq.utils.close(out1)
pyfastaq.utils.close(out2)
raise Error('Error subsetting reads. No mate found for read ' + read1.id)

if random.randint(0, 100) <= percent_wanted:
print(read1, file=out1)
print(read2, file=out2)
reads_written += 2

pyfastaq.utils.close(out1)
pyfastaq.utils.close(out2)
return reads_written
else:
os.symlink(reads_in1, reads_out1)
os.symlink(reads_in2, reads_out2)
return total_reads


def run(self):
self.logfile = os.path.join(self.root_dir, 'log.txt')
self.log_fh = pyfastaq.utils.open_file_write(self.logfile)

print('Choosing best reference sequence:', file=self.log_fh, flush=True)
seq_chooser = best_seq_chooser.BestSeqChooser(
self.reads1,
self.reads2,
self.all_reads1,
self.all_reads2,
self.references_fa,
self.log_fh,
samtools_exe=self.extern_progs.exe('samtools'),
@@ -174,12 +228,15 @@ def run(self):
self.status_flag.add('ref_seq_choose_fail')
self.assembled_ok = False
else:
print('\nAssembling reads:', file=self.log_fh, flush=True)
wanted_reads = self._number_of_reads_for_assembly(self.reference_fa, self.reads_insert, self.total_reads_bases, self.total_reads, self.assembly_coverage)
made_reads = self._make_reads_for_assembly(wanted_reads, self.total_reads, self.all_reads1, self.all_reads2, self.reads_for_assembly1, self.reads_for_assembly2, random_seed=self.random_seed)
print('\nUsing', made_reads, 'from a total of', self.total_reads, 'for assembly.', file=self.log_fh, flush=True)
print('Assembling reads:', file=self.log_fh, flush=True)
self.ref_sequence_type = self.refdata.sequence_type(self.ref_sequence.id)
assert self.ref_sequence_type is not None
self.assembly = assembly.Assembly(
self.reads1,
self.reads2,
self.reads_for_assembly1,
self.reads_for_assembly2,
self.reference_fa,
self.assembly_dir,
self.final_assembly_fa,
@@ -202,8 +259,8 @@ def run(self):
print('\nAssembly was successful\n\nMapping reads to assembly:', file=self.log_fh, flush=True)

mapping.run_bowtie2(
self.reads1,
self.reads2,
self.all_reads1,
self.all_reads2,
self.final_assembly_fa,
self.final_assembly_bam[:-4],
threads=1,
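As a quick aside on the new coverage maths: _number_of_reads_for_assembly pads the reference length by one insert size at each end, multiplies by the target coverage to get the number of bases wanted, converts that to reads using the mean read length (total_reads_bases / total_reads), and rounds up to an even number so the reads stay paired. A minimal sketch of the same arithmetic, using the values from the first case of the new unit test (100 bp reference, insert size 50, 10 reads totalling 1000 bases, 20x target coverage):

import math

ref_length = 100 + 2 * 50                                        # reference padded by the insert size at each end -> 200
mean_read_length = 1000 / 10                                      # total_reads_bases / total_reads -> 100.0
wanted_bases = 20 * ref_length                                    # coverage * padded reference length -> 4000
wanted_reads = int(math.ceil(wanted_bases / mean_read_length))    # -> 40
wanted_reads += wanted_reads % 2                                  # bump an odd count up to the next even number
print(wanted_reads)                                               # 40, matching the (50, 1000, 10, 20, 40) test case

_make_reads_for_assembly then keeps each read pair when random.randint(0, 100) <= percent_wanted, so the number of reads written is only approximately the requested count (with seed 42 the unit test asks for 10 of 20 reads and gets 14); when the request is at least the total number of reads, it just symlinks the input files.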
11 changes: 10 additions & 1 deletion ariba/clusters.py
@@ -29,6 +29,7 @@ def __init__(self,
outdir,
extern_progs,
assembly_kmer=21,
assembly_coverage=100,
threads=1,
verbose=False,
assembler='spades',
@@ -57,6 +58,7 @@ def __init__(self,
self.assembler = assembler
assert self.assembler in ['spades']
self.assembly_kmer = assembly_kmer
self.assembly_coverage = assembly_coverage
self.spades_other = spades_other

self.refdata_files_prefix = os.path.join(self.outdir, 'refdata')
@@ -91,6 +93,8 @@ def __init__(self,

self.cluster_to_dir = {} # gene name -> abs path of cluster directory
self.clusters = {} # gene name -> Cluster object
self.cluster_read_counts = {} # gene name -> number of reads
self.cluster_base_counts = {} # gene name -> number of bases

self.cdhit_seq_identity_threshold = cdhit_seq_identity_threshold
self.cdhit_length_diff_cutoff = cdhit_length_diff_cutoff
@@ -210,6 +214,8 @@ def _bam_to_clusters_reads(self):

print(read1, file=filehandles_1[ref])
print(read2, file=filehandles_2[ref])
self.cluster_read_counts[ref] = self.cluster_read_counts.get(ref, 0) + 2
self.cluster_base_counts[ref] = self.cluster_base_counts.get(ref, 0) + len(read1) + len(read2)

sam1 = None

@@ -257,7 +263,10 @@ def _init_and_run_clusters(self):
cluster_list.append(cluster.Cluster(
new_dir,
seq_name,
refdata=self.refdata,
self.refdata,
self.cluster_read_counts[seq_name],
self.cluster_base_counts[seq_name],
assembly_coverage=self.assembly_coverage,
assembly_kmer=self.assembly_kmer,
assembler=self.assembler,
max_insert=self.insert_proper_pair_max,
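The per-cluster bookkeeping above is what feeds those totals: while _bam_to_clusters_reads splits the mapped reads into per-cluster fastq files, every pair adds 2 to that cluster's read count and its combined length to the base count, and the two dictionaries are then handed to each Cluster as total_reads and total_reads_bases, whose ratio gives the mean read length used in cluster.py. A rough sketch of the accumulation pattern, with plain strings standing in for the pyfastaq read objects (len() gives the read length in both cases):

cluster_read_counts = {}   # cluster name -> number of reads
cluster_base_counts = {}   # cluster name -> number of bases

def add_pair(ref, read1, read2):
    cluster_read_counts[ref] = cluster_read_counts.get(ref, 0) + 2
    cluster_base_counts[ref] = cluster_base_counts.get(ref, 0) + len(read1) + len(read2)

for ref, pairs in [('ref1', 2), ('ref2', 1)]:
    for _ in range(pairs):
        add_pair(ref, 'A' * 60, 'A' * 60)   # two 60 bp stand-in reads per pair

print(cluster_read_counts)   # {'ref1': 4, 'ref2': 2}
print(cluster_base_counts)   # {'ref1': 240, 'ref2': 120}, the values expected in clusters_test.py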
2 changes: 2 additions & 0 deletions ariba/tasks/run.py
@@ -31,6 +31,7 @@ def run():
nucmer_group.add_argument('--nucmer_breaklen', type=int, help='Value to use for -breaklen when running nucmer [%(default)s]', default=50, metavar='INT')

assembly_group = parser.add_argument_group('Assembly options')
assembly_group.add_argument('--assembly_cov', type=int, help='Target read coverage when sampling reads for assembly [%(default)s]', default=100, metavar='INT')
assembly_group.add_argument('--assembler_k', type=int, help='kmer size to use with assembler. You can use 0 to set kmer to 2/3 of the read length. Warning - lower kmers are usually better. [%(default)s]', metavar='INT', default=21)
assembly_group.add_argument('--spades_other', help='Put options string to be used with spades in quotes. This will NOT be sanity checked. Do not use -k or -t: for these options you should use the ariba run options --assembler_k and --threads [%(default)s]', default="--only-assembler", metavar="OPTIONS")
assembly_group.add_argument('--min_scaff_depth', type=int, help='Minimum number of read pairs needed as evidence for scaffold link between two contigs. This is also the value used for sspace -k when scaffolding [%(default)s]', default=10, metavar='INT')
@@ -72,6 +73,7 @@ def run():
options.outdir,
extern_progs,
assembly_kmer=options.assembler_k,
assembly_coverage=options.assembly_cov,
assembler='spades',
threads=options.threads,
verbose=options.verbose,
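One small detail of the new option: the [%(default)s] in its help text is expanded by argparse, so ariba run --help reports the sampling target as [100] unless --assembly_cov is given. A tiny standalone sketch with just that one option (the rest of the ariba argument setup is omitted here):

import argparse

parser = argparse.ArgumentParser(prog='ariba run')
parser.add_argument('--assembly_cov', type=int, metavar='INT', default=100,
    help='Target read coverage when sampling reads for assembly [%(default)s]')

options = parser.parse_args(['--assembly_cov', '50'])
print(options.assembly_cov)   # 50; omit the flag and the default of 100 is used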
72 changes: 57 additions & 15 deletions ariba/tests/cluster_test.py
@@ -48,18 +48,60 @@ def test_init_fail_files_missing(self):
tmpdir = 'tmp.cluster_test_init_fail_files_missing'
shutil.copytree(d, tmpdir)
with self.assertRaises(cluster.Error):
c = cluster.Cluster(tmpdir, 'name', refdata=refdata)
c = cluster.Cluster(tmpdir, 'name', refdata=refdata, total_reads=42, total_reads_bases=4242)
shutil.rmtree(tmpdir)

with self.assertRaises(cluster.Error):
c = cluster.Cluster('directorydoesnotexistshouldthrowerror', 'name', refdata=refdata)

c = cluster.Cluster('directorydoesnotexistshouldthrowerror', 'name', refdata=refdata, total_reads=42, total_reads_bases=4242)


def test_number_of_reads_for_assembly(self):
'''Test _number_of_reads_for_assembly'''
# ref is 100bp long
ref_fa = os.path.join(data_dir, 'cluster_test_number_of_reads_for_assembly.ref.fa')
tests = [
(50, 1000, 10, 20, 40),
(50, 999, 10, 20, 42),
(50, 1000, 10, 10, 20),
(50, 1000, 10, 5, 10),
]

def test_count_reads(self):
'''test _count_reads pass'''
reads1 = os.path.join(data_dir, 'cluster_test_count_reads_1.fq')
reads2 = os.path.join(data_dir, 'cluster_test_count_reads_2.fq')
self.assertEqual(4, cluster.Cluster._count_reads(reads1, reads2))
for insert, bases, reads, coverage, expected in tests:
self.assertEqual(expected, cluster.Cluster._number_of_reads_for_assembly(ref_fa, insert, bases, reads, coverage))


def test_make_reads_for_assembly_proper_sample(self):
'''Test _make_reads_for_assembly when sampling from reads'''
reads_in1 = os.path.join(data_dir, 'cluster_test_make_reads_for_assembly.in1.fq')
reads_in2 = os.path.join(data_dir, 'cluster_test_make_reads_for_assembly.in2.fq')
expected_out1 = os.path.join(data_dir, 'cluster_test_make_reads_for_assembly.out1.fq')
expected_out2 = os.path.join(data_dir, 'cluster_test_make_reads_for_assembly.out2.fq')
reads_out1 = 'tmp.test_make_reads_for_assembly.reads.out1.fq'
reads_out2 = 'tmp.test_make_reads_for_assembly.reads.out2.fq'
reads_written = cluster.Cluster._make_reads_for_assembly(10, 20, reads_in1, reads_in2, reads_out1, reads_out2, random_seed=42)
self.assertEqual(14, reads_written)
self.assertTrue(filecmp.cmp(expected_out1, reads_out1, shallow=False))
self.assertTrue(filecmp.cmp(expected_out2, reads_out2, shallow=False))
os.unlink(reads_out1)
os.unlink(reads_out2)


def test_make_reads_for_assembly_symlinks(self):
'''Test _make_reads_for_assembly when just makes symlinks'''
reads_in1 = os.path.join(data_dir, 'cluster_test_make_reads_for_assembly.in1.fq')
reads_in2 = os.path.join(data_dir, 'cluster_test_make_reads_for_assembly.in2.fq')
expected_out1 = os.path.join(data_dir, 'cluster_test_make_reads_for_assembly.out1.fq')
expected_out2 = os.path.join(data_dir, 'cluster_test_make_reads_for_assembly.out2.fq')
reads_out1 = 'tmp.test_make_reads_for_assembly.reads.out1.fq'
reads_out2 = 'tmp.test_make_reads_for_assembly.reads.out2.fq'
reads_written = cluster.Cluster._make_reads_for_assembly(20, 20, reads_in1, reads_in2, reads_out1, reads_out2, random_seed=42)
self.assertEqual(20, reads_written)
self.assertTrue(os.path.islink(reads_out1))
self.assertTrue(os.path.islink(reads_out2))
self.assertEqual(os.readlink(reads_out1), reads_in1)
self.assertEqual(os.readlink(reads_out2), reads_in2)
os.unlink(reads_out1)
os.unlink(reads_out2)


def test_full_run_choose_ref_fail(self):
Expand All @@ -70,7 +112,7 @@ def test_full_run_choose_ref_fail(self):
tmpdir = 'tmp.test_full_run_choose_ref_fail'
shutil.copytree(os.path.join(data_dir, 'cluster_test_full_run_choose_ref_fail'), tmpdir)

c = cluster.Cluster(tmpdir, 'cluster_name', refdata)
c = cluster.Cluster(tmpdir, 'cluster_name', refdata, total_reads=2, total_reads_bases=108)
c.run()

expected = '\t'.join(['.', '.', '1088', '2', 'cluster_name'] + ['.'] * 23)
@@ -88,7 +130,7 @@ def test_full_run_assembly_fail(self):
tmpdir = 'tmp.test_full_run_assembly_fail'
shutil.copytree(os.path.join(data_dir, 'cluster_test_full_run_assembly_fail'), tmpdir)

c = cluster.Cluster(tmpdir, 'cluster_name', refdata)
c = cluster.Cluster(tmpdir, 'cluster_name', refdata, total_reads=4, total_reads_bases=304)
c.run()

expected = '\t'.join(['noncoding_ref_seq', 'non_coding', '64', '4', 'cluster_name'] + ['.'] * 23)
@@ -108,7 +150,7 @@ def test_full_run_ok_non_coding(self):
tmpdir = 'tmp.test_full_run_ok_non_coding'
shutil.copytree(os.path.join(data_dir, 'cluster_test_full_run_ok_non_coding'), tmpdir)

c = cluster.Cluster(tmpdir, 'cluster_name', refdata, spades_other_options='--only-assembler')
c = cluster.Cluster(tmpdir, 'cluster_name', refdata, spades_other_options='--only-assembler', total_reads=72, total_reads_bases=3600)
c.run()

expected = [
@@ -134,7 +176,7 @@ def test_full_run_ok_presence_absence(self):
tmpdir = 'tmp.cluster_test_full_run_ok_presence_absence'
shutil.copytree(os.path.join(data_dir, 'cluster_test_full_run_ok_presence_absence'), tmpdir)

c = cluster.Cluster(tmpdir, 'cluster_name', refdata, spades_other_options='--only-assembler')
c = cluster.Cluster(tmpdir, 'cluster_name', refdata, spades_other_options='--only-assembler', total_reads=64, total_reads_bases=3200)
c.run()

expected = [
@@ -160,7 +202,7 @@ def test_full_run_ok_variants_only_variant_not_present(self):
tmpdir = 'tmp.cluster_test_full_run_ok_variants_only.not_present'
shutil.copytree(os.path.join(data_dir, 'cluster_test_full_run_ok_variants_only'), tmpdir)

c = cluster.Cluster(tmpdir, 'cluster_name', refdata, spades_other_options='--only-assembler')
c = cluster.Cluster(tmpdir, 'cluster_name', refdata, spades_other_options='--only-assembler', total_reads=66, total_reads_bases=3300)
c.run()
expected = [
'variants_only1\tvariants_only\t27\t66\tcluster_name\t96\t96\t100.0\tvariants_only1.scaffold.1\t215\t1\tSNP\tp\tR3S\t0\t.\t.\t7\t9\tC;G;C\t65\t67\tC;G;C\t18;18;19\t.;.;.\t18;18;19\tvariants_only1_p_R3S_Ref and assembly have wild type, so do not report\tGeneric description of variants_only1'
@@ -179,7 +221,7 @@ def test_full_run_ok_variants_only_variant_not_present_always_report(self):
tmpdir = 'tmp.cluster_test_full_run_ok_variants_only.not_present.always_report'
shutil.copytree(os.path.join(data_dir, 'cluster_test_full_run_ok_variants_only'), tmpdir)

c = cluster.Cluster(tmpdir, 'cluster_name', refdata, spades_other_options='--only-assembler')
c = cluster.Cluster(tmpdir, 'cluster_name', refdata, spades_other_options='--only-assembler', total_reads=66, total_reads_bases=3300)
c.run()
expected = [
'variants_only1\tvariants_only\t27\t66\tcluster_name\t96\t96\t100.0\tvariants_only1.scaffold.1\t215\t1\tSNP\tp\tR3S\t0\t.\t.\t7\t9\tC;G;C\t65\t67\tC;G;C\t18;18;19\t.;.;.\t18;18;19\tvariants_only1_p_R3S_Ref and assembly have wild type, but always report anyway\tGeneric description of variants_only1'
@@ -198,7 +240,7 @@ def test_full_run_ok_variants_only_variant_is_present(self):
tmpdir = 'tmp.cluster_test_full_run_ok_variants_only.present'
shutil.copytree(os.path.join(data_dir, 'cluster_test_full_run_ok_variants_only'), tmpdir)

c = cluster.Cluster(tmpdir, 'cluster_name', refdata, spades_other_options='--only-assembler')
c = cluster.Cluster(tmpdir, 'cluster_name', refdata, spades_other_options='--only-assembler', total_reads=66, total_reads_bases=3300)
c.run()

expected = [
2 changes: 2 additions & 0 deletions ariba/tests/clusters_test.py
@@ -93,6 +93,8 @@ def test_bam_to_clusters_reads(self):
self.assertTrue(filecmp.cmp(expected[i], got[i], shallow=False))

self.assertEqual({780:1}, c.insert_hist.bins)
self.assertEqual({'ref1': 4, 'ref2': 2}, c.cluster_read_counts)
self.assertEqual({'ref1': 240, 'ref2': 120}, c.cluster_base_counts)

shutil.rmtree(clusters_dir)
