Skip to content

Commit

Permalink
feat: select branches for confidence plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
rneher committed Sep 7, 2023
1 parent 354899c commit a642872
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions treetime/treetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,7 @@ def merge_nodes(source_arr, isall=False):


def generate_subtree(self, parent):
from .branch_len_interpolator import BranchLenInterpolator
# use the random number generator of TreeTime
exp_dis = self.rng.exponential

Expand Down Expand Up @@ -870,6 +871,10 @@ def generate_subtree(self, parent):
if hasattr(parent, "_cseq"):
new_node._cseq = parent._cseq
self.add_branch_state(new_node)
new_node.branch_length_interpolator = BranchLenInterpolator(new_node, self.gtr,
pattern_multiplicity = self.data.multiplicity(mask=new_node.mask), min_width=self.min_width,
one_mutation=self.one_mutation, branch_length_mode=self.branch_length_mode,
n_grid_points = self.branch_grid_points)
branches_alive = [b for b in branches_alive if b not in [n1,n2]] + [new_node]

remaining_branches = []
Expand Down Expand Up @@ -1079,7 +1084,7 @@ def _find_best_root(self, covariation=True, force_positive=True, slope=None, **k
return Treg.optimal_reroot(force_positive=force_positive, slope=slope, keep_node_order=self.keep_node_order)['node']


def plot_vs_years(tt, step = None, ax=None, confidence=None, ticks=True, **kwargs):
def plot_vs_years(tt, step = None, ax=None, confidence=None, ticks=True, selective_confidence=None, **kwargs):
'''
Converts branch length to years and plots the time tree on a time axis.
Expand Down Expand Up @@ -1166,11 +1171,13 @@ def plot_vs_years(tt, step = None, ax=None, confidence=None, ticks=True, **kwarg
facecolor=[0.7+0.1*(1+yi%2)] * 3,
edgecolor=[1,1,1])
ax.add_patch(r)
if year in tick_vals and pos>=xlim[0] and pos<=xlim[1] and ticks:
label_str = "%1.2f"%(step*(year//step)) if step<1 else str(int(year))
ax.text(pos,ylim[0]-0.04*(ylim[1]-ylim[0]), label_str,
horizontalalignment='center')
ax.set_axis_off()
if step>=1:
if year in tick_vals and pos>=xlim[0] and pos<=xlim[1] and ticks:
label_str = "%1.2f"%(step*(year//step)) if step<1 else str(int(year))
ax.text(pos,ylim[0]-0.04*(ylim[1]-ylim[0]), label_str,
horizontalalignment='center')
if step>=1:
ax.set_axis_off()

# add confidence intervals to the tree graph -- grey bars
if confidence:
Expand All @@ -1185,7 +1192,7 @@ def plot_vs_years(tt, step = None, ax=None, confidence=None, ticks=True, **kwarg
raise NotReadyError("confidence needs to be either a float (for max posterior region) or a two numbers specifying lower and upper bounds")

for n in tt.tree.find_clades():
if not n.bad_branch:
if not n.bad_branch and (selective_confidence is None or selective_confidence(n)):
pos = cfunc(n, confidence)
ax.plot(pos-offset, np.ones(len(pos))*n.ypos, lw=3, c=(0.5,0.5,0.5))
return fig, ax
Expand Down

0 comments on commit a642872

Please sign in to comment.