Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Dec 7, 2024
1 parent 61ab683 commit 01cd6bb
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 5 deletions.
15 changes: 13 additions & 2 deletions python/sdist/amici/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@

from warnings import warn

from amici.jax.petab import JAXProblem, run_simulations, petab_simulate
from amici.jax.petab import (
JAXProblem,
run_simulations,
petab_simulate,
ReturnValue,
)
from amici.jax.model import JAXModel

warn(
Expand All @@ -18,4 +23,10 @@
stacklevel=2,
)

__all__ = ["JAXModel", "JAXProblem", "run_simulations", "petab_simulate"]
__all__ = [
"JAXModel",
"JAXProblem",
"run_simulations",
"petab_simulate",
"ReturnValue",
]
4 changes: 3 additions & 1 deletion python/sdist/amici/jax/petab.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ def run_simulations(
**DEFAULT_CONTROLLER_SETTINGS
),
max_steps: int = 2**10,
ret: ReturnValue = ReturnValue.llh,
ret: ReturnValue | str = ReturnValue.llh,
):
"""
Run simulations for a problem.
Expand All @@ -582,6 +582,8 @@ def run_simulations(
:return:
Overall output value and condition specific results and statistics.
"""
ret = ReturnValue[ret]

if simulation_conditions is None:
simulation_conditions = problem.get_all_simulation_conditions()

Expand Down
4 changes: 2 additions & 2 deletions python/tests/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from amici.pysb_import import pysb2amici, pysb2jax
from amici.testing import TemporaryDirectoryWinSafe, skip_on_valgrind
from amici.petab.petab_import import import_petab_problem
from amici.jax import JAXProblem
from amici.jax import JAXProblem, ReturnValue
from numpy.testing import assert_allclose
from test_petab_objective import lotka_volterra # noqa: F401

Expand Down Expand Up @@ -208,7 +208,7 @@ def check_fields_jax(
okwargs = kwargs | {
"adjoint": diffrax.DirectAdjoint(),
"max_steps": 2**8,
"ret": output,
"ret": ReturnValue[output],
}
if sensi_order == amici.SensitivityOrder.none:
r_jax[output] = fun(p, **okwargs)[0]
Expand Down

0 comments on commit 01cd6bb

Please sign in to comment.