Skip to content

Commit

Permalink
Improver api changes (#2031)
Browse files Browse the repository at this point in the history
* Add missing improver API changes (register `maximum_in_height` in improver/api/__init__.py)

* Update to StandardiseMetadata class init arguments
  • Loading branch information
cpelley authored Sep 20, 2024
1 parent 03f3f29 commit cc8fa30
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 70 deletions.
1 change: 1 addition & 0 deletions improver/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
"LightningFromCapePrecip": "improver.lightning",
"LightningMultivariateProbability_USAF2024": "improver.lightning",
"ManipulateReliabilityTable": "improver.calibration.reliability_calibration",
"maximum_in_height": "improver.utilities.cube_manipulation",
"MaxInTimeWindow": "improver.cube_combiner",
"MergeCubes": "improver.utilities.cube_manipulation",
"MergeCubesForWeightedBlending": "improver.blending.weighted_blend",
Expand Down
5 changes: 2 additions & 3 deletions improver/cli/standardise.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,10 @@ def process(
"""
from improver.standardise import StandardiseMetadata

return StandardiseMetadata()(
cube,
return StandardiseMetadata(
new_name=new_name,
new_units=new_units,
coords_to_remove=coords_to_remove,
coord_modification=coord_modification,
attributes_dict=attributes_config,
)
)(cube)
91 changes: 51 additions & 40 deletions improver/standardise.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,46 @@
class StandardiseMetadata(BasePlugin):
"""Plugin to standardise cube metadata"""

def __init__(
    self,
    new_name: Optional[str] = None,
    new_units: Optional[str] = None,
    coords_to_remove: Optional[List[str]] = None,
    coord_modification: Optional[Dict[str, float]] = None,
    attributes_dict: Optional[Dict[str, Any]] = None,
):
    """
    Record the user-configurable metadata adjustments on the plugin.

    Nothing is validated or applied here: every argument is stored as
    given on a private attribute, and ``process`` consults those
    attributes when it is handed a cube. An argument left as None
    (or otherwise falsy) means the corresponding adjustment is skipped.

    Args:
        new_name:
            Optional rename for output cube
        new_units:
            Optional unit conversion for output cube
        coords_to_remove:
            Optional list of scalar coordinates to remove from output cube
        coord_modification:
            Optional dictionary used to directly modify the values of
            scalar coordinates. To be used with extreme caution.
            For example, {"height": 1.5} sets the height coordinate to a
            value of 1.5m (assuming original units of m). This can be
            used to align e.g. temperatures defined at slightly different
            heights where the difference is considered small enough to
            ignore. Type is inferred, so a value of 2 results in an
            integer type, whilst a value of 2.0 results in a float type.
        attributes_dict:
            Optional dictionary of required attribute updates. Keys are
            attribute names, and values are the required changes.
            See improver.metadata.amend.amend_attributes for details.
    """
    # Naming and units of the output cube.
    self._new_name, self._new_units = new_name, new_units
    # Scalar coordinate removal and direct value overrides.
    self._coords_to_remove, self._coord_modification = (
        coords_to_remove,
        coord_modification,
    )
    # Attribute amendments.
    self._attributes_dict = attributes_dict

@staticmethod
def _rm_air_temperature_status_flag(cube: Cube) -> Cube:
"""
Expand Down Expand Up @@ -198,15 +238,7 @@ def _discard_redundant_cell_methods(cube: Cube) -> None:

cube.cell_methods = updated_cms

def process(
self,
cube: Cube,
new_name: Optional[str] = None,
new_units: Optional[str] = None,
coords_to_remove: Optional[List[str]] = None,
coord_modification: Optional[Dict[str, float]] = None,
attributes_dict: Optional[Dict[str, Any]] = None,
) -> Cube:
def process(self, cube: Cube) -> Cube:
"""
Perform compulsory and user-configurable metadata adjustments. The
compulsory adjustments are:
Expand All @@ -220,27 +252,6 @@ def process(
Args:
cube:
Input cube to be standardised
new_name:
Optional rename for output cube
new_units:
Optional unit conversion for output cube
coords_to_remove:
Optional list of scalar coordinates to remove from output cube
coord_modification:
Optional dictionary used to directly modify the values of
scalar coordinates. To be used with extreme caution.
For example this dictionary might take the form:
{"height": 1.5} to set the height coordinate to have a value
of 1.5m (assuming original units of m).
This can be used to align e.g. temperatures defined at slightly
different heights where this difference is considered small
enough to ignore. Type is inferred, so providing a value of 2
will result in an integer type, whilst a value of 2.0 will
result in a float type.
attributes_dict:
Optional dictionary of required attribute updates. Keys are
attribute names, and values are the required changes.
See improver.metadata.amend.amend_attributes for details.
Returns:
The processed cube
Expand All @@ -249,16 +260,16 @@ def process(
cube = self._rm_air_temperature_status_flag(cube)
cube = self._collapse_scalar_dimensions(cube)

if new_name:
cube.rename(new_name)
if new_units:
cube.convert_units(new_units)
if coords_to_remove:
self._remove_scalar_coords(cube, coords_to_remove)
if coord_modification:
self._modify_scalar_coord_value(cube, coord_modification)
if attributes_dict:
amend_attributes(cube, attributes_dict)
if self._new_name:
cube.rename(self._new_name)
if self._new_units:
cube.convert_units(self._new_units)
if self._coords_to_remove:
self._remove_scalar_coords(cube, self._coords_to_remove)
if self._coord_modification:
self._modify_scalar_coord_value(cube, self._coord_modification)
if self._attributes_dict:
amend_attributes(cube, self._attributes_dict)
self._discard_redundant_cell_methods(cube)

# this must be done after unit conversion as if the input is an integer
Expand Down
49 changes: 22 additions & 27 deletions improver_tests/standardise/test_StandardiseMetadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,13 @@ class HaltExecution(Exception):
def test_as_cubelist_called(mock_as_cube):
mock_as_cube.side_effect = HaltExecution
try:
StandardiseMetadata()(
sentinel.cube,
StandardiseMetadata(
new_name=sentinel.new_name,
new_units=sentinel.new_units,
coords_to_remove=sentinel.coords_to_remove,
coord_modification=sentinel.coord_modification,
attributes_dict=sentinel.attributes_dict,
)
)(sentinel.cube)
except HaltExecution:
pass
mock_as_cube.assert_called_once_with(sentinel.cube)
Expand All @@ -51,12 +50,11 @@ def setUp(self):
time_bounds=[datetime(2019, 10, 10, 23), datetime(2019, 10, 11)],
frt=datetime(2019, 10, 10, 18),
)
self.plugin = StandardiseMetadata()

def test_null(self):
"""Test process method with default arguments returns an unchanged
cube"""
result = self.plugin.process(self.cube.copy())
result = StandardiseMetadata().process(self.cube.copy())
self.assertIsInstance(result, iris.cube.Cube)
self.assertArrayAlmostEqual(result.data, self.cube.data)
self.assertEqual(result.metadata, self.cube.metadata)
Expand All @@ -72,7 +70,7 @@ def test_standardise_time_coords(self):
np.float64
)
self.cube.coord("forecast_period").convert_units("hours")
result = self.plugin.process(self.cube)
result = StandardiseMetadata().process(self.cube)
self.assertEqual(result.coord("forecast_period").units, "seconds")
self.assertEqual(result.coord("forecast_period").points.dtype, np.int32)
self.assertEqual(result.coord("forecast_period").bounds.dtype, np.int32)
Expand All @@ -86,13 +84,13 @@ def test_standardise_time_coords_missing_fp(self):
np.float64
)
self.cube.remove_coord("forecast_period")
result = self.plugin.process(self.cube)
result = StandardiseMetadata().process(self.cube)
self.assertEqual(result.coord("time").points.dtype, np.int64)

def test_collapse_scalar_dimensions(self):
"""Test scalar dimension is collapsed"""
cube = iris.util.new_axis(self.cube, "time")
result = self.plugin.process(cube)
result = StandardiseMetadata().process(cube)
dim_coord_names = [coord.name() for coord in result.coords(dim_coords=True)]
aux_coord_names = [coord.name() for coord in result.coords(dim_coords=False)]
self.assertSequenceEqual(result.shape, (5, 5))
Expand All @@ -104,7 +102,7 @@ def test_realization_not_collapsed(self):
realization = AuxCoord([1], "realization")
self.cube.add_aux_coord(realization)
cube = iris.util.new_axis(self.cube, "realization")
result = self.plugin.process(cube)
result = StandardiseMetadata().process(cube)
dim_coord_names = [coord.name() for coord in result.coords(dim_coords=True)]
self.assertSequenceEqual(result.shape, (1, 5, 5))
self.assertIn("realization", dim_coord_names)
Expand All @@ -129,14 +127,14 @@ def test_metadata_changes(self):
# Modifier for scalar height coordinate
coord_modification = {"height": 2.0}

result = self.plugin.process(
self.cube,
plugin = StandardiseMetadata(
new_name=new_name,
new_units="degC",
coords_to_remove=["forecast_period"],
coord_modification=coord_modification,
attributes_dict=attribute_changes,
)
result = plugin.process(self.cube)
self.assertEqual(result.name(), new_name)
self.assertEqual(result.units, "degC")
self.assertArrayAlmostEqual(result.data, expected_data, decimal=5)
Expand All @@ -150,11 +148,10 @@ def test_attempt_modify_dimension_coord(self):

coord_modification = {"latitude": [0.1, 1.2, 3.4, 5.6, 7]}
msg = "Modifying dimension coordinate values is not allowed "
plugin = StandardiseMetadata(coord_modification=coord_modification)

with self.assertRaisesRegex(ValueError, msg):
self.plugin.process(
self.cube, coord_modification=coord_modification,
)
plugin.process(self.cube)

def test_attempt_modify_multi_valued_coord(self):
"""Test that an exception is raised if the coord_modification is used
Expand All @@ -168,10 +165,9 @@ def test_attempt_modify_multi_valued_coord(self):
coord_modification = {"kittens": [2, 3, 4, 5, 6]}
msg = "Modifying multi-valued coordinates is not allowed."

plugin = StandardiseMetadata(coord_modification=coord_modification)
with self.assertRaisesRegex(ValueError, msg):
self.plugin.process(
cube, coord_modification=coord_modification,
)
plugin.process(cube)

def test_attempt_modify_time_coord(self):
"""Test that an exception is raised if the coord_modification targets
Expand All @@ -181,10 +177,9 @@ def test_attempt_modify_time_coord(self):
for coord in ["time", "forecast_period", "forecast_reference_time"]:
coord_modification = {coord: 100}

plugin = StandardiseMetadata(coord_modification=coord_modification)
with self.assertRaisesRegex(ValueError, msg):
self.plugin.process(
self.cube, coord_modification=coord_modification,
)
plugin.process(self.cube)

def test_discard_cellmethod(self):
"""Test changes to cell_methods"""
Expand All @@ -193,7 +188,7 @@ def test_discard_cellmethod(self):
iris.coords.CellMethod(method="point", coords="time"),
iris.coords.CellMethod(method="max", coords="realization"),
]
result = self.plugin.process(cube,)
result = StandardiseMetadata().process(cube)
self.assertEqual(
result.cell_methods,
(iris.coords.CellMethod(method="max", coords="realization"),),
Expand All @@ -203,7 +198,7 @@ def test_float_deescalation(self):
"""Test precision de-escalation from float64 to float32"""
cube = self.cube.copy()
cube.data = cube.data.astype(np.float64)
result = self.plugin.process(cube)
result = StandardiseMetadata().process(cube)
self.assertEqual(result.data.dtype, np.float32)
self.assertArrayAlmostEqual(result.data, self.cube.data, decimal=4)

Expand All @@ -213,7 +208,7 @@ def test_float_deescalation_with_unit_change(self):
cube = set_up_variable_cube(
np.ones((5, 5), dtype=np.int16), name="rainrate", units="mm h-1"
)
result = self.plugin.process(cube, new_units="m s-1")
result = StandardiseMetadata(new_units="m s-1").process(cube)
self.assertEqual(cube.dtype, np.float32)
self.assertEqual(result.data.dtype, np.float32)

Expand All @@ -237,7 +232,7 @@ def test_air_temperature_status_flag_coord(self):
# surface altitude.
result_no_sf = cube.copy()
result_no_sf.data[:, 0, ...] = np.nan
target = self.plugin.process(result_no_sf)
target = StandardiseMetadata().process(result_no_sf)

cube_with_flags = cube.copy()
flag_status = np.zeros((3, 3, 5, 5), dtype=np.int8)
Expand All @@ -253,7 +248,7 @@ def test_air_temperature_status_flag_coord(self):
)
cube_with_flags.add_aux_coord(status_flag_coord, (0, 1, 2, 3))

result = self.plugin.process(cube_with_flags)
result = StandardiseMetadata().process(cube_with_flags)
self.assertArrayEqual(result.data, target.data)
self.assertEqual(result.coords(), target.coords())

Expand All @@ -277,7 +272,7 @@ def test_air_temperature_status_flag_coord_without_realization(self):
# surface altitude.
result_no_sf = cube.copy()
result_no_sf.data[0, ...] = np.nan
target = self.plugin.process(result_no_sf)
target = StandardiseMetadata().process(result_no_sf)

cube_with_flags = cube.copy()
flag_status = np.zeros((3, 5, 5), dtype=np.int8)
Expand All @@ -293,7 +288,7 @@ def test_air_temperature_status_flag_coord_without_realization(self):
)
cube_with_flags.add_aux_coord(status_flag_coord, (0, 1, 2))

result = self.plugin.process(cube_with_flags)
result = StandardiseMetadata().process(cube_with_flags)
self.assertArrayEqual(result.data, target.data)
self.assertEqual(result.coords(), target.coords())

Expand Down

0 comments on commit cc8fa30

Please sign in to comment.