Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Treat pivoted leaf import mapping correctly even when its position is set to header #257

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion spinedb_api/import_mapping/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def _split_mapping(mapping):
non_pivoted = []
pivoted_from_header = []
for m in flattened:
if pivoted and m is flattened[-1]:
if (pivoted or pivoted_from_header) and m is flattened[-1]:
# If any other mapping is pivoted, ignore last mapping's position
break
if m.position == Position.header and m.value is None:
Expand Down
3 changes: 2 additions & 1 deletion spinedb_api/import_mapping/import_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class ImportKey(Enum):

def __str__(self):
name = {
self.ALTERNATIVE_NAME.value: "Alternative names",
self.CLASS_NAME.value: "Class names",
self.OBJECT_CLASS_NAME.value: "Object class names",
self.OBJECT_NAME.value: "Object names",
Expand Down Expand Up @@ -279,7 +280,7 @@ def is_constant(self):
def is_pivoted(self):
if is_pivoted(self.position):
return True
if self.position == Position.header and self.value is None:
if self.position == Position.header and self.value is None and self.child is not None:
return True
if self.child is None:
return False
Expand Down
37 changes: 37 additions & 0 deletions tests/import_mapping/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,43 @@ def test_arrays_get_imported_correctly_when_objects_are_in_header_and_alternativ
},
)

def test_header_position_is_ignored_in_last_mapping_if_other_mappings_are_in_header(self):
header = ["Dimension", "parameter1", "parameter2"]
data_source = iter([["d1", 1.1, -2.3], ["d2", -1.1, 2.3]])
mappings = [
[
{"map_type": "ObjectClass", "position": "table_name"},
{"map_type": "Object", "position": 0},
{"map_type": "ObjectMetadata", "position": "hidden"},
{"map_type": "ParameterDefinition", "position": "header"},
{"map_type": "Alternative", "position": "hidden", "value": "Base"},
{"map_type": "ParameterValueMetadata", "position": "hidden"},
{"map_type": "ParameterValue", "position": "header"},
]
]
convert_function_specs = {0: "string", 1: "float", 2: "float"}
convert_functions = {column: value_to_convert_spec(spec) for column, spec in convert_function_specs.items()}

mapped_data, errors = get_mapped_data(
data_source, mappings, header, table_name="Data", column_convert_fns=convert_functions
)
self.assertEqual(errors, [])
self.assertEqual(
mapped_data,
{
"alternatives": {"Base"},
"object_classes": {"Data"},
"object_parameter_values": [
["Data", "d1", "parameter1", 1.1, "Base"],
["Data", "d1", "parameter2", -2.3, "Base"],
["Data", "d2", "parameter1", -1.1, "Base"],
["Data", "d2", "parameter2", 2.3, "Base"],
],
"object_parameters": [("Data", "parameter1"), ("Data", "parameter2")],
"objects": {("Data", "d1"), ("Data", "d2")},
},
)


if __name__ == '__main__':
unittest.main()
19 changes: 19 additions & 0 deletions tests/import_mapping/test_import_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
IndexNameMapping,
ParameterValueIndexMapping,
ExpandedParameterValueMapping,
ParameterValueMapping,
ParameterValueTypeMapping,
ParameterDefaultValueTypeMapping,
DefaultValueIndexNameMapping,
Expand Down Expand Up @@ -2096,5 +2097,23 @@ def test_child_mapping_without_filter_doesnt_have_filter(self):
self.assertFalse(mapping.has_filter())


class TestIsPivoted(unittest.TestCase):
def test_pivoted_position_returns_true(self):
mapping = AlternativeMapping(-1)
self.assertTrue(mapping.is_pivoted())

def test_recursively_returns_false_when_all_mappings_are_non_pivoted(self):
mapping = unflatten([AlternativeMapping(0), ParameterValueMapping(1)])
self.assertFalse(mapping.is_pivoted())

def test_returns_true_when_position_is_header_and_has_child(self):
mapping = unflatten([AlternativeMapping(Position.header), ParameterValueMapping(0)])
self.assertTrue(mapping.is_pivoted())

def test_returns_false_when_position_is_header_and_is_leaf(self):
mapping = unflatten([AlternativeMapping(0), ParameterValueMapping(Position.header)])
self.assertFalse(mapping.is_pivoted())


if __name__ == "__main__":
unittest.main()
9 changes: 9 additions & 0 deletions tests/test_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ def test_non_pivoted_columns_when_non_tail_mapping_is_pivoted(self):
root_mapping = unflatten([Mapping(5), Mapping(Position.hidden), Mapping(-1), Mapping(13), Mapping(23)])
self.assertEqual(root_mapping.non_pivoted_columns(), [5, 13])

def test_is_pivoted_returns_true_when_position_is_pivoted(self):
mapping = Mapping(-1)
self.assertTrue(mapping.is_pivoted())

def test_is_pivoted_returns_false_when_all_mappings_are_non_pivoted(self):
mappings = [Mapping(0), Mapping(1)]
root = unflatten(mappings)
self.assertFalse(root.is_pivoted())


if __name__ == "__main__":
unittest.main()