Merge pull request #235 from neherlab/fix/narrow-convolutions
more strict condition for fall back on delta function convolution
rneher committed May 2, 2023
2 parents 6da5884 + f2de79c commit 23cd008
Showing 5 changed files with 22 additions and 13 deletions.
1 change: 1 addition & 0 deletions changelog.md
@@ -4,6 +4,7 @@
 * move most function related to IO of the command line wrappers into a separate file.
 * make TreeTime own its random number generator and add `--rng-seed` to control state in CLI. Any previous usage of `numpy.random.seed` will now be ignored in favor of `--rng-seed`. See [PR #234](https://github.com/neherlab/treetime/pull/234)
 * add flag `--greedy-resolve` as inverse to `--stochastic-resolve` with the aim of switching the two.
+* tighten conditions that trigger approximation of narrow distribution as a delta function in convolution using FFT [PR #235](https://github.com/neherlab/treetime/pull/235).
 
 # 0.9.6: bug fixes and new mode of polytomy resolution
 * in cases when very large polytomies are resolved, the multiplication of the discretized message results in messages/distributions of length 1. This resulted in an error, since interpolation objects require at least two points. This is now caught and a small discrete grid created.
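For context on the new changelog entry: in an FFT-based convolution, a factor that is effectively a spike on the numerical grid can be replaced by a delta function, which turns the convolution into a simple shift of the other factor. A minimal NumPy sketch of that identity (generic arrays, not TreeTime's `Distribution` objects; all grid values below are made up for illustration):

```python
import numpy as np

# Illustrative grid and two normalized densities: a broad "branch" distribution
# and a much narrower "node" distribution peaked at x = 1.0.
x = np.linspace(-10, 10, 2001)
dx = x[1] - x[0]

branch = np.exp(-0.5*(x/2.0)**2)
branch /= branch.sum()*dx
narrow = np.exp(-0.5*((x - 1.0)/0.05)**2)
narrow /= narrow.sum()*dx

# Exact discrete convolution vs. the delta-function shortcut (shift by the peak).
full = np.convolve(branch, narrow, mode='same')*dx
shifted = np.exp(-0.5*((x - 1.0)/2.0)**2)
shifted /= shifted.sum()*dx

print(np.max(np.abs(full - shifted)))  # tiny: the delta approximation is accurate here
```

The commit tightens the conditions under which this shortcut is taken, so that only distributions that are genuinely unresolvable on the FFT grid are treated as delta functions.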
9 changes: 7 additions & 2 deletions treetime/clock_tree.py
@@ -704,8 +704,13 @@ def _cleanup():
         if hasattr(self, 'merger_model') and self.merger_model:
             time_points = parent.marginal_pos_LH.x
             if len(time_points)<5:
-                time_points = np.linspace(np.min([x.xmin for x in complementary_msgs]),
-                                          np.max([x.xmax for x in complementary_msgs]), 50)
+                time_points = np.unique(np.concatenate([
+                    time_points,
+                    np.linspace(np.min([x.xmin for x in complementary_msgs]),
+                                np.max([x.xmax for x in complementary_msgs]), 50),
+                    np.linspace(np.min([x.effective_support[0] for x in complementary_msgs]),
+                                np.max([x.effective_support[1] for x in complementary_msgs]), 50),
+                ]))
             # As Lx (the product of child messages) does not include the node contribution this must
             # be added to recover the full distribution of the parent node w/o contribution of the focal node.
             complementary_msgs.append(self.merger_model.node_contribution(parent, time_points))
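The replacement above widens the fallback grid for the merger-model contribution: besides a coarse grid spanning the full domain (`xmin`..`xmax`) of the complementary messages, it now also keeps any pre-existing time points and adds a grid across the messages' effective support. A rough standalone sketch of that grid construction (plain NumPy with a made-up stand-in for the message objects; only the attributes used above are modeled):

```python
import numpy as np
from collections import namedtuple

# Minimal stand-in for a message/distribution: only the attributes used in the
# snippet above are represented; effective_support is a (low, high) pair.
Msg = namedtuple("Msg", ["xmin", "xmax", "effective_support"])

def fallback_time_points(existing_points, complementary_msgs, n=50):
    """Union of: the existing grid, a coarse grid over the messages' full
    domain, and a coarse grid over their (possibly much narrower) effective
    support, so the support region is never left under-sampled."""
    domain_grid = np.linspace(min(m.xmin for m in complementary_msgs),
                              max(m.xmax for m in complementary_msgs), n)
    support_grid = np.linspace(min(m.effective_support[0] for m in complementary_msgs),
                               max(m.effective_support[1] for m in complementary_msgs), n)
    return np.unique(np.concatenate([existing_points, domain_grid, support_grid]))

msgs = [Msg(0.0, 10.0, (2.0, 2.5)), Msg(0.0, 12.0, (1.8, 2.2))]
print(fallback_time_points(np.array([2.0, 2.1]), msgs)[:5])
```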
3 changes: 2 additions & 1 deletion treetime/distribution.py
@@ -272,6 +272,7 @@ def x(self):
     def y(self):
         if self.is_delta:
             print("Warning: evaluating log probability of a delta distribution.")
+            import ipdb; ipdb.set_trace()
             return [self.weight]
         else:
             return self._peak_val + self._func.y
@@ -344,7 +345,7 @@ def calc_effective_support(self, cutoff=1e-15):
 
     def _adjust_grid(self, rel_tol=0.01, yc=10):
         n_iter=0
-        while len(self.y)>200 and n_iter<5:
+        while len(self.x)>200 and n_iter<5:
             interp_err = 2*self.y[1:-1] - self.y[2:] - self.y[:-2]
             ind = np.ones_like(self.y, dtype=bool)
             dy = self.y-self.peak_val
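For orientation on the one-character fix above: `_adjust_grid` repeatedly prunes grid points whose removal barely changes the interpolated curve, so the loop condition is naturally expressed in terms of the grid itself (`self.x`) rather than the evaluated values (`self.y`). A generic sketch of second-difference-based grid pruning, not TreeTime's actual implementation:

```python
import numpy as np

def prune_grid(x, y, tol=1e-3, min_points=50, max_iter=5):
    """Iteratively drop interior points whose value is almost exactly the
    average of their neighbors (small discrete second difference), i.e. points
    that linear interpolation would reproduce anyway. Generic illustration
    only, not TreeTime's Distribution._adjust_grid."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    for _ in range(max_iter):
        if len(x) <= min_points:
            break
        err = np.abs(2*y[1:-1] - y[:-2] - y[2:])   # interpolation error of interior points
        keep = np.ones(len(x), dtype=bool)
        removable = np.where(err < tol)[0] + 1     # shift to interior indices
        keep[removable[::2]] = False               # never drop two adjacent points in one pass
        if keep.all():
            break
        x, y = x[keep], y[keep]
    return x, y

xs = np.linspace(-5, 5, 1000)
print(len(prune_grid(xs, xs**2)[0]))               # far fewer than 1000 points remain
```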
19 changes: 11 additions & 8 deletions treetime/node_interpolator.py
@@ -168,18 +168,21 @@ def convolve_fft(cls, node_interp, branch_interp, fft_grid_size=FFT_FWHM_GRID_SI
         b_support_range = b_effsupport[1]-b_effsupport[0]
         n_support_range = n_effsupport[1]-n_effsupport[0]
         #ratio = n_support_range/b_support_range
-        ratio = node_interp.fwhm/branch_interp.fwhm
+        ratio = n_support_range/branch_interp.fwhm
 
-        if ratio < 1/fft_grid_size and 4*dt > node_interp.fwhm:
-            ## node distribution is much narrower than the branch distribution, proceed as if node distribution is
-            ## a delta distribution
+        if ratio < 1/fft_grid_size and 4.0*dt > node_interp.fwhm:
+            ## node distribution is much narrower than the branch distribution, proceed as if
+            # node distribution is a delta distribution with the peak 4 full-width-half-maxima
+            # away from the nominal peak to avoid slicing the relevant range to zero
             log_scale_node_interp = node_interp.integrate(return_log=True, a=node_interp.xmin,b=node_interp.xmax,n=max(100, len(node_interp.x))) #probability of node distribution
             if inverse_time:
-                x = branch_interp.x + node_interp._peak_pos
-                dist = Distribution(x, branch_interp(x - node_interp._peak_pos) - log_scale_node_interp, min_width=max(node_interp.min_width, branch_interp.min_width), is_log=True)
+                x = branch_interp.x + max(n_effsupport[0], node_interp._peak_pos - 4.0*node_interp.fwhm)
+                dist = Distribution(x, branch_interp(x - node_interp._peak_pos) - log_scale_node_interp,
+                                    min_width=max(node_interp.min_width, branch_interp.min_width), is_log=True)
             else:
-                x = - branch_interp.x + node_interp._peak_pos
-                dist = Distribution(x, branch_interp(branch_interp.x) - log_scale_node_interp, min_width=max(node_interp.min_width, branch_interp.min_width), is_log=True)
+                x = - branch_interp.x + min(n_effsupport[1], node_interp._peak_pos + 4.0*node_interp.fwhm)
+                dist = Distribution(x, branch_interp(branch_interp.x) - log_scale_node_interp,
+                                    min_width=max(node_interp.min_width, branch_interp.min_width), is_log=True)
             return dist
         elif ratio > fft_grid_size and 4*dt > branch_interp.fwhm:
             raise ValueError("ERROR: Unexpected behavior: branch distribution is much narrower than the node distribution.")
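The heart of the fix is the branch above: the narrowness test now compares the node distribution's effective-support range (rather than its FWHM) against the branch distribution's FWHM, and the peak position used for the delta-function shift is clamped so that it cannot wander outside the node's effective support. A schematic sketch of that decision logic (plain floats instead of TreeTime's `Distribution` objects, only the inverse-time direction, and an illustrative default for the grid-size constant):

```python
def choose_convolution_strategy(node_support, node_fwhm, node_peak,
                                branch_fwhm, dt, fft_grid_size=150):
    """node_support: (low, high) effective support of the node distribution;
    dt: spacing of the FFT grid. Returns the strategy and, for the delta
    fallback, the clamped shift applied to the branch grid."""
    support_range = node_support[1] - node_support[0]
    ratio = support_range / branch_fwhm          # new criterion: support range vs. branch FWHM

    if ratio < 1.0/fft_grid_size and 4.0*dt > node_fwhm:
        # Node distribution is far narrower than the branch distribution and
        # under-resolved by the FFT grid: treat it as a delta function, but keep
        # the shift within the node's effective support (cf. the max(...) above).
        shift = max(node_support[0], node_peak - 4.0*node_fwhm)
        return "delta", shift
    elif ratio > fft_grid_size and 4.0*dt > branch_fwhm:
        raise ValueError("branch distribution is much narrower than the node distribution")
    return "fft", None

# Example: a node distribution confined to a sliver of time triggers the fallback.
print(choose_convolution_strategy(node_support=(2.0, 2.001), node_fwhm=5e-4,
                                  node_peak=2.0005, branch_fwhm=1.0, dt=0.01))
```

In the diff itself, the same clamping is mirrored with `min(n_effsupport[1], ...)` for the non-inverse-time direction.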
3 changes: 1 addition & 2 deletions treetime/wrappers.py
@@ -325,7 +325,7 @@ def timetree(params):
     if params.aln is None and params.sequence_length is None:
         print("one of arguments '--aln' and '--sequence-length' is required.", file=sys.stderr)
         return 1
-    print(f"rng_seed: {params.rng_seed}")
+
     myTree = TreeTime(dates=dates, tree=params.tree, ref=ref,
                       aln=aln, gtr=gtr, seq_len=params.sequence_length,
                       verbose=params.verbose, fill_overhangs=not params.keep_overhangs,
@@ -379,7 +379,6 @@ def run_timetree(myTree, params, outdir, tree_suffix='', prune_short=True, metho
     if hasattr(params, 'stochastic_resolve'):
         stochastic_resolve = params.stochastic_resolve
     else: stochastic_resolve = False
-    print(f"stochastic_resolve: {stochastic_resolve}")
 
     # determine whether confidence intervals are to be computed and how the
     # uncertainty in the rate estimate should be treated
