Skip to content

Commit

Permalink
fixes & remove sciml dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Dec 3, 2024
1 parent 031b524 commit cfb0b5a
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
11 changes: 7 additions & 4 deletions python/sdist/amici/jax/nn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pathlib import Path

from petab_sciml import MLModel, Layer, Node

import equinox as eqx
import jax.numpy as jnp

Expand Down Expand Up @@ -30,7 +30,10 @@ def tanhshrink(x: jnp.ndarray) -> jnp.ndarray:
return x - jnp.tanh(x)

Check warning on line 30 in python/sdist/amici/jax/nn.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/nn.py#L29-L30

Added lines #L29 - L30 were not covered by tests


def generate_equinox(ml_model: MLModel, filename: Path | str):
def generate_equinox(ml_model: "MLModel", filename: Path | str): # noqa: F821

Check warning on line 33 in python/sdist/amici/jax/nn.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/nn.py#L33

Added line #L33 was not covered by tests
# TODO: move to top level import and replace forward type definitions
from petab_sciml import Layer

filename = Path(filename)
layer_indent = 12
node_indent = 8

Check warning on line 39 in python/sdist/amici/jax/nn.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/nn.py#L37-L39

Added lines #L37 - L39 were not covered by tests
Expand Down Expand Up @@ -87,7 +90,7 @@ def _process_argval(v):
return str(v)

Check warning on line 90 in python/sdist/amici/jax/nn.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/nn.py#L85-L90

Added lines #L85 - L90 were not covered by tests


def _generate_layer(layer: Layer, indent: int, ilayer: int) -> str:
def _generate_layer(layer: "Layer", indent: int, ilayer: int) -> str: # noqa: F821
layer_map = {

Check warning on line 94 in python/sdist/amici/jax/nn.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/nn.py#L93-L94

Added lines #L93 - L94 were not covered by tests
"Dropout1d": "eqx.nn.Dropout",
"Dropout2d": "eqx.nn.Dropout",
Expand Down Expand Up @@ -146,7 +149,7 @@ def _generate_layer(layer: Layer, indent: int, ilayer: int) -> str:
return f"{' ' * indent}'{layer.layer_id}': {layer_str}"

Check warning on line 149 in python/sdist/amici/jax/nn.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/nn.py#L146-L149

Added lines #L146 - L149 were not covered by tests


def _generate_forward(node: Node, indent, layer_type=str) -> str:
def _generate_forward(node: "Node", indent, layer_type=str) -> str: # noqa: F821
if node.op == "placeholder":

Check warning on line 153 in python/sdist/amici/jax/nn.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/nn.py#L152-L153

Added lines #L152 - L153 were not covered by tests
# TODO: inconsistent target vs name
return f"{' ' * indent}{node.name} = input"

Check warning on line 155 in python/sdist/amici/jax/nn.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/nn.py#L155

Added line #L155 was not covered by tests
Expand Down
4 changes: 2 additions & 2 deletions python/sdist/amici/jax/ode_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __init__(
outdir: Path | str | None = None,
verbose: bool | int | None = False,
model_name: str | None = "model",
hybridisation: dict[str, str] = {},
hybridisation: dict[str, dict] = None,
):
"""
Generate AMICI jax files for the ODE provided to the constructor.
Expand Down Expand Up @@ -159,7 +159,7 @@ def __init__(

self.model: DEModel = ode_model

self.hybridisation = hybridisation
self.hybridisation = hybridisation if hybridisation is not None else {}

Check warning on line 162 in python/sdist/amici/jax/ode_export.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/ode_export.py#L162

Added line #L162 was not covered by tests

self._code_printer = AmiciJaxCodePrinter()

Expand Down
5 changes: 3 additions & 2 deletions tests/sciml/test_sciml.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,9 @@ def test_ude(test):
# gradient
sllh, _ = eqx.filter_grad(run_simulations, has_aux=True)(
jax_problem,
solver=diffrax.Tsit5(),
controller=diffrax.PIDController(atol=1e-10, rtol=1e-10),
solver=diffrax.Kvaerno5(),
controller=diffrax.PIDController(atol=1e-14, rtol=1e-14),
max_steps=2**16,
)
expected = (
pd.concat(
Expand Down

0 comments on commit cfb0b5a

Please sign in to comment.