Skip to content

Commit

Permalink
Merge pull request #68 from Open-EO/additional-mask
Browse files Browse the repository at this point in the history
Fixed scl mask
  • Loading branch information
GriffinBabe authored Mar 20, 2024
2 parents c72d48e + 1fd09e5 commit 018fd0c
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 17 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -161,4 +161,5 @@ cython_debug/
#.idea/

tests/test_openeo_gfmap/results/
notebook/
notebook/
data/
36 changes: 22 additions & 14 deletions examples/extraction_pipelines/S2_extraction_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -555,34 +555,42 @@
" fetch_type = FetchType.POLYGON\n",
" bands_to_download = ['S2-L2A-B01', 'S2-L2A-B02', 'S2-L2A-B03', 'S2-L2A-B04', 'S2-L2A-B05', 'S2-L2A-B06', 'S2-L2A-B07', 'S2-L2A-B08', 'S2-L2A-B8A', 'S2-L2A-B09', 'S2-L2A-B11', 'S2-L2A-B12', 'S2-L2A-SCL']\n",
"\n",
" # Compute the SCL dilation and add it to the cube\n",
" sub_collection = connection.load_collection(\n",
" collection_id=\"SENTINEL2_L2A\",\n",
" bands=[\"SCL\"],\n",
" temporal_extent=[start_date, end_date],\n",
" properties={\n",
" \"eo:cloud_cover\": lambda val: val <= 95.0,\n",
" \"tileId\": lambda val: val == row.s2_tile\n",
" }\n",
" )\n",
" scl_dilated_mask = sub_collection.process(\n",
" \"to_scl_dilation_mask\",\n",
" data=sub_collection,\n",
" scl_band_name=\"SCL\",\n",
" kernel1_size=17, # 17px dilation on a 20m layer\n",
" kernel2_size=77, # 77px dilation on a 20m layer\n",
" mask1_values=[2, 4, 5, 6, 7],\n",
" mask2_values=[3, 8, 9, 10, 11],\n",
" erosion_kernel_size=3\n",
" ).rename_labels(\"bands\", [\"S2-L2A-SCL_DILATED_MASK\"])\n",
"\n",
" # Create the job to extract S2\n",
" extraction_parameters = {\n",
" \"target_resolution\": 10,\n",
" \"load_collection\": {\n",
" \"eo:cloud_cover\": lambda val: val <= 95.0,\n",
" \"tileId\": lambda val: val == row.s2_tile\n",
" },\n",
" \"additional_masks\": scl_dilated_mask # Add an additional mask computed from the SCL layer\n",
" }\n",
" extractor = build_sentinel2_l2a_extractor(\n",
" backend_context, bands=bands_to_download, fetch_type=fetch_type.POLYGON, **extraction_parameters \n",
" )\n",
"\n",
" cube = extractor.get_cube(connection, spatial_extent_url, temporal_context)\n",
"\n",
" # Compute the SCL dilation and add it to the cube\n",
" scl_dilated_mask = cube.process(\n",
" \"to_scl_dilation_mask\",\n",
" data=cube,\n",
" scl_band_name=\"S2-L2A-SCL\",\n",
" kernel1_size=17, # 17px dilation on a 20m layer\n",
" kernel2_size=77, # 77px dilation on a 20m layer\n",
" mask1_values=[2, 4, 5, 6, 7],\n",
" mask2_values=[3, 8, 9, 10, 11],\n",
" erosion_kernel_size=3\n",
" ).rename_labels(\"bands\", [\"S2-L2A-SCL_DILATED_MASK\"])\n",
"\n",
" cube = cube.merge_cubes(scl_dilated_mask)\n",
"\n",
" # Compute the distance to cloud and add it to the cube\n",
" scl = cube.filter_bands(['S2-L2A-SCL'])\n",
" distance_to_cloud = scl.apply_neighborhood(\n",
Expand Down
12 changes: 10 additions & 2 deletions src/openeo_gfmap/fetching/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,8 @@ def resample_reproject(

def rename_bands(datacube: openeo.DataCube, mapping: dict) -> openeo.DataCube:
"""Rename the bands from the given mapping scheme"""
# Filter out bands that are not part of the datacube
print(datacube.dimension_labels("bands"))

# Filter out bands that are not part of the datacube
def filter_condition(band_name, _):
return band_name in datacube.metadata.band_names

Expand Down Expand Up @@ -136,6 +135,15 @@ def load_collection(
)
cube = cube.mask(pre_mask.resample_cube_spatial(cube))

# Include a band containing the SCL dilated band
additional_mask = params.get("additional_mask", None)
if additional_mask is not None:
assert isinstance(additional_mask, openeo.DataCube), (
f"The 'include_scl_dilation' parameter must be an openeo datacube, "
f"got {additional_mask}."
)
cube = cube.merge_cubes(additional_mask.resample_cube_spatial(cube))

if fetch_type == FetchType.POLYGON:
if isinstance(spatial_extent, str):
geometry = connection.load_url(
Expand Down
26 changes: 26 additions & 0 deletions src/openeo_gfmap/preprocessing/cloudmasking.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,29 @@ def bap_masking(cube: openeo.DataCube, period: Union[str, list], **params: dict)
return optical_cube

return optical_cube.merge_cubes(nonoptical_cube)


def cldmask_percentage(cube: openeo.DataCube, percentage: float = 0.95) -> openeo.DataCube:
    """Mask whole observations that are almost entirely covered by clouds.

    A UDF is run on the SCL band in chunks of 1024x1024 pixels and one time step. It
    produces a mask that is either full or empty for each chunk, and that mask is then
    applied to every band whose name does not contain "SCL".

    Parameters
    ----------
    cube:
        Datacube holding the optical bands together with a band named "SCL".
    percentage:
        Cloud fraction threshold, handed to the UDF through its context under the key
        "percentage". NOTE(review): it only takes effect if the UDF reads that key.

    Returns
    -------
    openeo.DataCube
        The masked non-SCL bands merged back with the untouched SCL band.
    """
    # Local import keeps this block self-contained; the file's import section is not
    # visible from here.
    from pathlib import Path

    # Resolve the UDF next to this module so the lookup does not depend on the current
    # working directory. The UDF shipped with this module is `udf_cldmask.py`.
    udf_path = Path(__file__).parent / "udf_cldmask.py"

    non_scl_bands = [band for band in cube.metadata.band_names if "SCL" not in band]
    non_scl_cube = cube.filter_bands(bands=non_scl_bands)

    # NOTE(review): this selects a band named exactly "SCL", while the filter above
    # excludes any band whose name merely contains "SCL" -- confirm the band naming.
    scl_cube = cube.filter_bands(["SCL"])

    cld_mask = scl_cube.apply_neighborhood(
        process=openeo.UDF.from_file(udf_path, context={"percentage": percentage}),
        size=[
            {"dimension": "x", "unit": "px", "value": 1024},
            {"dimension": "y", "unit": "px", "value": 1024},
            {"dimension": "t", "value": 1},
        ],
        overlap=[],
    )

    non_scl_cube = non_scl_cube.mask(cld_mask.resample_cube_spatial(cube))

    return non_scl_cube.merge_cubes(scl_cube)
36 changes: 36 additions & 0 deletions src/openeo_gfmap/preprocessing/udf_cldmask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import numpy as np
import xarray as xr
from openeo.udf import XarrayDataCube


def apply_datacube(cube: XarrayDataCube, context: dict) -> XarrayDataCube:
    """
    Compute an all-or-nothing cloud mask for each observation of the input cube.

    For every time step, the pixels equal to 9 (the high cloud probability class) are
    counted and divided by the number of pixels in one (y, x) slice. If that fraction is
    strictly higher than the threshold, the mask for that time step is 1 everywhere,
    otherwise it is 0 everywhere.

    Parameters
    ----------
    cube:
        Input cube with the dimensions "t", "bands", "y" and "x".
    context:
        Optional settings. The key "percentage" overrides the threshold; when it is
        absent (or the context is empty/None) the threshold is 0.95.

    Returns
    -------
    XarrayDataCube
        A uint8 cube with a single band labelled "mask" and the same t/y/x coordinates
        as the input.
    """
    # Pixel value counted as a high probability cloud.
    high_cloud_proba_value = 9
    threshold = (context or {}).get("percentage", 0.95)

    array = cube.get_array().transpose("t", "bands", "y", "x")
    n_obs, _, n_rows, n_cols = array.shape

    output_array = np.zeros(shape=(n_obs, 1, n_rows, n_cols), dtype=np.uint8)

    pixels_per_obs = n_rows * n_cols
    if pixels_per_obs > 0:  # an empty chunk has no fraction to compute; keep the empty mask
        # Count the matching pixels of every observation in a single vectorized pass.
        high_proba_counts = (array.values == high_cloud_proba_value).sum(axis=(1, 2, 3))
        output_array[high_proba_counts / pixels_per_obs > threshold] = 1

    output_array = xr.DataArray(
        output_array,
        dims=["t", "bands", "y", "x"],
        coords={
            "t": array.t,
            "bands": ["mask"],
            "y": array.y,
            "x": array.x,
        },
    )

    return XarrayDataCube(output_array)

0 comments on commit 018fd0c

Please sign in to comment.