From a642872eb37ba47f2a12e180cdf3aaaecf7b5fbe Mon Sep 17 00:00:00 2001 From: Richard Neher Date: Thu, 7 Sep 2023 15:40:42 +0200 Subject: [PATCH] feat: select branches for confidence plotting --- treetime/treetime.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/treetime/treetime.py b/treetime/treetime.py index 57441650..80fc6514 100644 --- a/treetime/treetime.py +++ b/treetime/treetime.py @@ -773,6 +773,7 @@ def merge_nodes(source_arr, isall=False): def generate_subtree(self, parent): + from .branch_len_interpolator import BranchLenInterpolator # use the random number generator of TreeTime exp_dis = self.rng.exponential @@ -870,6 +871,10 @@ def generate_subtree(self, parent): if hasattr(parent, "_cseq"): new_node._cseq = parent._cseq self.add_branch_state(new_node) + new_node.branch_length_interpolator = BranchLenInterpolator(new_node, self.gtr, + pattern_multiplicity = self.data.multiplicity(mask=new_node.mask), min_width=self.min_width, + one_mutation=self.one_mutation, branch_length_mode=self.branch_length_mode, + n_grid_points = self.branch_grid_points) branches_alive = [b for b in branches_alive if b not in [n1,n2]] + [new_node] remaining_branches = [] @@ -1079,7 +1084,7 @@ def _find_best_root(self, covariation=True, force_positive=True, slope=None, **k return Treg.optimal_reroot(force_positive=force_positive, slope=slope, keep_node_order=self.keep_node_order)['node'] -def plot_vs_years(tt, step = None, ax=None, confidence=None, ticks=True, **kwargs): +def plot_vs_years(tt, step = None, ax=None, confidence=None, ticks=True, selective_confidence=None, **kwargs): ''' Converts branch length to years and plots the time tree on a time axis. @@ -1166,11 +1171,13 @@ def plot_vs_years(tt, step = None, ax=None, confidence=None, ticks=True, **kwarg facecolor=[0.7+0.1*(1+yi%2)] * 3, edgecolor=[1,1,1]) ax.add_patch(r) - if year in tick_vals and pos>=xlim[0] and pos<=xlim[1] and ticks: - label_str = "%1.2f"%(step*(year//step)) if step<1 else str(int(year)) - ax.text(pos,ylim[0]-0.04*(ylim[1]-ylim[0]), label_str, - horizontalalignment='center') - ax.set_axis_off() + if step>=1: + if year in tick_vals and pos>=xlim[0] and pos<=xlim[1] and ticks: + label_str = "%1.2f"%(step*(year//step)) if step<1 else str(int(year)) + ax.text(pos,ylim[0]-0.04*(ylim[1]-ylim[0]), label_str, + horizontalalignment='center') + if step>=1: + ax.set_axis_off() # add confidence intervals to the tree graph -- grey bars if confidence: @@ -1185,7 +1192,7 @@ def plot_vs_years(tt, step = None, ax=None, confidence=None, ticks=True, **kwarg raise NotReadyError("confidence needs to be either a float (for max posterior region) or a two numbers specifying lower and upper bounds") for n in tt.tree.find_clades(): - if not n.bad_branch: + if not n.bad_branch and (selective_confidence is None or selective_confidence(n)): pos = cfunc(n, confidence) ax.plot(pos-offset, np.ones(len(pos))*n.ypos, lw=3, c=(0.5,0.5,0.5)) return fig, ax