diff --git a/changelog.md b/changelog.md index e54bf507..fc59aa40 100644 --- a/changelog.md +++ b/changelog.md @@ -1,3 +1,8 @@ +# Next: add auspice.json output + + * the output directory now contains a json file that is compatible with auspice.us. Both time scaled phylogenies and ancestral inferences can now be visualized and explored using auspice. Available colorings are "Date", "genotype", "Branch support", and Excluded. See [PR #232](https://github.com/neherlab/treetime/pull/232) for details. + * move most function related to IO of the command line wrappers into a separate file. + # 0.9.6: bug fixes and new mode of polytomy resolution * in cases when very large polytomies are resolved, the multiplication of the discretized message results in messages/distributions of length 1. This resulted in an error, since interpolation objects require at least two points. This is now caught and a small discrete grid created. * increase recursion limit to 10000 by default. The recursion limit can now also be set via the environment variable `TREETIME_RECURSION_LIMIT`. diff --git a/treetime/CLI_io.py b/treetime/CLI_io.py new file mode 100644 index 00000000..131c5852 --- /dev/null +++ b/treetime/CLI_io.py @@ -0,0 +1,292 @@ +import os, sys +from Bio import AlignIO, Phylo +from .vcf_utils import read_vcf, write_vcf +from .seq_utils import alphabets +from Bio import __version__ as bioversion +import numpy as np + +def get_outdir(params, suffix='_treetime'): + if params.outdir: + if os.path.exists(params.outdir): + if os.path.isdir(params.outdir): + return params.outdir.rstrip('/') + '/' + else: + print("designated output location %s is not a directory"%params.outdir, file=sys.stderr) + else: + os.makedirs(params.outdir) + return params.outdir.rstrip('/') + '/' + + from datetime import datetime + outdir_stem = datetime.now().date().isoformat() + outdir = outdir_stem + suffix.rstrip('/')+'/' + count = 1 + while os.path.exists(outdir): + outdir = outdir_stem + '-%04d'%count + suffix.rstrip('/')+'/' + count += 1 + + os.makedirs(outdir) + return outdir + +def get_basename(params, outdir): + # if params.aln: + # basename = outdir + '.'.join(params.aln.split('/')[-1].split('.')[:-1]) + # elif params.tree: + # basename = outdir + '.'.join(params.tree.split('/')[-1].split('.')[:-1]) + # else: + basename = outdir + return basename + +def read_in_DRMs(drm_file, offset): + import pandas as pd + + DRMs = {} + drmPositions = [] + + df = pd.read_csv(drm_file, sep='\t') + for mi, m in df.iterrows(): + pos = m.GENOMIC_POSITION-1+offset #put in correct numbering + drmPositions.append(pos) + + if pos in DRMs: + DRMs[pos]['alt_base'][m.ALT_BASE] = m.SUBSTITUTION + else: + DRMs[pos] = {} + DRMs[pos]['drug'] = m.DRUG + DRMs[pos]['alt_base'] = {} + DRMs[pos]['alt_base'][m.ALT_BASE] = m.SUBSTITUTION + DRMs[pos]['gene'] = m.GENE + + drmPositions = np.array(drmPositions) + drmPositions = np.unique(drmPositions) + drmPositions = np.sort(drmPositions) + + DRM_info = {'DRMs': DRMs, + 'drmPositions': drmPositions} + + return DRM_info + + +def read_if_vcf(params): + """ + Checks if input is VCF and reads in appropriately if it is + """ + ref = None + aln = params.aln + fixed_pi = None + if hasattr(params, 'aln') and params.aln is not None: + if any([params.aln.lower().endswith(x) for x in ['.vcf', '.vcf.gz']]): + if not params.vcf_reference: + print("ERROR: a reference Fasta is required with VCF-format alignments") + return -1 + compress_seq = read_vcf(params.aln, params.vcf_reference) + sequences = compress_seq['sequences'] + ref = compress_seq['reference'] + aln = sequences + + if not hasattr(params, 'gtr') or params.gtr=="infer": #if not specified, set it: + alpha = alphabets['aa'] if params.aa else alphabets['nuc'] + fixed_pi = [ref.count(base)/len(ref) for base in alpha] + if fixed_pi[-1] == 0: + fixed_pi[-1] = 0.05 + fixed_pi = [v-0.01 for v in fixed_pi] + + return aln, ref, fixed_pi + + +def plot_rtt(tt, fname): + tt.plot_root_to_tip() + + from matplotlib import pyplot as plt + plt.savefig(fname) + print("--- root-to-tip plot saved to \n\t"+fname) + + +def export_sequences_and_tree(tt, basename, is_vcf=False, zero_based=False, + report_ambiguous=False, timetree=False, confidence=False, + reconstruct_tip_states=False, tree_suffix=''): + seq_info = is_vcf or tt.aln + if is_vcf: + outaln_name = basename + f'ancestral_sequences{tree_suffix}.vcf' + write_vcf(tt.get_reconstructed_alignment(reconstruct_tip_states=reconstruct_tip_states), outaln_name) + elif tt.aln: + outaln_name = basename + f'ancestral_sequences{tree_suffix}.fasta' + AlignIO.write(tt.get_reconstructed_alignment(reconstruct_tip_states=reconstruct_tip_states), outaln_name, 'fasta') + if seq_info: + print("\n--- alignment including ancestral nodes saved as \n\t %s\n"%outaln_name) + + # decorate tree with inferred mutations + terminal_count = 0 + offset = 0 if zero_based else 1 + if timetree: + dates_fname = basename + f'dates{tree_suffix}.tsv' + fh_dates = open(dates_fname, 'w', encoding='utf-8') + if confidence: + fh_dates.write('#Lower and upper bound delineate the 90% max posterior region\n') + fh_dates.write('#node\tdate\tnumeric date\tlower bound\tupper bound\n') + else: + fh_dates.write('#node\tdate\tnumeric date\n') + + mutations_out = open(basename + "branch_mutations.txt", "w") + mutations_out.write("node\tstate1\tpos\tstate2\n") + for n in tt.tree.find_clades(): + if timetree: + if confidence: + if n.bad_branch: + fh_dates.write('%s\t--\t--\t--\t--\n'%(n.name)) + else: + conf = tt.get_max_posterior_region(n, fraction=0.9) + fh_dates.write('%s\t%s\t%f\t%f\t%f\n'%(n.name, n.date, n.numdate,conf[0], conf[1])) + else: + if n.bad_branch: + fh_dates.write('%s\t--\t--\n'%(n.name)) + else: + fh_dates.write('%s\t%s\t%f\n'%(n.name, n.date, n.numdate)) + + n.confidence=None + # due to a bug in older versions of biopython that truncated filenames in nexus export + # we truncate them by hand and make them unique. + if n.is_terminal() and len(n.name)>40 and bioversion<"1.69": + n.name = n.name[:35]+'_%03d'%terminal_count + terminal_count+=1 + n.comment='' + if seq_info and len(n.mutations): + if n.mask is None: + if report_ambiguous: + n.comment= '&mutations="' + ','.join([a+str(pos + offset)+d for (a,pos, d) in n.mutations])+'"' + else: + n.comment= '&mutations="' + ','.join([a+str(pos + offset)+d for (a,pos, d) in n.mutations + if tt.gtr.ambiguous not in [a,d]])+'"' + else: + if report_ambiguous: + n.comment= '&mutations="' + ','.join([a+str(pos + offset)+d for (a,pos, d) in n.mutations if n.mask[pos]>0])+f'",mcc="{n.mcc}"' + else: + n.comment= '&mutations="' + ','.join([a+str(pos + offset)+d for (a,pos, d) in n.mutations + if tt.gtr.ambiguous not in [a,d] and n.mask[pos]>0])+f'",mcc="{n.mcc}"' + + for (a, pos, d) in n.mutations: + if tt.gtr.ambiguous not in [a,d] or report_ambiguous: + mutations_out.write("%s\t%s\t%s\t%s\n" %(n.name, a, pos + 1, d)) + if timetree: + n.comment+=(',' if n.comment else '&') + 'date=%1.2f'%n.numdate + mutations_out.close() + + # write tree to file + fmt_bl = "%1.6f" if tt.data.full_length<1e6 else "%1.8e" + if timetree: + outtree_name = basename + f'timetree{tree_suffix}.nexus' + print("--- saved divergence times in \n\t %s\n"%dates_fname) + Phylo.write(tt.tree, outtree_name, 'nexus') + else: + outtree_name = basename + f'annotated_tree{tree_suffix}.nexus' + Phylo.write(tt.tree, outtree_name, 'nexus', format_branch_length=fmt_bl) + print("--- tree saved in nexus format as \n\t %s\n"%outtree_name) + + auspice = create_auspice_json(tt, timetree=timetree, confidence=confidence) + outtree_name_json = basename + f'auspice_tree{tree_suffix}.json' + with open(outtree_name_json, 'w') as fh: + import json + json.dump(auspice, fh, indent=0) + print("--- tree saved in auspice json format as \n\t %s\n"%outtree_name_json) + + if timetree: + for n in tt.tree.find_clades(): + n.branch_length = n.mutation_length + outtree_name = basename + f'divergence_tree{tree_suffix}.nexus' + Phylo.write(tt.tree, outtree_name, 'nexus', format_branch_length=fmt_bl) + print("--- divergence tree saved in nexus format as \n\t %s\n"%outtree_name) + + +def print_save_plot_skyline(tt, n_std=2.0, screen=True, save='', plot='', gen=50): + if plot: + import matplotlib.pyplot as plt + + skyline, conf = tt.merger_model.skyline_inferred(gen=gen, confidence=n_std) + if save: fh = open(save, 'w', encoding='utf-8') + header1 = "Skyline assuming "+ str(gen)+" gen/year and approximate confidence bounds (+/- %f standard deviations of the LH)\n"%n_std + header2 = "date \tN_e \tlower \tupper" + if screen: print('\t'+header1+'\t'+header2) + if save: fh.write("#"+ header1+'#'+header2+'\n') + for (x,y, y1, y2) in zip(skyline.x, skyline.y, conf[0], conf[1]): + if screen: print("\t%1.3f\t%1.3e\t%1.3e\t%1.3e"%(x,y, y1, y2)) + if save: fh.write("%1.3f\t%1.3e\t%1.3e\t%1.3e\n"%(x,y, y1, y2)) + + if save: + print("\n --- written skyline to %s\n"%save) + fh.close() + + if plot: + plt.figure() + plt.fill_between(skyline.x, conf[0], conf[1], color=(0.8, 0.8, 0.8)) + plt.plot(skyline.x, skyline.y, label='maximum likelihood skyline') + plt.yscale('log') + plt.legend() + plt.ticklabel_format(axis='x',useOffset=False) + plt.savefig(plot) + + + +def create_auspice_json(tt, timetree=False, confidence=False): + # mock up meta data for auspice json + meta = { + "title": "Auspice visualization of TreeTime analysis", + "build_url": "https://github.com/neherlab/treetime", + "genome_annotations": { + "nuc":{"start":1, "end":int(tt.data.full_length), "type":"source", "strand":"+:"} + }, + "panels":["tree", "entropy"], + "colorings": [ + { + "title": "Date", + "type": "continuous", + "key": "gt", + }, + { + "title": "Genotype", + "type": "categorical", + "key": "num_date", + }, + { + "title": "Excluded", + "type": "categorical", + "key": "bad_branch" + }, + { + "title": "Branch Support", + "type": "continuous", + "key": "confidence" + } + ] + } + + def node_to_json(n, pdiv=0.0): + j = {"name":n.name, "node_attrs":{}, "branch_attrs":{}} + if n.clades: + j["children"] = [] + + if timetree: + j["node_attrs"]["num_date"] = {"value":float(n.numdate)} + if confidence: + conf = tt.get_max_posterior_region(n, fraction=0.9) + j["node_attrs"]["num_date"]["confidence"] = (float(conf[0]), float(conf[1])) + j["node_attrs"]["div"] = float(pdiv + n.mutation_length) + j["node_attrs"]["bad_branch"] = {"value": "Yes" if n.bad_branch else "No"} + + j["branch_attrs"]["mutations"] = {"nuc": [f"{a}{pos+1}{d}" for a,pos,d in n.mutations if d in "ACGT-"]} + # generate bootstrap confidence substitute via the negative exponential of the number of mutations + # this is the bootstrap confidence for iid mutations (only ACGT mutations) + j["node_attrs"]["confidence"] = {"value":round(1-np.exp(-len([pos for a,pos,d in n.mutations if d in "ACGT"])),3) + if not n.is_terminal() else 1.0} + return j + + # create the tree data structure from the Biopython tree + tree = node_to_json(tt.tree.root, 0.0) + # dictionary to look up nodes by name + node_lookup = {tt.tree.root.name: tree} + for n in tt.tree.get_nonterminals(): + n_json = node_lookup[n.name] + for c in n.clades: + # generate node jsons for all children and attach them the to parent + n_json["children"].append(node_to_json(c, n_json["node_attrs"]["div"])) + node_lookup[c.name] = n_json["children"][-1] + + return {"meta":meta, "tree":tree} diff --git a/treetime/argument_parser.py b/treetime/argument_parser.py index f2012e95..a646d1b7 100644 --- a/treetime/argument_parser.py +++ b/treetime/argument_parser.py @@ -125,10 +125,10 @@ def add_aln_group(parser, required=True): def add_reroot_group(parser): - parser.add_argument('--clock-filter', type=float, default=3, + parser.add_argument('--clock-filter', type=float, default=4.0, help="ignore tips that don't follow a loose clock, " "'clock-filter=number of interquartile ranges from regression'. " - "Default=3.0, set to 0 to switch off.") + "Default=4.0, set to 0 to switch off.") reroot_group = parser.add_mutually_exclusive_group() reroot_group.add_argument('--reroot', nargs='+', default='best', help=reroot_description) reroot_group.add_argument('--keep-root', required = False, action="store_true", default=False, diff --git a/treetime/treetime.py b/treetime/treetime.py index 5883bbfa..c282a727 100644 --- a/treetime/treetime.py +++ b/treetime/treetime.py @@ -667,7 +667,6 @@ def cost_gain(n1, n2, parent): method='bounded',args=(n1,n2, parent), options={'xatol':1e-4*self.one_mutation}) return cg['x'], - cg['fun'] except: - import ipdb; ipdb.set_trace() self.logger("TreeTime._poly.cost_gain: optimization of gain failed", 3, warn=True) return parent.time_before_present, 0.0 @@ -801,16 +800,21 @@ def generate_subtree(self, parent): # loop until time collides with the parent node or all but two branches have been dealt with # the remaining two would be the children of the parent while len(branches_alive)+len(branches_to_come)>2 and t advance to next branch if (total_mut_rate + total_coalescent_rate)==0 and len(branches_to_come): branches_alive.append(branches_to_come.pop(0)) diff --git a/treetime/wrappers.py b/treetime/wrappers.py index e707b8ab..7b88f939 100644 --- a/treetime/wrappers.py +++ b/treetime/wrappers.py @@ -2,14 +2,13 @@ import numpy as np import pandas as pd from textwrap import fill -from Bio import Phylo, AlignIO +from Bio import Phylo from Bio import __version__ as bioversion from . import TreeAnc, GTR, TreeTime from . import utils -from .vcf_utils import read_vcf, write_vcf -from .seq_utils import alphabets from . import TreeTimeError, MissingDataError, UnknownMethodError from .treetime import reduce_time_marginal_argument +from .CLI_io import * def assure_tree(params, tmp_dir='treetime_tmp'): """ @@ -76,218 +75,6 @@ def create_gtr(params): return gtr -def get_outdir(params, suffix='_treetime'): - if params.outdir: - if os.path.exists(params.outdir): - if os.path.isdir(params.outdir): - return params.outdir.rstrip('/') + '/' - else: - print("designated output location %s is not a directory"%params.outdir, file=sys.stderr) - else: - os.makedirs(params.outdir) - return params.outdir.rstrip('/') + '/' - - from datetime import datetime - outdir_stem = datetime.now().date().isoformat() - outdir = outdir_stem + suffix.rstrip('/')+'/' - count = 1 - while os.path.exists(outdir): - outdir = outdir_stem + '-%04d'%count + suffix.rstrip('/')+'/' - count += 1 - - os.makedirs(outdir) - return outdir - -def get_basename(params, outdir): - # if params.aln: - # basename = outdir + '.'.join(params.aln.split('/')[-1].split('.')[:-1]) - # elif params.tree: - # basename = outdir + '.'.join(params.tree.split('/')[-1].split('.')[:-1]) - # else: - basename = outdir - return basename - -def read_in_DRMs(drm_file, offset): - import pandas as pd - - DRMs = {} - drmPositions = [] - - df = pd.read_csv(drm_file, sep='\t') - for mi, m in df.iterrows(): - pos = m.GENOMIC_POSITION-1+offset #put in correct numbering - drmPositions.append(pos) - - if pos in DRMs: - DRMs[pos]['alt_base'][m.ALT_BASE] = m.SUBSTITUTION - else: - DRMs[pos] = {} - DRMs[pos]['drug'] = m.DRUG - DRMs[pos]['alt_base'] = {} - DRMs[pos]['alt_base'][m.ALT_BASE] = m.SUBSTITUTION - DRMs[pos]['gene'] = m.GENE - - drmPositions = np.array(drmPositions) - drmPositions = np.unique(drmPositions) - drmPositions = np.sort(drmPositions) - - DRM_info = {'DRMs': DRMs, - 'drmPositions': drmPositions} - - return DRM_info - - -def read_if_vcf(params): - """ - Checks if input is VCF and reads in appropriately if it is - """ - ref = None - aln = params.aln - fixed_pi = None - if hasattr(params, 'aln') and params.aln is not None: - if any([params.aln.lower().endswith(x) for x in ['.vcf', '.vcf.gz']]): - if not params.vcf_reference: - print("ERROR: a reference Fasta is required with VCF-format alignments") - return -1 - compress_seq = read_vcf(params.aln, params.vcf_reference) - sequences = compress_seq['sequences'] - ref = compress_seq['reference'] - aln = sequences - - if not hasattr(params, 'gtr') or params.gtr=="infer": #if not specified, set it: - alpha = alphabets['aa'] if params.aa else alphabets['nuc'] - fixed_pi = [ref.count(base)/len(ref) for base in alpha] - if fixed_pi[-1] == 0: - fixed_pi[-1] = 0.05 - fixed_pi = [v-0.01 for v in fixed_pi] - - return aln, ref, fixed_pi - - -def plot_rtt(tt, fname): - tt.plot_root_to_tip() - - from matplotlib import pyplot as plt - plt.savefig(fname) - print("--- root-to-tip plot saved to \n\t"+fname) - - -def export_sequences_and_tree(tt, basename, is_vcf=False, zero_based=False, - report_ambiguous=False, timetree=False, confidence=False, - reconstruct_tip_states=False, tree_suffix=''): - seq_info = is_vcf or tt.aln - if is_vcf: - outaln_name = basename + f'ancestral_sequences{tree_suffix}.vcf' - write_vcf(tt.get_reconstructed_alignment(reconstruct_tip_states=reconstruct_tip_states), outaln_name) - elif tt.aln: - outaln_name = basename + f'ancestral_sequences{tree_suffix}.fasta' - AlignIO.write(tt.get_reconstructed_alignment(reconstruct_tip_states=reconstruct_tip_states), outaln_name, 'fasta') - if seq_info: - print("\n--- alignment including ancestral nodes saved as \n\t %s\n"%outaln_name) - - # decorate tree with inferred mutations - terminal_count = 0 - offset = 0 if zero_based else 1 - if timetree: - dates_fname = basename + f'dates{tree_suffix}.tsv' - fh_dates = open(dates_fname, 'w', encoding='utf-8') - if confidence: - fh_dates.write('#Lower and upper bound delineate the 90% max posterior region\n') - fh_dates.write('#node\tdate\tnumeric date\tlower bound\tupper bound\n') - else: - fh_dates.write('#node\tdate\tnumeric date\n') - - mutations_out = open(basename + "branch_mutations.txt", "w") - mutations_out.write("node\tstate1\tpos\tstate2\n") - for n in tt.tree.find_clades(): - if timetree: - if confidence: - if n.bad_branch: - fh_dates.write('%s\t--\t--\t--\t--\n'%(n.name)) - else: - conf = tt.get_max_posterior_region(n, fraction=0.9) - fh_dates.write('%s\t%s\t%f\t%f\t%f\n'%(n.name, n.date, n.numdate,conf[0], conf[1])) - else: - if n.bad_branch: - fh_dates.write('%s\t--\t--\n'%(n.name)) - else: - fh_dates.write('%s\t%s\t%f\n'%(n.name, n.date, n.numdate)) - - n.confidence=None - # due to a bug in older versions of biopython that truncated filenames in nexus export - # we truncate them by hand and make them unique. - if n.is_terminal() and len(n.name)>40 and bioversion<"1.69": - n.name = n.name[:35]+'_%03d'%terminal_count - terminal_count+=1 - n.comment='' - if seq_info and len(n.mutations): - if n.mask is None: - if report_ambiguous: - n.comment= '&mutations="' + ','.join([a+str(pos + offset)+d for (a,pos, d) in n.mutations])+'"' - else: - n.comment= '&mutations="' + ','.join([a+str(pos + offset)+d for (a,pos, d) in n.mutations - if tt.gtr.ambiguous not in [a,d]])+'"' - else: - if report_ambiguous: - n.comment= '&mutations="' + ','.join([a+str(pos + offset)+d for (a,pos, d) in n.mutations if n.mask[pos]>0])+f'",mcc="{n.mcc}"' - else: - n.comment= '&mutations="' + ','.join([a+str(pos + offset)+d for (a,pos, d) in n.mutations - if tt.gtr.ambiguous not in [a,d] and n.mask[pos]>0])+f'",mcc="{n.mcc}"' - - for (a, pos, d) in n.mutations: - if tt.gtr.ambiguous not in [a,d] or report_ambiguous: - mutations_out.write("%s\t%s\t%s\t%s\n" %(n.name, a, pos + 1, d)) - if timetree: - n.comment+=(',' if n.comment else '&') + 'date=%1.2f'%n.numdate - mutations_out.close() - - # write tree to file - fmt_bl = "%1.6f" if tt.data.full_length<1e6 else "%1.8e" - if timetree: - outtree_name = basename + f'timetree{tree_suffix}.nexus' - print("--- saved divergence times in \n\t %s\n"%dates_fname) - Phylo.write(tt.tree, outtree_name, 'nexus') - else: - outtree_name = basename + f'annotated_tree{tree_suffix}.nexus' - Phylo.write(tt.tree, outtree_name, 'nexus', format_branch_length=fmt_bl) - print("--- tree saved in nexus format as \n\t %s\n"%outtree_name) - - if timetree: - for n in tt.tree.find_clades(): - n.branch_length = n.mutation_length - outtree_name = basename + f'divergence_tree{tree_suffix}.nexus' - Phylo.write(tt.tree, outtree_name, 'nexus', format_branch_length=fmt_bl) - print("--- divergence tree saved in nexus format as \n\t %s\n"%outtree_name) - - -def print_save_plot_skyline(tt, n_std=2.0, screen=True, save='', plot='', gen=50): - if plot: - import matplotlib.pyplot as plt - - skyline, conf = tt.merger_model.skyline_inferred(gen=gen, confidence=n_std) - if save: fh = open(save, 'w', encoding='utf-8') - header1 = "Skyline assuming "+ str(gen)+" gen/year and approximate confidence bounds (+/- %f standard deviations of the LH)\n"%n_std - header2 = "date \tN_e \tlower \tupper" - if screen: print('\t'+header1+'\t'+header2) - if save: fh.write("#"+ header1+'#'+header2+'\n') - for (x,y, y1, y2) in zip(skyline.x, skyline.y, conf[0], conf[1]): - if screen: print("\t%1.3f\t%1.3e\t%1.3e\t%1.3e"%(x,y, y1, y2)) - if save: fh.write("%1.3f\t%1.3e\t%1.3e\t%1.3e\n"%(x,y, y1, y2)) - - if save: - print("\n --- written skyline to %s\n"%save) - fh.close() - - if plot: - plt.figure() - plt.fill_between(skyline.x, conf[0], conf[1], color=(0.8, 0.8, 0.8)) - plt.plot(skyline.x, skyline.y, label='maximum likelihood skyline') - plt.yscale('log') - plt.legend() - plt.ticklabel_format(axis='x',useOffset=False) - plt.savefig(plot) - - def scan_homoplasies(params): """ the function implementing treetime homoplasies