Skip to content

Commit

Permalink
Make CoalescentModelInterface expect a TimeTreeModelInterface
Browse files Browse the repository at this point in the history
  • Loading branch information
4ment committed Mar 10, 2024
1 parent 2216fc0 commit f34a571
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
8 changes: 4 additions & 4 deletions benchmarks/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,18 +204,18 @@ def ratio_transform_jacobian(args):
tree_model()

@benchmark
def fn(internal_heights):
def fn():
return torch.tensor([tree_model.inst.transform_jacobian()])

@benchmark
def fn_grad(internal_heights):
def fn_grad():
return torch.tensor(tree_model.inst.gradient_transform_jacobian())

total_time, log_det_jac = fn(args.replicates, internal_heights)
total_time, log_det_jac = fn(args.replicates)
print(f" {args.replicates} evaluations: {total_time} ({log_det_jac})")

internal_heights.requires_grad = True
grad_total_time, grad_log_det_jac = fn_grad(args.replicates, internal_heights)
grad_total_time, grad_log_det_jac = fn_grad(args.replicates)
print(f" {args.replicates} gradient evaluations: {grad_total_time}")

if args.output:
Expand Down
2 changes: 1 addition & 1 deletion torchtree_physher/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
Version number (major.minor.patch[-label])
"""

__version__ = "1.0.0-dev3"
__version__ = "1.0.0-dev4"
8 changes: 4 additions & 4 deletions torchtree_physher/physher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,20 +229,20 @@ PYBIND11_MODULE(physher, m) {

py::class_<ConstantCoalescentModelInterface, CoalescentModelInterface>(
m, "ConstantCoalescentModel")
.def(py::init<double, TreeModelInterface *>());
.def(py::init<double, TimeTreeModelInterface *>());

py::class_<PiecewiseConstantCoalescentInterface, CoalescentModelInterface>(
m, "PiecewiseConstantCoalescentModel")
.def(py::init<const std::vector<double>, TreeModelInterface *>());
.def(py::init<const std::vector<double>, TimeTreeModelInterface *>());

py::class_<PiecewiseConstantCoalescentGridInterface,
CoalescentModelInterface>(m,
"PiecewiseConstantCoalescentGridModel")
.def(py::init<const std::vector<double>, TreeModelInterface *, double>());
.def(py::init<const std::vector<double>, TimeTreeModelInterface *, double>());

py::class_<PiecewiseLinearCoalescentGridInterface, CoalescentModelInterface>(
m, "PiecewiseLinearCoalescentGridModel")
.def(py::init<const std::vector<double>, TreeModelInterface *, double>());
.def(py::init<const std::vector<double>, TimeTreeModelInterface *, double>());

py::class_<CTMCScaleModelInterface, CallableModelInterface>(m,
"CTMCScaleModel")
Expand Down

0 comments on commit f34a571

Please sign in to comment.