diff --git a/spinedb_api/import_mapping/import_mapping.py b/spinedb_api/import_mapping/import_mapping.py index 3aafc106..2192fd30 100644 --- a/spinedb_api/import_mapping/import_mapping.py +++ b/spinedb_api/import_mapping/import_mapping.py @@ -280,7 +280,7 @@ def is_constant(self): def is_pivoted(self): if is_pivoted(self.position): return True - if self.position == Position.header and self.value is None: + if self.position == Position.header and self.value is None and self.child is not None: return True if self.child is None: return False diff --git a/tests/import_mapping/test_import_mapping.py b/tests/import_mapping/test_import_mapping.py index 7809c64f..0ac76a55 100644 --- a/tests/import_mapping/test_import_mapping.py +++ b/tests/import_mapping/test_import_mapping.py @@ -26,6 +26,7 @@ IndexNameMapping, ParameterValueIndexMapping, ExpandedParameterValueMapping, + ParameterValueMapping, ParameterValueTypeMapping, ParameterDefaultValueTypeMapping, DefaultValueIndexNameMapping, @@ -2096,5 +2097,23 @@ def test_child_mapping_without_filter_doesnt_have_filter(self): self.assertFalse(mapping.has_filter()) +class TestIsPivoted(unittest.TestCase): + def test_pivoted_position_returns_true(self): + mapping = AlternativeMapping(-1) + self.assertTrue(mapping.is_pivoted()) + + def test_recursively_returns_false_when_all_mappings_are_non_pivoted(self): + mapping = unflatten([AlternativeMapping(0), ParameterValueMapping(1)]) + self.assertFalse(mapping.is_pivoted()) + + def test_returns_true_when_position_is_header_and_has_child(self): + mapping = unflatten([AlternativeMapping(Position.header), ParameterValueMapping(0)]) + self.assertTrue(mapping.is_pivoted()) + + def test_returns_false_when_position_is_header_and_is_leaf(self): + mapping = unflatten([AlternativeMapping(0), ParameterValueMapping(Position.header)]) + self.assertFalse(mapping.is_pivoted()) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_mapping.py b/tests/test_mapping.py index acb67ef1..94358de2 100644 --- a/tests/test_mapping.py +++ b/tests/test_mapping.py @@ -38,6 +38,15 @@ def test_non_pivoted_columns_when_non_tail_mapping_is_pivoted(self): root_mapping = unflatten([Mapping(5), Mapping(Position.hidden), Mapping(-1), Mapping(13), Mapping(23)]) self.assertEqual(root_mapping.non_pivoted_columns(), [5, 13]) + def test_is_pivoted_returns_true_when_position_is_pivoted(self): + mapping = Mapping(-1) + self.assertTrue(mapping.is_pivoted()) + + def test_is_pivoted_returns_false_when_all_mappings_are_non_pivoted(self): + mappings = [Mapping(0), Mapping(1)] + root = unflatten(mappings) + self.assertFalse(root.is_pivoted()) + if __name__ == "__main__": unittest.main()