Skip to content

Commit

Permalink
Merge pull request #248 from neherlab/feat/outlier-detection
Browse files Browse the repository at this point in the history
Feat/outlier detection
  • Loading branch information
rneher committed Jul 7, 2023
2 parents dee02eb + 421aef3 commit 4894b93
Show file tree
Hide file tree
Showing 5 changed files with 200 additions and 26 deletions.
6 changes: 6 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# 0.11.0: new clock filter method

Previously, only a crude analysis of whether the divergence of tips roughly follows a linear trend was implemented. Tips that deviated too much from the regression line were flagged as outliers, with the threshold parameterized as the number of interquartile distances of the distribution of residuals (`n_iqd`).
This filter is not very sensitive and often misses misdated tips that severely distort the tree but still fall within the distribution of root-to-tip distances at that time.
To overcome this, we implemented a novel filtering method that fits a simple Gaussian model of divergence accumulation. It is selected with `--clock-filter-method local`, in which case the value of `--clock-filter` is interpreted as a z-score threshold; the previous residual-based filter (`residual`) remains the default.

# 0.10.1: bug fix release

* avoid probability loss at the end of the domain of distributions
Expand Down
6 changes: 5 additions & 1 deletion treetime/argument_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,12 @@ def add_aln_group(parser, required=True):
def add_reroot_group(parser):
parser.add_argument('--clock-filter', type=float, default=4.0,
help="ignore tips that don't follow a loose clock, "
"'clock-filter=number of interquartile ranges from regression'. "
"'clock-filter=number of interquartile ranges from regression (method=`residual`)' "
"or z-score of local clock deviation (method=`local`). "
"Default=4.0, set to 0 to switch off.")
parser.add_argument('--clock-filter-method', choices=['residual', 'local'], default='residual',
help="Use residuals from global clock (`residual`, default) or local clock deviation (`clock`) "
"to filter out tips that don't follow the clock")
reroot_group = parser.add_mutually_exclusive_group()
reroot_group.add_argument('--reroot', nargs='+', default='best', help=reroot_description)
reroot_group.add_argument('--keep-root', required = False, action="store_true", default=False,
Expand Down
173 changes: 173 additions & 0 deletions treetime/clock_filter_methods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import numpy as np
import pandas as pd

def residual_filter(tt, n_iqd):
    """Flag tips whose root-to-tip distance deviates from the global clock regression.

    For every terminal node with a date constraint, the residual
    ``dist2root - slope*mean(date) - intercept`` is computed from
    ``tt.clock_model``. A tip is marked ``bad_branch = True`` when the absolute
    residual exceeds ``n_iqd`` interquartile distances of all residuals and its
    grandparent exists (``node.up.up is not None``); every other dated tip is
    reset to ``bad_branch = False``. Tips without a date constraint are left
    untouched.

    Parameters
    ----------
    tt : object
        Tree object exposing ``tree``, ``clock_model`` (with ``'slope'`` and
        ``'intercept'``), ``logger`` and ``verbose``.
    n_iqd : float
        Threshold in units of the interquartile distance of the residuals.

    Returns
    -------
    int
        Number of tips flagged as outliers (0 if no tip carries a date).
    """
    clock_rate = tt.clock_model['slope']
    icpt = tt.clock_model['intercept']

    res = {}
    for node in tt.tree.get_terminals():
        if getattr(node, 'raw_date_constraint', None) is not None:
            res[node] = node.dist2root - clock_rate*np.mean(node.raw_date_constraint) - icpt

    # Without dated tips there is nothing to filter, and the percentile of an
    # empty array is undefined.
    if not res:
        return 0

    residuals = np.array(list(res.values()))
    iqd = np.percentile(residuals, 75) - np.percentile(residuals, 25)

    outliers = {}
    for node, r in res.items():
        if abs(r) > n_iqd*iqd and node.up.up is not None:
            node.bad_branch = True
            constraint = node.raw_date_constraint
            outliers[node.name] = {
                # date at which the regression line reaches this tip's divergence
                'tau': (node.dist2root - icpt)/clock_rate,
                'avg_date': np.mean(constraint),
                # only scalar constraints are exact dates; intervals give None
                'exact_date': constraint if np.ndim(constraint) == 0 else None,
                'residual': r/iqd,
            }
        else:
            node.bad_branch = False

    if outliers:
        tt.logger("Clock_filter.residual_filter marked the following outliers:", 2, warn=True)
        if tt.verbose >= 2:
            # the table is only needed for printing, so build it lazily
            outlier_df = pd.DataFrame(outliers).T.loc[:, ['avg_date', 'tau', 'residual']]\
                           .rename(columns={'avg_date': 'given_date', 'tau': 'apparent_date'})
            print(outlier_df)
    return len(outliers)

def local_filter(tt, z_score_threshold):
    """Flag tips whose date disagrees with the locally inferred node timing.

    Runs ``collect_node_info`` and ``calculate_node_timings`` to obtain a time
    estimate for every node plus a global scale ``z_scale``, then lets
    ``flag_outliers`` select the tips whose deviation exceeds
    ``z_score_threshold``. Selected tips get ``bad_branch = True``; the flag
    of all other tips is not modified here.

    Parameters
    ----------
    tt : object
        Tree object exposing ``tree``, ``logger`` and ``verbose`` (plus
        whatever the helper functions read from it).
    z_score_threshold : float
        Threshold passed on to ``flag_outliers``.

    Returns
    -------
    int
        Number of tips flagged as outliers.
    """
    tt.logger("TreeTime.ClockFilter: starting local_outlier_detection", 2)

    timings, scale = calculate_node_timings(tt, collect_node_info(tt))
    tt.logger(f"TreeTime.ClockFilter: z-scale {scale:1.2f}", 2)

    flagged = flag_outliers(tt, timings, z_score_threshold, scale)

    for tip in tt.tree.get_terminals():
        if tip.name in flagged:
            tip.bad_branch = True

    n_flagged = len(flagged)
    if n_flagged:
        # one row per flagged tip, restricted to the columns worth reporting
        table = pd.DataFrame(flagged).T
        table = table.loc[:, ['avg_date', 'tau', 'z']]
        table = table.rename(columns={'avg_date': 'given_date', 'tau': 'apparent_date'})
        tt.logger("Clock_filter.local_filter marked the following outliers", 2, warn=True)
        if tt.verbose >= 2:
            print(table)
    return n_flagged


def flag_outliers(tt, node_info, z_score_threshold, z_scale):
    """Select the tips whose date deviates too much from the inferred timing.

    For tips marked ``'exact_date'`` the score is
    ``(avg_date - tau) / z_scale``. For tips with a date interval, the tip is
    an outlier only if ``tau`` lies outside the interval (both bounds deviate
    in the same direction) and even the nearest bound exceeds the threshold;
    the score recorded is that of the nearest bound, using the same sign
    convention as for exact dates. Tips without a usable constraint are
    ignored.

    Parameters
    ----------
    tt : object
        Tree object; only ``tt.tree.get_terminals()`` is used.
    node_info : dict
        Per-node dictionaries keyed by node name, containing at least
        ``'exact_date'`` and ``'tau'`` (and ``'avg_date'`` for exact dates).
    z_score_threshold : float
        Absolute score above which a tip is flagged.
    z_scale : float
        Scale by which raw deviations are divided.

    Returns
    -------
    dict
        ``{tip name: node_info entry}`` for each outlier. The entry is the
        same dict object as in ``node_info`` and gains a ``'z'`` key.
    """
    outliers = {}
    for n in tt.tree.get_terminals():
        n_info = node_info[n.name]
        if n_info['exact_date']:
            z = (n_info['avg_date'] - n_info['tau'])/z_scale
            if np.abs(z) > z_score_threshold:
                n_info['z'] = z
                outliers[n.name] = n_info
            continue

        constraint = n.raw_date_constraint
        # explicit checks instead of truthiness: numpy arrays have no
        # unambiguous truth value and scalars have no len()
        if constraint is None or np.ndim(constraint) == 0 or len(constraint) < 2:
            continue

        zs = [(x - n_info['tau'])/z_scale for x in constraint]
        if zs[0]*zs[1] > 0 and np.min(np.abs(zs)) > z_score_threshold:
            # report the deviation from the closest interval bound
            n_info['z'] = min(zs, key=abs)
            outliers[n.name] = n_info

    return outliers

def calculate_node_timings(tt, node_info, eps=0.2):
    """Estimate a time ``tau`` for every node from dates and mutation counts.

    A post-order pass stores two coefficients ``'a'`` and ``'b'`` on every
    non-skipped node with exact-date information, such that the node's time is
    ``a + b * tau_parent``. A pre-order pass then assigns ``'tau'`` starting
    from the root (``tau_root = a_root``). Skipped nodes inherit the parent's
    ``tau``; nodes without exact dates are placed ``nmuts / mu`` after their
    parent.

    Parameters
    ----------
    tt : object
        Tree object exposing ``tree``, ``clock_model['slope']``,
        ``data.full_length`` and ``logger``.
    node_info : dict
        Per-node dictionaries keyed by node name, as produced by
        ``collect_node_info``; updated in place.
    eps : float, optional
        Constant added to mutation counts in denominators, which keeps
        zero-mutation branches finite.

    Returns
    -------
    tuple
        ``(node_info, z_scale)`` where ``z_scale`` is the standard deviation
        of ``avg_date - tau`` over terminal nodes with exact dates.
    """
    # presumably the substitution rate per year for the whole sequence; the
    # log message below labels it "/y" -- confirm units against the clock model
    mu = tt.clock_model['slope']*tt.data.full_length
    sigma_sq = (3/mu)**2
    tt.logger(f"Clockfilter.calculate_node_timings: mu={mu:1.3e}/y, sigma={3/mu:1.3e}y", 2)

    root = tt.tree.root
    for node in tt.tree.find_clades(order='postorder'):
        info = node_info[node.name]
        if info['skip'] or not info['exact_date']:
            continue

        branch_weight = mu**2/(info["nmuts"]+eps)
        if node.is_terminal():
            prefactor = (info["observations"]/sigma_sq + branch_weight)
            info["a"] = (info["avg_date"]/sigma_sq + mu*info["nmuts"]/(info["nmuts"]+eps))/prefactor
        else:
            # only children that carry 'a'/'b' from this pass contribute
            kids = [node_info[c.name] for c in node
                    if node_info[c.name]['exact_date'] and not node_info[c.name]['skip']]
            child_num = mu*np.sum([(mu*k["a"]-k["nmuts"])/(eps+k["nmuts"]) for k in kids])
            child_den = mu**2*np.sum([(1-k["b"])/(eps+k["nmuts"]) for k in kids])
            date_term = info["observations"]*info["avg_date"]/sigma_sq
            if node == root:
                # the root has no parent branch, hence no branch term
                prefactor = (info["observations"]/sigma_sq + child_den)
                info["a"] = (date_term + child_num)/prefactor
            else:
                prefactor = (info["observations"]/sigma_sq + branch_weight + child_den)
                info["a"] = (date_term + mu*info["nmuts"]/(info["nmuts"]+eps) + child_num)/prefactor
        info["b"] = branch_weight/prefactor

    node_info[root.name]["tau"] = node_info[root.name]["a"]

    ## need to deal with tips without exact dates below.
    deviations = []
    for parent in tt.tree.get_nonterminals(order='preorder'):
        parent_tau = node_info[parent.name]['tau']
        for child in parent:
            c_info = node_info[child.name]
            if c_info['skip']:
                c_info['tau'] = parent_tau
            elif c_info['exact_date']:
                c_info["tau"] = c_info["a"] + c_info["b"]*parent_tau
            else:
                c_info["tau"] = parent_tau + c_info['nmuts']/mu
            # NOTE(review): the nesting of this check was reconstructed from a
            # listing without indentation; it is placed at child-loop level so
            # skipped exact-date tips also contribute -- confirm against upstream.
            if child.is_terminal() and c_info['exact_date']:
                deviations.append(c_info['avg_date']-c_info['tau'])

    return node_info, np.std(deviations)


def _date_uncertainty(constraint):
    """Return the width of a date constraint; scalar dates have width 0.0."""
    if np.ndim(constraint) == 0:
        return 0.0
    return np.abs(constraint[1] - constraint[0])


def _mutation_count(node, use_mutations, seq_length):
    """Count branch mutations ending in A/C/G/T, or estimate them from the branch length."""
    if use_mutations:
        return len([m for m in node.mutations if m[-1] in 'ACGT'])
    # a missing branch length is treated as zero length
    return np.round((node.branch_length or 0.0)*seq_length)


def collect_node_info(tt, percentile_for_exact_date=90):
    """Gather per-node mutation counts and date information.

    Date constraints no wider than 1.01 times the
    ``percentile_for_exact_date``-th percentile of all constraint widths are
    treated as exact dates. Exact-date tips on zero-mutation branches are
    marked ``'skip'`` and folded into their parent as direct observations.

    Parameters
    ----------
    tt : object
        Tree object exposing ``tree``, ``aln``, ``sequence_reconstruction``,
        ``infer_ancestral_sequences`` and ``data.full_length``. If an
        alignment is present but ancestral sequences have not been
        reconstructed, reconstruction is triggered here.
    percentile_for_exact_date : float, optional
        Percentile of the constraint-width distribution used as cutoff.

    Returns
    -------
    dict
        Maps node name to a dict. Terminal nodes carry ``'skip'``,
        ``'nmuts'``, ``'exact_date'``, ``'avg_date'`` (if dated) and
        ``'observations'`` (exact-date tips with mutations). Internal nodes
        carry ``'dates'``, ``'tips'``, ``'skip'``, ``'exact_date'``,
        ``'nmuts'``, ``'observations'`` and ``'avg_date'``.
    """
    node_info = {}
    aln = tt.aln or False
    if aln and (not tt.sequence_reconstruction):
        tt.infer_ancestral_sequences(infer_gtr=False)
    L = tt.data.full_length

    date_uncertainty = [_date_uncertainty(n.raw_date_constraint)
                        for n in tt.tree.get_terminals()
                        if n.raw_date_constraint is not None]
    # with no dated tips the cutoff is never consulted; avoid taking the
    # percentile of an empty list
    if date_uncertainty:
        uncertainty_cutoff = np.percentile(date_uncertainty, percentile_for_exact_date)*1.01
    else:
        uncertainty_cutoff = 0.0

    # post-order guarantees that internal children are already in node_info
    for n in tt.tree.get_nonterminals(order='postorder'):
        parent = {"dates": [], "tips": {}, "skip": False}
        exact_dates = 0
        for c in n:
            if not c.is_terminal():
                if node_info[c.name]['exact_date']:
                    exact_dates += 1
                continue

            constraint = c.raw_date_constraint
            child = {'skip': False}
            child["nmuts"] = _mutation_count(c, aln, L)
            if constraint is None:
                child['exact_date'] = False
            elif np.ndim(constraint) == 0:
                # scalar dates (float, int, numpy scalar) are exact by definition
                child['exact_date'] = True
            else:
                child['exact_date'] = bool(_date_uncertainty(constraint) <= uncertainty_cutoff)

            if child['exact_date']:
                exact_dates += 1
                if child["nmuts"] == 0:
                    # identical to the parent: record as an observation of the parent
                    child['skip'] = True
                    parent["tips"][c.name] = {'date': np.mean(constraint),
                                              'exact_date': child['exact_date']}
                else:
                    child['skip'] = False
                    child['observations'] = 1
            if constraint is not None:
                child["avg_date"] = np.mean(constraint)
            node_info[c.name] = child

        parent['exact_date'] = exact_dates > 0

        parent["nmuts"] = _mutation_count(n, aln, L)
        d = [v['date'] for v in parent['tips'].values() if v['exact_date']]
        parent["observations"] = len(d)
        parent["avg_date"] = np.mean(d) if len(d) else 0.0
        node_info[n.name] = parent

    return node_info
36 changes: 13 additions & 23 deletions treetime/treetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ def run(self, raise_uncaught_exceptions=False, **kwargs):
sys.exit(2)


def _run(self, root=None, infer_gtr=True, relaxed_clock=None, n_iqd = None,
resolve_polytomies=True, max_iter=0, Tc=None, fixed_clock_rate=None,
def _run(self, root=None, infer_gtr=True, relaxed_clock=None, clock_filter_method='residuals',
n_iqd = None, resolve_polytomies=True, max_iter=0, Tc=None, fixed_clock_rate=None,
time_marginal='never', sequence_marginal=False, branch_length_mode='auto',
vary_rate=False, use_covariation=False, tracelog_file=None,
method_anc = 'probabilistic', assign_gamma=None, stochastic_resolve=False,
Expand Down Expand Up @@ -222,7 +222,8 @@ def _run(self, root=None, infer_gtr=True, relaxed_clock=None, n_iqd = None,
else:
plot_rtt=False
reroot_mechanism = 'least-squares' if root=='clock_filter' else root
self.clock_filter(reroot=reroot_mechanism, n_iqd=n_iqd, plot=plot_rtt, fixed_clock_rate=fixed_clock_rate)
self.clock_filter(reroot=reroot_mechanism, method=clock_filter_method,
n_iqd=n_iqd, plot=plot_rtt, fixed_clock_rate=fixed_clock_rate)
elif root is not None:
self.reroot(root=root, clock_rate=fixed_clock_rate)

Expand Down Expand Up @@ -383,7 +384,8 @@ def _set_branch_length_mode(self, branch_length_mode):
self.branch_length_mode = 'input'


def clock_filter(self, reroot='least-squares', n_iqd=None, plot=False, fixed_clock_rate=None):
def clock_filter(self, reroot='least-squares', method='residual',
n_iqd=None, plot=False, fixed_clock_rate=None):
r'''
Labels outlier branches that don't seem to follow a molecular clock
and excludes them from subsequent molecular clock estimation and
Expand All @@ -405,34 +407,22 @@ def clock_filter(self, reroot='least-squares', n_iqd=None, plot=False, fixed_clo
If True, plot the results
'''
from .clock_filter_methods import residual_filter, local_filter
if n_iqd is None:
n_iqd = ttconf.NIQD
if type(reroot) is list and len(reroot)==1:
reroot=str(reroot[0])

terminals = self.tree.get_terminals()
if reroot:
self.reroot(root='least-squares' if reroot=='best' else reroot, covariation=False, clock_rate=fixed_clock_rate)
self.reroot(root='least-squares' if reroot=='best' else reroot,
covariation=False, clock_rate=fixed_clock_rate)
else:
self.get_clock_model(covariation=False, slope=fixed_clock_rate)

clock_rate = self.clock_model['slope']
icpt = self.clock_model['intercept']
res = {}
for node in terminals:
if hasattr(node, 'raw_date_constraint') and (node.raw_date_constraint is not None):
res[node] = node.dist2root - clock_rate*np.mean(node.raw_date_constraint) - icpt

residuals = np.array(list(res.values()))
iqd = np.percentile(residuals,75) - np.percentile(residuals,25)
bad_branch_count = 0
for node,r in res.items():
if abs(r)>n_iqd*iqd and node.up.up is not None:
self.logger('TreeTime.ClockFilter: marking %s as outlier, residual %f interquartile distances'%(node.name,r/iqd), 3, warn=True)
node.bad_branch=True
bad_branch_count += 1
else:
node.bad_branch=False
if method=='residual':
bad_branch_count = residual_filter(self, n_iqd)
elif method=='local':
bad_branch_count = local_filter(self, n_iqd)

if bad_branch_count>0.34*self.tree.count_terminals():
self.logger("TreeTime.clock_filter: More than a third of leaves have been excluded by the clock filter. Please check your input data.", 0, warn=True)
Expand Down
5 changes: 3 additions & 2 deletions treetime/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def run_timetree(myTree, params, outdir, tree_suffix='', prune_short=True, metho
stochastic_resolve = stochastic_resolve,
Tc=coalescent, max_iter=params.max_iter,
fixed_clock_rate=params.clock_rate,
n_iqd=params.clock_filter,
n_iqd=params.clock_filter, clock_filter_method=params.clock_filter_method,
time_marginal="confidence-only" if (calc_confidence and time_marginal=='never') else time_marginal,
vary_rate = vary_rate,
branch_length_mode = branch_length_mode,
Expand Down Expand Up @@ -809,7 +809,8 @@ def estimate_clock_model(params):
myTree.tip_slack=params.tip_slack
if params.clock_filter:
n_bad = [n.name for n in myTree.tree.get_terminals() if n.bad_branch]
myTree.clock_filter(n_iqd=params.clock_filter, reroot=params.reroot or 'least-squares')
myTree.clock_filter(n_iqd=params.clock_filter, reroot=params.reroot or 'least-squares',
method=params.clock_filter_method)
n_bad_after = [n.name for n in myTree.tree.get_terminals() if n.bad_branch]
if len(n_bad_after)>len(n_bad):
print("The following leaves don't follow a loose clock and "
Expand Down

0 comments on commit 4894b93

Please sign in to comment.