Skip to content

Commit

Permalink
Fix scenario filter for parameter values of multidimensional entities (
Browse files Browse the repository at this point in the history
  • Loading branch information
soininen authored Nov 21, 2024
2 parents b5f78bd + 510414f commit 1564a61
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 0 deletions.
32 changes: 32 additions & 0 deletions spinedb_api/filters/scenario_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,22 @@ def _make_scenario_filtered_parameter_value_sq(db_map, state):
Alias: a subquery for parameter value filtered by selected scenario
"""
ext_entity_sq = _ext_entity_sq(db_map, state)
ext_entity_element_count_sq = (
db_map.query(
db_map.entity_element_sq.c.entity_id,
func.count(db_map.entity_element_sq.c.element_id).label("element_count"),
)
.group_by(db_map.entity_element_sq.c.entity_id)
.subquery()
)
ext_entity_class_dimension_count_sq = (
db_map.query(
db_map.entity_class_dimension_sq.c.entity_class_id,
func.count(db_map.entity_class_dimension_sq.c.dimension_id).label("dimension_count"),
)
.group_by(db_map.entity_class_dimension_sq.c.entity_class_id)
.subquery()
)
ext_parameter_value_sq = (
db_map.query(
state.original_parameter_value_sq,
Expand Down Expand Up @@ -387,6 +403,22 @@ def _make_scenario_filtered_parameter_value_sq(db_map, state):
and_(ext_entity_sq.c.active == None, ext_entity_sq.c.active_by_default == True),
),
)
.outerjoin(
ext_entity_element_count_sq, ext_entity_element_count_sq.c.entity_id == ext_parameter_value_sq.c.entity_id
)
.outerjoin(
ext_entity_class_dimension_count_sq,
ext_entity_class_dimension_count_sq.c.entity_class_id == ext_parameter_value_sq.c.entity_class_id,
)
.filter(
or_(
and_(
ext_entity_element_count_sq.c.element_count == None,
ext_entity_class_dimension_count_sq.c.dimension_count == None,
),
ext_entity_element_count_sq.c.element_count == ext_entity_class_dimension_count_sq.c.dimension_count,
)
)
.subquery()
)

Expand Down
56 changes: 56 additions & 0 deletions tests/filters/test_scenario_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,62 @@ def test_parameter_values_for_entities_that_swim_against_active_by_default(self)
self.assertEqual(len(values), 1)
self.assertEqual(from_database(values[0]["value"], values[0]["type"]), -2.3)

def test_parameter_values_of_multidimensional_entity_whose_elements_have_entity_alternatives(self):
with DatabaseMapping("sqlite://", create=True) as db_map:
self._assert_success(db_map.add_scenario_item(name="base"))
self._assert_success(
db_map.add_scenario_alternative_item(scenario_name="base", alternative_name="Base", rank=1)
)
self._assert_success(db_map.add_entity_class_item(name="Object"))
self._assert_success(db_map.add_entity_item(name="visible", entity_class_name="Object"))
self._assert_success(
db_map.add_entity_alternative_item(
entity_class_name="Object", entity_byname=("visible",), alternative_name="Base", active=True
)
)
self._assert_success(db_map.add_entity_item(name="invisible", entity_class_name="Object"))
self._assert_success(
db_map.add_entity_alternative_item(
entity_class_name="Object", entity_byname=("invisible",), alternative_name="Base", active=False
)
)
self._assert_success(db_map.add_entity_class_item(name="Relationship", dimension_name_list=("Object",)))
self._assert_success(
db_map.add_entity_item(element_name_list=("visible",), entity_class_name="Relationship")
)
self._assert_success(
db_map.add_entity_item(element_name_list=("invisible",), entity_class_name="Relationship")
)
self._assert_success(db_map.add_parameter_definition_item(name="y", entity_class_name="Relationship"))
value, value_type = to_database(2.3)
self._assert_success(
db_map.add_parameter_value_item(
entity_class_name="Relationship",
entity_byname=("visible",),
parameter_definition_name="y",
alternative_name="Base",
value=value,
type=value_type,
)
)
value, value_type = to_database(-2.3)
self._assert_success(
db_map.add_parameter_value_item(
entity_class_name="Relationship",
entity_byname=("invisible",),
parameter_definition_name="y",
alternative_name="Base",
value=value,
type=value_type,
)
)
db_map.commit_session("Add test data")
config = scenario_filter_config("base")
scenario_filter_from_dict(db_map, config)
values = db_map.query(db_map.parameter_value_sq).all()
self.assertEqual(len(values), 1)
self.assertEqual(from_database(values[0].value, values[0].type), 2.3)


class TestScenarioFilterUtils(DataBuilderTestCase):
def test_scenario_filter_config(self):
Expand Down

0 comments on commit 1564a61

Please sign in to comment.