Skip to content

Commit

Permalink
test: add tests for the DependencyNotInstalled exception for xgboost-…
Browse files Browse the repository at this point in the history
…sampler and gp-sampler
  • Loading branch information
marcofavorito authored and marcofavoritobi committed Aug 24, 2023
1 parent b9a8834 commit bd5c2fe
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 0 deletions.
15 changes: 15 additions & 0 deletions tests/test_samplers/test_gaussian_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
import pytest
from numpy.typing import NDArray

import black_it
from black_it.samplers.gaussian_process import GaussianProcessSampler, _AcquisitionTypes
from black_it.search_space import SearchSpace
from tests.utils.base import no_gpy_installed, no_python311_for_gpy
from tests.utils.extras_tester import generic_test_import_error

pytestmark = [no_python311_for_gpy, no_gpy_installed] # noqa

Expand Down Expand Up @@ -137,3 +139,16 @@ def test_gaussian_process_sample_wrong_acquisition() -> None:
f"got {wrong_acquisition}",
):
GaussianProcessSampler(4, acquisition=wrong_acquisition)


def test_dependency_not_installed_error() -> None:
"""Test the "DependencyNotInstalled" exception in case the dependencies of the component are not installed."""
generic_test_import_error(
module_obj=black_it.samplers.gaussian_process,
import_error_global_variable_name="_GPY_IMPORT_ERROR",
component_initializer=lambda: GaussianProcessSampler(4),
expected_message_pattern=(
r"Cannot import package 'GPy', required by component GaussianProcessSampler\. "
r"To solve the issue, you can install the extra 'gp-sampler': pip install black-it\[gp-sampler\]"
),
)
15 changes: 15 additions & 0 deletions tests/test_samplers/test_xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import numpy as np

import black_it
from black_it.calibrator import Calibrator
from black_it.loss_functions.msm import MethodOfMomentsLoss
from black_it.samplers.halton import HaltonSampler
Expand All @@ -26,6 +27,7 @@

from ..fixtures.test_models import BH4 # type: ignore
from ..utils.base import no_xgboost_installed
from ..utils.extras_tester import generic_test_import_error

pytestmark = no_xgboost_installed # noqa

Expand Down Expand Up @@ -136,3 +138,16 @@ def test_clip_losses() -> None:
assert (
y2 == np.array([0.0, MIN_FLOAT32 + EPS_FLOAT32, MAX_FLOAT32 - EPS_FLOAT32])
).all()


def test_dependency_not_installed_error() -> None:
"""Test the "DependencyNotInstalled" exception in case the dependencies of the component are not installed."""
generic_test_import_error(
module_obj=black_it.samplers.xgboost,
import_error_global_variable_name="_XGBOOST_IMPORT_ERROR",
component_initializer=lambda: XGBoostSampler(4),
expected_message_pattern=(
r"Cannot import package 'xgboost', required by component XGBoostSampler. "
r"To solve the issue, you can install the extra 'xgboost-sampler': pip install black-it\[xgboost-sampler\]"
),
)
73 changes: 73 additions & 0 deletions tests/utils/extras_tester.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Black-box ABM Calibration Kit (Black-it)
# Copyright (C) 2021-2023 Banca d'Italia
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Test utilities to test the correct working of the extras mechanism."""
import contextlib
from types import ModuleType
from typing import Any, Callable, Generator

import pytest

from black_it._load_dependency import DependencyNotInstalled


@contextlib.contextmanager
def _change_variable_value(
module_obj: ModuleType, variable_name: str, value: Any
) -> Generator[None, None, None]:
"""Change, temporarily, the value of a variable in a module."""
old_value = getattr(module_obj, variable_name)
setattr(module_obj, variable_name, value)
yield
setattr(module_obj, variable_name, old_value)


def generic_test_import_error(
module_obj: ModuleType,
import_error_global_variable_name: str,
component_initializer: Callable,
expected_message_pattern: str,
) -> None:
"""
Test that the correct exception is raised when a dependency is not installed.
This function is an utility testing function to test that the correct exception is raised when a dependency is not
installed.
It assumes the module under testing, associated to some Black-it component (e.g. loss, sampler, etc.), has a global
variable named import_error_global_variable_name, which is set to None if the dependency is installed, or to a
DependencyNotInstalled exception if the dependency is not installed. This convention is the one used in the
Black-it code.
For example, samplers.xgboost has a global variable named _XGBOOST_IMPORT_ERROR, which is set to None if the xgboost
package is installed, or to an DependencyNotInstalled if the xgboost package is not installed. Then, during the
initialization of XGBoostSampler, the value of the _XGBOOST_IMPORT_ERROR variable is checked, and if it is not
None, an exception is raised.
This test function checks that the correct exception is raised when the import_error_global_variable_name variable
is set to None and we try to initialize the Black-it component.
Args:
module_obj: the module object to test
import_error_global_variable_name: the name of the global variable in the module object
component_initializer: the function to call to initialize the component under testing
expected_message_pattern: the pattern of the expected exception message
"""
with _change_variable_value(
module_obj, import_error_global_variable_name, ImportError("fake error")
):
with pytest.raises(DependencyNotInstalled, match=expected_message_pattern):
component_initializer()

0 comments on commit bd5c2fe

Please sign in to comment.