From c4ee43077dab4bf9bf4dbde073a5ce6cab682a49 Mon Sep 17 00:00:00 2001
From: Impact <pascal.mouquet@ird.fr>
Date: Thu, 23 Dec 2021 13:03:23 +0400
Subject: [PATCH] implementation of new cloudmask cm003 version, with
 possibility to choose cld_prb threshold (probability) and number of dilatation
 iterations (iterations). cm001 stays the default cloudmask version

---
 sen2chain/cloud_mask.py | 53 +++++++++++++++++++++++++++++++++++++++++
 sen2chain/products.py   | 28 +++++++++++++++-------
 sen2chain/tiles.py      | 30 +++++++++++++++++------
 3 files changed, 96 insertions(+), 15 deletions(-)

diff --git a/sen2chain/cloud_mask.py b/sen2chain/cloud_mask.py
index eadd536..6653ef8 100755
--- a/sen2chain/cloud_mask.py
+++ b/sen2chain/cloud_mask.py
@@ -390,3 +390,56 @@ def create_cloud_mask_b11(
     os.remove(out_mask)
     logger.info("Done: {}".format(out_path.name))
 
+def create_cloud_mask_v003(cloud_mask: Union[str, pathlib.PosixPath],
+                           out_path="./cm003.jp2",
+                           probability: int = 1,
+                           iterations: int = 5
+                           ) -> None:
+    """
+    Create cloud mask version cm003. This cloudmask uses a simple thresholding and dilatations
+    over the 20m cloud_probability band from Sen2Cor. The threshold value and number of dilatation cycles
+    can be manually modified by the user. Default values 1% and 5 cycles.
+    :param cloud_mask: path to the 20m cloud mask raster.
+    :param out_path: path to the output file (str or pathlib path).
+    :param probability: threshold in percent for the 20m cloud_probability band binarisation.
+    :param iterations: number of dilatation cycles to apply.
+    """
+    out_path = Path(out_path)  # the default is a plain str, which has no .stem/.name
+    out_temp_path = Path(Config().get("temp_path"))
+    out_temp = str(out_temp_path / (out_path.stem + "_tmp_cm003.tif"))
+
+    with rasterio.open(str(cloud_mask)) as cld_src:
+        cld_profile = cld_src.profile
+        cld = cld_src.read(1).astype(np.int8)
+
+    cld = np.where(cld >= probability, 1, 0)
+    # Structuring element applied at each dilatation iteration.
+    kernel = np.array([[0, 1, 1, 1, 0, 0],
+                       [1, 1, 1, 1, 1, 1],
+                       [1, 1, 1, 1, 1, 1],
+                       [1, 1, 1, 1, 1, 1],
+                       [0, 1, 1, 1, 1, 0]])
+
+    cld_dilated = ndimage.binary_dilation(cld, kernel, iterations = iterations)
+    cld_profile.update(driver="Gtiff",
+                       compress="DEFLATE",
+                       tiled=False,
+                       dtype=np.int8)
+
+    with rasterio.Env(GDAL_CACHEMAX=512):
+        with rasterio.open(str(out_temp), "w", **cld_profile) as dst:
+            dst.write(cld_dilated.astype(np.int8), 1)
+
+    # Save to JPEG 2000
+    src_ds = gdal.Open(out_temp)
+    driver = gdal.GetDriverByName("JP2OpenJPEG")
+    dst_ds = driver.CreateCopy(str(out_path), src_ds,
+                               options=["CODEC=JP2", "QUALITY=100", "REVERSIBLE=YES", "YCBCR420=NO"])
+    dst_ds = None
+    src_ds = None
+
+    os.remove(out_temp)
+    logger.info("Done: {}".format(out_path.name))
+
+    
+
diff --git a/sen2chain/products.py b/sen2chain/products.py
index aca2814..2c8051e 100755
--- a/sen2chain/products.py
+++ b/sen2chain/products.py
@@ -17,7 +17,7 @@ from .utils import grouper, setPermissions, get_current_Sen2Cor_version
 from .config import Config, SHARED_DATA
 from .xmlparser import MetadataParser, Sen2ChainMetadataParser
 from .sen2cor import process_sen2cor
-from .cloud_mask import create_cloud_mask, create_cloud_mask_v2, create_cloud_mask_b11
+from .cloud_mask import create_cloud_mask, create_cloud_mask_v2, create_cloud_mask_b11, create_cloud_mask_v003
 from .indices import IndicesCollection
 from .colormap import create_l2a_ql, create_l1c_ql
 
@@ -517,12 +517,15 @@ class L2aProduct(Product):
                            probability: int = 1,
                            iterations: int = 5,
                            reprocess: bool = False,
-                           out_path_mask = None,
-                           out_path_mask_b11 = None
+                           #~ out_path_mask = None,
+                           #~ out_path_mask_b11 = None
                            ) -> "L2aProduct":
         """
         """
