From 9ffacb6d0fc7fd24cd85e0ed16e1ad5c33b1fbfb Mon Sep 17 00:00:00 2001
From: "julie.mlls" <julie.mollies.pro@gmail.com>
Date: Tue, 10 Dec 2024 09:21:54 +0400
Subject: [PATCH] masque nuage modifie

---
 sen2chain/cloud_mask.py | 23 ++++++++++++++++++-----
 1 file changed, 18 insertions(+), 5 deletions(-)

diff --git a/sen2chain/cloud_mask.py b/sen2chain/cloud_mask.py
index 26ed50b..07073af 100755
--- a/sen2chain/cloud_mask.py
+++ b/sen2chain/cloud_mask.py
@@ -298,6 +298,7 @@ def create_cloud_mask_v2(
 
     out_temp_path = Path(Config().get("temp_path"))
     out_dilate = str(out_temp_path / (out_path.stem + "_tmp_dilate.tif"))
+    out_temp2 = str(out_temp_path / (out_path.stem + "_tmp_dilate_2.tif"))
 
     CLD_seuil = 25
     with rasterio.open(str(cloud_mask)) as cld_src:
@@ -341,6 +342,7 @@ def create_cloud_mask_v2(
 
     # Save to JP2000
     src_ds = gdal.Open(out_dilate)
+    src_ds = gdal.Translate(out_temp2, src_ds, outputType=gdal.GDT_Byte)
     driver = gdal.GetDriverByName("JP2OpenJPEG")
     dst_ds = driver.CreateCopy(
         str(out_path),
@@ -349,8 +351,8 @@ def create_cloud_mask_v2(
     )
     dst_ds = None
     src_ds = None
-
     os.remove(out_dilate)
+    os.remove(out_temp2)
     logger.info("Done: {}".format(out_path.name))
 
 
@@ -373,6 +375,7 @@ def create_cloud_mask_b11(
 
     out_temp_path = Path(Config().get("temp_path"))
     out_mask = str(out_temp_path / (out_path.stem + "_tmp_mask.tif"))
+    out_temp2 = str(out_temp_path / (out_path.stem + "_tmp_mask_2.tif"))
 
     b11_seuil = 1500
     with rasterio.open(str(b11_path)) as b11_src:
@@ -439,6 +442,7 @@ def create_cloud_mask_b11(
 
     # Save to JP2000
     src_ds = gdal.Open(out_mask)
+    src_ds = gdal.Translate(out_temp2, src_ds, outputType=gdal.GDT_Byte)
     driver = gdal.GetDriverByName("JP2OpenJPEG")
     dst_ds = driver.CreateCopy(
         str(out_path),
@@ -449,6 +453,7 @@ def create_cloud_mask_b11(
     src_ds = None
 
     os.remove(out_mask)
+    os.remov(out_temp2)
     logger.info("Done: {}".format(out_path.name))
 
 
@@ -456,7 +461,7 @@ def create_cloud_mask_v003(
     cloud_mask: Union[str, pathlib.PosixPath],
     out_path="./CM003.jp2",
     probability: int = 1,
-    iterations: int = 5,
+    iterations: int = 1,
 ) -> None:
     
     """
@@ -472,6 +477,7 @@ def create_cloud_mask_v003(
 
     out_temp_path = Path(Config().get("temp_path"))
     out_temp = str(out_temp_path / (out_path.stem + "_tmp_cm003.tif"))
+    out_temp2 = str(out_temp_path / (out_path.stem + "_tmp_cm003_2.tif"))
 
     with rasterio.open(str(cloud_mask)) as cld_src:
         cld_profile = cld_src.profile
@@ -500,6 +506,7 @@ def create_cloud_mask_v003(
 
     # Save to JP2000
     src_ds = gdal.Open(out_temp)
+    src_ds = gdal.Translate(out_temp2, src_ds, outputType=gdal.GDT_Byte)
     driver = gdal.GetDriverByName("JP2OpenJPEG")
     dst_ds = driver.CreateCopy(
         str(out_path),
@@ -510,13 +517,14 @@ def create_cloud_mask_v003(
     src_ds = None
 
     os.remove(out_temp)
+    os.remove(out_temp2)
     logger.info("Done: {}".format(out_path.name))
 
 
 def create_cloud_mask_v004(
     scl_path: Union[str, pathlib.PosixPath],
     out_path="./CM004.jp2",
-    iterations: int = 5,
+    iterations: int = 1,
     cld_shad: bool = True,
     cld_med_prob: bool = True,
     cld_hi_prob: bool = True,
@@ -537,6 +545,7 @@ def create_cloud_mask_v004(
 
     out_temp_path = Path(Config().get("temp_path"))
     out_temp = str(out_temp_path / (out_path.stem + "_tmp_cm004.tif"))
+    out_temp2 = str(out_temp_path / (out_path.stem + "_tmp_cm004_2.tif"))
 
     with rasterio.open(str(scl_path)) as scl_src:
         scl_profile = scl_src.profile
@@ -575,6 +584,8 @@ def create_cloud_mask_v004(
 
     # Save to JP2000
     src_ds = gdal.Open(out_temp)
+    src_ds = gdal.Translate(out_temp2, src_ds, outputType=gdal.GDT_Byte)
+    
     driver = gdal.GetDriverByName("JP2OpenJPEG")
     dst_ds = driver.CreateCopy(
         str(out_path),
@@ -583,6 +594,8 @@ def create_cloud_mask_v004(
     )
     dst_ds = None
     src_ds = None
-
+    
     os.remove(out_temp)
-    logger.info("Done: {}".format(out_path.name))
+    os.remove(out_temp2)
+    
+    logger.info("Done : {}".format(out_path.name))
-- 
GitLab