From d7f7e3ae6bebc809c020cb8086790e1c557400eb Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Mon, 9 Dec 2024 22:14:18 +0100 Subject: [PATCH] Store problem configuration in `v2.Problem` (#338) Introduces `v2.Problem.config` which contains the info from the PEtab yaml file. The same as #326, but for `v2.Problem`. --------- Co-authored-by: Dilan Pathirana <59329744+dilpath@users.noreply.github.com> --- petab/v2/problem.py | 102 ++++++++++++++++++++++++++++++-------------- 1 file changed, 71 insertions(+), 31 deletions(-) diff --git a/petab/v2/problem.py b/petab/v2/problem.py index c22d74e1..b61d8b14 100644 --- a/petab/v2/problem.py +++ b/petab/v2/problem.py @@ -12,6 +12,7 @@ from typing import TYPE_CHECKING import pandas as pd +from pydantic import AnyUrl, BaseModel, Field from ..v1 import ( conditions, @@ -25,6 +26,7 @@ yaml, ) from ..v1.models.model import Model, model_factory +from ..v1.problem import ListOfFiles, VersionNumber from ..v1.yaml import get_path_prefix from ..v2.C import * # noqa: F403 from . import experiments @@ -73,6 +75,7 @@ def __init__( observable_df: pd.DataFrame = None, mapping_df: pd.DataFrame = None, extensions_config: dict = None, + config: ProblemConfig = None, ): from ..v2.lint import default_validation_tasks @@ -88,7 +91,7 @@ def __init__( self.validation_tasks: list[ ValidationTask ] = default_validation_tasks.copy() - + self.config = config if self.experiment_df is not None: warnings.warn( "The experiment table is not yet supported and " @@ -199,40 +202,37 @@ def get_path(filename): "Consider using " "petab.CompositeProblem.from_yaml() instead." ) + config = ProblemConfig( + **yaml_config, base_path=base_path, filepath=yaml_file + ) + problem0 = config.problems[0] - problem0 = yaml_config["problems"][0] - - if isinstance(yaml_config[PARAMETER_FILE], list): + if isinstance(config.parameter_file, list): parameter_df = parameters.get_parameter_df( - [get_path(f) for f in yaml_config[PARAMETER_FILE]] + [get_path(f) for f in config.parameter_file] ) else: parameter_df = ( - parameters.get_parameter_df( - get_path(yaml_config[PARAMETER_FILE]) - ) - if yaml_config[PARAMETER_FILE] + parameters.get_parameter_df(get_path(config.parameter_file)) + if config.parameter_file else None ) - if len(problem0[MODEL_FILES] or []) > 1: + if len(problem0.model_files or []) > 1: # TODO https://github.com/PEtab-dev/libpetab-python/issues/6 raise NotImplementedError( "Support for multiple models is not yet implemented." ) - if not problem0[MODEL_FILES]: - model = None - else: - model_id, model_info = next(iter(problem0[MODEL_FILES].items())) + model = None + if problem0.model_files: + model_id, model_info = next(iter(problem0.model_files.items())) model = model_factory( - get_path(model_info[MODEL_LOCATION]), - model_info[MODEL_LANGUAGE], + get_path(model_info.location), + model_info.language, model_id=model_id, ) - measurement_files = [ - get_path(f) for f in problem0.get(MEASUREMENT_FILES, []) - ] + measurement_files = [get_path(f) for f in problem0.measurement_files] # If there are multiple tables, we will merge them measurement_df = ( core.concat_tables( @@ -242,9 +242,7 @@ def get_path(filename): else None ) - condition_files = [ - get_path(f) for f in problem0.get(CONDITION_FILES, []) - ] + condition_files = [get_path(f) for f in problem0.condition_files] # If there are multiple tables, we will merge them condition_df = ( core.concat_tables(condition_files, conditions.get_condition_df) @@ -252,9 +250,7 @@ def get_path(filename): else None ) - experiment_files = [ - get_path(f) for f in problem0.get(EXPERIMENT_FILES, []) - ] + experiment_files = [get_path(f) for f in problem0.experiment_files] # If there are multiple tables, we will merge them experiment_df = ( core.concat_tables(experiment_files, experiments.get_experiment_df) @@ -263,7 +259,7 @@ def get_path(filename): ) visualization_files = [ - get_path(f) for f in problem0.get(VISUALIZATION_FILES, []) + get_path(f) for f in problem0.visualization_files ] # If there are multiple tables, we will merge them visualization_df = ( @@ -272,9 +268,7 @@ def get_path(filename): else None ) - observable_files = [ - get_path(f) for f in problem0.get(OBSERVABLE_FILES, []) - ] + observable_files = [get_path(f) for f in problem0.observable_files] # If there are multiple tables, we will merge them observable_df = ( core.concat_tables(observable_files, observables.get_observable_df) @@ -282,7 +276,7 @@ def get_path(filename): else None ) - mapping_files = [get_path(f) for f in problem0.get(MAPPING_FILES, [])] + mapping_files = [get_path(f) for f in problem0.mapping_files] # If there are multiple tables, we will merge them mapping_df = ( core.concat_tables(mapping_files, mapping.get_mapping_df) @@ -299,7 +293,7 @@ def get_path(filename): model=model, visualization_df=visualization_df, mapping_df=mapping_df, - extensions_config=yaml_config.get(EXTENSIONS, {}), + extensions_config=config.extensions, ) @staticmethod @@ -981,3 +975,49 @@ def add_experiment(self, id_: str, *args): if self.experiment_df is not None else tmp_df ) + + +class ModelFile(BaseModel): + """A file in the PEtab problem configuration.""" + + location: str | AnyUrl + language: str + + +class SubProblem(BaseModel): + """A `problems` object in the PEtab problem configuration.""" + + model_files: dict[str, ModelFile] | None = {} + measurement_files: ListOfFiles = [] + condition_files: ListOfFiles = [] + experiment_files: ListOfFiles = [] + observable_files: ListOfFiles = [] + visualization_files: ListOfFiles = [] + mapping_files: ListOfFiles = [] + + +class ExtensionConfig(BaseModel): + """The configuration of a PEtab extension.""" + + name: str + version: str + config: dict + + +class ProblemConfig(BaseModel): + """The PEtab problem configuration.""" + + filepath: str | AnyUrl | None = Field( + None, + description="The path to the PEtab problem configuration.", + exclude=True, + ) + base_path: str | AnyUrl | None = Field( + None, + description="The base path to resolve relative paths.", + exclude=True, + ) + format_version: VersionNumber = "2.0.0" + parameter_file: str | AnyUrl | None = None + problems: list[SubProblem] = [] + extensions: list[ExtensionConfig] = []