From 5bc3b8c55a85cbdf854d7467cfc613b7028d0ffb Mon Sep 17 00:00:00 2001
From: ptresson <paul.tresson@ird.fr>
Date: Thu, 17 Oct 2024 11:41:00 +0200
Subject: [PATCH] move some parameter handling in base function for encoder

---
 encoder.py            | 159 +------------------------
 tests/test_encoder.py |   1 +
 utils/algo.py         | 261 ++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 265 insertions(+), 156 deletions(-)
 create mode 100644 utils/algo.py

diff --git a/encoder.py b/encoder.py
index 17ec23c..bb9027a 100644
--- a/encoder.py
+++ b/encoder.py
@@ -57,6 +57,7 @@ from .utils.misc import (QGISLogHandler,
                          log_parameters_to_csv,
                          )
 from .utils.trch import quantize_model
+from .utils.algo import IAMAPAlgorithm
 
 from .tg.datasets import RasterDataset
 from .tg.utils import stack_samples, BoundingBox
@@ -66,7 +67,7 @@ from .tg.transforms import AugmentationSequential
 
 
 
-class EncoderAlgorithm(QgsProcessingAlgorithm):
+class EncoderAlgorithm(IAMAPAlgorithm):
     """
     """
 
@@ -772,23 +773,7 @@ class EncoderAlgorithm(QgsProcessingAlgorithm):
         feedback.pushInfo(
                 f'FEEDBACK :\n{feedback}')
 
-        rlayer = self.parameterAsRasterLayer(
-            parameters, self.INPUT, context)
-        
-        if rlayer is None:
-            raise QgsProcessingException(
-                self.invalidRasterError(parameters, self.INPUT))
-
-        self.selected_bands = self.parameterAsInts(
-            parameters, self.BANDS, context)
-
-        if len(self.selected_bands) == 0:
-            self.selected_bands = list(range(1, rlayer.bandCount()+1))
-
-        if max(self.selected_bands) > rlayer.bandCount():
-            raise QgsProcessingException(
-                self.tr("The chosen bands exceed the largest band number!")
-            )
+        self.process_geo_parameters(parameters, context, feedback)
 
         ckpt_path = self.parameterAsFile(
             parameters, self.CKPT, context)
@@ -814,12 +799,6 @@ class EncoderAlgorithm(QgsProcessingAlgorithm):
             parameters, self.STRIDE, context)
         self.size = self.parameterAsInt(
             parameters, self.SIZE, context)
