From 74346e9e931557c021f62115d43018dd1d8f8647 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Tue, 26 Nov 2024 22:29:03 +0100 Subject: [PATCH] Improve model extension import Previously, it wasn't possible to import two model modules with the same name. Now this is at least possible if they are in different locations. Overwriting and importing a previously imported extension is still not supported. --- python/sdist/amici/__init__.py | 123 +++++++++--------------- python/sdist/amici/__init__.template.py | 38 +++++++- python/tests/test_events.py | 4 +- python/tests/test_sbml_import.py | 64 +++++++++--- swig/modelname.template.i | 49 +++++++++- 5 files changed, 174 insertions(+), 104 deletions(-) diff --git a/python/sdist/amici/__init__.py b/python/sdist/amici/__init__.py index ee78a6045b..6788fefe77 100644 --- a/python/sdist/amici/__init__.py +++ b/python/sdist/amici/__init__.py @@ -7,14 +7,13 @@ """ import contextlib -import datetime +import importlib.util import importlib import os import re import sys -import sysconfig from pathlib import Path -from types import ModuleType as ModelModule +from types import ModuleType from typing import Any from collections.abc import Callable @@ -145,6 +144,8 @@ def get_model(self) -> amici.Model: def get_jax_model(self) -> JAXModel: ... AmiciModel = Union[amici.Model, amici.ModelPtr] +else: + ModelModule = ModuleType class add_path: @@ -182,6 +183,29 @@ def __exit__(self, exc_type, exc_value, traceback): sys.path = self.orginal_path +def _module_from_path(module_name: str, module_path: Path | str) -> ModuleType: + """Import a module from a given path. + + Import a module from a given path. The module is not added to + `sys.modules`. The `_self` attribute of the module is set to the module + itself. + + :param module_name: + Name of the module. + :param module_path: + Path to the module file. Absolute or relative to the current working + directory. + """ + module_path = Path(module_path).resolve() + if not module_path.is_file(): + raise ModuleNotFoundError(f"Module file not found: {module_path}") + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + module._self = module + spec.loader.exec_module(module) + return module + + def import_model_module( module_name: str, module_path: Path | str ) -> ModelModule: @@ -195,86 +219,29 @@ def import_model_module( :return: The model module """ - module_path = str(module_path) + model_root = str(module_path) # ensure we will find the newly created module importlib.invalidate_caches() if not os.path.isdir(module_path): - raise ValueError(f"module_path '{module_path}' is not a directory.") - - module_path = os.path.abspath(module_path) - ext_suffix = sysconfig.get_config_var("EXT_SUFFIX") - ext_mod_name = f"{module_name}._{module_name}" - - # module already loaded? - if (m := sys.modules.get(ext_mod_name)) and m.__file__.endswith( - ext_suffix - ): - # this is the c++ extension we can't unload - loaded_file = Path(m.__file__) - needed_file = Path( - module_path, - module_name, - f"_{module_name}{ext_suffix}", - ) - # if we import a matlab-generated model where the extension - # is in a different directory - needed_file_matlab = Path( - module_path, - f"_{module_name}{ext_suffix}", - ) - if not needed_file.exists(): - if needed_file_matlab.exists(): - needed_file = needed_file_matlab - else: - raise ModuleNotFoundError( - f"Cannot find extension module for {module_name} in " - f"{module_path}." - ) - - if not loaded_file.samefile(needed_file): - # this is not the right module, and we can't unload it - raise RuntimeError( - f"Cannot import extension for {module_name} from " - f"{module_path}, because an extension with the same name was " - f"has already been imported from {loaded_file.parent}. " - "Import the module with a different name or restart the " - "Python kernel." - ) - # this is the right file, but did it change on disk? - t_imported = m._get_import_time() # noqa: protected-access - t_modified = os.path.getmtime(m.__file__) - if t_imported < t_modified: - t_imp_str = datetime.datetime.fromtimestamp(t_imported).isoformat() - t_mod_str = datetime.datetime.fromtimestamp(t_modified).isoformat() - raise RuntimeError( - f"Cannot import extension for {module_name} from " - f"{module_path}, because an extension in the same location " - f"has already been imported, but the file was modified on " - f"disk. \nImported at {t_imp_str}\nModified at {t_mod_str}.\n" - "Import the module with a different name or restart the " - "Python kernel." - ) - - # unlike extension modules, Python modules can be unloaded - if module_name in sys.modules: - # if a module with that name is already in sys.modules, we remove it, - # along with all other modules from that package. otherwise, there - # will be trouble if two different models with the same name are to - # be imported. - del sys.modules[module_name] - # collect first, don't delete while iterating - to_unload = { - loaded_module_name - for loaded_module_name in sys.modules.keys() - if loaded_module_name.startswith(f"{module_name}.") - } - for m in to_unload: - del sys.modules[m] - - with set_path(module_path): - return importlib.import_module(module_name) + raise ValueError(f"module_path '{model_root}' is not a directory.") + + module_path = Path(model_root, module_name, "__init__.py") + + # We may want to import a matlab-generated model where the extension + # is in a different directory. This is not a regular use case. It's only + # used in the amici tests and can be removed at any time. + # The models (currently) use the default swig-import and require + # modifying sys.path. + module_path_matlab = Path(model_root, f"{module_name}.py") + if not module_path.is_file() and module_path_matlab.is_file(): + with set_path(model_root): + return _module_from_path(module_name, module_path_matlab) + + module = _module_from_path(module_name, module_path) + module._self = module + return module class AmiciVersionError(RuntimeError): diff --git a/python/sdist/amici/__init__.template.py b/python/sdist/amici/__init__.template.py index 56064535e8..077c961db8 100644 --- a/python/sdist/amici/__init__.template.py +++ b/python/sdist/amici/__init__.template.py @@ -1,9 +1,12 @@ """AMICI-generated module for model TPL_MODELNAME""" +import datetime +import os from pathlib import Path from typing import TYPE_CHECKING import amici + if TYPE_CHECKING: from amici.jax import JAXModel @@ -18,14 +21,39 @@ "version currently installed." ) -from .TPL_MODELNAME import * # noqa: F403, F401 -from .TPL_MODELNAME import getModel as get_model # noqa: F401 +TPL_MODELNAME = amici._module_from_path( + "TPL_MODELNAME.TPL_MODELNAME", Path(__file__).parent / "TPL_MODELNAME.py" +) +for var in dir(TPL_MODELNAME): + if not var.startswith("_"): + globals()[var] = getattr(TPL_MODELNAME, var) +get_model = TPL_MODELNAME.getModel +# _self: this module; will be set during import +TPL_MODELNAME._model_module = _self # noqa: F821 -def get_jax_model() -> "JAXModel": - from .jax import JAXModel_TPL_MODELNAME - return JAXModel_TPL_MODELNAME() +def get_jax_model() -> "JAXModel": + # If the model directory was meanwhile overwritten, this would load the + # new version, which would not match the previously imported extension. + # This is not allowed, as it would lead to inconsistencies. + jax_py_file = Path(__file__).parent / "jax.py" + jax_py_file = jax_py_file.resolve() + t_imported = TPL_MODELNAME._get_import_time() # noqa: protected-access + t_modified = os.path.getmtime(jax_py_file) + if t_imported < t_modified: + t_imp_str = datetime.datetime.fromtimestamp(t_imported).isoformat() + t_mod_str = datetime.datetime.fromtimestamp(t_modified).isoformat() + raise RuntimeError( + f"Refusing to import {jax_py_file} which was changed since " + f"TPL_MODELNAME was imported. This is to avoid inconsistencies " + "between the different model implementations.\n" + f"Imported at {t_imp_str}\nModified at {t_mod_str}.\n" + "Import the module with a different name or restart the " + "Python kernel." + ) + jax = amici._module_from_path("jax", jax_py_file) + return jax.JAXModel_TPL_MODELNAME() __version__ = "TPL_PACKAGE_VERSION" diff --git a/python/tests/test_events.py b/python/tests/test_events.py index d16877fd2e..87d6738ffc 100644 --- a/python/tests/test_events.py +++ b/python/tests/test_events.py @@ -724,7 +724,7 @@ def test_handling_of_fixed_time_point_event_triggers(): end """ module_name = "test_events_time_based" - with TemporaryDirectory(prefix=module_name, delete=False) as outdir: + with TemporaryDirectory(prefix=module_name) as outdir: antimony2amici( ant_model, model_name=module_name, @@ -765,7 +765,7 @@ def test_multiple_event_assignment_with_compartment(): """ # watch out for too long path names on windows ... module_name = "tst_mltple_ea_w_cmprtmnt" - with TemporaryDirectory(prefix=module_name, delete=False) as outdir: + with TemporaryDirectory(prefix=module_name) as outdir: antimony2amici( ant_model, model_name=module_name, diff --git a/python/tests/test_sbml_import.py b/python/tests/test_sbml_import.py index 7a3f0a2720..b089d3fb37 100644 --- a/python/tests/test_sbml_import.py +++ b/python/tests/test_sbml_import.py @@ -199,6 +199,7 @@ def test_logging_works(observable_dependent_error_model, caplog): @skip_on_valgrind def test_model_module_is_set(observable_dependent_error_model): model_module = observable_dependent_error_model + assert model_module.getModel().module is model_module assert isinstance(model_module.getModel().module, amici.ModelModule) @@ -763,11 +764,14 @@ def test_constraints(): @skip_on_valgrind -def test_same_extension_error(): +def test_import_same_model_name(): """Test for error when loading a model with the same extension name as an already loaded model.""" from amici.antimony_import import antimony2amici + from amici import import_model_module + # create three versions of a toy model with different parameter values + # to detect which model was loaded ant_model_1 = """ model test_same_extension_error species A = 0 @@ -776,40 +780,68 @@ def test_same_extension_error(): end """ ant_model_2 = ant_model_1.replace("1", "2") + ant_model_3 = ant_model_1.replace("1", "3") module_name = "test_same_extension" - with TemporaryDirectory(prefix=module_name, delete=False) as outdir: + with TemporaryDirectory(prefix=module_name) as outdir: + outdir_1 = Path(outdir, "model_1") + outdir_2 = Path(outdir, "model_2") + + # import the first two models, with the same name, + # but in different location (this is now supported) antimony2amici( ant_model_1, model_name=module_name, - output_dir=outdir, + output_dir=outdir_1, compute_conservation_laws=False, ) - model_module_1 = amici.import_model_module( - module_name=module_name, module_path=outdir + + antimony2amici( + ant_model_2, + model_name=module_name, + output_dir=outdir_2, + compute_conservation_laws=False, + ) + + model_module_1 = import_model_module( + module_name=module_name, module_path=outdir_1 ) assert model_module_1.get_model().getParameters()[0] == 1.0 + # no error if the same model is loaded again without changes on disk - model_module_1 = amici.import_model_module( - module_name=module_name, module_path=outdir + model_module_1b = import_model_module( + module_name=module_name, module_path=outdir_1 + ) + # downside: the modules will compare as different + assert (model_module_1 == model_module_1b) is False + assert model_module_1.__file__ == model_module_1b.__file__ + assert model_module_1b.get_model().getParameters()[0] == 1.0 + + model_module_2 = import_model_module( + module_name=module_name, module_path=outdir_2 ) assert model_module_1.get_model().getParameters()[0] == 1.0 + assert model_module_2.get_model().getParameters()[0] == 2.0 - # Try to import another model with the same name + # import the third model, with the same name and location as the second + # model -- this is not supported, because there is some caching at + # the C level we cannot control (or don't know how to) # On Windows, this will give "permission denied" when building the - # extension + # extension, because we cannot delete a shared library that is in use + if sys.platform == "win32": return antimony2amici( - ant_model_2, + ant_model_3, model_name=module_name, - output_dir=outdir, - compute_conservation_laws=False, + output_dir=outdir_2, ) - with pytest.raises(RuntimeError, match="has already been imported"): - amici.import_model_module( - module_name=module_name, module_path=outdir - ) + + with pytest.raises(RuntimeError, match="in the same location"): + import_model_module(module_name=module_name, module_path=outdir_2) + + # this should not affect the previously loaded models assert model_module_1.get_model().getParameters()[0] == 1.0 + assert model_module_2.get_model().getParameters()[0] == 2.0 diff --git a/swig/modelname.template.i b/swig/modelname.template.i index d7aab8ed8a..db857348b4 100644 --- a/swig/modelname.template.i +++ b/swig/modelname.template.i @@ -1,4 +1,48 @@ -%module TPL_MODELNAME +%define MODULEIMPORT +" +import amici +import datetime +import importlib.util +import os +import sysconfig +from pathlib import Path + +ext_suffix = sysconfig.get_config_var('EXT_SUFFIX') +_TPL_MODELNAME = amici._module_from_path( + 'TPL_MODELNAME._TPL_MODELNAME' if __package__ or '.' in __name__ + else '_TPL_MODELNAME', + Path(__file__).parent / f'_TPL_MODELNAME{ext_suffix}', +) + +def _get_import_time(): + return _TPL_MODELNAME._get_import_time() + +t_imported = _get_import_time() +t_modified = os.path.getmtime(__file__) +if t_imported < t_modified: + t_imp_str = datetime.datetime.fromtimestamp(t_imported).isoformat() + t_mod_str = datetime.datetime.fromtimestamp(t_modified).isoformat() + module_path = Path(__file__).resolve() + raise RuntimeError( + f'Cannot import extension for TPL_MODELNAME from ' + f'{module_path}, because an extension in the same location ' + f'has already been imported, but the file was modified on ' + f'disk. \\nImported at {t_imp_str}\\nModified at {t_mod_str}.\\n' + 'Import the module with a different name or restart the ' + 'Python kernel.' + ) +" +%enddef + +%module(package="TPL_MODELNAME",moduleimport=MODULEIMPORT) TPL_MODELNAME + +%pythoncode %{ +# the model-package __init__.py module (will be set during import) +_model_module = None + + +%} + %import amici.i // Add necessary symbols to generated header @@ -30,8 +74,7 @@ static double _get_import_time(); // Make model module accessible from the model %feature("pythonappend") amici::generic_model::getModel %{ if '.' in __name__: - import sys - val.module = sys.modules['.'.join(__name__.split('.')[:-1])] + val.module = _model_module %}