diff --git a/surround/__init__.py b/surround/__init__.py index 1cc198ff..dfac8ce2 100644 --- a/surround/__init__.py +++ b/surround/__init__.py @@ -4,7 +4,7 @@ from .state import State from .stage import Validator, Filter, Estimator from .visualiser import Visualiser -from .config import Config +from .config import Config, has_config from .assembler import Assembler from .runners import Runner, RunMode diff --git a/surround/assembler.py b/surround/assembler.py index 72ddd58a..f3d87313 100644 --- a/surround/assembler.py +++ b/surround/assembler.py @@ -7,7 +7,7 @@ from abc import ABC from datetime import datetime -from .config import Config +from .config import Config, has_config from .stage import Filter, Estimator, Validator from .visualiser import Visualiser @@ -57,7 +57,8 @@ class Assembler(ABC): """ # pylint: disable=too-many-instance-attributes - def __init__(self, assembler_name=""): + @has_config + def __init__(self, assembler_name="", config=None): """ Constructor for an Assembler pipeline: @@ -66,7 +67,7 @@ def __init__(self, assembler_name=""): """ self.assembler_name = assembler_name - self.config = Config(auto_load=True) + self.config = config self.stages = None self.estimator = None self.validator = None diff --git a/surround/config.py b/surround/config.py index b961bbb6..04ebd328 100644 --- a/surround/config.py +++ b/surround/config.py @@ -1,5 +1,6 @@ import ast import os +import functools from pathlib import Path from collections.abc import Mapping @@ -51,6 +52,19 @@ class Config(Mapping): SURRROUND_PREDICT_DEBUG=False """ + __instance = None + + @staticmethod + def instance(): + """ + Static method which returns the a singleton instance of Config. + """ + + if not Config.__instance: + Config.__instance = Config(auto_load=True) + + return Config.__instance + def __init__(self, project_root=None, package_path=None, auto_load=False): """ Constructor of the Config class, loads the default YAML file into storage. @@ -373,3 +387,40 @@ def __len__(self): """ return len(self._storage) + +def has_config(func=None, name="config", filename=None): + """ + Decorator that injects the singleton config instance into the arguments of the function. + e.g. + ``` + @has_config + def some_func(config): + value = config.get_path("some.config") + ... + + @has_config(name="global_config") + def other_func(global_config, new_config): + value = config.get_path("some.config") + + @has_config(filename="override.yaml") + def some_func(config): + value = config.get_path("override.value") + ``` + """ + + @functools.wraps(func) + def function_wrapper(*args, **kwargs): + config = Config.instance() + if filename: + path = os.path.join(config.get_path("package_path"), filename) + config.read_config_files([path]) + kwargs[name] = config + return func(*args, **kwargs) + + if func: + return function_wrapper + + def recursive_wrapper(func): + return has_config(func, name, filename) + + return recursive_wrapper diff --git a/templates/new/batch_main.py.txt b/templates/new/batch_main.py.txt index 8c800212..defa8066 100644 --- a/templates/new/batch_main.py.txt +++ b/templates/new/batch_main.py.txt @@ -5,7 +5,7 @@ Runners and assemblies are defined in here. import os import argparse -from surround import Surround, Assembler, Config +from surround import Surround, Assembler, has_config from .stages import Baseline, InputValidator, ReportGenerator from .file_system_runner import FileSystemRunner @@ -20,8 +20,8 @@ ASSEMBLIES = [ .set_visualiser(ReportGenerator()) ] -def main(): - config = Config(auto_load=True) +@has_config +def main(config=None): default_runner = config.get_path('runner.default') default_assembler = config.get_path('assembler.default') diff --git a/templates/new/web_main.py.txt b/templates/new/web_main.py.txt index 7de7cb84..ef0feec9 100644 --- a/templates/new/web_main.py.txt +++ b/templates/new/web_main.py.txt @@ -5,7 +5,7 @@ Runners and ASSEMBLIES are defined in here. import os import argparse -from surround import Surround, Assembler, Config +from surround import Surround, Assembler, has_config from .stages import Baseline, InputValidator, ReportGenerator from .file_system_runner import FileSystemRunner from .web_runner import WebRunner @@ -22,8 +22,8 @@ ASSEMBLIES = [ .set_visualiser(ReportGenerator()) ] -def main(): - config = Config(auto_load=True) +@has_config +def main(config=None): default_runner = config.get_path('runner.default') default_assembler = config.get_path('assembler.default')