-        logger.info("Computing cloudmask version {}: {}".format(cm_version, self.identifier))
+        if cm_version == "cm003":
+            logger.info("Computing cloudmask version {}, probability {}%, iteration(s) {}: {}".format(cm_version, probability, iterations, self.identifier))
+        else:
+            logger.info("Computing cloudmask version {}: {}".format(cm_version, self.identifier))
 
         cloudmask = NewCloudMaskProduct(l2a_identifier = self.identifier, 
                                         sen2chain_processing_version = self.sen2chain_processing_version,
@@ -559,7 +562,16 @@ class L2aProduct(Product):
                 else:
                     logger.info("No cloudmask version cm001 found, please compute this one first")
             elif cm_version == "cm003":
-                toto=12
+                if cloudmask.path.exists():     # in version 3.8 will be updated using missing_ok = True
+                    cloudmask.path.unlink()
+                    cloudmask._info_path.unlink()
+                create_cloud_mask_v003(cloud_mask = self.msk_cldprb_20m,
+                                       out_path = cloudmask.path,
+                                       probability = probability,
+                                       iterations = iterations,
+                                       )
+                
+                
             elif cm_version == "cm004":
                 toto=12
             else:
@@ -1062,7 +1074,7 @@ class NewCloudMaskProduct:
         else:
             self.tile = self.get_tile(identifier or l2a_identifier)
             self.l2a = (l2a_identifier or self.get_l2a(identifier)).replace(".SAFE", "")
-            self.suffix = [i for i in ["CM001", "CM002-B11", "CM003-PRB" + str(probability) + "ITR" + str(iterations)] if cm_version.upper() in i][0]
+            self.suffix = [i for i in ["CM001", "CM002-B11", "CM003-PRB" + str(probability) + "-ITER" + str(iterations)] if cm_version.upper() in i][0]
             self.identifier = identifier or self.l2a + "_" + self.suffix + ".jp2"
             self.cm_version, self.probability, self.iterations = self.get_cm_version(self.identifier)
             self.path = self._library_path / self.tile / self.l2a / self.identifier
@@ -1104,7 +1116,7 @@ class NewCloudMaskProduct:
         :param string: string from which to extract the version name.
         """
         try:
-            return re.findall(r"S2.+_(CM[0-9]{3})-PRB(.*)ITR(.*)\.", identifier)[0]
+            return re.findall(r"S2.+_(CM[0-9]{3})-PRB(.*)-ITER(.*)\.", identifier)[0]
         except:
             try:
                 return [re.findall(r"S2.+_(CM[0-9]{3}).+", identifier)[0], None, None]
@@ -1169,7 +1181,7 @@ class IndiceProduct:
             self.indice = (indice or identifier.replace(".", "_").split("_")[7]).upper()
             self.masked = masked
             if self.masked:
-                self.suffix = [i for i in ["CM001", "CM002-B11", "CM003-PRB" + str(probability) + "ITR" + str(iterations)] if cm_version.upper() in i][0]
+                self.suffix = [i for i in ["CM001", "CM002-B11", "CM003-PRB" + str(probability) + "-ITER" + str(iterations)] if cm_version.upper() in i][0]
                 self.identifier = identifier or self.l2a + "_" + self.indice + "_" + self.suffix + ".jp2"
                 self.cm_version, self.probability, self.iterations = NewCloudMaskProduct.get_cm_version(self.identifier)
             else:
diff --git a/sen2chain/tiles.py b/sen2chain/tiles.py
index 7425ca5..925cae3 100644
--- a/sen2chain/tiles.py
+++ b/sen2chain/tiles.py
@@ -18,7 +18,6 @@ from datetime import datetime
 from pprint import pformat
 # type annotations
 from typing import List, Dict, Iterable
-from itertools import chain
 
 from .config import Config, SHARED_DATA
 from .utils import str_to_datetime, human_size, getFolderSize
@@ -173,7 +172,7 @@ class CloudMaskList(ProductsList):
         filtered = CloudMaskList()
         for k, v in self._dict.items():
             if "_CM003" in k:
-                if "-PRB" + str(probability) + "ITER" + str(iterations) in k:
+                if "-PRB" + str(probability) + "-ITER" + str(iterations) in k:
                     filtered[k] = {"date": v["date"], "cloud_cover": v["cloud_cover"]}
             else:
                 filtered[k] = {"date": v["date"], "cloud_cover": v["cloud_cover"]}
@@ -539,12 +538,23 @@ class Tile:
         return prodlist
         
     def compute_l2a(self,
+                    reprocess: bool = False,
+                    p_60m_missing: bool = False,
                     date_min: str = None,
                     date_max: str = None,
                     nb_proc: int = 4):
         """
-        Compute all missing l2a for l1c products
+        Compute all missing l2a for l1c products between date_min and date_max
+        If reprocess = True reprocess already processed products
+        
         """
+        if reprocess:
+            if p_60m_missing:
+                l2a_remove_list = [product.identifier for product in self.l2a.filter_dates(date_min = date_min, date_max = date_max) if not L2aProduct(product.identifier).b01_60m]
+            else:
+                l2a_remove_list = [product.identifier for product in self.l2a.filter_dates(date_min = date_min, date_max = date_max)]
+            if l2a_remove_list:
+                self.remove_l2a(l2a_remove_list)
         l1c_process_list =  []
         l1c_process_list.append(list(p.identifier for p in self.l2a_missings.filter_dates(date_min = date_min, date_max = date_max)))
         l1c_process_list = list(chain.from_iterable(l1c_process_list))
@@ -556,6 +566,8 @@ class Tile:
         l2a_res = False
         if l1c_process_list:
             l2a_res = l2a_multiprocessing(l1c_process_list, nb_proc=nb_proc)
+            
+        
         
     #~ def compute_cloudmasks(self,
                            #~ version: str = "cm001",
@@ -1031,12 +1043,15 @@ class Tile:
                 logger.info("Removing: {}".format(l1c.path))
                 shutil.rmtree(str(l1c.path))
                 
-    def remove_l2a(self):
+    def remove_l2a(self, 
+                   identifier_list: list = []):
         """
         Remove l2a files
         """
-        for product in self.l2a:
-            l2a = L2aProduct(product.identifier)
+        if not identifier_list:
+            identifier_list = [product.identifier for product in self.l2a]
+        for identifier in identifier_list:
+            l2a = L2aProduct(identifier)
             if l2a.path.is_symlink():
                 l2a_path = os.readlink(str(l2a.path))
                 logger.info("Removing: {}".format(l2a_path))
@@ -1044,7 +1059,8 @@ class Tile:
                 logger.info("Removing symlink: {}".format(l2a.path))
                 l2a.path.unlink()
             else:
-                #~ l2a_path = os.readlink(str(l2a.path))
                 logger.info("Removing: {}".format(l2a.path))
                 shutil.rmtree(str(l2a.path))
+            logger.info("Removing: {}".format(l2a.path))
+        logger.info("Removed: {} products".format(len(identifier_list)))
 
-- 
GitLab