From 5bc3b8c55a85cbdf854d7467cfc613b7028d0ffb Mon Sep 17 00:00:00 2001 From: ptresson <paul.tresson@ird.fr> Date: Thu, 17 Oct 2024 11:41:00 +0200 Subject: [PATCH] move some parameter handling in base function for encoder --- encoder.py | 159 +------------------------ tests/test_encoder.py | 1 + utils/algo.py | 261 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 265 insertions(+), 156 deletions(-) create mode 100644 utils/algo.py diff --git a/encoder.py b/encoder.py index 17ec23c..bb9027a 100644 --- a/encoder.py +++ b/encoder.py @@ -57,6 +57,7 @@ from .utils.misc import (QGISLogHandler, log_parameters_to_csv, ) from .utils.trch import quantize_model +from .utils.algo import IAMAPAlgorithm from .tg.datasets import RasterDataset from .tg.utils import stack_samples, BoundingBox @@ -66,7 +67,7 @@ from .tg.transforms import AugmentationSequential -class EncoderAlgorithm(QgsProcessingAlgorithm): +class EncoderAlgorithm(IAMAPAlgorithm): """ """ @@ -772,23 +773,7 @@ class EncoderAlgorithm(QgsProcessingAlgorithm): feedback.pushInfo( f'FEEDBACK :\n{feedback}') - rlayer = self.parameterAsRasterLayer( - parameters, self.INPUT, context) - - if rlayer is None: - raise QgsProcessingException( - self.invalidRasterError(parameters, self.INPUT)) - - self.selected_bands = self.parameterAsInts( - parameters, self.BANDS, context) - - if len(self.selected_bands) == 0: - self.selected_bands = list(range(1, rlayer.bandCount()+1)) - - if max(self.selected_bands) > rlayer.bandCount(): - raise QgsProcessingException( - self.tr("The chosen bands exceed the largest band number!") - ) + self.process_geo_parameters(parameters, context, feedback) ckpt_path = self.parameterAsFile( parameters, self.CKPT, context) @@ -814,12 +799,6 @@ class EncoderAlgorithm(QgsProcessingAlgorithm): parameters, self.STRIDE, context) self.size = self.parameterAsInt( parameters, self.SIZE, context) - res = self.parameterAsDouble( - parameters, self.RESOLUTION, context) - crs = self.parameterAsCrs( - parameters, self.CRS, context) - extent = self.parameterAsExtent( - parameters, self.EXTENT, context) self.quantization = self.parameterAsBoolean( parameters, self.QUANT, context) self.use_gpu = self.parameterAsBoolean( @@ -842,133 +821,6 @@ class EncoderAlgorithm(QgsProcessingAlgorithm): self.remove_tmp_files = self.parameterAsBoolean( parameters, self.REMOVE_TEMP_FILES, context) - rlayer_data_provider = rlayer.dataProvider() - - # handle crs - if crs is None or not crs.isValid(): - crs = rlayer.crs() - feedback.pushInfo( - f'Layer CRS unit is {crs.mapUnits()}') # 0 for meters, 6 for degrees, 9 for unknown - feedback.pushInfo( - f'whether the CRS is a geographic CRS (using lat/lon coordinates) {crs.isGeographic()}') - if crs.mapUnits() == Qgis.DistanceUnit.Degrees: - crs = self.estimate_utm_crs(rlayer.extent()) - - # target crs should use meters as units - if crs.mapUnits() != Qgis.DistanceUnit.Meters: - feedback.pushInfo( - f'Layer CRS unit is {crs.mapUnits()}') - feedback.pushInfo( - f'whether the CRS is a geographic CRS (using lat/lon coordinates) {crs.isGeographic()}') - raise QgsProcessingException( - self.tr("Only support CRS with the units as meters") - ) - - # 0 for meters, 6 for degrees, 9 for unknown - UNIT_METERS = 0 - UNIT_DEGREES = 6 - if rlayer.crs().mapUnits() == UNIT_DEGREES: # Qgis.DistanceUnit.Degrees: - layer_units = 'degrees' - else: - layer_units = 'meters' - # if res is not provided, get res info from rlayer - if np.isnan(res) or res == 0: - res = rlayer.rasterUnitsPerPixelX() # rasterUnitsPerPixelY() is negative - target_units = layer_units - else: - # when given res in meters by users, convert crs to utm if the original crs unit is degree - if crs.mapUnits() != UNIT_METERS: # Qgis.DistanceUnit.Meters: - if rlayer.crs().mapUnits() == UNIT_DEGREES: # Qgis.DistanceUnit.Degrees: - # estimate utm crs based on layer extent - crs = self.estimate_utm_crs(rlayer.extent()) - else: - raise QgsProcessingException( - f"Resampling of image with the CRS of {crs.authid()} in meters is not supported.") - target_units = 'meters' - # else: - # res = (rlayer_extent.xMaximum() - - # rlayer_extent.xMinimum()) / rlayer.width() - self.res = res - - # handle extent - if extent.isNull(): - extent = rlayer.extent() # QgsProcessingUtils.combineLayerExtents(layers, crs, context) - extent_crs = rlayer.crs() - else: - if extent.isEmpty(): - raise QgsProcessingException( - self.tr("The extent for processing can not be empty!")) - extent_crs = self.parameterAsExtentCrs( - parameters, self.EXTENT, context) - # if extent crs != target crs, convert it to target crs - if extent_crs != crs: - transform = QgsCoordinateTransform( - extent_crs, crs, context.transformContext()) - # extent = transform.transformBoundingBox(extent) - # to ensure coverage of the transformed extent - # convert extent to polygon, transform polygon, then get boundingBox of the new polygon - extent_polygon = QgsGeometry.fromRect(extent) - extent_polygon.transform(transform) - extent = extent_polygon.boundingBox() - extent_crs = crs - - # check intersects between extent and rlayer_extent - if rlayer.crs() != crs: - transform = QgsCoordinateTransform( - rlayer.crs(), crs, context.transformContext()) - rlayer_extent = transform.transformBoundingBox( - rlayer.extent()) - else: - rlayer_extent = rlayer.extent() - if not rlayer_extent.intersects(extent): - raise QgsProcessingException( - self.tr("The extent for processing is not intersected with the input image!")) - - feedback.pushInfo(f'backbne type : {self.backbone_name}') - - img_width_in_extent = round( - (extent.xMaximum() - extent.xMinimum())/self.res) - img_height_in_extent = round( - (extent.yMaximum() - extent.yMinimum())/self.res) - - # Send some information to the user - feedback.pushInfo( - f'Layer path: {rlayer_data_provider.dataSourceUri()}') - # feedback.pushInfo( - # f'Layer band scale: {rlayer_data_provider.bandScale(self.selected_bands[0])}') - feedback.pushInfo(f'Layer name: {rlayer.name()}') - if rlayer.crs().authid(): - feedback.pushInfo(f'Layer CRS: {rlayer.crs().authid()}') - else: - feedback.pushInfo( - f'Layer CRS in WKT format: {rlayer.crs().toWkt()}') - feedback.pushInfo( - f'Layer pixel size: {rlayer.rasterUnitsPerPixelX()}, {rlayer.rasterUnitsPerPixelY()} {layer_units}') - - feedback.pushInfo(f'Bands selected: {self.selected_bands}') - - if crs.authid(): - feedback.pushInfo(f'Target CRS: {crs.authid()}') - else: - feedback.pushInfo(f'Target CRS in WKT format: {crs.toWkt()}') - # feedback.pushInfo('Band number is {}'.format(rlayer.bandCount())) - # feedback.pushInfo('Band name is {}'.format(rlayer.bandName(1))) - feedback.pushInfo(f'Target resolution: {self.res} {target_units}') - # feedback.pushInfo('Layer display band name is {}'.format( - # rlayer.dataProvider().displayBandName(1))) - feedback.pushInfo( - (f'Processing extent: minx:{extent.xMinimum():.6f}, maxx:{extent.xMaximum():.6f},' - f'miny:{extent.yMinimum():.6f}, maxy:{extent.yMaximum():.6f}')) - feedback.pushInfo( - (f'Processing image size: (width {img_width_in_extent}, ' - f'height {img_height_in_extent})')) - - # feedback.pushInfo( - # f'SAM Image Size: {self.sam_model.image_encoder.img_size}') - - self.rlayer_path = rlayer.dataProvider().dataSourceUri() - self.rlayer_dir = os.path.dirname(self.rlayer_path) - self.rlayer_name = os.path.basename(self.rlayer_path) # get mean and sd of dataset from raster metadata feedback.pushInfo(f'Computing means and sds for normalization') @@ -980,11 +832,6 @@ class EncoderAlgorithm(QgsProcessingAlgorithm): feedback.pushInfo(f'Means for normalization: {self.means}') feedback.pushInfo(f'Std. dev. for normalization: {self.sds}') - ## passing parameters to self once everything has been processed - self.extent = extent - self.rlayer = rlayer - self.crs = crs - # used to handle any thread-sensitive cleanup which is required by the algorithm. def postProcessAlgorithm(self, context, feedback) -> Dict[str, Any]: diff --git a/tests/test_encoder.py b/tests/test_encoder.py index c68e4f6..4079660 100644 --- a/tests/test_encoder.py +++ b/tests/test_encoder.py @@ -63,6 +63,7 @@ class TestEncoderAlgorithm(unittest.TestCase): # '94658648037138c64159ae457c3928dd', # '496ac2e9b92f62d16c8c8f1a0fa07009', # 'a6230b57bcf0050aa6f21107a16a5548', + '0fb32cc57a0dd427d9f0165ec6d5418f', '48c3a78773dbc2c4c7bb7885409284ab', '431e034b842129679b99a067f2bd3ba4', '60153535214eaa44458db4e297af72b9', diff --git a/utils/algo.py b/utils/algo.py new file mode 100644 index 0000000..607838d --- /dev/null +++ b/utils/algo.py @@ -0,0 +1,261 @@ +import os +import tempfile +import numpy as np +from pathlib import Path +from typing import Dict, Any +from qgis.core import (Qgis, + QgsGeometry, + QgsCoordinateTransform, + QgsProcessingException, + QgsProcessingAlgorithm, + QgsProcessingParameterRasterLayer, + QgsProcessingParameterFolderDestination, + QgsProcessingParameterBand, + QgsProcessingParameterNumber, + QgsProcessingParameterEnum, + QgsProcessingParameterExtent, + QgsProcessingParameterCrs, + QgsProcessingParameterDefinition, + ) + + + +class IAMAPAlgorithm(QgsProcessingAlgorithm): + """ + """ + + INPUT = 'INPUT' + BANDS = 'BANDS' + EXTENT = 'EXTENT' + OUTPUT = 'OUTPUT' + RESOLUTION = 'RESOLUTION' + CRS = 'CRS' + TMP_DIR = 'iamap_tmp' + + + def initAlgorithm(self, config=None): + """ + Here we define the inputs and output of the algorithm, along + with some other properties. + """ + cwd = Path(__file__).parent.absolute() + tmp_wd = os.path.join(tempfile.gettempdir(), self.TMP_DIR) + + self.addParameter( + QgsProcessingParameterRasterLayer( + name=self.INPUT, + description=self.tr( + 'Input raster layer or image file path'), + defaultValue=os.path.join(cwd,'assets','test.tif'), + ), + ) + + self.addParameter( + QgsProcessingParameterBand( + name=self.BANDS, + description=self.tr('Selected Bands (defaults to all bands selected)'), + defaultValue = None, + parentLayerParameterName=self.INPUT, + optional=True, + allowMultiple=True, + ) + ) + + crs_param = QgsProcessingParameterCrs( + name=self.CRS, + description=self.tr('Target CRS (default to original CRS)'), + optional=True, + ) + + res_param = QgsProcessingParameterNumber( + name=self.RESOLUTION, + description=self.tr( + 'Target resolution in meters (default to native resolution)'), + type=QgsProcessingParameterNumber.Double, + optional=True, + minValue=0, + maxValue=100000 + ) + + + self.addParameter( + QgsProcessingParameterExtent( + name=self.EXTENT, + description=self.tr( + 'Processing extent (default to the entire image)'), + optional=True + ) + ) + + self.addParameter( + QgsProcessingParameterFolderDestination( + self.OUTPUT, + self.tr( + "Output directory (choose the location that the image features will be saved)"), + defaultValue=tmp_wd, + ) + ) + + self.out_dtype_opt = ['float32', 'int8'] + dtype_param = QgsProcessingParameterEnum( + name=self.OUT_DTYPE, + description=self.tr( + 'Data type of exported features (int8 saves space)'), + options=self.out_dtype_opt, + defaultValue=0, + ) + + + for param in ( + crs_param, + res_param, + dtype_param, + ): + param.setFlags( + param.flags() | QgsProcessingParameterDefinition.FlagAdvanced) + self.addParameter(param) + + + def process_geo_parameters(self,parameters, context, feedback): + """ + Handle geographic parameters that are common to all algorithms (CRS, resolution, extent, selected bands). + """ + + rlayer = self.parameterAsRasterLayer( + parameters, self.INPUT, context) + + if rlayer is None: + raise QgsProcessingException( + self.invalidRasterError(parameters, self.INPUT)) + + self.rlayer_path = rlayer.dataProvider().dataSourceUri() + self.rlayer_dir = os.path.dirname(self.rlayer_path) + self.rlayer_name = os.path.basename(self.rlayer_path) + + self.selected_bands = self.parameterAsInts( + parameters, self.BANDS, context) + + if len(self.selected_bands) == 0: + self.selected_bands = list(range(1, rlayer.bandCount()+1)) + + if max(self.selected_bands) > rlayer.bandCount(): + raise QgsProcessingException( + self.tr("The chosen bands exceed the largest band number!") + ) + res = self.parameterAsDouble( + parameters, self.RESOLUTION, context) + crs = self.parameterAsCrs( + parameters, self.CRS, context) + extent = self.parameterAsExtent( + parameters, self.EXTENT, context) + + # handle crs + if crs is None or not crs.isValid(): + crs = rlayer.crs() + feedback.pushInfo( + f'Layer CRS unit is {crs.mapUnits()}') # 0 for meters, 6 for degrees, 9 for unknown + feedback.pushInfo( + f'whether the CRS is a geographic CRS (using lat/lon coordinates) {crs.isGeographic()}') + if crs.mapUnits() == Qgis.DistanceUnit.Degrees: + crs = self.estimate_utm_crs(rlayer.extent()) + + # target crs should use meters as units + if crs.mapUnits() != Qgis.DistanceUnit.Meters: + feedback.pushInfo( + f'Layer CRS unit is {crs.mapUnits()}') + feedback.pushInfo( + f'whether the CRS is a geographic CRS (using lat/lon coordinates) {crs.isGeographic()}') + raise QgsProcessingException( + self.tr("Only support CRS with the units as meters") + ) + + # 0 for meters, 6 for degrees, 9 for unknown + UNIT_METERS = 0 + UNIT_DEGREES = 6 + if rlayer.crs().mapUnits() == UNIT_DEGREES: # Qgis.DistanceUnit.Degrees: + layer_units = 'degrees' + else: + layer_units = 'meters' + # if res is not provided, get res info from rlayer + if np.isnan(res) or res == 0: + res = rlayer.rasterUnitsPerPixelX() # rasterUnitsPerPixelY() is negative + target_units = layer_units + else: + # when given res in meters by users, convert crs to utm if the original crs unit is degree + if crs.mapUnits() != UNIT_METERS: # Qgis.DistanceUnit.Meters: + if rlayer.crs().mapUnits() == UNIT_DEGREES: # Qgis.DistanceUnit.Degrees: + # estimate utm crs based on layer extent + crs = self.estimate_utm_crs(rlayer.extent()) + else: + raise QgsProcessingException( + f"Resampling of image with the CRS of {crs.authid()} in meters is not supported.") + target_units = 'meters' + # else: + # res = (rlayer_extent.xMaximum() - + # rlayer_extent.xMinimum()) / rlayer.width() + + # handle extent + if extent.isNull(): + extent = rlayer.extent() # QgsProcessingUtils.combineLayerExtents(layers, crs, context) + extent_crs = rlayer.crs() + else: + if extent.isEmpty(): + raise QgsProcessingException( + self.tr("The extent for processing can not be empty!")) + extent_crs = self.parameterAsExtentCrs( + parameters, self.EXTENT, context) + # if extent crs != target crs, convert it to target crs + if extent_crs != crs: + transform = QgsCoordinateTransform( + extent_crs, crs, context.transformContext()) + # extent = transform.transformBoundingBox(extent) + # to ensure coverage of the transformed extent + # convert extent to polygon, transform polygon, then get boundingBox of the new polygon + extent_polygon = QgsGeometry.fromRect(extent) + extent_polygon.transform(transform) + extent = extent_polygon.boundingBox() + extent_crs = crs + + # check intersects between extent and rlayer_extent + if rlayer.crs() != crs: + transform = QgsCoordinateTransform( + rlayer.crs(), crs, context.transformContext()) + rlayer_extent = transform.transformBoundingBox( + rlayer.extent()) + else: + rlayer_extent = rlayer.extent() + if not rlayer_extent.intersects(extent): + raise QgsProcessingException( + self.tr("The extent for processing is not intersected with the input image!")) + + + img_width_in_extent = round( + (extent.xMaximum() - extent.xMinimum())/res) + img_height_in_extent = round( + (extent.yMaximum() - extent.yMinimum())/res) + + feedback.pushInfo( + (f'Processing extent: minx:{extent.xMinimum():.6f}, maxx:{extent.xMaximum():.6f},' + f'miny:{extent.yMinimum():.6f}, maxy:{extent.yMaximum():.6f}')) + feedback.pushInfo( + (f'Processing image size: (width {img_width_in_extent}, ' + f'height {img_height_in_extent})')) + + # Send some information to the user + feedback.pushInfo( + f'Layer path: {rlayer.dataProvider().dataSourceUri()}') + # feedback.pushInfo( + # f'Layer band scale: {rlayer_data_provider.bandScale(self.selected_bands[0])}') + feedback.pushInfo(f'Layer name: {rlayer.name()}') + + feedback.pushInfo(f'Bands selected: {self.selected_bands}') + + self.extent = extent + self.rlayer = rlayer + self.crs = crs + self.res = res + + + # used to handle any thread-sensitive cleanup which is required by the algorithm. + def postProcessAlgorithm(self, context, feedback) -> Dict[str, Any]: + return {} -- GitLab