diff --git a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb
index 7fe21257a9..e3677af346 100644
--- a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb
+++ b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb
@@ -378,6 +378,7 @@
     "    iy_trafos=jnp.array(iy_trafos),\n",
     "    solver=diffrax.Kvaerno5(),\n",
     "    controller=diffrax.PIDController(atol=1e-8, rtol=1e-8),\n",
+    "    steady_state_event=diffrax.steady_state_event(),\n",
     "    max_steps=2**10,\n",
     "    adjoint=diffrax.DirectAdjoint(),\n",
     "    ret=ReturnValue.y,  # Return observables\n",
diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py
index bf93217153..8b2c09fcc6 100644
--- a/python/sdist/amici/jax/model.py
+++ b/python/sdist/amici/jax/model.py
@@ -12,6 +12,8 @@
 import jax
 import jaxtyping as jt
 
+from collections.abc import Callable
+
 
 class ReturnValue(enum.Enum):
     llh = "log-likelihood"
@@ -32,6 +34,13 @@ class JAXModel(eqx.Module):
     JAXModel provides an abstract base class for a JAX-based implementation of an AMICI model.
     The class implements routines for simulation and evaluation of derived quantities, model specific implementations
     need to be provided by classes inheriting from JAXModel.
+
+    :ivar api_version:
+        API version of the derived class. Needs to match the API version of the base class (MODEL_API_VERSION).
+    :ivar MODEL_API_VERSION:
+        API version of the base class.
+    :ivar jax_py_file:
+        Path to the JAX model file.
     """
 
     MODEL_API_VERSION = "0.0.2"
@@ -248,6 +257,9 @@ def _eq(
         x0: jt.Float[jt.Array, "nxs"],
         solver: diffrax.AbstractSolver,
         controller: diffrax.AbstractStepSizeController,
+        steady_state_event: Callable[
+            ..., diffrax._custom_types.BoolScalarLike
+        ],
         max_steps: jnp.int_,
     ) -> tuple[jt.Float[jt.Array, "1 nxs"], dict]:
         """
@@ -278,10 +290,20 @@ def _eq(
             stepsize_controller=controller,
             max_steps=max_steps,
             adjoint=diffrax.DirectAdjoint(),
-            event=diffrax.Event(cond_fn=diffrax.steady_state_event()),
+            event=diffrax.Event(
+                cond_fn=steady_state_event,
+            ),
             throw=False,
         )
-        return sol.ys[-1, :], sol.stats
+        # If the event was triggered, the event mask is True and the solution is the steady state. Otherwise, the
+        # solution is the last state and the event mask is False. In the latter case, we return inf for the steady
+        # state.
+        ys = jnp.where(
+            sol.event_mask,
+            sol.ys[-1, :],
+            jnp.inf * jnp.ones_like(sol.ys[-1, :]),
+        )
+        return ys, sol.stats
 
     def _solve(
         self,
@@ -450,6 +472,9 @@ def simulate_condition(
         solver: diffrax.AbstractSolver,
         controller: diffrax.AbstractStepSizeController,
         adjoint: diffrax.AbstractAdjoint,
+        steady_state_event: Callable[
+            ..., diffrax._custom_types.BoolScalarLike
+        ],
         max_steps: int | jnp.int_,
         x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]),
         mask_reinit: jt.Bool[jt.Array, "*nx"] = jnp.array([]),
@@ -525,7 +550,13 @@ def simulate_condition(
         # Post-equilibration
         if ts_posteq.shape[0]:
             x_solver, stats_posteq = self._eq(
-                p, tcl, x_solver, solver, controller, max_steps
+                p,
+                tcl,
+                x_solver,
+                solver,
+                controller,
+                steady_state_event,
+                max_steps,
             )
         else:
             stats_posteq = None
@@ -596,6 +627,9 @@ def preequilibrate_condition(
         mask_reinit: jt.Bool[jt.Array, "*nx"],
         solver: diffrax.AbstractSolver,
         controller: diffrax.AbstractStepSizeController,
+        steady_state_event: Callable[
+            ..., diffrax._custom_types.BoolScalarLike
+        ],
         max_steps: int | jnp.int_,
     ) -> tuple[jt.Float[jt.Array, "nx"], dict]:
         r"""
@@ -603,6 +637,10 @@ def preequilibrate_condition(
 
         :param p:
            parameters for simulation ordered according to ids in :ivar parameter_ids:
+        :param x_reinit:
+            re-initialized state vector. If not provided, the state vector is not re-initialized.
+        :param mask_reinit:
+            mask for re-initialization. If `True`, the corresponding state variable is re-initialized.
         :param solver:
            ODE solver
         :param controller:
@@ -619,7 +657,13 @@ def preequilibrate_condition(
         tcl = self._tcl(x0, p)
         current_x = self._x_solver(x0)
         current_x, stats_preeq = self._eq(
-            p, tcl, current_x, solver, controller, max_steps
+            p,
+            tcl,
+            current_x,
+            solver,
+            controller,
+            steady_state_event,
+            max_steps,
         )
 
         return self._x_rdata(current_x, tcl), dict(stats_preeq=stats_preeq)
diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py
index 6a7da4b42f..80acc9969a 100644
--- a/python/sdist/amici/jax/petab.py
+++ b/python/sdist/amici/jax/petab.py
@@ -3,6 +3,7 @@
 from numbers import Number
 from collections.abc import Iterable
 from pathlib import Path
+from collections.abc import Callable
 
 import diffrax
 
@@ -465,6 +466,9 @@ def run_simulation(
         simulation_condition: tuple[str, ...],
         solver: diffrax.AbstractSolver,
         controller: diffrax.AbstractStepSizeController,
+        steady_state_event: Callable[
+            ..., diffrax._custom_types.BoolScalarLike
+        ],
         max_steps: jnp.int_,
         x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]),  # noqa: F821, F722
         ret: ReturnValue = ReturnValue.llh,
@@ -507,6 +511,7 @@ def run_simulation(
             solver=solver,
             controller=controller,
             max_steps=max_steps,
+            steady_state_event=steady_state_event,
             adjoint=diffrax.RecursiveCheckpointAdjoint()
             if ret in (ReturnValue.llh, ReturnValue.chi2)
             else diffrax.DirectAdjoint(),
@@ -518,6 +523,9 @@ def run_preequilibration(
         simulation_condition: str,
         solver: diffrax.AbstractSolver,
         controller: diffrax.AbstractStepSizeController,
+        steady_state_event: Callable[
+            ..., diffrax._custom_types.BoolScalarLike
+        ],
         max_steps: jnp.int_,
     ) -> tuple[jt.Float[jt.Array, "nx"], dict]:  # noqa: F821
         """
@@ -539,12 +547,13 @@ def run_preequilibration(
             simulation_condition, p
         )
         return self.model.preequilibrate_condition(
-            p=eqx.debug.backward_nan(p),
+            p=p,
             mask_reinit=mask_reinit,
             x_reinit=x_reinit,
             solver=solver,
             controller=controller,
             max_steps=max_steps,
+            steady_state_event=steady_state_event,
         )
 
 
@@ -555,6 +564,9 @@ def run_simulations(
     controller: diffrax.AbstractStepSizeController = diffrax.PIDController(
         **DEFAULT_CONTROLLER_SETTINGS
     ),
+    steady_state_event: Callable[
+        ..., diffrax._custom_types.BoolScalarLike
+    ] = diffrax.steady_state_event(),
     max_steps: int = 2**10,
     ret: ReturnValue | str = ReturnValue.llh,
 ):
@@ -569,6 +581,9 @@ def run_simulations(
        ODE solver to use for simulation.
     :param controller:
        Step size controller to use for simulation.
+    :param steady_state_event:
+       Steady state event function to use for pre-/post-equilibration. Allows customisation of the steady state
+       condition; see :func:`diffrax.steady_state_event` for details.
     :param max_steps:
        Maximum number of steps to take during simulation.
     :param ret:
@@ -583,7 +598,9 @@ def run_simulations(
     simulation_conditions = problem.get_all_simulation_conditions()
 
     preeqs = {
-        sc: problem.run_preequilibration(sc, solver, controller, max_steps)
+        sc: problem.run_preequilibration(
+            sc, solver, controller, steady_state_event, max_steps
+        )
         # only run preequilibration once per condition
         for sc in {sc[1] for sc in simulation_conditions if len(sc) > 1}
     }
@@ -593,6 +610,7 @@ def run_simulations(
             sc,
             solver,
             controller,
+            steady_state_event,
             max_steps,
             preeqs.get(sc[1])[0] if len(sc) > 1 else jnp.array([]),
             ret=ret,
@@ -617,6 +635,9 @@ def petab_simulate(
     controller: diffrax.AbstractStepSizeController = diffrax.PIDController(
         **DEFAULT_CONTROLLER_SETTINGS
     ),
+    steady_state_event: Callable[
+        ..., diffrax._custom_types.BoolScalarLike
+    ] = diffrax.steady_state_event(),
     max_steps: int = 2**10,
 ):
     """
@@ -637,6 +658,7 @@ def petab_simulate(
         problem,
         solver=solver,
         controller=controller,
+        steady_state_event=steady_state_event,
         max_steps=max_steps,
         ret=ReturnValue.y,
     )
diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py
index 672982686c..78fa026cfc 100644
--- a/python/tests/test_jax.py
+++ b/python/tests/test_jax.py
@@ -11,11 +11,12 @@
 import diffrax
 import numpy as np
 from beartype import beartype
+from petab.v1.C import PREEQUILIBRATION_CONDITION_ID, SIMULATION_CONDITION_ID
 
 from amici.pysb_import import pysb2amici, pysb2jax
 from amici.testing import TemporaryDirectoryWinSafe, skip_on_valgrind
 from amici.petab.petab_import import import_petab_problem
-from amici.jax import JAXProblem, ReturnValue
+from amici.jax import JAXProblem, ReturnValue, run_simulations
 from numpy.testing import assert_allclose
 
 from test_petab_objective import lotka_volterra  # noqa: F401
@@ -198,6 +199,7 @@ def check_fields_jax(
         "solver": diffrax.Kvaerno5(),
         "controller": diffrax.PIDController(atol=ATOL_SIM, rtol=RTOL_SIM),
         "adjoint": diffrax.RecursiveCheckpointAdjoint(),
+        "steady_state_event": diffrax.steady_state_event(),
         "max_steps": 2**8,  # max_steps
     }
     fun = beartype(jax_model.simulate_condition)
@@ -266,6 +268,28 @@ def check_fields_jax(
     )
 
 
+def test_preequilibration_failure(lotka_volterra):  # noqa: F811
+    petab_problem = lotka_volterra
+    # oscillating system, preequilibration should fail when interaction is active
+    with TemporaryDirectoryWinSafe(prefix="normal") as model_dir:
+        jax_model = import_petab_problem(
+            petab_problem, jax=True, model_output_dir=model_dir
+        )
+        jax_problem = JAXProblem(jax_model, petab_problem)
+        r = run_simulations(jax_problem)
+        assert not np.isinf(r[0].item())
+    petab_problem.measurement_df[PREEQUILIBRATION_CONDITION_ID] = (
+        petab_problem.measurement_df[SIMULATION_CONDITION_ID]
+    )
+    with TemporaryDirectoryWinSafe(prefix="failure") as model_dir:
+        jax_model = import_petab_problem(
+            petab_problem, jax=True, model_output_dir=model_dir
+        )
+        jax_problem = JAXProblem(jax_model, petab_problem)
+        r = run_simulations(jax_problem)
+        assert np.isinf(r[0].item())
+
+
 @skip_on_valgrind
 def test_serialisation(lotka_volterra):  # noqa: F811
     petab_problem = lotka_volterra
diff --git a/tests/benchmark-models/test_petab_benchmark.py b/tests/benchmark-models/test_petab_benchmark.py
index 4a63d8bfda..d9f836b0b4 100644
--- a/tests/benchmark-models/test_petab_benchmark.py
+++ b/tests/benchmark-models/test_petab_benchmark.py
@@ -7,6 +7,7 @@
 from functools import partial
 from pathlib import Path
 
+
 import fiddy
 import amici
 import numpy as np
@@ -342,8 +343,9 @@ def test_jax_llh(benchmark_problem):
             [problem_parameters[pid] for pid in jax_problem.parameter_ids]
         ),
     )
-    llh_jax, _ = beartype(run_simulations)(jax_problem)
+
     if problem_id in problems_for_gradient_check:
+        beartype(run_simulations)(jax_problem)
         (llh_jax, _), sllh_jax = eqx.filter_value_and_grad(
             run_simulations, has_aux=True
         )(jax_problem)
diff --git a/tests/petab_test_suite/test_petab_suite.py b/tests/petab_test_suite/test_petab_suite.py
index 5fe61adcf2..4fcbe0b631 100755
--- a/tests/petab_test_suite/test_petab_suite.py
+++ b/tests/petab_test_suite/test_petab_suite.py
@@ -4,6 +4,8 @@
 import logging
 import sys
 
+import diffrax
+
 import amici
 import pandas as pd
 import petab.v1 as petab
@@ -68,10 +70,17 @@ def _test_case(case, model_type, version, jax):
     if jax:
         from amici.jax import JAXProblem, run_simulations, petab_simulate
 
+        steady_state_event = diffrax.steady_state_event(rtol=1e-6, atol=1e-6)
         jax_problem = JAXProblem(model, problem)
-        llh, ret = run_simulations(jax_problem)
-        chi2, _ = run_simulations(jax_problem, ret="chi2")
-        simulation_df = petab_simulate(jax_problem)
+        llh, ret = run_simulations(
+            jax_problem, steady_state_event=steady_state_event
+        )
+        chi2, _ = run_simulations(
+            jax_problem, ret="chi2", steady_state_event=steady_state_event
+        )
+        simulation_df = petab_simulate(
+            jax_problem, steady_state_event=steady_state_event
+        )
         simulation_df.rename(
             columns={petab.SIMULATION: petab.MEASUREMENT}, inplace=True
         )
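
Usage sketch (not part of the patch): a minimal example of passing the new steady_state_event argument through run_simulations, based only on the calls exercised above in tests/petab_test_suite/test_petab_suite.py and python/tests/test_jax.py. The PEtab YAML path is a hypothetical placeholder.

    import diffrax
    import petab.v1 as petab

    from amici.jax import JAXProblem, run_simulations
    from amici.petab.petab_import import import_petab_problem

    # Load a PEtab problem (placeholder path).
    petab_problem = petab.Problem.from_yaml("problem.yaml")

    # Import the model as a JAX model and wrap it in a JAXProblem.
    jax_model = import_petab_problem(petab_problem, jax=True)
    jax_problem = JAXProblem(jax_model, petab_problem)

    # Customise the steady-state condition used for pre-/post-equilibration,
    # here with the tolerances used in the PEtab test suite above.
    steady_state_event = diffrax.steady_state_event(rtol=1e-6, atol=1e-6)

    # Run all simulation conditions. If equilibration does not converge, the
    # equilibrated states are set to inf (see JAXModel._eq), so the returned
    # log-likelihood is non-finite, as asserted in test_preequilibration_failure.
    llh, results = run_simulations(
        jax_problem, steady_state_event=steady_state_event
    )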