
fix net test-cases
FFroehlich committed Dec 3, 2024
1 parent b8632f1 commit 031b524
Showing 2 changed files with 87 additions and 81 deletions.
38 changes: 5 additions & 33 deletions python/sdist/amici/jax/nn.py
@@ -89,14 +89,13 @@ def _process_argval(v):

def _generate_layer(layer: Layer, indent: int, ilayer: int) -> str:
layer_map = {
"InstanceNorm1d": "eqx.nn.LayerNorm",
"InstanceNorm2d": "eqx.nn.LayerNorm",
"InstanceNorm3d": "eqx.nn.LayerNorm",
"Dropout1d": "eqx.nn.Dropout",
"Dropout2d": "eqx.nn.Dropout",
"Flatten": "amici.jax.nn.Flatten",
}
if layer.layer_type.startswith(("BatchNorm", "AlphaDropout")):
if layer.layer_type.startswith(
("BatchNorm", "AlphaDropout", "InstanceNorm")
):
raise NotImplementedError(
f"{layer.layer_type} layers currently not supported"
)
@@ -117,30 +116,12 @@ def _generate_layer(layer: Layer, indent: int, ilayer: int) -> str:
"Conv2d": {
"bias": "use_bias",
},
"InstanceNorm1d": {
"affine": "elementwise_affine",
"num_features": "shape",
},
"InstanceNorm2d": {
"affine": "elementwise_affine",
"num_features": "shape",
},
"InstanceNorm3d": {
"affine": "elementwise_affine",
"num_features": "shape",
},
"LayerNorm": {
"affine": "elementwise_affine",
"normalized_shape": "shape",
},
}
kwarg_ignore = {
"InstanceNorm1d": ("track_running_stats", "momentum"),
"InstanceNorm2d": ("track_running_stats", "momentum"),
"InstanceNorm3d": ("track_running_stats", "momentum"),
"BatchNorm1d": ("track_running_stats", "momentum"),
"BatchNorm2d": ("track_running_stats", "momentum"),
"BatchNorm3d": ("track_running_stats", "momentum"),
"Dropout1d": ("inplace",),
"Dropout2d": ("inplace",),
}
@@ -162,13 +143,6 @@ def _generate_layer(layer: Layer, indent: int, ilayer: int) -> str:
kwargs += [f"key=keys[{ilayer}]"]
type_str = layer_map.get(layer.layer_type, f"eqx.nn.{layer.layer_type}")
layer_str = f"{type_str}({', '.join(kwargs)})"
if layer.layer_type.startswith(("InstanceNorm",)):
if layer.layer_type.endswith(("1d", "2d", "3d")):
layer_str = f"jax.vmap({layer_str}, in_axes=1, out_axes=1)"
if layer.layer_type.endswith(("2d", "3d")):
layer_str = f"jax.vmap({layer_str}, in_axes=2, out_axes=2)"
if layer.layer_type.endswith("3d"):
layer_str = f"jax.vmap({layer_str}, in_axes=3, out_axes=3)"
return f"{' ' * indent}'{layer.layer_id}': {layer_str}"


@@ -179,10 +153,8 @@ def _generate_forward(node: Node, indent, layer_type=str) -> str:

if node.op == "call_module":
fun_str = f"self.layers['{node.target}']"
if layer_type.startswith(
("InstanceNorm", "Conv", "Linear", "LayerNorm")
):
if layer_type in ("LayerNorm", "InstanceNorm"):
if layer_type.startswith(("Conv", "Linear", "LayerNorm")):
if layer_type in ("LayerNorm",):
dims = f"len({fun_str}.shape)+1"
if layer_type == "Linear":
dims = 2
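For reference, a minimal stand-alone sketch of the stricter layer check introduced above. The function name check_supported and the bare-string interface are illustrative only; in the module itself the check runs inside _generate_layer on a Layer object.

# Illustrative sketch: the real check receives a Layer object, not a string.
def check_supported(layer_type: str) -> None:
    if layer_type.startswith(("BatchNorm", "AlphaDropout", "InstanceNorm")):
        raise NotImplementedError(
            f"{layer_type} layers currently not supported"
        )

check_supported("Conv2d")            # supported, returns silently
# check_supported("InstanceNorm2d")  # would now raise NotImplementedError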
130 changes: 82 additions & 48 deletions tests/sciml/test_sciml.py
@@ -40,8 +40,26 @@ def change_directory(destination):
cases_dir = Path(__file__).parent / "testsuite" / "test_cases"


def _reshape_flat_array(array_flat):
array_flat["ix"] = array_flat["ix"].astype(str)
ix_cols = [
f"ix_{i}" for i in range(len(array_flat["ix"].values[0].split(";")))
]
if len(ix_cols) == 1:
array_flat[ix_cols[0]] = array_flat["ix"].apply(int)
else:
array_flat[ix_cols] = pd.DataFrame(
array_flat["ix"].str.split(";").apply(np.array).to_list(),
index=array_flat.index,
).astype(int)
array_flat.sort_values(by=ix_cols, inplace=True)
array_shape = tuple(array_flat[ix_cols].max().astype(int) + 1)
array = np.array(array_flat["value"].values).reshape(array_shape)
return array
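As a quick illustration of the helper above (not part of the committed file), a hypothetical round trip for a 2x2 tensor stored in the flat TSV layout the test suite uses, with one row per element, semicolon-separated indices in "ix", and the entry in "value":

import numpy as np
import pandas as pd

flat = pd.DataFrame(
    {
        "ix": ["0;0", "0;1", "1;0", "1;1"],
        "value": [1.0, 2.0, 3.0, 4.0],
    }
)
arr = _reshape_flat_array(flat)
assert arr.shape == (2, 2)
np.testing.assert_array_equal(arr, [[1.0, 2.0], [3.0, 4.0]])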


@pytest.mark.parametrize(
"test", [d.stem for d in cases_dir.glob("net_[0-9]*")]
"test", sorted([d.stem for d in cases_dir.glob("net_[0-9]*")])
)
def test_net(test):
test_dir = cases_dir / test
@@ -59,17 +77,20 @@ def test_net(test):
for ml_model in ml_models.models:
module_dir = outdir / f"{ml_model.mlmodel_id}.py"
if test in (
"net_022",
"net_002",
"net_045",
"net_042",
"net_009",
"net_018",
"net_019",
"net_020",
"net_021",
"net_022",
"net_042",
"net_043",
"net_044",
"net_021",
"net_019",
"net_002",
"net_045",
"net_046",
"net_047",
"net_048",
):
with pytest.raises(NotImplementedError):
generate_equinox(ml_model, module_dir)
@@ -84,38 +105,14 @@ def test_net(test):
solutions.get("net_ps", solutions["net_input"]),
solutions["net_output"],
):
input_flat = pd.read_csv(test_dir / input_file, sep="\t").sort_values(
by="ix"
)
input_shape = tuple(
np.stack(
input_flat["ix"].astype(str).str.split(";").apply(np.array)
)
.astype(int)
.max(axis=0)
+ 1
)
input = jnp.array(input_flat["value"].values).reshape(input_shape)

output_flat = pd.read_csv(
test_dir / output_file, sep="\t"
).sort_values(by="ix")
output_shape = tuple(
np.stack(
output_flat["ix"].astype(str).str.split(";").apply(np.array)
)
.astype(int)
.max(axis=0)
+ 1
)
output = jnp.array(output_flat["value"].values).reshape(output_shape)
input_flat = pd.read_csv(test_dir / input_file, sep="\t")
input = _reshape_flat_array(input_flat)

output_flat = pd.read_csv(test_dir / output_file, sep="\t")
output = _reshape_flat_array(output_flat)

if "net_ps" in solutions:
par = (
pd.read_csv(test_dir / par_file, sep="\t")
.set_index("parameterId")
.sort_index()
)
par = pd.read_csv(test_dir / par_file, sep="\t")
for ml_model in ml_models.models:
net = nets[ml_model.mlmodel_id](jr.PRNGKey(0))
for layer in net.layers.keys():
@@ -126,32 +123,67 @@ def test_net(test):
and net.layers[layer].weight is not None
):
prefix = layer_prefix + "_weight"
df = par[
par[petab.PARAMETER_ID].str.startswith(prefix)
]
df["ix"] = (
df[petab.PARAMETER_ID]
.str.split("_")
.str[3:]
.apply(lambda x: ";".join(x))
)
w = _reshape_flat_array(df)
if isinstance(net.layers[layer], eqx.nn.ConvTranspose):
# see FAQ in https://docs.kidger.site/equinox/api/nn/conv/#equinox.nn.ConvTranspose
w = np.flip(
w, axis=tuple(range(2, w.ndim))
).swapaxes(0, 1)
assert w.shape == net.layers[layer].weight.shape
net = eqx.tree_at(
lambda x: x.layers[layer].weight,
net,
jnp.array(
par[par.index.str.startswith(prefix)][
"value"
].values
).reshape(net.layers[layer].weight.shape),
jnp.array(w),
)
if (
isinstance(net.layers[layer], eqx.Module)
and hasattr(net.layers[layer], "bias")
and net.layers[layer].bias is not None
):
prefix = layer_prefix + "_bias"
df = par[
par[petab.PARAMETER_ID].str.startswith(prefix)
]
df["ix"] = (
df[petab.PARAMETER_ID]
.str.split("_")
.str[3:]
.apply(lambda x: ";".join(x))
)
b = _reshape_flat_array(df)
if isinstance(
net.layers[layer],
eqx.nn.Conv | eqx.nn.ConvTranspose,
):
b = np.expand_dims(
b,
tuple(
range(
1,
net.layers[layer].num_spatial_dims + 1,
)
),
)
assert b.shape == net.layers[layer].bias.shape
net = eqx.tree_at(
lambda x: x.layers[layer].bias,
net,
jnp.array(
par[par.index.str.startswith(prefix)][
"value"
].values
).reshape(net.layers[layer].bias.shape),
jnp.array(b),
)
net = eqx.nn.inference_mode(net)

if test == "net_004_alt":
return # skipping, no support for non-cross-correlation in equinox

np.testing.assert_allclose(
net.forward(input),
output,
@@ -160,7 +192,9 @@
)


@pytest.mark.parametrize("test", [d.stem for d in cases_dir.glob("[0-9]*")])
@pytest.mark.parametrize(
"test", sorted([d.stem for d in cases_dir.glob("[0-9]*")])
)
def test_ude(test):
test_dir = cases_dir / test
with open(test_dir / "petab" / "problem_ude.yaml") as f:
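Two of the conversions in the updated test_net are worth spelling out. First, the ConvTranspose weights: assuming the parameter table follows PyTorch's (in_channels, out_channels, *kernel) convention, equinox's ConvTranspose wants the channel axes swapped and the spatial kernel axes flipped (see the equinox FAQ link in the diff). A hypothetical helper mirroring the conversion done in the test:

import numpy as np

def torch_to_equinox_convtranspose_weight(w):
    # Flip the spatial kernel axes, then swap the channel axes, as in the test.
    return np.flip(w, axis=tuple(range(2, w.ndim))).swapaxes(0, 1)

w_table = np.ones((3, 4, 5, 5))  # hypothetical (in=3, out=4, 5x5 kernel)
w_eqx = torch_to_equinox_convtranspose_weight(w_table)
assert w_eqx.shape == (4, 3, 5, 5)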

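Second, the convolution biases: the table provides one flat value per output channel, while equinox Conv/ConvTranspose layers store the bias with trailing singleton spatial dimensions so it broadcasts over the feature map; the expand_dims call in the test accounts for this. A small sketch with hypothetical shapes:

import numpy as np

b_flat = np.arange(4.0)   # one bias value per output channel
num_spatial_dims = 2      # e.g. a Conv2d or ConvTranspose2d layer
b = np.expand_dims(b_flat, tuple(range(1, num_spatial_dims + 1)))
assert b.shape == (4, 1, 1)  # broadcasts against a (4, H, W) feature map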