Skip to content

Commit

Permalink
Merge pull request #10 from martinghunt/use_cdhit
Browse files Browse the repository at this point in the history
Use cdhit
  • Loading branch information
martinghunt committed Feb 20, 2015
2 parents 361ceee + ad46956 commit 85a38fe
Show file tree
Hide file tree
Showing 64 changed files with 1,948 additions and 1,659 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Installation
------------

ARIBA has the following dependencies, which need to be installed:
* [cd-hit] [cdhit] version >= 4.6
* [samtools and bcftools] [samtools] version >= 1.2
* [SSPACE-basic scaffolder] [sspace]
* [GapFiller] [gapfiller]
Expand Down Expand Up @@ -39,6 +40,7 @@ Usage
Please read the [ARIBA wiki page] [ARIBA wiki] for usage instructions.


[cdhit]: http://weizhongli-lab.org/cd-hit/
[ARIBA wiki]: https://github.com/sanger-pathogens/ariba/wiki
[gapfiller]: http://www.baseclear.com/genomics/bioinformatics/basetools/gapfiller
[mummer]: http://mummer.sourceforge.net/
Expand Down
2 changes: 2 additions & 0 deletions ariba/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
__all__ = [
'bam_parse',
'cdhit',
'cluster',
'clusters',
'common',
'external_progs',
'faidx',
'flag',
'histogram',
'link',
Expand Down
121 changes: 121 additions & 0 deletions ariba/cdhit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import tempfile
import shutil
import os
import pyfastaq
from ariba import common

class Error(Exception):
    """Raised when cd-hit input is missing or invalid, or its output cannot be parsed."""


class Runner:
    """Cluster the sequences of a FASTA file by running cd-hit.

    ``run()`` writes one representative sequence per cluster to ``outfile``
    (each renamed to its cluster number) and returns the cluster membership.
    """

    def __init__(
        self,
        infile,
        outfile,
        seq_identity_threshold=0.9,
        threads=1,
        length_diff_cutoff=0.9,
        verbose=False,
    ):
        """Store the run settings.

        Args:
            infile: FASTA file of sequences to cluster. Must exist.
            outfile: file to which the cluster representatives are written.
            seq_identity_threshold: value passed to cd-hit's ``-c`` option.
            threads: value passed to cd-hit's ``-T`` option.
            length_diff_cutoff: value passed to cd-hit's ``-s`` option.
            verbose: forwarded to ``common.syscall``.

        Raises:
            Error: if ``infile`` does not exist.
        """
        if not os.path.exists(infile):
            raise Error('File not found: "' + infile + '". Cannot continue')

        self.infile = os.path.abspath(infile)
        self.outfile = os.path.abspath(outfile)
        self.seq_identity_threshold = seq_identity_threshold
        self.threads = threads
        self.length_diff_cutoff = length_diff_cutoff
        self.verbose = verbose

    def run(self):
        """Run cd-hit and return the clusters.

        Returns:
            dict mapping cluster number (a string, as written in the cd-hit
            cluster report) to the set of original sequence names in it.
        """
        # Local import so this block does not depend on the file's import list.
        import shlex

        tmpdir = tempfile.mkdtemp(prefix='tmp.run_cd-hit.', dir=os.getcwd())
        try:
            cdhit_fasta = os.path.join(tmpdir, 'cdhit')
            cluster_info_outfile = cdhit_fasta + '.bak.clstr'
            infile_renamed = os.path.join(tmpdir, 'input.renamed.fa')

            # cd-hit truncates all names to 19 characters in its report of
            # which sequences belong to which clusters. So need to temporarily
            # rename all sequences to have short enough names.
            new_to_old_name = self._enumerate_fasta(self.infile, infile_renamed)

            # Paths are quoted because the command is passed on as one string
            # (assumes common.syscall hands it to a shell - TODO confirm);
            # the working directory may contain spaces or shell metacharacters.
            cmd = ' '.join([
                'cd-hit',
                '-i', shlex.quote(infile_renamed),
                '-o', shlex.quote(cdhit_fasta),
                '-c', str(self.seq_identity_threshold),
                '-T', str(self.threads),
                '-s', str(self.length_diff_cutoff),
                '-bak 1',
            ])

            common.syscall(cmd, verbose=self.verbose)

            cluster_representatives = self._get_ids(cdhit_fasta)
            clusters, cluster_rep_to_cluster = self._parse_cluster_info_file(
                cluster_info_outfile,
                new_to_old_name,
                cluster_representatives,
            )
            self._rename_fasta(cdhit_fasta, self.outfile, cluster_rep_to_cluster)
            return clusters
        finally:
            # Remove the scratch directory even if cd-hit or the parsing failed.
            shutil.rmtree(tmpdir)

    def _enumerate_fasta(self, infile, outfile):
        """Write ``infile`` to ``outfile`` with enumerated sequence names.

        Returns:
            dict mapping each new name to the original name.

        Raises:
            Error: if the temporary rename file already exists, or if the
                original names are not unique.
        """
        rename_file = outfile + '.tmp.rename_info'
        if os.path.exists(rename_file):
            raise Error('File already exists: "' + rename_file + '". Cannot continue')

        pyfastaq.tasks.enumerate_names(infile, outfile, rename_file=rename_file)

        try:
            return self._load_rename_info(rename_file)
        finally:
            os.unlink(rename_file)

    @staticmethod
    def _load_rename_info(rename_file):
        """Parse a tab-separated old-name/new-name file into {new: old}.

        The header line ``#old<TAB>new`` and blank lines are ignored.

        Raises:
            Error: on a line with fewer than two columns, or if any old or
                new name occurs more than once.
        """
        new_to_old_name = {}
        old_names_seen = set()

        with open(rename_file) as f:
            for line in f:
                if line == '#old\tnew\n' or not line.strip():
                    continue

                fields = line.rstrip().split('\t')
                if len(fields) < 2:
                    raise Error('Unexpected format of rename line:\n' + line)

                old_name, new_name = fields[0], fields[1]
                # The old names must be checked explicitly: the new names are
                # generated by enumeration, so counting dict keys alone would
                # never detect a repeated input name.
                if old_name in old_names_seen or new_name in new_to_old_name:
                    raise Error('Sequence names in input file not unique! Cannot continue')

                old_names_seen.add(old_name)
                new_to_old_name[new_name] = old_name

        return new_to_old_name

    def _rename_fasta(self, infile, outfile, names_dict):
        """Copy ``infile`` to ``outfile``, replacing each id via ``names_dict``."""
        seq_reader = pyfastaq.sequences.file_reader(infile)
        f = pyfastaq.utils.open_file_write(outfile)
        try:
            for seq in seq_reader:
                seq.id = names_dict[seq.id]
                print(seq, file=f)
        finally:
            pyfastaq.utils.close(f)

    def _parse_cluster_info_file(self, infile, names_dict, cluster_representatives):
        """Parse the cd-hit cluster report ``infile``.

        Args:
            infile: the cluster report written by cd-hit.
            names_dict: dict of new (enumerated) name -> original name.
            cluster_representatives: collection of new names that are
                cluster representatives.

        Returns:
            tuple ``(clusters, representative_to_cluster)``: ``clusters`` maps
            cluster number to the set of original names; the second dict maps
            each representative's new name to its cluster number.

        Raises:
            Error: on malformed lines, unknown names, or a name repeated
                within a cluster.
        """
        f = pyfastaq.utils.open_file_read(infile)
        try:
            return self._parse_cluster_lines(f, names_dict, cluster_representatives)
        finally:
            pyfastaq.utils.close(f)

    @staticmethod
    def _parse_cluster_lines(lines, names_dict, cluster_representatives):
        """Parse an iterable of cluster report lines (see _parse_cluster_info_file)."""
        clusters = {}
        cluster_representative_to_cluster_number = {}

        for line in lines:
            data = line.rstrip().split()
            if not data:
                continue
            if len(data) < 3:
                raise Error('Unexpected format of line:\n' + line)

            cluster = data[0]
            seqname = data[2]
            # The name column is expected to look like ">name...".
            if not (seqname.startswith('>') and seqname.endswith('...')):
                raise Error('Unexpected format of sequence name in line:\n' + line)
            seqname = seqname[1:-3]

            if seqname in cluster_representatives:
                cluster_representative_to_cluster_number[seqname] = cluster

            if seqname not in names_dict:
                raise Error('Unknown sequence name "' + seqname + '" in line:\n' + line)
            seqname = names_dict[seqname]

            members = clusters.setdefault(cluster, set())
            if seqname in members:
                raise Error('Duplicate name "' + seqname + '" found in cluster ' + str(cluster))
            members.add(seqname)

        return clusters, cluster_representative_to_cluster_number

    def _get_ids(self, infile):
        """Return the set of sequence ids found in ``infile``."""
        seq_reader = pyfastaq.sequences.file_reader(infile)
        return {seq.id for seq in seq_reader}

119 changes: 114 additions & 5 deletions ariba/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
import operator
import pyfastaq
import pymummer
from ariba import common, mapping, bam_parse, flag
from ariba import common, mapping, bam_parse, flag, faidx

class Error (Exception): pass


class Cluster:
def __init__(self,
root_dir,
name,
assembly_kmer=0,
assembler='velvet',
max_insert=1000,
Expand All @@ -39,22 +40,24 @@ def __init__(self,
sspace_exe='SSPACE_Basic_v2.0.pl',
velvet_exe='velvet', # prefix of velvet{g,h}
spades_other=None,
clean=1,
):

self.root_dir = os.path.abspath(root_dir)
if not os.path.exists(self.root_dir):
raise Error('Directory ' + self.root_dir + ' not found. Cannot continue')

self.name = name
self.reads1 = os.path.join(self.root_dir, 'reads_1.fq')
self.reads2 = os.path.join(self.root_dir, 'reads_2.fq')
self.gene_fa = os.path.join(self.root_dir, 'gene.fa')
self.genes_fa = os.path.join(self.root_dir, 'genes.fa')
self.gene_bam = os.path.join(self.root_dir, 'gene.reads_mapped.bam')

for fname in [self.reads1, self.reads2, self.gene_fa]:
for fname in [self.reads1, self.reads2, self.genes_fa]:
if not os.path.exists(fname):
raise Error('File ' + fname + ' not found. Cannot continue')

self.gene = self._get_gene()

self.max_insert = max_insert
self.min_scaff_depth = min_scaff_depth
Expand Down Expand Up @@ -104,6 +107,7 @@ def __init__(self,
self.unique_threshold = unique_threshold
self.status_flag = flag.Flag()
self.flag_file = os.path.join(self.root_dir, 'flag')
self.clean = clean

self.assembly_dir = os.path.join(self.root_dir, 'Assembly')
try:
Expand All @@ -123,7 +127,64 @@ def __init__(self,
self.variants = {}


def _get_gene(self):
def _get_total_alignment_score(self, gene_name):
    """Map the cluster's reads to one gene and return the total alignment score.

    The gene ``gene_name`` is extracted from ``self.genes_fa`` into a
    temporary FASTA file, the read pair is mapped to it with
    ``mapping.run_smalt``, and the resulting BAM is scored with
    ``mapping.get_total_alignment_score``.

    Raises:
        Error: if a temporary file from an earlier call is still present.
    """
    tmp_bam = os.path.join(self.root_dir, 'tmp.get_total_alignment_score.bam')
    tmp_fa = os.path.join(self.root_dir, 'tmp.get_total_alignment_score.ref.fa')

    # Explicit check rather than assert, so it still runs under "python -O".
    for tmp_file in (tmp_bam, tmp_fa):
        if os.path.exists(tmp_file):
            raise Error('Temporary file ' + tmp_file + ' already exists. Cannot continue')

    try:
        faidx.write_fa_subset([gene_name], self.genes_fa, tmp_fa, samtools_exe=self.samtools_exe, verbose=self.verbose)
        mapping.run_smalt(
            self.reads1,
            self.reads2,
            tmp_fa,
            tmp_bam[:-4],  # drop '.bam'; presumably run_smalt takes an output prefix - verify
            threads=self.threads,
            samtools=self.samtools_exe,
            smalt=self.smalt_exe,
            verbose=self.verbose,
        )
        return mapping.get_total_alignment_score(tmp_bam)
    finally:
        # Clean up on failure too; otherwise the leftovers would make the
        # next call fail the "must not exist" check above.
        for tmp_file in (tmp_bam, tmp_fa, tmp_fa + '.fai'):
            if os.path.exists(tmp_file):
                os.unlink(tmp_file)


def _get_best_gene_by_alignment_score(self):
    """Return the name of the gene in ``self.genes_fa`` with the best score.

    A cluster holding a single gene returns that gene's name without any
    mapping. Otherwise every gene is scored with
    ``_get_total_alignment_score`` and the first gene with the highest
    score strictly above zero wins. Returns None if no gene scores above
    zero.
    """
    n_genes = pyfastaq.tasks.count_sequences(self.genes_fa)

    if n_genes == 1:
        only_gene = {}
        pyfastaq.tasks.file_to_dict(self.genes_fa, only_gene)
        assert len(only_gene) == 1
        sole_name = next(iter(only_gene.values())).id
        if self.verbose:
            print('No need to choose gene for this cluster because only has one gene:', sole_name)
        return sole_name

    if self.verbose:
        print('\nChoosing best gene from cluster of', n_genes, 'genes...')

    winner = None
    top_score = 0
    for candidate in pyfastaq.sequences.file_reader(self.genes_fa):
        candidate_score = self._get_total_alignment_score(candidate.id)
        if self.verbose:
            print('Total alignment score for gene', candidate.id, 'is', candidate_score)
        # Strict comparison: on a tie the earlier gene is kept.
        if candidate_score > top_score:
            top_score, winner = candidate_score, candidate.id

    if self.verbose:
        print('Best gene is', winner, 'with total alignment score of', top_score)
        print()

    return winner


def _choose_best_gene(self):
gene_name = self._get_best_gene_by_alignment_score()
faidx.write_fa_subset([gene_name], self.genes_fa, self.gene_fa, samtools_exe=self.samtools_exe, verbose=self.verbose)
seqs = {}
pyfastaq.tasks.file_to_dict(self.gene_fa, seqs)
assert len(seqs) == 1
Expand Down Expand Up @@ -342,6 +403,7 @@ def _fix_contig_orientation(self):
else:
to_revcomp.add(hit.qry_name)

os.unlink(tmp_coords)
in_both = to_revcomp.intersection(not_revcomp)
for name in in_both:
print('WARNING: hits to both strands of gene for scaffold. Interpretation of any variants cannot be trusted', name, file=sys.stderr)
Expand Down Expand Up @@ -649,7 +711,7 @@ def _make_report_lines(self):
self.report_lines = []

if len(self.variants) == 0:
self.report_lines.append([self.gene.id, self.status_flag.to_number(), len(self.gene)] + ['.'] * 11)
self.report_lines.append([self.gene.id, self.status_flag.to_number(), self.name, len(self.gene)] + ['.'] * 11)

for contig in self.variants:
for variants in self.variants[contig]:
Expand All @@ -660,6 +722,7 @@ def _make_report_lines(self):
self.report_lines.append([
self.gene.id,
self.status_flag.to_number(),
self.name,
len(self.gene),
pymummer.variant.var_types[v.var_type],
effect,
Expand All @@ -675,7 +738,52 @@ def _make_report_lines(self):
])


def _clean(self):
    """Delete intermediate files from ``self.root_dir`` according to ``self.clean``.

    Cleaning levels are cumulative:
        0: remove only the unsorted BAM file.
        1: also remove the assembly directory, index/coords files, the
           genes FASTA and the reads.
        2: also remove the mapped BAM and its VCF.
    Values above 2 behave like 2; negative values remove nothing.
    Files that do not exist are skipped.
    """
    if self.verbose:
        print('Cleaning', self.root_dir)

    # Tolerate an assembly directory that is already gone (e.g. a re-run).
    if self.clean > 0 and os.path.isdir(self.assembly_dir):
        if self.verbose:
            print(' rm -r', self.assembly_dir)
        shutil.rmtree(self.assembly_dir)

    to_clean = [
        [
            'assembly.reads_mapped.unsorted.bam',
        ],
        [
            'assembly.fa.fai',
            'assembly.reads_mapped.bam.scaff',
            'assembly.reads_mapped.bam.soft_clipped',
            'assembly.reads_mapped.bam.unmapped_mates',
            'assembly_vs_gene.coords',
            'assembly_vs_gene.coords.snps',
            'genes.fa',
            'genes.fa.fai',
            'reads_1.fq',
            'reads_2.fq',
        ],
        [
            'assembly.fa.fai',
            'assembly.reads_mapped.bam',
            'assembly.reads_mapped.bam.vcf',
            'assembly_vs_gene.coords',
            'assembly_vs_gene.coords.snps',
        ],
    ]

    # Slice instead of indexing each level so a level above the highest
    # defined one cannot raise IndexError; max() keeps negative levels empty.
    for level_files in to_clean[:max(self.clean + 1, 0)]:
        for fname in level_files:
            fullname = os.path.join(self.root_dir, fname)
            if os.path.exists(fullname):
                if self.verbose:
                    print(' rm', fname)
                os.unlink(fullname)


def run(self):
self.gene = self._choose_best_gene()

if self.assembler == 'velvet':
self._assemble_with_velvet()
elif self.assembler == 'spades':
Expand Down Expand Up @@ -720,3 +828,4 @@ def run(self):
self._get_vcf_variant_counts()

self._make_report_lines()
self._clean()
Loading

0 comments on commit 85a38fe

Please sign in to comment.