From cfb0b5a5485f87a41e900c063f060507285436a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Tue, 3 Dec 2024 17:54:56 +0000 Subject: [PATCH] fixes & remove sciml dependency --- python/sdist/amici/jax/nn.py | 11 +++++++---- python/sdist/amici/jax/ode_export.py | 4 ++-- tests/sciml/test_sciml.py | 5 +++-- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/python/sdist/amici/jax/nn.py b/python/sdist/amici/jax/nn.py index 343a749ea6..d503df2393 100644 --- a/python/sdist/amici/jax/nn.py +++ b/python/sdist/amici/jax/nn.py @@ -1,6 +1,6 @@ from pathlib import Path -from petab_sciml import MLModel, Layer, Node + import equinox as eqx import jax.numpy as jnp @@ -30,7 +30,10 @@ def tanhshrink(x: jnp.ndarray) -> jnp.ndarray: return x - jnp.tanh(x) -def generate_equinox(ml_model: MLModel, filename: Path | str): +def generate_equinox(ml_model: "MLModel", filename: Path | str): # noqa: F821 + # TODO: move to top level import and replace forward type definitions + from petab_sciml import Layer + filename = Path(filename) layer_indent = 12 node_indent = 8 @@ -87,7 +90,7 @@ def _process_argval(v): return str(v) -def _generate_layer(layer: Layer, indent: int, ilayer: int) -> str: +def _generate_layer(layer: "Layer", indent: int, ilayer: int) -> str: # noqa: F821 layer_map = { "Dropout1d": "eqx.nn.Dropout", "Dropout2d": "eqx.nn.Dropout", @@ -146,7 +149,7 @@ def _generate_layer(layer: Layer, indent: int, ilayer: int) -> str: return f"{' ' * indent}'{layer.layer_id}': {layer_str}" -def _generate_forward(node: Node, indent, layer_type=str) -> str: +def _generate_forward(node: "Node", indent, layer_type=str) -> str: # noqa: F821 if node.op == "placeholder": # TODO: inconsistent target vs name return f"{' ' * indent}{node.name} = input" diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index 385bc65e07..f36f67ab85 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -130,7 +130,7 @@ def __init__( outdir: Path | str | None = None, verbose: bool | int | None = False, model_name: str | None = "model", - hybridisation: dict[str, str] = {}, + hybridisation: dict[str, dict] = None, ): """ Generate AMICI jax files for the ODE provided to the constructor. @@ -159,7 +159,7 @@ def __init__( self.model: DEModel = ode_model - self.hybridisation = hybridisation + self.hybridisation = hybridisation if hybridisation is not None else {} self._code_printer = AmiciJaxCodePrinter() diff --git a/tests/sciml/test_sciml.py b/tests/sciml/test_sciml.py index 4986899fd1..e872718f48 100644 --- a/tests/sciml/test_sciml.py +++ b/tests/sciml/test_sciml.py @@ -327,8 +327,9 @@ def test_ude(test): # gradient sllh, _ = eqx.filter_grad(run_simulations, has_aux=True)( jax_problem, - solver=diffrax.Tsit5(), - controller=diffrax.PIDController(atol=1e-10, rtol=1e-10), + solver=diffrax.Kvaerno5(), + controller=diffrax.PIDController(atol=1e-14, rtol=1e-14), + max_steps=2**16, ) expected = ( pd.concat(