more strict condition for fall back on delta function convolution #235

Merged 2 commits on May 2, 2023
changelog.md (1 change: 1 addition & 0 deletions)
@@ -4,6 +4,7 @@
 * move most function related to IO of the command line wrappers into a separate file.
 * make TreeTime own its random number generator and add `--rgn-seed` to control state in CLI. See [PR #234](https://github.com/neherlab/treetime/pull/234)
 * add flag `--greedy-resolve` as inverse to `--stochastic-resolve` with the aim of switching the two.
+* tighten conditions that trigger approximation of narrow distribution as a delta function in convolution using FFT [PR #235](https://github.com/neherlab/treetime/pull/235).

 # 0.9.6: bug fixes and new mode of polytomy resolution
 * in cases when very large polytomies are resolved, the multiplication of the discretized message results in messages/distributions of length 1. This resulted in an error, since interpolation objects require at least two points. This is now caught and a small discrete grid created.
treetime/clock_tree.py (9 changes: 7 additions & 2 deletions)
@@ -704,8 +704,13 @@ def _cleanup():
         if hasattr(self, 'merger_model') and self.merger_model:
             time_points = parent.marginal_pos_LH.x
             if len(time_points)<5:
-                time_points = np.linspace(np.min([x.xmin for x in complementary_msgs]),
-                                          np.max([x.xmax for x in complementary_msgs]), 50)
+                time_points = np.unique(np.concatenate([
+                    time_points,
+                    np.linspace(np.min([x.xmin for x in complementary_msgs]),
+                                np.max([x.xmax for x in complementary_msgs]), 50),
+                    np.linspace(np.min([x.effective_support[0] for x in complementary_msgs]),
+                                np.max([x.effective_support[1] for x in complementary_msgs]), 50),
+                ]))
             # As Lx (the product of child messages) does not include the node contribution this must
             # be added to recover the full distribution of the parent node w/o contribution of the focal node.
             complementary_msgs.append(self.merger_model.node_contribution(parent, time_points))
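The clock_tree.py change above widens the time-point grid handed to the merger model: when the parent grid is very coarse, the new grid is the union of the original points, a grid spanning the full xmin–xmax range of the complementary messages, and a grid spanning their effective support. A minimal sketch of that construction, assuming Distribution-like objects that expose `xmin`, `xmax`, and `effective_support`; the helper name `combined_time_grid` is hypothetical:

```python
import numpy as np

def combined_time_grid(time_points, complementary_msgs, n=50):
    """Hypothetical helper mirroring the merged logic: if the parent grid has
    fewer than 5 points, union it with grids that span both the full x-range
    and the effective support of the complementary messages."""
    if len(time_points) >= 5:
        return time_points
    full_range = np.linspace(min(m.xmin for m in complementary_msgs),
                             max(m.xmax for m in complementary_msgs), n)
    eff_support = np.linspace(min(m.effective_support[0] for m in complementary_msgs),
                              max(m.effective_support[1] for m in complementary_msgs), n)
    # np.unique sorts and de-duplicates the concatenated grids
    return np.unique(np.concatenate([time_points, full_range, eff_support]))
```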
treetime/distribution.py (3 changes: 2 additions & 1 deletion)
@@ -272,6 +272,7 @@ def x(self):
     def y(self):
         if self.is_delta:
             print("Warning: evaluating log probability of a delta distribution.")
+            import ipdb; ipdb.set_trace()
             return [self.weight]
         else:
             return self._peak_val + self._func.y
@@ -344,7 +345,7 @@ def calc_effective_support(self, cutoff=1e-15):

     def _adjust_grid(self, rel_tol=0.01, yc=10):
         n_iter=0
-        while len(self.y)>200 and n_iter<5:
+        while len(self.x)>200 and n_iter<5:
             interp_err = 2*self.y[1:-1] - self.y[2:] - self.y[:-2]
             ind = np.ones_like(self.y, dtype=bool)
             dy = self.y-self.peak_val
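For context on the distribution.py hunk: `_adjust_grid` thins an over-dense interpolation grid using the discrete second difference of the stored y values as an error estimate, and the loop condition now reads the grid length from `self.x` instead of evaluating the `self.y` property. A rough, illustrative sketch of the thinning idea; the drop criterion below is invented for the example and is not TreeTime's actual rule:

```python
import numpy as np

def thin_grid_once(x, y, rel_tol=0.01):
    """One illustrative thinning pass: estimate local interpolation error from
    the discrete second difference of y and drop interior points whose error
    is below rel_tol (endpoints are always kept)."""
    interp_err = 2*y[1:-1] - y[2:] - y[:-2]    # second difference, as in _adjust_grid
    keep = np.ones_like(y, dtype=bool)
    keep[1:-1] = np.abs(interp_err) > rel_tol  # invented threshold for illustration
    return x[keep], y[keep]
```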
treetime/node_interpolator.py (19 changes: 11 additions & 8 deletions)
@@ -168,18 +168,21 @@ def convolve_fft(cls, node_interp, branch_interp, fft_grid_size=FFT_FWHM_GRID_SI
         b_support_range = b_effsupport[1]-b_effsupport[0]
         n_support_range = n_effsupport[1]-n_effsupport[0]
         #ratio = n_support_range/b_support_range
-        ratio = node_interp.fwhm/branch_interp.fwhm
+        ratio = n_support_range/branch_interp.fwhm

-        if ratio < 1/fft_grid_size and 4*dt > node_interp.fwhm:
-            ## node distribution is much narrower than the branch distribution, proceed as if node distribution is
-            ## a delta distribution
+        if ratio < 1/fft_grid_size and 4.0*dt > node_interp.fwhm:
+            ## node distribution is much narrower than the branch distribution, proceed as if
+            # node distribution is a delta distribution with the peak 4 full-width-half-maxima
+            # away from the nominal peak to avoid slicing the relevant range to zero
             log_scale_node_interp = node_interp.integrate(return_log=True, a=node_interp.xmin,b=node_interp.xmax,n=max(100, len(node_interp.x))) #probability of node distribution
             if inverse_time:
-                x = branch_interp.x + node_interp._peak_pos
-                dist = Distribution(x, branch_interp(x - node_interp._peak_pos) - log_scale_node_interp, min_width=max(node_interp.min_width, branch_interp.min_width), is_log=True)
+                x = branch_interp.x + max(n_effsupport[0], node_interp._peak_pos - 4.0*node_interp.fwhm)
+                dist = Distribution(x, branch_interp(x - node_interp._peak_pos) - log_scale_node_interp,
+                                    min_width=max(node_interp.min_width, branch_interp.min_width), is_log=True)
             else:
-                x = - branch_interp.x + node_interp._peak_pos
-                dist = Distribution(x, branch_interp(branch_interp.x) - log_scale_node_interp, min_width=max(node_interp.min_width, branch_interp.min_width), is_log=True)
+                x = - branch_interp.x + min(n_effsupport[1], node_interp._peak_pos + 4.0*node_interp.fwhm)
+                dist = Distribution(x, branch_interp(branch_interp.x) - log_scale_node_interp,
+                                    min_width=max(node_interp.min_width, branch_interp.min_width), is_log=True)
             return dist
         elif ratio > fft_grid_size and 4*dt > branch_interp.fwhm:
             raise ValueError("ERROR: Unexpected behavior: branch distribution is much narrower than the node distribution.")
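The node_interpolator.py hunk carries the substance of the PR. The fall-back that treats the node distribution as a delta function now compares the node's effective-support range (rather than its FWHM) to the branch distribution's FWHM, which makes the test stricter for distributions with narrow peaks but broad tails; in addition, the position used for the delta is shifted by up to four FWHM from the nominal peak, bounded by the node's effective support, to avoid slicing the relevant range to zero (per the in-diff comment). A minimal sketch of the tightened condition only, with assumed argument meanings (`dt` is presumably the FFT grid spacing) and not the library API:

```python
def delta_fallback_ok(n_effsupport, node_fwhm, branch_fwhm, dt, fft_grid_size):
    """Sketch of the tightened test: approximate the node distribution by a
    delta function only if its effective-support range is much narrower than
    the branch FWHM *and* the grid spacing dt could not resolve the node FWHM."""
    n_support_range = n_effsupport[1] - n_effsupport[0]
    ratio = n_support_range / branch_fwhm
    return ratio < 1.0 / fft_grid_size and 4.0 * dt > node_fwhm
```

Compared to the previous `node_interp.fwhm/branch_interp.fwhm` ratio, using the effective-support range raises the ratio for distributions whose mass extends well beyond the half-maximum width, so the delta approximation is triggered less often, as the PR title states.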
treetime/wrappers.py (3 changes: 1 addition & 2 deletions)
@@ -325,7 +325,7 @@ def timetree(params):
     if params.aln is None and params.sequence_length is None:
         print("one of arguments '--aln' and '--sequence-length' is required.", file=sys.stderr)
         return 1
-    print(f"rng_seed: {params.rng_seed}")
+
     myTree = TreeTime(dates=dates, tree=params.tree, ref=ref,
                       aln=aln, gtr=gtr, seq_len=params.sequence_length,
                       verbose=params.verbose, fill_overhangs=not params.keep_overhangs,
@@ -379,7 +379,6 @@ def run_timetree(myTree, params, outdir, tree_suffix='', prune_short=True, metho
     if hasattr(params, 'stochastic_resolve'):
         stochastic_resolve = params.stochastic_resolve
     else: stochastic_resolve = False
-    print(f"stochastic_resolve: {stochastic_resolve}")

     # determine whether confidence intervals are to be computed and how the
     # uncertainty in the rate estimate should be treated