Skip to content

Commit

Permalink
Store problem configuration in v2.Problem (#338)
Browse files Browse the repository at this point in the history
Introduces `v2.Problem.config` which contains the info from the PEtab yaml file.
The same as #326, but for `v2.Problem`.



---------

Co-authored-by: Dilan Pathirana <[email protected]>
  • Loading branch information
dweindl and dilpath authored Dec 9, 2024
1 parent 4a551a7 commit d7f7e3a
Showing 1 changed file with 71 additions and 31 deletions.
102 changes: 71 additions & 31 deletions petab/v2/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import TYPE_CHECKING

import pandas as pd
from pydantic import AnyUrl, BaseModel, Field

from ..v1 import (
conditions,
Expand All @@ -25,6 +26,7 @@
yaml,
)
from ..v1.models.model import Model, model_factory
from ..v1.problem import ListOfFiles, VersionNumber
from ..v1.yaml import get_path_prefix
from ..v2.C import * # noqa: F403
from . import experiments
Expand Down Expand Up @@ -73,6 +75,7 @@ def __init__(
observable_df: pd.DataFrame = None,
mapping_df: pd.DataFrame = None,
extensions_config: dict = None,
config: ProblemConfig = None,
):
from ..v2.lint import default_validation_tasks

Expand All @@ -88,7 +91,7 @@ def __init__(
self.validation_tasks: list[
ValidationTask
] = default_validation_tasks.copy()

self.config = config
if self.experiment_df is not None:
warnings.warn(
"The experiment table is not yet supported and "
Expand Down Expand Up @@ -199,40 +202,37 @@ def get_path(filename):
"Consider using "
"petab.CompositeProblem.from_yaml() instead."
)
config = ProblemConfig(
**yaml_config, base_path=base_path, filepath=yaml_file
)
problem0 = config.problems[0]

problem0 = yaml_config["problems"][0]

if isinstance(yaml_config[PARAMETER_FILE], list):
if isinstance(config.parameter_file, list):
parameter_df = parameters.get_parameter_df(
[get_path(f) for f in yaml_config[PARAMETER_FILE]]
[get_path(f) for f in config.parameter_file]
)
else:
parameter_df = (
parameters.get_parameter_df(
get_path(yaml_config[PARAMETER_FILE])
)
if yaml_config[PARAMETER_FILE]
parameters.get_parameter_df(get_path(config.parameter_file))
if config.parameter_file
else None
)

if len(problem0[MODEL_FILES] or []) > 1:
if len(problem0.model_files or []) > 1:
# TODO https://github.com/PEtab-dev/libpetab-python/issues/6
raise NotImplementedError(
"Support for multiple models is not yet implemented."
)
if not problem0[MODEL_FILES]:
model = None
else:
model_id, model_info = next(iter(problem0[MODEL_FILES].items()))
model = None
if problem0.model_files:
model_id, model_info = next(iter(problem0.model_files.items()))
model = model_factory(
get_path(model_info[MODEL_LOCATION]),
model_info[MODEL_LANGUAGE],
get_path(model_info.location),
model_info.language,
model_id=model_id,
)

measurement_files = [
get_path(f) for f in problem0.get(MEASUREMENT_FILES, [])
]
measurement_files = [get_path(f) for f in problem0.measurement_files]
# If there are multiple tables, we will merge them
measurement_df = (
core.concat_tables(
Expand All @@ -242,19 +242,15 @@ def get_path(filename):
else None
)

condition_files = [
get_path(f) for f in problem0.get(CONDITION_FILES, [])
]
condition_files = [get_path(f) for f in problem0.condition_files]
# If there are multiple tables, we will merge them
condition_df = (
core.concat_tables(condition_files, conditions.get_condition_df)
if condition_files
else None
)

experiment_files = [
get_path(f) for f in problem0.get(EXPERIMENT_FILES, [])
]
experiment_files = [get_path(f) for f in problem0.experiment_files]
# If there are multiple tables, we will merge them
experiment_df = (
core.concat_tables(experiment_files, experiments.get_experiment_df)
Expand All @@ -263,7 +259,7 @@ def get_path(filename):
)

visualization_files = [
get_path(f) for f in problem0.get(VISUALIZATION_FILES, [])
get_path(f) for f in problem0.visualization_files
]
# If there are multiple tables, we will merge them
visualization_df = (
Expand All @@ -272,17 +268,15 @@ def get_path(filename):
else None
)

observable_files = [
get_path(f) for f in problem0.get(OBSERVABLE_FILES, [])
]
observable_files = [get_path(f) for f in problem0.observable_files]
# If there are multiple tables, we will merge them
observable_df = (
core.concat_tables(observable_files, observables.get_observable_df)
if observable_files
else None
)

mapping_files = [get_path(f) for f in problem0.get(MAPPING_FILES, [])]
mapping_files = [get_path(f) for f in problem0.mapping_files]
# If there are multiple tables, we will merge them
mapping_df = (
core.concat_tables(mapping_files, mapping.get_mapping_df)
Expand All @@ -299,7 +293,7 @@ def get_path(filename):
model=model,
visualization_df=visualization_df,
mapping_df=mapping_df,
extensions_config=yaml_config.get(EXTENSIONS, {}),
extensions_config=config.extensions,
)

@staticmethod
Expand Down Expand Up @@ -981,3 +975,49 @@ def add_experiment(self, id_: str, *args):
if self.experiment_df is not None
else tmp_df
)


class ModelFile(BaseModel):
"""A file in the PEtab problem configuration."""

location: str | AnyUrl
language: str


class SubProblem(BaseModel):
"""A `problems` object in the PEtab problem configuration."""

model_files: dict[str, ModelFile] | None = {}
measurement_files: ListOfFiles = []
condition_files: ListOfFiles = []
experiment_files: ListOfFiles = []
observable_files: ListOfFiles = []
visualization_files: ListOfFiles = []
mapping_files: ListOfFiles = []


class ExtensionConfig(BaseModel):
"""The configuration of a PEtab extension."""

name: str
version: str
config: dict


class ProblemConfig(BaseModel):
"""The PEtab problem configuration."""

filepath: str | AnyUrl | None = Field(
None,
description="The path to the PEtab problem configuration.",
exclude=True,
)
base_path: str | AnyUrl | None = Field(
None,
description="The base path to resolve relative paths.",
exclude=True,
)
format_version: VersionNumber = "2.0.0"
parameter_file: str | AnyUrl | None = None
problems: list[SubProblem] = []
extensions: list[ExtensionConfig] = []

0 comments on commit d7f7e3a

Please sign in to comment.