Skip to content

Commit

Permalink
lint: fix NPY002 for examples/models/sir/sir_python.py
Browse files Browse the repository at this point in the history
Fixed NPY002 issues in SIR models implemented in examples/models/sir/sir_python.py.

Added tests to verify the new seed management works.
  • Loading branch information
marcofavorito authored and marcofavoritobi committed Sep 22, 2023
1 parent daf4a93 commit d7a3e84
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 6 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -362,3 +362,5 @@ Java\ model/

examples/output

!tests/fixtures/data/test_sir_python.npy
!tests/fixtures/data/test_sir_python_w_breaks_python.npy
8 changes: 2 additions & 6 deletions examples/models/sir/sir_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,10 @@ def SIR(theta: NDArray, N: int, seed: int | None) -> NDArray: # noqa: N802, N80
Returns:
simulated series
"""
np.random.seed(seed=seed)

num_agents = 100000
g = nx.watts_strogatz_graph(num_agents, int(theta[0]), theta[1], seed=theta[5])

model = ep.SIRModel(g)
model = ep.SIRModel(g, seed=seed)

cfg = ModelConfig.Configuration()
cfg.add_model_parameter("beta", theta[3]) # infection rate
Expand Down Expand Up @@ -102,12 +100,10 @@ def SIR_w_breaks( # noqa: N802
Returns:
simulated series
"""
np.random.seed(seed=seed)

num_agents = 100000
g = nx.watts_strogatz_graph(num_agents, int(theta[0]), theta[1], seed=theta[11])

model = ep.SIRModel(g)
model = ep.SIRModel(g, seed=seed)

cfg = ModelConfig.Configuration()
cfg.add_model_parameter("beta", theta[3]) # infection rate
Expand Down
Binary file added tests/fixtures/data/test_sir_python.npy
Binary file not shown.
96 changes: 96 additions & 0 deletions tests/test_examples/test_sir_python.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Black-box ABM Calibration Kit (Black-it)
# Copyright (C) 2021-2023 Banca d'Italia
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Test the SIR model implementation in Python."""

import numpy as np
import pytest

try:
from examples.models.sir.sir_python import SIR, SIR_w_breaks
except ModuleNotFoundError as e:
pytest.skip(
f"skipping tests for SIR python models, reason: {str(e)}",
allow_module_level=True,
)

from tests.conftest import TEST_DIR


def test_sir() -> None:
"""Test the 'SIR' function in examples/models/sir/sir_python.py."""
expected_output = np.load(TEST_DIR / "fixtures" / "data" / "test_sir_python.npy")
model_seed = 0

lattice_order = 20
rewire_probability = 0.2
percentage_infected = 0.05
beta = 0.2
gamma = 0.15
networkx_seed = 0
theta = [
lattice_order,
rewire_probability,
percentage_infected,
beta,
gamma,
networkx_seed,
]

n = 100
output = SIR(theta, n, seed=model_seed)

assert np.isclose(output, expected_output).all()


def test_sir_w_breaks() -> None:
"""Test the 'SIR_w_breaks' function in examples/models/sir/sir_python.py."""
expected_output = np.load(
TEST_DIR / "fixtures" / "data" / "test_sir_w_breaks_python.npy",
)
model_seed = 0

lattice_order = 20
rewire_probability = 0.2
percentage_infected = 0.05
beta_1 = 0.2
gamma_1 = 0.15
beta_2 = 0.3
beta_3 = 0.1
beta_4 = 0.01
t_break_1 = 10
t_break_2 = 20
t_break_3 = 30
networkx_seed = 0
theta = [
lattice_order,
rewire_probability,
percentage_infected,
beta_1,
gamma_1,
beta_2,
beta_3,
beta_4,
t_break_1,
t_break_2,
t_break_3,
networkx_seed,
]

n = 100
output = SIR_w_breaks(theta, n, seed=model_seed)

assert np.isclose(output, expected_output).all()

0 comments on commit d7a3e84

Please sign in to comment.