-
Notifications
You must be signed in to change notification settings - Fork 2
/
test_sir_python.py
96 lines (82 loc) · 2.51 KB
/
test_sir_python.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# Black-box ABM Calibration Kit (Black-it)
# Copyright (C) 2021-2024 Banca d'Italia
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Test the SIR model implementation in Python."""
import numpy as np
import pytest
try:
from examples.models.sir.sir_python import SIR, SIR_w_breaks
except ModuleNotFoundError as e:
pytest.skip(
f"skipping tests for SIR python models, reason: {e!s}",
allow_module_level=True,
)
from tests.conftest import TEST_DIR
def test_sir() -> None:
    """Test the 'SIR' function in examples/models/sir/sir_python.py.

    Runs ``SIR`` with a fixed parameter vector, horizon and seed, and checks
    the result element-wise against the reference array stored in
    ``fixtures/data/test_sir_python.npy`` under ``TEST_DIR``.
    """
    expected_output = np.load(TEST_DIR / "fixtures" / "data" / "test_sir_python.npy")

    model_seed = 0
    lattice_order = 20
    rewire_probability = 0.2
    percentage_infected = 0.05
    beta = 0.2
    gamma = 0.15
    networkx_seed = 0
    # theta is passed positionally; this order must match the order in which
    # SIR reads its parameters (not visible here — verify against
    # examples/models/sir/sir_python.py when changing it).
    theta = [
        lattice_order,
        rewire_probability,
        percentage_infected,
        beta,
        gamma,
        networkx_seed,
    ]
    n = 100

    output = SIR(theta, n, seed=model_seed)

    # assert_allclose rejects (non-scalar) shape mismatches instead of
    # broadcasting them away as np.isclose(...).all() would, and reports which
    # elements differ. The tolerances and equal_nan=False reproduce
    # np.isclose's defaults, so the acceptance criterion is unchanged.
    np.testing.assert_allclose(
        output,
        expected_output,
        rtol=1e-5,
        atol=1e-8,
        equal_nan=False,
    )
def test_sir_w_breaks() -> None:
    """Test the 'SIR_w_breaks' function in examples/models/sir/sir_python.py.

    Runs ``SIR_w_breaks`` with a fixed parameter vector (including three
    break times and the beta values they are paired with), horizon and seed,
    and checks the result element-wise against the reference array stored in
    ``fixtures/data/test_sir_w_breaks_python.npy`` under ``TEST_DIR``.
    """
    expected_output = np.load(
        TEST_DIR / "fixtures" / "data" / "test_sir_w_breaks_python.npy",
    )

    model_seed = 0
    lattice_order = 20
    rewire_probability = 0.2
    percentage_infected = 0.05
    beta_1 = 0.2
    gamma_1 = 0.15
    beta_2 = 0.3
    beta_3 = 0.1
    beta_4 = 0.01
    t_break_1 = 10
    t_break_2 = 20
    t_break_3 = 30
    networkx_seed = 0
    # theta is passed positionally; this order must match the order in which
    # SIR_w_breaks reads its parameters (not visible here — verify against
    # examples/models/sir/sir_python.py when changing it).
    theta = [
        lattice_order,
        rewire_probability,
        percentage_infected,
        beta_1,
        gamma_1,
        beta_2,
        beta_3,
        beta_4,
        t_break_1,
        t_break_2,
        t_break_3,
        networkx_seed,
    ]
    n = 100

    output = SIR_w_breaks(theta, n, seed=model_seed)

    # assert_allclose rejects (non-scalar) shape mismatches instead of
    # broadcasting them away as np.isclose(...).all() would, and reports which
    # elements differ. The tolerances and equal_nan=False reproduce
    # np.isclose's defaults, so the acceptance criterion is unchanged.
    np.testing.assert_allclose(
        output,
        expected_output,
        rtol=1e-5,
        atol=1e-8,
        equal_nan=False,
    )