|
10 | 10 | from torchtree.evolution.alignment import Alignment, Sequence
|
11 | 11 | from torchtree.evolution.coalescent import ConstantCoalescent
|
12 | 12 | from torchtree.evolution.datatype import NucleotideDataType
|
| 13 | +from torchtree.evolution.io import read_tree, read_tree_and_alignment |
13 | 14 | from torchtree.evolution.site_pattern import compress_alignment
|
14 | 15 | from torchtree.evolution.substitution_model import JC69
|
15 | 16 | from torchtree.evolution.taxa import Taxa, Taxon
|
|
21 | 22 | ReparameterizedTimeTreeModel,
|
22 | 23 | heights_from_branch_lengths,
|
23 | 24 | )
|
24 |
| -from torchtree.io import read_tree, read_tree_and_alignment |
25 | 25 |
|
26 | 26 |
|
27 | 27 | def benchmark(f):
|
@@ -473,10 +473,13 @@ def fn_grad(ratios_root_height):
|
473 | 473 | def ratio_transform(args):
|
474 | 474 | replicates = args.replicates
|
475 | 475 | tree = read_tree(args.tree, True, True)
|
| 476 | + taxa_count = len(tree.taxon_namespace) |
476 | 477 | taxa = []
|
477 | 478 | for node in tree.leaf_node_iter():
|
478 | 479 | taxa.append(Taxon(node.label, {'date': node.date}))
|
479 |
| - ratios_root_height = Parameter("internal_heights", torch.tensor([0.5] * 67 + [10])) |
| 480 | + ratios_root_height = Parameter( |
| 481 | + "internal_heights", torch.tensor([0.5] * (taxa_count - 2) + [10]) |
| 482 | + ) |
480 | 483 | tree_model = ReparameterizedTimeTreeModel(
|
481 | 484 | "tree", tree, Taxa('taxa', taxa), ratios_root_height
|
482 | 485 | )
|
@@ -632,11 +635,12 @@ def fn3_grad_jit(ratios_root_height):
|
632 | 635 |
|
633 | 636 | def constant_coalescent(args):
|
634 | 637 | tree = read_tree(args.tree, True, True)
|
| 638 | + taxa_count = len(tree.taxon_namespace) |
635 | 639 | taxa = []
|
636 | 640 | for node in tree.leaf_node_iter():
|
637 | 641 | taxa.append(Taxon(node.label, {'date': node.date}))
|
638 | 642 | ratios_root_height = Parameter(
|
639 |
| - "internal_heights", torch.tensor([0.5] * 67 + [20.0]) |
| 643 | + "internal_heights", torch.tensor([0.5] * (taxa_count - 2) + [20.0]) |
640 | 644 | )
|
641 | 645 | tree_model = ReparameterizedTimeTreeModel(
|
642 | 646 | "tree", tree, Taxa('taxa', taxa), ratios_root_height
|
@@ -743,7 +747,7 @@ def fn3_grad(tree_model, ratios_root_height, pop_size):
|
743 | 747 | return log_p
|
744 | 748 |
|
745 | 749 | x, counts = torch.unique(tree_model.sampling_times, return_counts=True)
|
746 |
| - counts = torch.cat((counts, torch.tensor([-1] * 68))) |
| 750 | + counts = torch.cat((counts, torch.tensor([-1] * (taxa_count - 1)))) |
747 | 751 |
|
748 | 752 | with torch.no_grad():
|
749 | 753 | total_time, log_p = fn3(args.replicates, tree_model, pop_size)
|
|
0 commit comments