Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sample reads #46

Merged
merged 9 commits into from
Apr 7, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 78 additions & 21 deletions ariba/cluster.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
import random
import math
import shutil
import sys
import pyfastaq
from ariba import assembly, assembly_compare, assembly_variants, bam_parse, best_seq_chooser, external_progs, flag, mapping, report, samtools_variants
from ariba import assembly, assembly_compare, assembly_variants, bam_parse, best_seq_chooser, common, external_progs, flag, mapping, report, samtools_variants

class Error (Exception): pass

Expand All @@ -13,6 +15,9 @@ def __init__(self,
root_dir,
name,
refdata,
total_reads,
total_reads_bases,
assembly_coverage=100,
assembly_kmer=21,
assembler='spades',
max_insert=1000,
Expand All @@ -34,6 +39,7 @@ def __init__(self,
spades_other_options=None,
clean=1,
extern_progs=None,
random_seed=42,
):

self.root_dir = os.path.abspath(root_dir)
Expand All @@ -42,19 +48,24 @@ def __init__(self,

self.name = name
self.refdata = refdata
self.total_reads = total_reads
self.total_reads_bases = total_reads_bases
self.assembly_coverage = assembly_coverage
self.assembly_kmer = assembly_kmer
self.assembler = assembler
self.sspace_k = sspace_k
self.sspace_sd = sspace_sd
self.reads_insert = reads_insert
self.spades_other_options = spades_other_options

self.reads1 = os.path.join(self.root_dir, 'reads_1.fq')
self.reads2 = os.path.join(self.root_dir, 'reads_2.fq')
self.all_reads1 = os.path.join(self.root_dir, 'reads_1.fq')
self.all_reads2 = os.path.join(self.root_dir, 'reads_2.fq')
self.reads_for_assembly1 = os.path.join(self.root_dir, 'reads_for_assembly_1.fq')
self.reads_for_assembly2 = os.path.join(self.root_dir, 'reads_for_assembly_2.fq')
self.reference_fa = os.path.join(self.root_dir, 'reference.fa')
self.references_fa = os.path.join(self.root_dir, 'references.fa')

for fname in [self.reads1, self.reads2, self.references_fa]:
for fname in [self.all_reads1, self.all_reads2, self.references_fa]:
if not os.path.exists(fname):
raise Error('File ' + fname + ' not found. Cannot continue')

Expand Down Expand Up @@ -92,7 +103,6 @@ def __init__(self,
self.mummer_variants = {}
self.variant_depths = {}
self.percent_identities = {}
self.total_reads = self._count_reads(self.reads1, self.reads2)

# The log filehandle self.log_fh is set at the start of the run() method.
# Lots of other methods use self.log_fh. But for unit testing, run() isn't
Expand All @@ -109,13 +119,7 @@ def __init__(self,
else:
self.extern_progs = extern_progs


@staticmethod
def _count_reads(reads1, reads2):
count1 = pyfastaq.tasks.count_sequences(reads1)
count2 = pyfastaq.tasks.count_sequences(reads2)
assert(count1 == count2)
return count1 + count2
self.random_seed = random_seed


def _clean(self):
Expand All @@ -132,8 +136,8 @@ def _clean(self):
shutil.rmtree(self.assembly_dir)

to_delete = [
self.reads1,
self.reads2,
self.all_reads1,
self.all_reads2,
self.references_fa,
self.references_fa + '.fai',
self.final_assembly_bam + '.read_depths.gz',
Expand All @@ -153,14 +157,64 @@ def _clean(self):
raise Error('Error deleting file', filename)


@staticmethod
def _number_of_reads_for_assembly(reference_fa, insert_size, total_bases, total_reads, coverage):
    '''Return the (even) number of reads needed to hit the target coverage
    of the reference sequence(s) in reference_fa. The reference length is
    padded by one insert size at each end so that reads overhanging the
    reference edges are accounted for. Mean read length is estimated from
    total_bases / total_reads. Raises AssertionError if the reference is empty.'''
    seqs = pyfastaq.sequences.file_reader(reference_fa)
    reference_length = sum(len(seq) for seq in seqs)
    assert reference_length > 0
    # pad by an insert size either side for reads spanning the reference ends
    padded_length = reference_length + 2 * insert_size
    mean_read_length = total_bases / total_reads
    bases_needed = coverage * padded_length
    reads_needed = int(math.ceil(bases_needed / mean_read_length))
    # round up to an even number, since reads come in pairs
    if reads_needed % 2 != 0:
        reads_needed += 1
    return reads_needed


@staticmethod
def _make_reads_for_assembly(number_of_wanted_reads, total_reads, reads_in1, reads_in2, reads_out1, reads_out2, random_seed=None):
    '''Makes fastq files that are random subset of input files. Returns total number of reads in output files.
    If the number of wanted reads is >= total reads, then just makes symlinks instead of making
    new copies of the input files.'''
    random.seed(random_seed)

    # Wanting everything (or more): no need to copy, just link to the originals
    if number_of_wanted_reads >= total_reads:
        os.symlink(reads_in1, reads_out1)
        os.symlink(reads_in2, reads_out2)
        return total_reads

    reads_written = 0
    # NOTE(review): randint(0, 100) has 101 possible values with an inclusive
    # compare, so the effective sampling rate is slightly above percent_wanted/100.
    # Kept as-is: the unit tests pin the exact output for seed 42.
    percent_wanted = 100 * number_of_wanted_reads / total_reads
    fwd_reader = pyfastaq.sequences.file_reader(reads_in1)
    rev_reader = pyfastaq.sequences.file_reader(reads_in2)
    fwd_out = pyfastaq.utils.open_file_write(reads_out1)
    rev_out = pyfastaq.utils.open_file_write(reads_out2)

    for fwd_read in fwd_reader:
        try:
            rev_read = next(rev_reader)
        except StopIteration:
            # mate file ran out first: close outputs before reporting the orphan
            pyfastaq.utils.close(fwd_out)
            pyfastaq.utils.close(rev_out)
            raise Error('Error subsetting reads. No mate found for read ' + fwd_read.id)

        # one draw per read pair; keep or drop the pair as a unit
        if random.randint(0, 100) <= percent_wanted:
            print(fwd_read, file=fwd_out)
            print(rev_read, file=rev_out)
            reads_written += 2

    pyfastaq.utils.close(fwd_out)
    pyfastaq.utils.close(rev_out)
    return reads_written


def run(self):
self.logfile = os.path.join(self.root_dir, 'log.txt')
self.log_fh = pyfastaq.utils.open_file_write(self.logfile)

print('Choosing best reference sequence:', file=self.log_fh, flush=True)
seq_chooser = best_seq_chooser.BestSeqChooser(
self.reads1,
self.reads2,
self.all_reads1,
self.all_reads2,
self.references_fa,
self.log_fh,
samtools_exe=self.extern_progs.exe('samtools'),
Expand All @@ -174,12 +228,15 @@ def run(self):
self.status_flag.add('ref_seq_choose_fail')
self.assembled_ok = False
else:
print('\nAssembling reads:', file=self.log_fh, flush=True)
wanted_reads = self._number_of_reads_for_assembly(self.reference_fa, self.reads_insert, self.total_reads_bases, self.total_reads, self.assembly_coverage)
made_reads = self._make_reads_for_assembly(wanted_reads, self.total_reads, self.all_reads1, self.all_reads2, self.reads_for_assembly1, self.reads_for_assembly2, random_seed=self.random_seed)
print('\nUsing', made_reads, 'from a total of', self.total_reads, 'for assembly.', file=self.log_fh, flush=True)
print('Assembling reads:', file=self.log_fh, flush=True)
self.ref_sequence_type = self.refdata.sequence_type(self.ref_sequence.id)
assert self.ref_sequence_type is not None
self.assembly = assembly.Assembly(
self.reads1,
self.reads2,
self.reads_for_assembly1,
self.reads_for_assembly2,
self.reference_fa,
self.assembly_dir,
self.final_assembly_fa,
Expand All @@ -202,8 +259,8 @@ def run(self):
print('\nAssembly was successful\n\nMapping reads to assembly:', file=self.log_fh, flush=True)

mapping.run_bowtie2(
self.reads1,
self.reads2,
self.all_reads1,
self.all_reads2,
self.final_assembly_fa,
self.final_assembly_bam[:-4],
threads=1,
Expand Down
11 changes: 10 additions & 1 deletion ariba/clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(self,
outdir,
extern_progs,
assembly_kmer=21,
assembly_coverage=100,
threads=1,
verbose=False,
assembler='spades',
Expand Down Expand Up @@ -57,6 +58,7 @@ def __init__(self,
self.assembler = assembler
assert self.assembler in ['spades']
self.assembly_kmer = assembly_kmer
self.assembly_coverage = assembly_coverage
self.spades_other = spades_other

self.refdata_files_prefix = os.path.join(self.outdir, 'refdata')
Expand Down Expand Up @@ -91,6 +93,8 @@ def __init__(self,

self.cluster_to_dir = {} # gene name -> abs path of cluster directory
self.clusters = {} # gene name -> Cluster object
self.cluster_read_counts = {} # gene name -> number of reads
self.cluster_base_counts = {} # gene name -> number of bases

self.cdhit_seq_identity_threshold = cdhit_seq_identity_threshold
self.cdhit_length_diff_cutoff = cdhit_length_diff_cutoff
Expand Down Expand Up @@ -210,6 +214,8 @@ def _bam_to_clusters_reads(self):

print(read1, file=filehandles_1[ref])
print(read2, file=filehandles_2[ref])
self.cluster_read_counts[ref] = self.cluster_read_counts.get(ref, 0) + 2
self.cluster_base_counts[ref] = self.cluster_base_counts.get(ref, 0) + len(read1) + len(read2)

sam1 = None

Expand Down Expand Up @@ -257,7 +263,10 @@ def _init_and_run_clusters(self):
cluster_list.append(cluster.Cluster(
new_dir,
seq_name,
refdata=self.refdata,
self.refdata,
self.cluster_read_counts[seq_name],
self.cluster_base_counts[seq_name],
assembly_coverage=self.assembly_coverage,
assembly_kmer=self.assembly_kmer,
assembler=self.assembler,
max_insert=self.insert_proper_pair_max,
Expand Down
2 changes: 2 additions & 0 deletions ariba/tasks/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def run():
nucmer_group.add_argument('--nucmer_breaklen', type=int, help='Value to use for -breaklen when running nucmer [%(default)s]', default=50, metavar='INT')

assembly_group = parser.add_argument_group('Assembly options')
assembly_group.add_argument('--assembly_cov', type=int, help='Target read coverage when sampling reads for assembly [%(default)s]', default=100, metavar='INT')
assembly_group.add_argument('--assembler_k', type=int, help='kmer size to use with assembler. You can use 0 to set kmer to 2/3 of the read length. Warning - lower kmers are usually better. [%(default)s]', metavar='INT', default=21)
assembly_group.add_argument('--spades_other', help='Put options string to be used with spades in quotes. This will NOT be sanity checked. Do not use -k or -t: for these options you should use the ariba run options --assembler_k and --threads [%(default)s]', default="--only-assembler", metavar="OPTIONS")
assembly_group.add_argument('--min_scaff_depth', type=int, help='Minimum number of read pairs needed as evidence for scaffold link between two contigs. This is also the value used for sspace -k when scaffolding [%(default)s]', default=10, metavar='INT')
Expand Down Expand Up @@ -72,6 +73,7 @@ def run():
options.outdir,
extern_progs,
assembly_kmer=options.assembler_k,
assembly_coverage=options.assembly_cov,
assembler='spades',
threads=options.threads,
verbose=options.verbose,
Expand Down
72 changes: 57 additions & 15 deletions ariba/tests/cluster_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,60 @@ def test_init_fail_files_missing(self):
tmpdir = 'tmp.cluster_test_init_fail_files_missing'
shutil.copytree(d, tmpdir)
with self.assertRaises(cluster.Error):
c = cluster.Cluster(tmpdir, 'name', refdata=refdata)
c = cluster.Cluster(tmpdir, 'name', refdata=refdata, total_reads=42, total_reads_bases=4242)
shutil.rmtree(tmpdir)

with self.assertRaises(cluster.Error):
c = cluster.Cluster('directorydoesnotexistshouldthrowerror', 'name', refdata=refdata)

c = cluster.Cluster('directorydoesnotexistshouldthrowerror', 'name', refdata=refdata, total_reads=42, total_reads_bases=4242)


def test_number_of_reads_for_assembly(self):
'''Test _number_of_reads_for_assembly'''
# ref is 100bp long
ref_fa = os.path.join(data_dir, 'cluster_test_number_of_reads_for_assembly.ref.fa')
tests = [
(50, 1000, 10, 20, 40),
(50, 999, 10, 20, 42),
(50, 1000, 10, 10, 20),
(50, 1000, 10, 5, 10),
]

def test_count_reads(self):
'''test _count_reads pass'''
reads1 = os.path.join(data_dir, 'cluster_test_count_reads_1.fq')
reads2 = os.path.join(data_dir, 'cluster_test_count_reads_2.fq')
self.assertEqual(4, cluster.Cluster._count_reads(reads1, reads2))
for insert, bases, reads, coverage, expected in tests:
self.assertEqual(expected, cluster.Cluster._number_of_reads_for_assembly(ref_fa, insert, bases, reads, coverage))


def test_make_reads_for_assembly_proper_sample(self):
    '''Test _make_reads_for_assembly when sampling from reads'''
    # Input pairs and the expected subsampled output for seed 42
    in1 = os.path.join(data_dir, 'cluster_test_make_reads_for_assembly.in1.fq')
    in2 = os.path.join(data_dir, 'cluster_test_make_reads_for_assembly.in2.fq')
    want1 = os.path.join(data_dir, 'cluster_test_make_reads_for_assembly.out1.fq')
    want2 = os.path.join(data_dir, 'cluster_test_make_reads_for_assembly.out2.fq')
    got1 = 'tmp.test_make_reads_for_assembly.reads.out1.fq'
    got2 = 'tmp.test_make_reads_for_assembly.reads.out2.fq'

    # 10 wanted < 20 total, so a real subsample is taken (not symlinks)
    total_written = cluster.Cluster._make_reads_for_assembly(10, 20, in1, in2, got1, got2, random_seed=42)
    self.assertEqual(14, total_written)
    for expected, got in ((want1, got1), (want2, got2)):
        self.assertTrue(filecmp.cmp(expected, got, shallow=False))
        os.unlink(got)


def test_make_reads_for_assembly_symlinks(self):
    '''Test _make_reads_for_assembly when just makes symlinks'''
    in1 = os.path.join(data_dir, 'cluster_test_make_reads_for_assembly.in1.fq')
    in2 = os.path.join(data_dir, 'cluster_test_make_reads_for_assembly.in2.fq')
    got1 = 'tmp.test_make_reads_for_assembly.reads.out1.fq'
    got2 = 'tmp.test_make_reads_for_assembly.reads.out2.fq'

    # wanted == total, so the input files should just be symlinked
    total_written = cluster.Cluster._make_reads_for_assembly(20, 20, in1, in2, got1, got2, random_seed=42)
    self.assertEqual(20, total_written)
    for source, link in ((in1, got1), (in2, got2)):
        self.assertTrue(os.path.islink(link))
        self.assertEqual(source, os.readlink(link))
        os.unlink(link)


def test_full_run_choose_ref_fail(self):
Expand All @@ -70,7 +112,7 @@ def test_full_run_choose_ref_fail(self):
tmpdir = 'tmp.test_full_run_choose_ref_fail'
shutil.copytree(os.path.join(data_dir, 'cluster_test_full_run_choose_ref_fail'), tmpdir)

c = cluster.Cluster(tmpdir, 'cluster_name', refdata)
c = cluster.Cluster(tmpdir, 'cluster_name', refdata, total_reads=2, total_reads_bases=108)
c.run()

expected = '\t'.join(['.', '.', '1088', '2', 'cluster_name'] + ['.'] * 23)
Expand All @@ -88,7 +130,7 @@ def test_full_run_assembly_fail(self):
tmpdir = 'tmp.test_full_run_assembly_fail'
shutil.copytree(os.path.join(data_dir, 'cluster_test_full_run_assembly_fail'), tmpdir)

c = cluster.Cluster(tmpdir, 'cluster_name', refdata)
c = cluster.Cluster(tmpdir, 'cluster_name', refdata, total_reads=4, total_reads_bases=304)
c.run()

expected = '\t'.join(['noncoding_ref_seq', 'non_coding', '64', '4', 'cluster_name'] + ['.'] * 23)
Expand All @@ -108,7 +150,7 @@ def test_full_run_ok_non_coding(self):
tmpdir = 'tmp.test_full_run_ok_non_coding'
shutil.copytree(os.path.join(data_dir, 'cluster_test_full_run_ok_non_coding'), tmpdir)

c = cluster.Cluster(tmpdir, 'cluster_name', refdata, spades_other_options='--only-assembler')
c = cluster.Cluster(tmpdir, 'cluster_name', refdata, spades_other_options='--only-assembler', total_reads=72, total_reads_bases=3600)
c.run()

expected = [
Expand All @@ -134,7 +176,7 @@ def test_full_run_ok_presence_absence(self):
tmpdir = 'tmp.cluster_test_full_run_ok_presence_absence'
shutil.copytree(os.path.join(data_dir, 'cluster_test_full_run_ok_presence_absence'), tmpdir)

c = cluster.Cluster(tmpdir, 'cluster_name', refdata, spades_other_options='--only-assembler')
c = cluster.Cluster(tmpdir, 'cluster_name', refdata, spades_other_options='--only-assembler', total_reads=64, total_reads_bases=3200)
c.run()

expected = [
Expand All @@ -160,7 +202,7 @@ def test_full_run_ok_variants_only_variant_not_present(self):
tmpdir = 'tmp.cluster_test_full_run_ok_variants_only.not_present'
shutil.copytree(os.path.join(data_dir, 'cluster_test_full_run_ok_variants_only'), tmpdir)

c = cluster.Cluster(tmpdir, 'cluster_name', refdata, spades_other_options='--only-assembler')
c = cluster.Cluster(tmpdir, 'cluster_name', refdata, spades_other_options='--only-assembler', total_reads=66, total_reads_bases=3300)
c.run()
expected = [
'variants_only1\tvariants_only\t27\t66\tcluster_name\t96\t96\t100.0\tvariants_only1.scaffold.1\t215\t1\tSNP\tp\tR3S\t0\t.\t.\t7\t9\tC;G;C\t65\t67\tC;G;C\t18;18;19\t.;.;.\t18;18;19\tvariants_only1_p_R3S_Ref and assembly have wild type, so do not report\tGeneric description of variants_only1'
Expand All @@ -179,7 +221,7 @@ def test_full_run_ok_variants_only_variant_not_present_always_report(self):
tmpdir = 'tmp.cluster_test_full_run_ok_variants_only.not_present.always_report'
shutil.copytree(os.path.join(data_dir, 'cluster_test_full_run_ok_variants_only'), tmpdir)

c = cluster.Cluster(tmpdir, 'cluster_name', refdata, spades_other_options='--only-assembler')
c = cluster.Cluster(tmpdir, 'cluster_name', refdata, spades_other_options='--only-assembler', total_reads=66, total_reads_bases=3300)
c.run()
expected = [
'variants_only1\tvariants_only\t27\t66\tcluster_name\t96\t96\t100.0\tvariants_only1.scaffold.1\t215\t1\tSNP\tp\tR3S\t0\t.\t.\t7\t9\tC;G;C\t65\t67\tC;G;C\t18;18;19\t.;.;.\t18;18;19\tvariants_only1_p_R3S_Ref and assembly have wild type, but always report anyway\tGeneric description of variants_only1'
Expand All @@ -198,7 +240,7 @@ def test_full_run_ok_variants_only_variant_is_present(self):
tmpdir = 'tmp.cluster_test_full_run_ok_variants_only.present'
shutil.copytree(os.path.join(data_dir, 'cluster_test_full_run_ok_variants_only'), tmpdir)

c = cluster.Cluster(tmpdir, 'cluster_name', refdata, spades_other_options='--only-assembler')
c = cluster.Cluster(tmpdir, 'cluster_name', refdata, spades_other_options='--only-assembler', total_reads=66, total_reads_bases=3300)
c.run()

expected = [
Expand Down
2 changes: 2 additions & 0 deletions ariba/tests/clusters_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def test_bam_to_clusters_reads(self):
self.assertTrue(filecmp.cmp(expected[i], got[i], shallow=False))

self.assertEqual({780:1}, c.insert_hist.bins)
self.assertEqual({'ref1': 4, 'ref2': 2}, c.cluster_read_counts)
self.assertEqual({'ref1': 240, 'ref2': 120}, c.cluster_base_counts)

shutil.rmtree(clusters_dir)

Expand Down
Loading