From 26667e7ae22d6c6dfb50b37ba434d4055de13520 Mon Sep 17 00:00:00 2001 From: Stefaan Lippens Date: Mon, 11 Mar 2024 09:57:25 +0100 Subject: [PATCH] to_scl_dilation_mask: use defaults from spec related to https://github.com/Open-EO/openeo-geopyspark-driver/issues/715 --- openeo_driver/ProcessGraphDeserializer.py | 33 ++++++++++++----------- tests/test_views_execute.py | 28 +++++++++++++++++++ 2 files changed, 46 insertions(+), 15 deletions(-) diff --git a/openeo_driver/ProcessGraphDeserializer.py b/openeo_driver/ProcessGraphDeserializer.py index 0f4ec397..8ee85aeb 100644 --- a/openeo_driver/ProcessGraphDeserializer.py +++ b/openeo_driver/ProcessGraphDeserializer.py @@ -1993,23 +1993,26 @@ def mask_scl_dilation(args: Dict, env: EvalEnv): @process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/to_scl_dilation_mask.json")) @process_registry_2xx.add_function(spec=read_spec("openeo-processes/experimental/to_scl_dilation_mask.json")) -def to_scl_dilation_mask(args: Dict, env: EvalEnv): - cube: DriverDataCube = extract_arg(args, "data") - if not isinstance(cube, DriverDataCube): - raise ProcessParameterInvalidException( - parameter="data", - process="to_scl_dilation_mask", - reason=f"Invalid data type {type(cube)!r} expected raster-cube.", - ) +def to_scl_dilation_mask(args: ProcessArgs, env: EvalEnv): + cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube) + if hasattr(cube, "to_scl_dilation_mask"): - erosion_kernel_size = args.get("erosion_kernel_size", 0) - mask1_values = args.get("mask1_values", [2, 4, 5, 6, 7]) - mask2_values = args.get("mask2_values", [3, 8, 9, 10, 11]) - kernel1_size = args.get("kernel1_size", 17) - kernel2_size = args.get("kernel2_size", 201) - return cube.to_scl_dilation_mask(erosion_kernel_size, mask1_values, mask2_values, kernel1_size, kernel2_size) + # Get default values from spec + spec = read_spec("openeo-processes/experimental/to_scl_dilation_mask.json") + defaults = {param["name"]: param["default"] for param in spec["parameters"] if "default" in param} + optionals = { + arg: args.get_optional(arg, default=defaults[arg]) + for arg in [ + "erosion_kernel_size", + "mask1_values", + "mask2_values", + "kernel1_size", + "kernel2_size", + ] + } + return cube.to_scl_dilation_mask(**optionals) else: - return cube + raise FeatureUnsupportedException(message="to_scl_dilation_mask is not supported") @process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/mask_l1c.json")) diff --git a/tests/test_views_execute.py b/tests/test_views_execute.py index 1ab9f016..e85065f0 100644 --- a/tests/test_views_execute.py +++ b/tests/test_views_execute.py @@ -4187,3 +4187,31 @@ def test_verify_for_synchronous_processing_failure(api, caplog): res = api.result(pg) res.assert_status_code(200) assert "Unexpected error while verifying synchronous processing: Nope, catch this" in caplog.text + + +def test_to_scl_dilation_mask_defaults(api): + api.check_result( + { + "loadcollection1": { + "process_id": "load_collection", + "arguments": {"id": "SENTINEL2_L2A_SENTINELHUB", "bands": ["SCL"]}, + }, + "to_scl_dilation_mask": { + "process_id": "to_scl_dilation_mask", + "arguments": {"data": {"from_node": "loadcollection1"}}, + "result": True, + }, + } + ) + + dummy = dummy_backend.get_collection("SENTINEL2_L2A_SENTINELHUB") + assert dummy.to_scl_dilation_mask.call_count == 1 + args, kwargs = dummy.to_scl_dilation_mask.call_args + assert args == () + assert kwargs == { + "erosion_kernel_size": 0, + "kernel1_size": 17, + "kernel2_size": 201, + "mask1_values": [2, 4, 5, 6, 7], + "mask2_values": [3, 8, 9, 10, 11], + }