Skip to content

Commit 09f087c

Browse files
committed
Fix classifiers in setup.cfg
1 parent 016a555 commit 09f087c

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

benchmarks/benchmark.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from torchtree.evolution.alignment import Alignment, Sequence
1111
from torchtree.evolution.coalescent import ConstantCoalescent
1212
from torchtree.evolution.datatype import NucleotideDataType
13+
from torchtree.evolution.io import read_tree, read_tree_and_alignment
1314
from torchtree.evolution.site_pattern import compress_alignment
1415
from torchtree.evolution.substitution_model import JC69
1516
from torchtree.evolution.taxa import Taxa, Taxon
@@ -21,7 +22,6 @@
2122
ReparameterizedTimeTreeModel,
2223
heights_from_branch_lengths,
2324
)
24-
from torchtree.io import read_tree, read_tree_and_alignment
2525

2626

2727
def benchmark(f):
@@ -473,10 +473,13 @@ def fn_grad(ratios_root_height):
473473
def ratio_transform(args):
474474
replicates = args.replicates
475475
tree = read_tree(args.tree, True, True)
476+
taxa_count = len(tree.taxon_namespace)
476477
taxa = []
477478
for node in tree.leaf_node_iter():
478479
taxa.append(Taxon(node.label, {'date': node.date}))
479-
ratios_root_height = Parameter("internal_heights", torch.tensor([0.5] * 67 + [10]))
480+
ratios_root_height = Parameter(
481+
"internal_heights", torch.tensor([0.5] * (taxa_count - 2) + [10])
482+
)
480483
tree_model = ReparameterizedTimeTreeModel(
481484
"tree", tree, Taxa('taxa', taxa), ratios_root_height
482485
)
@@ -632,11 +635,12 @@ def fn3_grad_jit(ratios_root_height):
632635

633636
def constant_coalescent(args):
634637
tree = read_tree(args.tree, True, True)
638+
taxa_count = len(tree.taxon_namespace)
635639
taxa = []
636640
for node in tree.leaf_node_iter():
637641
taxa.append(Taxon(node.label, {'date': node.date}))
638642
ratios_root_height = Parameter(
639-
"internal_heights", torch.tensor([0.5] * 67 + [20.0])
643+
"internal_heights", torch.tensor([0.5] * (taxa_count - 2) + [20.0])
640644
)
641645
tree_model = ReparameterizedTimeTreeModel(
642646
"tree", tree, Taxa('taxa', taxa), ratios_root_height
@@ -743,7 +747,7 @@ def fn3_grad(tree_model, ratios_root_height, pop_size):
743747
return log_p
744748

745749
x, counts = torch.unique(tree_model.sampling_times, return_counts=True)
746-
counts = torch.cat((counts, torch.tensor([-1] * 68)))
750+
counts = torch.cat((counts, torch.tensor([-1] * (taxa_count - 1))))
747751

748752
with torch.no_grad():
749753
total_time, log_p = fn3(args.replicates, tree_model, pop_size)

setup.cfg

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@ long_description = file: README.md
1010
license = GPL3
1111
license_file = LICENSE
1212
classifiers =
13-
Intended Audience :: Science/Research",
13+
Intended Audience :: Science/Research
1414
License :: OSI Approved :: GNU General Public License v3 (GPLv3)
1515
Operating System :: OS Independent
16-
Programming Language :: Python :: 3.5
1716
Programming Language :: Python :: 3.6
1817
Programming Language :: Python :: 3.7
1918
Programming Language :: Python :: 3.8
19+
Programming Language :: Python :: 3.9
2020
Topic :: Scientific/Engineering :: Bio-Informatics
2121

2222
[options]

0 commit comments

Comments
 (0)