Skip to content

Commit

Permalink
Treat pivoted leaf import mapping correctly even when its position is…
Browse files Browse the repository at this point in the history
… set to header (#257)
  • Loading branch information
soininen authored Aug 9, 2023
2 parents 4168c30 + f4272dd commit f50f7ba
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 2 deletions.
2 changes: 1 addition & 1 deletion spinedb_api/import_mapping/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def _split_mapping(mapping):
non_pivoted = []
pivoted_from_header = []
for m in flattened:
if pivoted and m is flattened[-1]:
if (pivoted or pivoted_from_header) and m is flattened[-1]:
# If any other mapping is pivoted, ignore last mapping's position
break
if m.position == Position.header and m.value is None:
Expand Down
3 changes: 2 additions & 1 deletion spinedb_api/import_mapping/import_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class ImportKey(Enum):

def __str__(self):
name = {
self.ALTERNATIVE_NAME.value: "Alternative names",
self.CLASS_NAME.value: "Class names",
self.OBJECT_CLASS_NAME.value: "Object class names",
self.OBJECT_NAME.value: "Object names",
Expand Down Expand Up @@ -279,7 +280,7 @@ def is_constant(self):
def is_pivoted(self):
if is_pivoted(self.position):
return True
if self.position == Position.header and self.value is None:
if self.position == Position.header and self.value is None and self.child is not None:
return True
if self.child is None:
return False
Expand Down
37 changes: 37 additions & 0 deletions tests/import_mapping/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,43 @@ def test_arrays_get_imported_correctly_when_objects_are_in_header_and_alternativ
},
)

def test_header_position_is_ignored_in_last_mapping_if_other_mappings_are_in_header(self):
header = ["Dimension", "parameter1", "parameter2"]
data_source = iter([["d1", 1.1, -2.3], ["d2", -1.1, 2.3]])
mappings = [
[
{"map_type": "ObjectClass", "position": "table_name"},
{"map_type": "Object", "position": 0},
{"map_type": "ObjectMetadata", "position": "hidden"},
{"map_type": "ParameterDefinition", "position": "header"},
{"map_type": "Alternative", "position": "hidden", "value": "Base"},
{"map_type": "ParameterValueMetadata", "position": "hidden"},
{"map_type": "ParameterValue", "position": "header"},
]
]
convert_function_specs = {0: "string", 1: "float", 2: "float"}
convert_functions = {column: value_to_convert_spec(spec) for column, spec in convert_function_specs.items()}

mapped_data, errors = get_mapped_data(
data_source, mappings, header, table_name="Data", column_convert_fns=convert_functions
)
self.assertEqual(errors, [])
self.assertEqual(
mapped_data,
{
"alternatives": {"Base"},
"object_classes": {"Data"},
"object_parameter_values": [
["Data", "d1", "parameter1", 1.1, "Base"],
["Data", "d1", "parameter2", -2.3, "Base"],
["Data", "d2", "parameter1", -1.1, "Base"],
["Data", "d2", "parameter2", 2.3, "Base"],
],
"object_parameters": [("Data", "parameter1"), ("Data", "parameter2")],
"objects": {("Data", "d1"), ("Data", "d2")},
},
)


if __name__ == '__main__':
unittest.main()
19 changes: 19 additions & 0 deletions tests/import_mapping/test_import_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
IndexNameMapping,
ParameterValueIndexMapping,
ExpandedParameterValueMapping,
ParameterValueMapping,
ParameterValueTypeMapping,
ParameterDefaultValueTypeMapping,
DefaultValueIndexNameMapping,
Expand Down Expand Up @@ -2096,5 +2097,23 @@ def test_child_mapping_without_filter_doesnt_have_filter(self):
self.assertFalse(mapping.has_filter())


class TestIsPivoted(unittest.TestCase):
def test_pivoted_position_returns_true(self):
mapping = AlternativeMapping(-1)
self.assertTrue(mapping.is_pivoted())

def test_recursively_returns_false_when_all_mappings_are_non_pivoted(self):
mapping = unflatten([AlternativeMapping(0), ParameterValueMapping(1)])
self.assertFalse(mapping.is_pivoted())

def test_returns_true_when_position_is_header_and_has_child(self):
mapping = unflatten([AlternativeMapping(Position.header), ParameterValueMapping(0)])
self.assertTrue(mapping.is_pivoted())

def test_returns_false_when_position_is_header_and_is_leaf(self):
mapping = unflatten([AlternativeMapping(0), ParameterValueMapping(Position.header)])
self.assertFalse(mapping.is_pivoted())


if __name__ == "__main__":
unittest.main()
9 changes: 9 additions & 0 deletions tests/test_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ def test_non_pivoted_columns_when_non_tail_mapping_is_pivoted(self):
root_mapping = unflatten([Mapping(5), Mapping(Position.hidden), Mapping(-1), Mapping(13), Mapping(23)])
self.assertEqual(root_mapping.non_pivoted_columns(), [5, 13])

def test_is_pivoted_returns_true_when_position_is_pivoted(self):
mapping = Mapping(-1)
self.assertTrue(mapping.is_pivoted())

def test_is_pivoted_returns_false_when_all_mappings_are_non_pivoted(self):
mappings = [Mapping(0), Mapping(1)]
root = unflatten(mappings)
self.assertFalse(root.is_pivoted())


if __name__ == "__main__":
unittest.main()

0 comments on commit f50f7ba

Please sign in to comment.