From 3fc43e582687c6b0954460d40f726dce05bc7d56 Mon Sep 17 00:00:00 2001 From: Marco Favorito Date: Thu, 16 Mar 2023 20:04:11 +0100 Subject: [PATCH] feat: add extras mechanism to finer-grained dependency selection Inspired by: https://github.com/bancaditalia/black-it/issues/36#issuecomment-1472327834. This commit includes the extra-dependencies mechanism of setuptools to overcome limitations specific to certain dependencies (e.g. no support for some Python interpreter versions). The changes use the following conventions for extras names: - `[all]`: install all dependencies from all extras - `[X-sampler]`: install all dependencies to make X sampler to work - `[X-loss]`: install all dependencies to make X loss function to work. We do not have yet an example for the last item for the moment; but for "forward-compatibility" of the nomenclature, we leave the -sampler suffix. E.g. for GPy, we could have the extra called gp-sampler, that installs GPy on-demand, and not installed if not needed by the user. This commit also includes a mechanism to handle import errors for the non-installed dependencies for some component. Such mechanism provides a useful message to the user, e.g. it raises an exception with a useful error message pointing out to the missing extra in its local installation of black-it. --- README.md | 21 +++- black_it/_load_dependency.py | 112 +++++++++++++++++++ black_it/samplers/gaussian_process.py | 25 ++++- black_it/samplers/xgboost.py | 25 ++++- poetry.lock | 7 +- pyproject.toml | 8 ++ tests/test_calibrator.py | 4 + tests/test_samplers/test_gaussian_process.py | 3 + tests/test_samplers/test_xgboost.py | 4 + tests/utils/base.py | 34 +++++- tox.ini | 4 + 11 files changed, 239 insertions(+), 8 deletions(-) create mode 100644 black_it/_load_dependency.py diff --git a/README.md b/README.md index aad9d1f4..acc293e4 100644 --- a/README.md +++ b/README.md @@ -38,19 +38,34 @@ matter of days, with no need to reimplement all the plumbings from scratch. This project requires Python v3.8 or later. -To install the latest version of the package from [PyPI](https://pypi.org/project/black-it/): +To install the latest version of the package from [PyPI](https://pypi.org/project/black-it/), with all the extra dependencies (recommended): ``` -pip install black-it +pip install "black-it[all]" ``` Or, directly from GitHub: ``` -pip install git+https://github.com/bancaditalia/black-it.git#egg=black-it +pip install git+https://github.com/bancaditalia/black-it.git#egg="black-it[all]" ``` If you'd like to contribute to the package, please read the [CONTRIBUTING.md](./CONTRIBUTING.md) guide. +### Feature-specific Package Dependencies + +We use the [optional dependencies mechanism of `setuptools`](https://setuptools.pypa.io/en/latest/userguide/dependency_management.html#optional-dependencies) +(also called _extras_) to allow users to avoid dependencies for features they don't use. + +For the basic features of the package, you can install the `black-it` package without extras, e.g. `pip install black-it`. +However, for certain components, you will need to install some more extras using the syntax `pip install black-it[extra-1,extra-2,...]`. + +For example, the [Gaussian Process Sampler](https://bancaditalia.github.io/black-it/samplers/#black_it.samplers.gaussian_process.GaussianProcessSampler) +depends on the Python package [`GPy`](https://github.com/SheffieldML/GPy/). +If the Gaussian Process sampler is not needed by your application, you can avoid its installation by just installing `black-it` as explained above. +However, if you need the sampler, you must install `black-it` with the `gp-sampler` extra: `pip install black-it[gp-sampler]`. + +The special extra `all` will install all the dependencies. + ## Quick Example The GitHub repo of Black-it contains a series ready-to-run calibration examples. diff --git a/black_it/_load_dependency.py b/black_it/_load_dependency.py new file mode 100644 index 00000000..e379c6c2 --- /dev/null +++ b/black_it/_load_dependency.py @@ -0,0 +1,112 @@ +# Black-box ABM Calibration Kit (Black-it) +# Copyright (C) 2021-2023 Banca d'Italia +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +""" +Python module to handle extras dependencies loading and import errors. + +This is a private module of the library. There should be no point in using it directly from client code. +""" + +import sys +from typing import Optional + +# known extras and their dependencies +_GPY_PACKAGE_NAME = "GPy" +_GP_SAMPLER_EXTRA_NAME = "gp-sampler" + +_XGBOOST_PACKAGE_NAME = "xgboost" +_XGBOOST_SAMPLER_EXTRA_NAME = "xgboost-sampler" + + +class DependencyNotInstalled(Exception): + """Library exception for when a required dependency is not installed.""" + + def __init__(self, component_name: str, package_name: str, extra_name: str) -> None: + """Initialize the exception object.""" + message = ( + f"Cannot import package '{package_name}', required by component {component_name}. " + f"To solve the issue, you can install the extra '{extra_name}': pip install black-it[{extra_name}]" + ) + super().__init__(message) + + +class GPyNotSupportedOnPy311Exception(Exception): + """Specific exception class for import error of GPy on Python 3.11.""" + + __ERROR_MSG = ( + f"The GaussianProcessSampler depends on '{_GPY_PACKAGE_NAME}', which is not supported on Python 3.11; " + f"see https://github.com/bancaditalia/black-it/issues/36" + ) + + def __init__(self) -> None: + """Initialize the exception object.""" + super().__init__(self.__ERROR_MSG) + + +def _check_import_error_else_raise_exception( + import_error: Optional[ImportError], + component_name: str, + package_name: str, + black_it_extra_name: str, +) -> None: + """ + Check an import error; raise the DependencyNotInstalled exception with a useful message. + + Args: + import_error: the ImportError object generated by the failed attempt. If None, then no error occurred. + component_name: the component for which the dependency is needed + package_name: the Python package name of the dependency + black_it_extra_name: the name of the black-it extra to install to solve the issue. + """ + if import_error is None: + # nothing to do. + return + + # an import error happened; we need to raise error to the caller + raise DependencyNotInstalled(component_name, package_name, black_it_extra_name) + + +def _check_gpy_import_error_else_raise_exception( + import_error: Optional[ImportError], + component_name: str, + package_name: str, + black_it_extra_name: str, +) -> None: + """ + Check GPy import error and if an error occurred, raise erorr with a useful error message. + + We need to handle two cases: + + - the user is using Python 3.11: the GPy package cannot be installed there; + see https://github.com/SheffieldML/GPy/issues/998 + - the user did not install the 'gp-sampler' extra. + + Args: + import_error: the ImportError object generated by the failed attempt. If None, then no error occurred. + component_name: the component for which the dependency is needed + package_name: the Python package name of the dependency + black_it_extra_name: the name of the black-it extra to install to solve the issue. + """ + if import_error is None: + # nothing to do. + return + + if sys.version_info == (3, 11): + raise GPyNotSupportedOnPy311Exception() + + _check_import_error_else_raise_exception( + import_error, component_name, package_name, black_it_extra_name + ) diff --git a/black_it/samplers/gaussian_process.py b/black_it/samplers/gaussian_process.py index 16d3dd46..b747f831 100644 --- a/black_it/samplers/gaussian_process.py +++ b/black_it/samplers/gaussian_process.py @@ -20,14 +20,26 @@ from enum import Enum from typing import Optional, Tuple, cast -import GPy import numpy as np -from GPy.models import GPRegression from numpy.typing import NDArray from scipy.special import erfc # pylint: disable=no-name-in-module +from black_it._load_dependency import ( + _GP_SAMPLER_EXTRA_NAME, + _GPY_PACKAGE_NAME, + _check_import_error_else_raise_exception, +) from black_it.samplers.surrogate import MLSurrogateSampler +_GPY_IMPORT_ERROR: Optional[ImportError] +try: + import GPy + from GPy.models import GPRegression +except ImportError as e: + _GPY_IMPORT_ERROR = e +else: + _GPY_IMPORT_ERROR = None + class _AcquisitionTypes(Enum): """Enumeration of allowed acquisition types.""" @@ -71,6 +83,8 @@ def __init__( # pylint: disable=too-many-arguments optimize_restarts: number of independent random trials of the optimization of the GP hyperparameters acquisition: type of acquisition function, it can be 'expected_improvement' of simply 'mean' """ + self.__check_gpy_import_error() + self._validate_acquisition(acquisition) super().__init__( @@ -81,6 +95,13 @@ def __init__( # pylint: disable=too-many-arguments self.acquisition = acquisition self._gpmodel: Optional[GPRegression] = None + @classmethod + def __check_gpy_import_error(cls) -> None: + """Check if an import error happened while attempting to import the 'GPy' package.""" + _check_import_error_else_raise_exception( + _GPY_IMPORT_ERROR, cls.__name__, _GPY_PACKAGE_NAME, _GP_SAMPLER_EXTRA_NAME + ) + @staticmethod def _validate_acquisition(acquisition: str) -> None: """ diff --git a/black_it/samplers/xgboost.py b/black_it/samplers/xgboost.py index 9f2d8d39..ba4c1a65 100644 --- a/black_it/samplers/xgboost.py +++ b/black_it/samplers/xgboost.py @@ -19,15 +19,27 @@ from typing import Optional, cast import numpy as np -import xgboost as xgb from numpy.typing import NDArray +from black_it._load_dependency import ( + _XGBOOST_PACKAGE_NAME, + _XGBOOST_SAMPLER_EXTRA_NAME, + _check_import_error_else_raise_exception, +) from black_it.samplers.surrogate import MLSurrogateSampler MAX_FLOAT32 = np.finfo(np.float32).max MIN_FLOAT32 = np.finfo(np.float32).min EPS_FLOAT32 = np.finfo(np.float32).eps +_XGBOOST_IMPORT_ERROR: Optional[ImportError] +try: + import xgboost as xgb +except ImportError as e: + _XGBOOST_IMPORT_ERROR = e +else: + _XGBOOST_IMPORT_ERROR = None + class XGBoostSampler(MLSurrogateSampler): """This class implements xgboost sampling.""" @@ -64,6 +76,7 @@ def __init__( # pylint: disable=too-many-arguments References: Lamperti, Roventini, and Sani, "Agent-based model calibration using machine learning surrogates" """ + self.__check_xgboost_import_error() super().__init__( batch_size, random_state, max_deduplication_passes, candidate_pool_size ) @@ -75,6 +88,16 @@ def __init__( # pylint: disable=too-many-arguments self._n_estimators = n_estimators self._xg_regressor: Optional[xgb.XGBRegressor] = None + @classmethod + def __check_xgboost_import_error(cls) -> None: + """Check if an import error happened while attempting to import the 'xgboost' package.""" + _check_import_error_else_raise_exception( + _XGBOOST_IMPORT_ERROR, + cls.__name__, + _XGBOOST_PACKAGE_NAME, + _XGBOOST_SAMPLER_EXTRA_NAME, + ) + @property def colsample_bytree(self) -> float: """Get the colsample_bytree parameter.""" diff --git a/poetry.lock b/poetry.lock index 3c723d1b..55154759 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5409,7 +5409,12 @@ files = [ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] +[extras] +all = ["GPy", "xgboost"] +gp-sampler = ["GPy"] +xgboost-sampler = ["xgboost"] + [metadata] lock-version = "2.0" python-versions = ">=3.8,<3.11" -content-hash = "f93afe542963751c007963960fdda77db21bedca2c49ee5c7ff22054f679f186" +content-hash = "1133be970965fc08db2c336306b04350ff67d332c904335dafbc741ff8b10f9a" diff --git a/pyproject.toml b/pyproject.toml index e8a32d80..03dac5c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,6 +91,14 @@ tox = "^4.4.7" twine = "^4.0.0" vulture = "^2.3" +GPy = { version = "^1.10.0", optional = true } +xgboost = { version = "^1.7.2", optional = true } + +[tool.poetry.extras] +gp-sampler = ["GPy"] +xgboost-sampler = ["xgboost"] +all = ["GPy", "xgboost"] + [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" diff --git a/tests/test_calibrator.py b/tests/test_calibrator.py index f2f0c80c..e09f0556 100644 --- a/tests/test_calibrator.py +++ b/tests/test_calibrator.py @@ -38,8 +38,12 @@ from black_it.search_space import SearchSpace from .fixtures.test_models import NormalMV # type: ignore +from .utils.base import no_gpy_installed, no_python311_for_gpy, no_xgboost_installed +@no_python311_for_gpy +@no_gpy_installed +@no_xgboost_installed class TestCalibrate: # pylint: disable=too-many-instance-attributes,attribute-defined-outside-init """Test the Calibrator.calibrate method.""" diff --git a/tests/test_samplers/test_gaussian_process.py b/tests/test_samplers/test_gaussian_process.py index bf3f443d..0173b6fd 100644 --- a/tests/test_samplers/test_gaussian_process.py +++ b/tests/test_samplers/test_gaussian_process.py @@ -22,6 +22,9 @@ from black_it.samplers.gaussian_process import GaussianProcessSampler, _AcquisitionTypes from black_it.search_space import SearchSpace +from tests.utils.base import no_gpy_installed, no_python311_for_gpy + +pytestmark = [no_python311_for_gpy, no_gpy_installed] class TestGaussianProcess2D: # pylint: disable=attribute-defined-outside-init diff --git a/tests/test_samplers/test_xgboost.py b/tests/test_samplers/test_xgboost.py index a92f0f4d..2d9e9562 100644 --- a/tests/test_samplers/test_xgboost.py +++ b/tests/test_samplers/test_xgboost.py @@ -23,6 +23,10 @@ from black_it.search_space import SearchSpace from ..fixtures.test_models import BH4 # type: ignore +from ..utils.base import no_xgboost_installed + +pytestmark = no_xgboost_installed + expected_params = np.array([[0.24, 0.26], [0.26, 0.02], [0.08, 0.24], [0.15, 0.15]]) diff --git a/tests/utils/base.py b/tests/utils/base.py index 4ad7cef2..c6a8629a 100644 --- a/tests/utils/base.py +++ b/tests/utils/base.py @@ -16,14 +16,19 @@ """Generic utility functions.""" import dataclasses +import importlib import shutil import signal import subprocess # nosec B404 +import sys +import types from functools import wraps -from typing import Callable, List, Type, Union +from typing import Any, Callable, List, Optional, Type, Union import pytest +from _pytest.mark.structures import MarkDecorator +from black_it._load_dependency import _GPY_PACKAGE_NAME, _XGBOOST_PACKAGE_NAME from tests.conftest import DEFAULT_SUBPROCESS_TIMEOUT @@ -170,3 +175,30 @@ def wrapper(*args, **kwargs): # type: ignore return wrapper return decorator + + +def try_import_else_none(module_name: str) -> Optional[types.ModuleType]: + """Try to import a module; if it fails, return None.""" + try: + return importlib.import_module(module_name) + except ImportError: + return None + + +def try_import_else_skip(package_name: str, **skipif_kwargs: Any) -> MarkDecorator: + """Try to import the package; else skip the test(s).""" + return pytest.mark.skipif( + try_import_else_none(package_name) is None, + reason=f"Cannot run the test because the package '{package_name}' is not installed", + **skipif_kwargs, + ) + + +no_python311_for_gpy = pytest.mark.skipif( + (3, 11) <= sys.version_info < (3, 12), + reason="GPy not supported on Python 3.11, see: https://github.com/bancaditalia/black-it/issues/36", +) + + +no_gpy_installed = try_import_else_skip(_GPY_PACKAGE_NAME) +no_xgboost_installed = try_import_else_skip(_XGBOOST_PACKAGE_NAME) diff --git a/tox.ini b/tox.ini index 1d672d51..f5fb2a0a 100644 --- a/tox.ini +++ b/tox.ini @@ -11,6 +11,10 @@ basepython = python3 [testenv] setenv = PYTHONPATH = {toxinidir} +extras = + all + gp-sampler + xgboost-sampler deps = pytest>=7.1.2,<7.2.0 pytest-cov>=3.0.0,<3.1.0