-        res = self.parameterAsDouble(
-            parameters, self.RESOLUTION, context)
-        crs = self.parameterAsCrs(
-            parameters, self.CRS, context)
-        extent = self.parameterAsExtent(
-            parameters, self.EXTENT, context)
         self.quantization = self.parameterAsBoolean(
             parameters, self.QUANT, context)
         self.use_gpu = self.parameterAsBoolean(
@@ -842,133 +821,6 @@ class EncoderAlgorithm(QgsProcessingAlgorithm):
         self.remove_tmp_files = self.parameterAsBoolean(
             parameters, self.REMOVE_TEMP_FILES, context)
 
-        rlayer_data_provider = rlayer.dataProvider()
-
-        # handle crs
-        if crs is None or not crs.isValid():
-            crs = rlayer.crs()
-            feedback.pushInfo(
-                f'Layer CRS unit is {crs.mapUnits()}')  # 0 for meters, 6 for degrees, 9 for unknown
-            feedback.pushInfo(
-                f'whether the CRS is a geographic CRS (using lat/lon coordinates) {crs.isGeographic()}')
-            if crs.mapUnits() == Qgis.DistanceUnit.Degrees:
-                crs = self.estimate_utm_crs(rlayer.extent())
-
-        # target crs should use meters as units
-        if crs.mapUnits() != Qgis.DistanceUnit.Meters:
-            feedback.pushInfo(
-                f'Layer CRS unit is {crs.mapUnits()}')
-            feedback.pushInfo(
-                f'whether the CRS is a geographic CRS (using lat/lon coordinates) {crs.isGeographic()}')
-            raise QgsProcessingException(
-                self.tr("Only support CRS with the units as meters")
-            )
-
-        # 0 for meters, 6 for degrees, 9 for unknown
-        UNIT_METERS = 0
-        UNIT_DEGREES = 6
-        if rlayer.crs().mapUnits() == UNIT_DEGREES: # Qgis.DistanceUnit.Degrees:
-            layer_units = 'degrees'
-        else:
-            layer_units = 'meters'
-        # if res is not provided, get res info from rlayer
-        if np.isnan(res) or res == 0:
-            res = rlayer.rasterUnitsPerPixelX()  # rasterUnitsPerPixelY() is negative
-            target_units = layer_units
-        else:
-            # when given res in meters by users, convert crs to utm if the original crs unit is degree
-            if crs.mapUnits() != UNIT_METERS: # Qgis.DistanceUnit.Meters:
-                if rlayer.crs().mapUnits() == UNIT_DEGREES: # Qgis.DistanceUnit.Degrees:
-                    # estimate utm crs based on layer extent
-                    crs = self.estimate_utm_crs(rlayer.extent())
-                else:
-                    raise QgsProcessingException(
-                        f"Resampling of image with the CRS of {crs.authid()} in meters is not supported.")
-            target_units = 'meters'
-            # else:
-            #     res = (rlayer_extent.xMaximum() -
-            #            rlayer_extent.xMinimum()) / rlayer.width()
-        self.res = res
-
-        # handle extent
-        if extent.isNull():
-            extent = rlayer.extent()  # QgsProcessingUtils.combineLayerExtents(layers, crs, context)
-            extent_crs = rlayer.crs()
-        else:
-            if extent.isEmpty():
-                raise QgsProcessingException(
-                    self.tr("The extent for processing can not be empty!"))
-            extent_crs = self.parameterAsExtentCrs(
-                parameters, self.EXTENT, context)
-        # if extent crs != target crs, convert it to target crs
-        if extent_crs != crs:
-            transform = QgsCoordinateTransform(
-                extent_crs, crs, context.transformContext())
-            # extent = transform.transformBoundingBox(extent)
-            # to ensure coverage of the transformed extent
-            # convert extent to polygon, transform polygon, then get boundingBox of the new polygon
-            extent_polygon = QgsGeometry.fromRect(extent)
-            extent_polygon.transform(transform)
-            extent = extent_polygon.boundingBox()
-            extent_crs = crs
-
-        # check intersects between extent and rlayer_extent
-        if rlayer.crs() != crs:
-            transform = QgsCoordinateTransform(
-                rlayer.crs(), crs, context.transformContext())
-            rlayer_extent = transform.transformBoundingBox(
-                rlayer.extent())
-        else:
-            rlayer_extent = rlayer.extent()
-        if not rlayer_extent.intersects(extent):
-            raise QgsProcessingException(
-                self.tr("The extent for processing is not intersected with the input image!"))
-
-        feedback.pushInfo(f'backbne type : {self.backbone_name}')
-        
-        img_width_in_extent = round(
-            (extent.xMaximum() - extent.xMinimum())/self.res)
-        img_height_in_extent = round(
-            (extent.yMaximum() - extent.yMinimum())/self.res)
-
-        # Send some information to the user
-        feedback.pushInfo(
-            f'Layer path: {rlayer_data_provider.dataSourceUri()}')
-        # feedback.pushInfo(
-        #     f'Layer band scale: {rlayer_data_provider.bandScale(self.selected_bands[0])}')
-        feedback.pushInfo(f'Layer name: {rlayer.name()}')
-        if rlayer.crs().authid():
-            feedback.pushInfo(f'Layer CRS: {rlayer.crs().authid()}')
-        else:
-            feedback.pushInfo(
-                f'Layer CRS in WKT format: {rlayer.crs().toWkt()}')
-        feedback.pushInfo(
-            f'Layer pixel size: {rlayer.rasterUnitsPerPixelX()}, {rlayer.rasterUnitsPerPixelY()} {layer_units}')
-
-        feedback.pushInfo(f'Bands selected: {self.selected_bands}')
-
-        if crs.authid():
-            feedback.pushInfo(f'Target CRS: {crs.authid()}')
-        else:
-            feedback.pushInfo(f'Target CRS in WKT format: {crs.toWkt()}')
-        # feedback.pushInfo('Band number is {}'.format(rlayer.bandCount()))
-        # feedback.pushInfo('Band name is {}'.format(rlayer.bandName(1)))
-        feedback.pushInfo(f'Target resolution: {self.res} {target_units}')
-        # feedback.pushInfo('Layer display band name is {}'.format(
-        #     rlayer.dataProvider().displayBandName(1)))
-        feedback.pushInfo(
-            (f'Processing extent: minx:{extent.xMinimum():.6f}, maxx:{extent.xMaximum():.6f},'
-             f'miny:{extent.yMinimum():.6f}, maxy:{extent.yMaximum():.6f}'))
-        feedback.pushInfo(
-            (f'Processing image size: (width {img_width_in_extent}, '
-             f'height {img_height_in_extent})'))
-
-        # feedback.pushInfo(
-        #     f'SAM Image Size: {self.sam_model.image_encoder.img_size}')
-
-        self.rlayer_path = rlayer.dataProvider().dataSourceUri()
-        self.rlayer_dir = os.path.dirname(self.rlayer_path)
-        self.rlayer_name = os.path.basename(self.rlayer_path)
 
         # get mean and sd of dataset from raster metadata
         feedback.pushInfo(f'Computing means and sds for normalization')
@@ -980,11 +832,6 @@ class EncoderAlgorithm(QgsProcessingAlgorithm):
         feedback.pushInfo(f'Means for normalization: {self.means}')
         feedback.pushInfo(f'Std. dev. for normalization: {self.sds}')
 
-        ## passing parameters to self once everything has been processed
-        self.extent = extent
-        self.rlayer = rlayer
-        self.crs = crs
-
 
     # used to handle any thread-sensitive cleanup which is required by the algorithm.
     def postProcessAlgorithm(self, context, feedback) -> Dict[str, Any]:
diff --git a/tests/test_encoder.py b/tests/test_encoder.py
index c68e4f6..4079660 100644
--- a/tests/test_encoder.py
+++ b/tests/test_encoder.py
@@ -63,6 +63,7 @@ class TestEncoderAlgorithm(unittest.TestCase):
                 # '94658648037138c64159ae457c3928dd',
                 # '496ac2e9b92f62d16c8c8f1a0fa07009',
                 # 'a6230b57bcf0050aa6f21107a16a5548',
+                '0fb32cc57a0dd427d9f0165ec6d5418f',
                 '48c3a78773dbc2c4c7bb7885409284ab',
                 '431e034b842129679b99a067f2bd3ba4',
                 '60153535214eaa44458db4e297af72b9',
diff --git a/utils/algo.py b/utils/algo.py
new file mode 100644
index 0000000..607838d
--- /dev/null
+++ b/utils/algo.py
@@ -0,0 +1,261 @@
+import os
+import tempfile
+import numpy as np
+from pathlib import Path
+from typing import Dict, Any
+from qgis.core import (Qgis,
+                       QgsGeometry,
+                       QgsCoordinateTransform,
+                       QgsProcessingException,
+                       QgsProcessingAlgorithm,
+                       QgsProcessingParameterRasterLayer,
+                       QgsProcessingParameterFolderDestination,
+                       QgsProcessingParameterBand,
+                       QgsProcessingParameterNumber,
+                       QgsProcessingParameterEnum,
+                       QgsProcessingParameterExtent,
+                       QgsProcessingParameterCrs,
+                       QgsProcessingParameterDefinition,
+                       )
+
+
+
+class IAMAPAlgorithm(QgsProcessingAlgorithm):
+    """
+    """
+
+    INPUT = 'INPUT'
+    BANDS = 'BANDS'
+    EXTENT = 'EXTENT'
+    OUTPUT = 'OUTPUT'
+    RESOLUTION = 'RESOLUTION'
+    CRS = 'CRS'
+    TMP_DIR = 'iamap_tmp'
+    
+
+    def initAlgorithm(self, config=None):
+        """
+        Here we define the inputs and output of the algorithm, along
+        with some other properties.
+        """
+        cwd = Path(__file__).parent.absolute()
+        tmp_wd = os.path.join(tempfile.gettempdir(), self.TMP_DIR)
+
+        self.addParameter(
+            QgsProcessingParameterRasterLayer(
+                name=self.INPUT,
+                description=self.tr(
+                    'Input raster layer or image file path'),
+            defaultValue=os.path.join(cwd,'assets','test.tif'),
+            ),
+        )
+
+        self.addParameter(
+            QgsProcessingParameterBand(
+                name=self.BANDS,
+                description=self.tr('Selected Bands (defaults to all bands selected)'),
+                defaultValue = None, 
+                parentLayerParameterName=self.INPUT,
+                optional=True,
+                allowMultiple=True,
+            )
+        )
+
+        crs_param = QgsProcessingParameterCrs(
+            name=self.CRS,
+            description=self.tr('Target CRS (default to original CRS)'),
+            optional=True,
+        )
+
+        res_param = QgsProcessingParameterNumber(
+            name=self.RESOLUTION,
+            description=self.tr(
+                'Target resolution in meters (default to native resolution)'),
+            type=QgsProcessingParameterNumber.Double,
+            optional=True,
+            minValue=0,
+            maxValue=100000
+        )
+
+
+        self.addParameter(
+            QgsProcessingParameterExtent(
+                name=self.EXTENT,
+                description=self.tr(
+                    'Processing extent (default to the entire image)'),
+                optional=True
+            )
+        )
+
+        self.addParameter(
+            QgsProcessingParameterFolderDestination(
+                self.OUTPUT,
+                self.tr(
+                    "Output directory (choose the location that the image features will be saved)"),
+            defaultValue=tmp_wd,
+            )
+        )
+
+        self.out_dtype_opt = ['float32', 'int8']
+        dtype_param = QgsProcessingParameterEnum(
+                name=self.OUT_DTYPE,
+                description=self.tr(
+                    'Data type of exported features (int8 saves space)'),
+                options=self.out_dtype_opt,
+                defaultValue=0,
+                )
+
+
+        for param in (
+                crs_param, 
+                res_param, 
+                dtype_param,
+                ):
+            param.setFlags(
+                param.flags() | QgsProcessingParameterDefinition.FlagAdvanced)
+            self.addParameter(param)
+
+
+    def process_geo_parameters(self,parameters, context, feedback):
+        """
+        Handle geographic parameters that are common to all algorithms (CRS, resolution, extent, selected bands).
+        """
+
+        rlayer = self.parameterAsRasterLayer(
+            parameters, self.INPUT, context)
+        
+        if rlayer is None:
+            raise QgsProcessingException(
+                self.invalidRasterError(parameters, self.INPUT))
+
+        self.rlayer_path = rlayer.dataProvider().dataSourceUri()
+        self.rlayer_dir = os.path.dirname(self.rlayer_path)
+        self.rlayer_name = os.path.basename(self.rlayer_path)
+
+        self.selected_bands = self.parameterAsInts(
+            parameters, self.BANDS, context)
+
+        if len(self.selected_bands) == 0:
+            self.selected_bands = list(range(1, rlayer.bandCount()+1))
+
+        if max(self.selected_bands) > rlayer.bandCount():
+            raise QgsProcessingException(
+                self.tr("The chosen bands exceed the largest band number!")
+            )
+        res = self.parameterAsDouble(
+            parameters, self.RESOLUTION, context)
+        crs = self.parameterAsCrs(
+            parameters, self.CRS, context)
+        extent = self.parameterAsExtent(
+            parameters, self.EXTENT, context)
+
+        # handle crs
+        if crs is None or not crs.isValid():
+            crs = rlayer.crs()
+            feedback.pushInfo(
+                f'Layer CRS unit is {crs.mapUnits()}')  # 0 for meters, 6 for degrees, 9 for unknown
+            feedback.pushInfo(
+                f'whether the CRS is a geographic CRS (using lat/lon coordinates) {crs.isGeographic()}')
+            if crs.mapUnits() == Qgis.DistanceUnit.Degrees:
+                crs = self.estimate_utm_crs(rlayer.extent())
+
+        # target crs should use meters as units
+        if crs.mapUnits() != Qgis.DistanceUnit.Meters:
+            feedback.pushInfo(
+                f'Layer CRS unit is {crs.mapUnits()}')
+            feedback.pushInfo(
+                f'whether the CRS is a geographic CRS (using lat/lon coordinates) {crs.isGeographic()}')
+            raise QgsProcessingException(
+                self.tr("Only support CRS with the units as meters")
+            )
+
+        # 0 for meters, 6 for degrees, 9 for unknown
+        UNIT_METERS = 0
+        UNIT_DEGREES = 6
+        if rlayer.crs().mapUnits() == UNIT_DEGREES: # Qgis.DistanceUnit.Degrees:
+            layer_units = 'degrees'
+        else:
+            layer_units = 'meters'
+        # if res is not provided, get res info from rlayer
+        if np.isnan(res) or res == 0:
+            res = rlayer.rasterUnitsPerPixelX()  # rasterUnitsPerPixelY() is negative
+            target_units = layer_units
+        else:
+            # when given res in meters by users, convert crs to utm if the original crs unit is degree
+            if crs.mapUnits() != UNIT_METERS: # Qgis.DistanceUnit.Meters:
+                if rlayer.crs().mapUnits() == UNIT_DEGREES: # Qgis.DistanceUnit.Degrees:
+                    # estimate utm crs based on layer extent
+                    crs = self.estimate_utm_crs(rlayer.extent())
+                else:
+                    raise QgsProcessingException(
+                        f"Resampling of image with the CRS of {crs.authid()} in meters is not supported.")
+            target_units = 'meters'
+            # else:
+            #     res = (rlayer_extent.xMaximum() -
+            #            rlayer_extent.xMinimum()) / rlayer.width()
+
+        # handle extent
+        if extent.isNull():
+            extent = rlayer.extent()  # QgsProcessingUtils.combineLayerExtents(layers, crs, context)
+            extent_crs = rlayer.crs()
+        else:
+            if extent.isEmpty():
+                raise QgsProcessingException(
+                    self.tr("The extent for processing can not be empty!"))
+            extent_crs = self.parameterAsExtentCrs(
+                parameters, self.EXTENT, context)
+        # if extent crs != target crs, convert it to target crs
+        if extent_crs != crs:
+            transform = QgsCoordinateTransform(
+                extent_crs, crs, context.transformContext())
+            # extent = transform.transformBoundingBox(extent)
+            # to ensure coverage of the transformed extent
+            # convert extent to polygon, transform polygon, then get boundingBox of the new polygon
+            extent_polygon = QgsGeometry.fromRect(extent)
+            extent_polygon.transform(transform)
+            extent = extent_polygon.boundingBox()
+            extent_crs = crs
+
+        # check intersects between extent and rlayer_extent
+        if rlayer.crs() != crs:
+            transform = QgsCoordinateTransform(
+                rlayer.crs(), crs, context.transformContext())
+            rlayer_extent = transform.transformBoundingBox(
+                rlayer.extent())
+        else:
+            rlayer_extent = rlayer.extent()
+        if not rlayer_extent.intersects(extent):
+            raise QgsProcessingException(
+                self.tr("The extent for processing is not intersected with the input image!"))
+
+
+        img_width_in_extent = round(
+            (extent.xMaximum() - extent.xMinimum())/res)
+        img_height_in_extent = round(
+            (extent.yMaximum() - extent.yMinimum())/res)
+
+        feedback.pushInfo(
+            (f'Processing extent: minx:{extent.xMinimum():.6f}, maxx:{extent.xMaximum():.6f},'
+             f'miny:{extent.yMinimum():.6f}, maxy:{extent.yMaximum():.6f}'))
+        feedback.pushInfo(
+            (f'Processing image size: (width {img_width_in_extent}, '
+             f'height {img_height_in_extent})'))
+
+        # Send some information to the user
+        feedback.pushInfo(
+            f'Layer path: {rlayer.dataProvider().dataSourceUri()}')
+        # feedback.pushInfo(
+        #     f'Layer band scale: {rlayer_data_provider.bandScale(self.selected_bands[0])}')
+        feedback.pushInfo(f'Layer name: {rlayer.name()}')
+
+        feedback.pushInfo(f'Bands selected: {self.selected_bands}')
+
+        self.extent = extent
+        self.rlayer = rlayer
+        self.crs = crs
+        self.res = res
+
+
+    # used to handle any thread-sensitive cleanup which is required by the algorithm.
+    def postProcessAlgorithm(self, context, feedback) -> Dict[str, Any]:
+        return {}
-- 
GitLab