Skip to content
Snippets Groups Projects
Commit 7b1c2f7f authored by paul.tresson_ird.fr's avatar paul.tresson_ird.fr
Browse files

Merge branch 'github-actions' into dev

parents ed1fdf6d bcbfb0f7
No related branches found
No related tags found
No related merge requests found
Showing with 1232 additions and 1116 deletions
name: CI/CD Pipeline
on:
push:
branches:
- main, dev
pull_request:
branches:
- main, dev
jobs:
build:
name: ${{ matrix.os }}, Python 3.${{ matrix.python-minor-version }}, QGIS 3.${{ matrix.qgis-minor-version }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
max-parallel: 6
matrix:
## Windows automatic testing is not functionnig yet
os: [ubuntu-latest , macos-latest]
# os: [ubuntu-latest , macos-latest , windows-latest]
python-minor-version: [11, 12]
qgis-minor-version: [34, 36, 38]
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Set up Miniconda
uses: conda-incubator/setup-miniconda@v3
with:
python-version: 3.${{ matrix.python-minor-version }}
channels: conda-forge
auto-update-conda: true
- name: Set up Environment and Install Dependencies
run: |
conda create -n pytest python=3.${{ matrix.python-minor-version }} qgis=3.${{ matrix.qgis-minor-version }} --yes
conda install -n pytest --file requirements.txt --yes
conda install -n pytest pytest --yes
shell: bash -el {0}
- name: Run Tests
run: |
conda run -n pytest pytest tests/
shell: bash -el {0}
...@@ -3,16 +3,24 @@ import inspect ...@@ -3,16 +3,24 @@ import inspect
cmd_folder = os.path.split(inspect.getfile(inspect.currentframe()))[0] cmd_folder = os.path.split(inspect.getfile(inspect.currentframe()))[0]
def classFactory(iface): def classFactory(iface):
from .dialogs.check_gpu import has_gpu from .dialogs.check_gpu import has_gpu
from .dialogs.packages_installer import packages_installer_dialog from .dialogs.packages_installer import packages_installer_dialog
device = has_gpu() device = has_gpu()
packages_installed_allready = packages_installer_dialog.check_required_packages_and_install_if_necessary(iface=iface, device=device) packages_installed_allready = (
packages_installer_dialog.check_required_packages_and_install_if_necessary(
iface=iface, device=device
)
)
# packages_installer_dialog.check_required_packages_and_install_if_necessary(iface=iface) # packages_installer_dialog.check_required_packages_and_install_if_necessary(iface=iface)
if packages_installed_allready: if packages_installed_allready:
from .iamap import IAMap from .iamap import IAMap
return IAMap(iface, cmd_folder) return IAMap(iface, cmd_folder)
else: else:
from .dialogs.packages_installer.packages_installer_dialog import IAMapEmpty from .dialogs.packages_installer.packages_installer_dialog import IAMapEmpty
return IAMapEmpty(iface, cmd_folder) return IAMapEmpty(iface, cmd_folder)
...@@ -3,17 +3,18 @@ from qgis.PyQt.QtCore import QCoreApplication ...@@ -3,17 +3,18 @@ from qgis.PyQt.QtCore import QCoreApplication
from .utils.algo import SKAlgorithm from .utils.algo import SKAlgorithm
from .icons import QIcon_ClusterTool from .icons import QIcon_ClusterTool
class ClusterAlgorithm(SKAlgorithm): class ClusterAlgorithm(SKAlgorithm):
""" """ """
"""
TYPE = 'cluster' TYPE = "cluster"
TMP_DIR = 'iamap_cluster' TMP_DIR = "iamap_cluster"
def tr(self, string): def tr(self, string):
""" """
Returns a translatable string with the self.tr() function. Returns a translatable string with the self.tr() function.
""" """
return QCoreApplication.translate('Processing', string) return QCoreApplication.translate("Processing", string)
def createInstance(self): def createInstance(self):
return ClusterAlgorithm() return ClusterAlgorithm()
...@@ -26,21 +27,21 @@ class ClusterAlgorithm(SKAlgorithm): ...@@ -26,21 +27,21 @@ class ClusterAlgorithm(SKAlgorithm):
lowercase alphanumeric characters only and no spaces or other lowercase alphanumeric characters only and no spaces or other
formatting characters. formatting characters.
""" """
return 'cluster' return "cluster"
def displayName(self): def displayName(self):
""" """
Returns the translated algorithm name, which should be used for any Returns the translated algorithm name, which should be used for any
user-visible display of the algorithm name. user-visible display of the algorithm name.
""" """
return self.tr('Clustering') return self.tr("Clustering")
def group(self): def group(self):
""" """
Returns the name of the group this algorithm belongs to. This string Returns the name of the group this algorithm belongs to. This string
should be localised. should be localised.
""" """
return self.tr('') return self.tr("")
def groupId(self): def groupId(self):
""" """
...@@ -50,7 +51,7 @@ class ClusterAlgorithm(SKAlgorithm): ...@@ -50,7 +51,7 @@ class ClusterAlgorithm(SKAlgorithm):
contain lowercase alphanumeric characters only and no spaces or other contain lowercase alphanumeric characters only and no spaces or other
formatting characters. formatting characters.
""" """
return '' return ""
def shortHelpString(self): def shortHelpString(self):
""" """
...@@ -58,7 +59,9 @@ class ClusterAlgorithm(SKAlgorithm): ...@@ -58,7 +59,9 @@ class ClusterAlgorithm(SKAlgorithm):
should provide a basic description about what the algorithm does and the should provide a basic description about what the algorithm does and the
parameters and outputs associated with it.. parameters and outputs associated with it..
""" """
return self.tr(f"Cluster a raster. Only KMeans is thoughfully tested. Other algorithms are implemented as is by sklearn. {self.get_help_sk_methods()}") return self.tr(
f"Cluster a raster. Only KMeans is thoughfully tested. Other algorithms are implemented as is by sklearn. {self.get_help_sk_methods()}"
)
def icon(self): def icon(self):
return QIcon_ClusterTool return QIcon_ClusterTool
import subprocess import subprocess
import platform import platform
def check_nvidia_gpu(): def check_nvidia_gpu():
try: try:
# Run the nvidia-smi command and capture the output # Run the nvidia-smi command and capture the output
output = subprocess.check_output(["nvidia-smi", "--query-gpu=name,driver_version", "--format=csv,noheader"], stderr=subprocess.STDOUT) output = subprocess.check_output(
["nvidia-smi", "--query-gpu=name,driver_version", "--format=csv,noheader"],
stderr=subprocess.STDOUT,
)
output = output.decode("utf-8").strip() output = output.decode("utf-8").strip()
# Parse the output # Parse the output
gpu_info = output.split(',') gpu_info = output.split(",")
gpu_name = gpu_info[0].strip() gpu_name = gpu_info[0].strip()
output_cuda_version = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) output_cuda_version = subprocess.run(
for line in output_cuda_version.stdout.split('\n'): ["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
if 'CUDA Version' in line: )
cuda_version = line.split('CUDA Version: ')[1].split()[0] for line in output_cuda_version.stdout.split("\n"):
if "CUDA Version" in line:
cuda_version = line.split("CUDA Version: ")[1].split()[0]
return True, gpu_name, cuda_version return True, gpu_name, cuda_version
except (subprocess.CalledProcessError, FileNotFoundError): except (subprocess.CalledProcessError, FileNotFoundError):
return False, None, None return False, None, None
def check_amd_gpu(): def check_amd_gpu():
try: try:
if platform.system() == "Windows": if platform.system() == "Windows":
output = subprocess.check_output(["wmic", "path", "win32_videocontroller", "get", "name"], universal_newlines=True) output = subprocess.check_output(
["wmic", "path", "win32_videocontroller", "get", "name"],
universal_newlines=True,
)
if "AMD" in output or "Radeon" in output: if "AMD" in output or "Radeon" in output:
return True return True
elif platform.system() == "Linux": elif platform.system() == "Linux":
...@@ -31,17 +41,20 @@ def check_amd_gpu(): ...@@ -31,17 +41,20 @@ def check_amd_gpu():
if "AMD" in output or "Radeon" in output: if "AMD" in output or "Radeon" in output:
return True return True
elif platform.system() == "Darwin": elif platform.system() == "Darwin":
output = subprocess.check_output(["system_profiler", "SPDisplaysDataType"], universal_newlines=True) output = subprocess.check_output(
["system_profiler", "SPDisplaysDataType"], universal_newlines=True
)
if "AMD" in output or "Radeon" in output: if "AMD" in output or "Radeon" in output:
return True return True
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
return False return False
return False return False
def has_gpu(): def has_gpu():
has_nvidia, gpu_name, cuda_version = check_nvidia_gpu() has_nvidia, gpu_name, cuda_version = check_nvidia_gpu()
if has_nvidia: if has_nvidia:
return cuda_version return cuda_version
if check_amd_gpu(): if check_amd_gpu():
return 'amd' return "amd"
return 'cpu' return "cpu"
This diff is collapsed.
...@@ -5,37 +5,63 @@ ...@@ -5,37 +5,63 @@
# -- Project information ----------------------------------------------------- # -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
import os
import re
project = 'iamap'
copyright = '2024, TRESSON Paul, TULET Hadrien, LE COZ Pierre' metadata_file_path = os.path.join('..', '..', 'metadata.txt')
author = 'TRESSON Paul, TULET Hadrien, LE COZ Pierre' metadata_file_path = os.path.abspath(metadata_file_path)
release = '0.5.9' with open(metadata_file_path, 'rt') as file:
file_content = file.read()
try:
versions_from_metadata = re.findall(r'version=(.*)', file_content)[0]
except Exception as e:
raise Exception("Failed to read version from metadata!")
try:
author_from_metadata = re.findall(r'author=(.*)', file_content)[0]
except Exception as e:
raise Exception("Failed to read author from metadata!")
try:
name_from_metadata = re.findall(r'name=(.*)', file_content)[0]
except Exception as e:
raise Exception("Failed to read name from metadata!")
project = "iamap"
copyright = "2024, TRESSON Paul, TULET Hadrien, LE COZ Pierre"
author = author_from_metadata
# The short X.Y version
version = versions_from_metadata
# The full version, including alpha/beta/rc tags
release = version
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
import pydata_sphinx_theme
extensions = [ extensions = [
'sphinx.ext.autodoc', "sphinx.ext.autodoc",
'sphinx.ext.autosummary', "sphinx.ext.autosummary",
"myst_parser", "myst_parser",
"sphinx_favicon", "sphinx_favicon",
] ]
templates_path = ['_templates'] templates_path = ["_templates"]
exclude_patterns = [] exclude_patterns = []
source_suffix = { source_suffix = {
'.rst': 'restructuredtext', ".rst": "restructuredtext",
'.md': 'markdown', ".md": "markdown",
} }
# -- Options for HTML output ------------------------------------------------- # -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = 'pydata_sphinx_theme' html_theme = "pydata_sphinx_theme"
html_static_path = ['_static'] html_static_path = ["_static"]
html_favicon = "./../../icons/favicon.svg" html_favicon = "./../../icons/favicon.svg"
This diff is collapsed.
import processing import processing
from PyQt5.QtWidgets import ( from PyQt5.QtWidgets import QAction, QToolBar
QAction,
QToolBar,
QApplication,
QDialog
)
from PyQt5.QtCore import pyqtSignal, QObject from PyQt5.QtCore import pyqtSignal, QObject
from qgis.core import QgsApplication from qgis.core import QgsApplication
from qgis.gui import QgisInterface from qgis.gui import QgisInterface
from .provider import IAMapProvider from .provider import IAMapProvider
from .icons import (QIcon_EncoderTool, from .icons import (
QIcon_ReductionTool, QIcon_EncoderTool,
QIcon_ClusterTool, QIcon_ReductionTool,
QIcon_SimilarityTool, QIcon_ClusterTool,
QIcon_RandomforestTool, QIcon_SimilarityTool,
) QIcon_RandomforestTool,
)
class IAMap(QObject): class IAMap(QObject):
...@@ -32,34 +28,26 @@ class IAMap(QObject): ...@@ -32,34 +28,26 @@ class IAMap(QObject):
def initGui(self): def initGui(self):
self.initProcessing() self.initProcessing()
self.toolbar: QToolBar = self.iface.addToolBar('IAMap Toolbar') self.toolbar: QToolBar = self.iface.addToolBar("IAMap Toolbar")
self.toolbar.setObjectName('IAMapToolbar') self.toolbar.setObjectName("IAMapToolbar")
self.toolbar.setToolTip('IAMap Toolbar') self.toolbar.setToolTip("IAMap Toolbar")
self.actionEncoder = QAction( self.actionEncoder = QAction(
QIcon_EncoderTool, QIcon_EncoderTool, "Deep Learning Image Encoder", self.iface.mainWindow()
"Deep Learning Image Encoder",
self.iface.mainWindow()
) )
self.actionReducer = QAction( self.actionReducer = QAction(
QIcon_ReductionTool, QIcon_ReductionTool, "Reduce dimensions", self.iface.mainWindow()
"Reduce dimensions",
self.iface.mainWindow()
) )
self.actionCluster = QAction( self.actionCluster = QAction(
QIcon_ClusterTool, QIcon_ClusterTool, "Cluster raster", self.iface.mainWindow()
"Cluster raster",
self.iface.mainWindow()
) )
self.actionSimilarity = QAction( self.actionSimilarity = QAction(
QIcon_SimilarityTool, QIcon_SimilarityTool, "Compute similarity", self.iface.mainWindow()
"Compute similarity",
self.iface.mainWindow()
) )
self.actionRF = QAction( self.actionRF = QAction(
QIcon_RandomforestTool, QIcon_RandomforestTool,
"Fit Machine Learning algorithm", "Fit Machine Learning algorithm",
self.iface.mainWindow() self.iface.mainWindow(),
) )
self.actionEncoder.setObjectName("mActionEncoder") self.actionEncoder.setObjectName("mActionEncoder")
self.actionReducer.setObjectName("mActionReducer") self.actionReducer.setObjectName("mActionReducer")
...@@ -67,16 +55,11 @@ class IAMap(QObject): ...@@ -67,16 +55,11 @@ class IAMap(QObject):
self.actionSimilarity.setObjectName("mactionSimilarity") self.actionSimilarity.setObjectName("mactionSimilarity")
self.actionRF.setObjectName("mactionRF") self.actionRF.setObjectName("mactionRF")
self.actionEncoder.setToolTip( self.actionEncoder.setToolTip("Encode a raster with a deep learning backbone")
"Encode a raster with a deep learning backbone") self.actionReducer.setToolTip("Reduce raster dimensions")
self.actionReducer.setToolTip( self.actionCluster.setToolTip("Cluster raster")
"Reduce raster dimensions") self.actionSimilarity.setToolTip("Compute similarity")
self.actionCluster.setToolTip( self.actionRF.setToolTip("Fit ML model")
"Cluster raster")
self.actionSimilarity.setToolTip(
"Compute similarity")
self.actionRF.setToolTip(
"Fit ML model")
self.actionEncoder.triggered.connect(self.encodeImage) self.actionEncoder.triggered.connect(self.encodeImage)
self.actionReducer.triggered.connect(self.reduceImage) self.actionReducer.triggered.connect(self.reduceImage)
...@@ -107,119 +90,111 @@ class IAMap(QObject): ...@@ -107,119 +90,111 @@ class IAMap(QObject):
QgsApplication.processingRegistry().removeProvider(self.provider) QgsApplication.processingRegistry().removeProvider(self.provider)
def encodeImage(self): def encodeImage(self):
''' """ """
''' result = processing.execAlgorithmDialog("iamap:encoder", {})
result = processing.execAlgorithmDialog('iamap:encoder', {})
print(result) print(result)
# Check if algorithm execution was successful # Check if algorithm execution was successful
if result: if result:
# Retrieve output parameters from the result dictionary # Retrieve output parameters from the result dictionary
if 'OUTPUT_RASTER' in result: if "OUTPUT_RASTER" in result:
output_raster_path = result['OUTPUT_RASTER'] output_raster_path = result["OUTPUT_RASTER"]
output_layer_name = result['OUTPUT_LAYER_NAME'] output_layer_name = result["OUTPUT_LAYER_NAME"]
# Add the output raster layer to the map canvas # Add the output raster layer to the map canvas
self.iface.addRasterLayer(str(output_raster_path),output_layer_name) self.iface.addRasterLayer(str(output_raster_path), output_layer_name)
else: else:
# Handle missing or unexpected output # Handle missing or unexpected output
print('Output raster not found in algorithm result.') print("Output raster not found in algorithm result.")
else: else:
# Handle algorithm execution failure or cancellation # Handle algorithm execution failure or cancellation
print('Algorithm execution was not successful.') print("Algorithm execution was not successful.")
# processing.execAlgorithmDialog('', {}) # processing.execAlgorithmDialog('', {})
# self.close_all_dialogs() # self.close_all_dialogs()
def reduceImage(self): def reduceImage(self):
''' """ """
''' result = processing.execAlgorithmDialog("iamap:reduction", {})
result = processing.execAlgorithmDialog('iamap:reduction', {})
print(result) print(result)
# Check if algorithm execution was successful # Check if algorithm execution was successful
if result: if result:
# Retrieve output parameters from the result dictionary # Retrieve output parameters from the result dictionary
if 'OUTPUT_RASTER' in result: if "OUTPUT_RASTER" in result:
output_raster_path = result['OUTPUT_RASTER'] output_raster_path = result["OUTPUT_RASTER"]
output_layer_name = result['OUTPUT_LAYER_NAME'] output_layer_name = result["OUTPUT_LAYER_NAME"]
# Add the output raster layer to the map canvas # Add the output raster layer to the map canvas
self.iface.addRasterLayer(str(output_raster_path), output_layer_name) self.iface.addRasterLayer(str(output_raster_path), output_layer_name)
else: else:
# Handle missing or unexpected output # Handle missing or unexpected output
print('Output raster not found in algorithm result.') print("Output raster not found in algorithm result.")
else: else:
# Handle algorithm execution failure or cancellation # Handle algorithm execution failure or cancellation
print('Algorithm execution was not successful.') print("Algorithm execution was not successful.")
# processing.execAlgorithmDialog('', {}) # processing.execAlgorithmDialog('', {})
def clusterImage(self): def clusterImage(self):
''' """ """
''' result = processing.execAlgorithmDialog("iamap:cluster", {})
result = processing.execAlgorithmDialog('iamap:cluster', {})
print(result) print(result)
# Check if algorithm execution was successful # Check if algorithm execution was successful
if result: if result:
# Retrieve output parameters from the result dictionary # Retrieve output parameters from the result dictionary
if 'OUTPUT_RASTER' in result: if "OUTPUT_RASTER" in result:
output_raster_path = result['OUTPUT_RASTER'] output_raster_path = result["OUTPUT_RASTER"]
output_layer_name = result['OUTPUT_LAYER_NAME'] output_layer_name = result["OUTPUT_LAYER_NAME"]
# Add the output raster layer to the map canvas # Add the output raster layer to the map canvas
self.iface.addRasterLayer(str(output_raster_path), output_layer_name) self.iface.addRasterLayer(str(output_raster_path), output_layer_name)
else: else:
# Handle missing or unexpected output # Handle missing or unexpected output
print('Output raster not found in algorithm result.') print("Output raster not found in algorithm result.")
else: else:
# Handle algorithm execution failure or cancellation # Handle algorithm execution failure or cancellation
print('Algorithm execution was not successful.') print("Algorithm execution was not successful.")
# processing.execAlgorithmDialog('', {}) # processing.execAlgorithmDialog('', {})
def similarityImage(self): def similarityImage(self):
''' """ """
''' result = processing.execAlgorithmDialog("iamap:similarity", {})
result = processing.execAlgorithmDialog('iamap:similarity', {})
print(result) print(result)
# Check if algorithm execution was successful # Check if algorithm execution was successful
if result: if result:
# Retrieve output parameters from the result dictionary # Retrieve output parameters from the result dictionary
if 'OUTPUT_RASTER' in result: if "OUTPUT_RASTER" in result:
output_raster_path = result['OUTPUT_RASTER'] output_raster_path = result["OUTPUT_RASTER"]
output_layer_name = result['OUTPUT_LAYER_NAME'] output_layer_name = result["OUTPUT_LAYER_NAME"]
used_shp = result['USED_SHP'] used_shp = result["USED_SHP"]
# Add the output raster layer to the map canvas # Add the output raster layer to the map canvas
self.iface.addRasterLayer(str(output_raster_path), output_layer_name) self.iface.addRasterLayer(str(output_raster_path), output_layer_name)
self.iface.addVectorLayer(str(used_shp), 'used points', "ogr") self.iface.addVectorLayer(str(used_shp), "used points", "ogr")
else: else:
# Handle missing or unexpected output # Handle missing or unexpected output
print('Output raster not found in algorithm result.') print("Output raster not found in algorithm result.")
else: else:
# Handle algorithm execution failure or cancellation # Handle algorithm execution failure or cancellation
print('Algorithm execution was not successful.') print("Algorithm execution was not successful.")
# processing.execAlgorithmDialog('', {}) # processing.execAlgorithmDialog('', {})
def rfImage(self): def rfImage(self):
''' """ """
''' result = processing.execAlgorithmDialog("iamap:ml", {})
result = processing.execAlgorithmDialog('iamap:ml', {})
print(result) print(result)
# Check if algorithm execution was successful # Check if algorithm execution was successful
if result: if result:
# Retrieve output parameters from the result dictionary # Retrieve output parameters from the result dictionary
if 'OUTPUT_RASTER' in result: if "OUTPUT_RASTER" in result:
output_raster_path = result['OUTPUT_RASTER'] output_raster_path = result["OUTPUT_RASTER"]
output_layer_name = result['OUTPUT_LAYER_NAME'] output_layer_name = result["OUTPUT_LAYER_NAME"]
used_shp = result['USED_SHP'] used_shp = result["USED_SHP"]
# Add the output raster layer to the map canvas # Add the output raster layer to the map canvas
self.iface.addRasterLayer(str(output_raster_path), output_layer_name) self.iface.addRasterLayer(str(output_raster_path), output_layer_name)
self.iface.addVectorLayer(str(used_shp), 'used points', "ogr") self.iface.addVectorLayer(str(used_shp), "used points", "ogr")
else: else:
# Handle missing or unexpected output # Handle missing or unexpected output
print('Output raster not found in algorithm result.') print("Output raster not found in algorithm result.")
else: else:
# Handle algorithm execution failure or cancellation # Handle algorithm execution failure or cancellation
print('Algorithm execution was not successful.') print("Algorithm execution was not successful.")
# processing.execAlgorithmDialog('', {}) # processing.execAlgorithmDialog('', {})
...@@ -2,11 +2,11 @@ import os ...@@ -2,11 +2,11 @@ import os
from PyQt5.QtGui import QIcon from PyQt5.QtGui import QIcon
cwd = os.path.abspath(os.path.dirname(__file__)) cwd = os.path.abspath(os.path.dirname(__file__))
encoder_tool_path = os.path.join(cwd, 'encoder.svg') encoder_tool_path = os.path.join(cwd, "encoder.svg")
reduction_tool_path = os.path.join(cwd, 'proj.svg') reduction_tool_path = os.path.join(cwd, "proj.svg")
cluster_tool_path = os.path.join(cwd, 'cluster.svg') cluster_tool_path = os.path.join(cwd, "cluster.svg")
similarity_tool_path = os.path.join(cwd, 'sim.svg') similarity_tool_path = os.path.join(cwd, "sim.svg")
random_forest_tool_path = os.path.join(cwd, 'forest.svg') random_forest_tool_path = os.path.join(cwd, "forest.svg")
QIcon_EncoderTool = QIcon(encoder_tool_path) QIcon_EncoderTool = QIcon(encoder_tool_path)
QIcon_ReductionTool = QIcon(reduction_tool_path) QIcon_ReductionTool = QIcon(reduction_tool_path)
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
name=iamap name=iamap
description=Extract and manipulate deep learning features from rasters description=Extract and manipulate deep learning features from rasters
about= This plugin is still a work in progress, feel free to fill an issue on github. about= This plugin is still a work in progress, feel free to fill an issue on github.
version=0.5.9 version=0.6.0
icon=icons/favicon.svg icon=icons/favicon.svg
qgisMinimumVersion=3.12 qgisMinimumVersion=3.12
author=Paul Tresson, Pierre Lecoz, Hadrien Tulet author=Paul Tresson, Pierre Lecoz, Hadrien Tulet
......
This diff is collapsed.
...@@ -9,7 +9,6 @@ from .icons import QIcon_EncoderTool ...@@ -9,7 +9,6 @@ from .icons import QIcon_EncoderTool
class IAMapProvider(QgsProcessingProvider): class IAMapProvider(QgsProcessingProvider):
def loadAlgorithms(self, *args, **kwargs): def loadAlgorithms(self, *args, **kwargs):
self.addAlgorithm(EncoderAlgorithm()) self.addAlgorithm(EncoderAlgorithm())
self.addAlgorithm(ReductionAlgorithm()) self.addAlgorithm(ReductionAlgorithm())
...@@ -25,7 +24,7 @@ class IAMapProvider(QgsProcessingProvider): ...@@ -25,7 +24,7 @@ class IAMapProvider(QgsProcessingProvider):
This string should be a unique, short, character only string, This string should be a unique, short, character only string,
eg "qgis" or "gdal". This string should not be localised. eg "qgis" or "gdal". This string should not be localised.
""" """
return 'iamap' return "iamap"
def name(self, *args, **kwargs): def name(self, *args, **kwargs):
"""The human friendly name of your plugin in Processing. """The human friendly name of your plugin in Processing.
...@@ -33,7 +32,7 @@ class IAMapProvider(QgsProcessingProvider): ...@@ -33,7 +32,7 @@ class IAMapProvider(QgsProcessingProvider):
This string should be as short as possible (e.g. "Lastools", not This string should be as short as possible (e.g. "Lastools", not
"Lastools version 1.0.1 64-bit") and localised. "Lastools version 1.0.1 64-bit") and localised.
""" """
return self.tr('IAMap') return self.tr("IAMap")
def icon(self): def icon(self):
"""Should return a QIcon which is used for your provider inside """Should return a QIcon which is used for your provider inside
...@@ -43,4 +42,3 @@ class IAMapProvider(QgsProcessingProvider): ...@@ -43,4 +42,3 @@ class IAMapProvider(QgsProcessingProvider):
def longName(self) -> str: def longName(self) -> str:
return self.name() return self.name()
import os
import tempfile
import numpy as np
from pathlib import Path
from typing import Dict, Any
import joblib
import pandas as pd
import psutil
import json
import rasterio
from rasterio import windows
from qgis.PyQt.QtCore import QCoreApplication from qgis.PyQt.QtCore import QCoreApplication
from qgis.core import (Qgis,
QgsGeometry,
QgsProcessingParameterBoolean,
QgsProcessingParameterEnum,
QgsCoordinateTransform,
QgsProcessingException,
QgsProcessingAlgorithm,
QgsProcessingParameterRasterLayer,
QgsProcessingParameterFolderDestination,
QgsProcessingParameterBand,
QgsProcessingParameterNumber,
QgsProcessingParameterExtent,
QgsProcessingParameterCrs,
QgsProcessingParameterDefinition,
)
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA, IncrementalPCA
from sklearn.cluster import KMeans
from .utils.misc import get_unique_filename
from .utils.algo import SKAlgorithm from .utils.algo import SKAlgorithm
from .icons import QIcon_ReductionTool from .icons import QIcon_ReductionTool
#from umap.umap_ import UMAP # from umap.umap_ import UMAP
class ReductionAlgorithm(SKAlgorithm): class ReductionAlgorithm(SKAlgorithm):
""" """ """
"""
def tr(self, string): def tr(self, string):
""" """
Returns a translatable string with the self.tr() function. Returns a translatable string with the self.tr() function.
""" """
return QCoreApplication.translate('Processing', string) return QCoreApplication.translate("Processing", string)
def createInstance(self): def createInstance(self):
return ReductionAlgorithm() return ReductionAlgorithm()
...@@ -62,21 +27,21 @@ class ReductionAlgorithm(SKAlgorithm): ...@@ -62,21 +27,21 @@ class ReductionAlgorithm(SKAlgorithm):
lowercase alphanumeric characters only and no spaces or other lowercase alphanumeric characters only and no spaces or other
formatting characters. formatting characters.
""" """
return 'reduction' return "reduction"
def displayName(self): def displayName(self):
""" """
Returns the translated algorithm name, which should be used for any Returns the translated algorithm name, which should be used for any
user-visible display of the algorithm name. user-visible display of the algorithm name.
""" """
return self.tr('Dimension Reduction') return self.tr("Dimension Reduction")
def group(self): def group(self):
""" """
Returns the name of the group this algorithm belongs to. This string Returns the name of the group this algorithm belongs to. This string
should be localised. should be localised.
""" """
return self.tr('') return self.tr("")
def groupId(self): def groupId(self):
""" """
...@@ -86,7 +51,7 @@ class ReductionAlgorithm(SKAlgorithm): ...@@ -86,7 +51,7 @@ class ReductionAlgorithm(SKAlgorithm):
contain lowercase alphanumeric characters only and no spaces or other contain lowercase alphanumeric characters only and no spaces or other
formatting characters. formatting characters.
""" """
return '' return ""
def shortHelpString(self): def shortHelpString(self):
""" """
...@@ -94,7 +59,9 @@ class ReductionAlgorithm(SKAlgorithm): ...@@ -94,7 +59,9 @@ class ReductionAlgorithm(SKAlgorithm):
should provide a basic description about what the algorithm does and the should provide a basic description about what the algorithm does and the
parameters and outputs associated with it.. parameters and outputs associated with it..
""" """
return self.tr(f"Reduce the dimension of deep learning features. Only PCA is thoughfully tested. Other algorithms are implemented as is by sklearn. {self.get_help_sk_methods()}") return self.tr(
f"Reduce the dimension of deep learning features. Only PCA is thoughfully tested. Other algorithms are implemented as is by sklearn. {self.get_help_sk_methods()}"
)
def icon(self): def icon(self):
return QIcon_ReductionTool return QIcon_ReductionTool
geopandas >= 0.14.4
scikit-learn >= 1.5.1
psutil >= 5.0.0
# from torchgeo
rasterio >= 1.2
rtree >= 0.9
einops >= 0.3
fiona >= 1.8.19
kornia >= 0.6.9
numpy >= 1.19.3
pyproj >= 3.3
shapely >= 1.7.1
timm >= 0.4.12
pytest
...@@ -6,7 +6,7 @@ psutil >= 5.0.0 ...@@ -6,7 +6,7 @@ psutil >= 5.0.0
# torchgeo == 0.5.2 # torchgeo == 0.5.2
# from torchgeo # from torchgeo
rasterio >= 1.2 rasterio >= 1.2
rtree <= 0.9 rtree >= 1
einops >= 0.3 einops >= 0.3
fiona >= 1.8.19 fiona >= 1.8.19
kornia >= 0.6.9 kornia >= 0.6.9
......
...@@ -3,14 +3,12 @@ from .utils.algo import SHPAlgorithm ...@@ -3,14 +3,12 @@ from .utils.algo import SHPAlgorithm
from .icons import QIcon_SimilarityTool from .icons import QIcon_SimilarityTool
class SimilarityAlgorithm(SHPAlgorithm): class SimilarityAlgorithm(SHPAlgorithm):
def tr(self, string): def tr(self, string):
""" """
Returns a translatable string with the self.tr() function. Returns a translatable string with the self.tr() function.
""" """
return QCoreApplication.translate('Processing', string) return QCoreApplication.translate("Processing", string)
def createInstance(self): def createInstance(self):
return SimilarityAlgorithm() return SimilarityAlgorithm()
...@@ -23,21 +21,21 @@ class SimilarityAlgorithm(SHPAlgorithm): ...@@ -23,21 +21,21 @@ class SimilarityAlgorithm(SHPAlgorithm):
lowercase alphanumeric characters only and no spaces or other lowercase alphanumeric characters only and no spaces or other
formatting characters. formatting characters.
""" """
return 'similarity' return "similarity"
def displayName(self): def displayName(self):
""" """
Returns the translated algorithm name, which should be used for any Returns the translated algorithm name, which should be used for any
user-visible display of the algorithm name. user-visible display of the algorithm name.
""" """
return self.tr('Similarity') return self.tr("Similarity")
def group(self): def group(self):
""" """
Returns the name of the group this algorithm belongs to. This string Returns the name of the group this algorithm belongs to. This string
should be localised. should be localised.
""" """
return self.tr('') return self.tr("")
def groupId(self): def groupId(self):
""" """
...@@ -47,7 +45,7 @@ class SimilarityAlgorithm(SHPAlgorithm): ...@@ -47,7 +45,7 @@ class SimilarityAlgorithm(SHPAlgorithm):
contain lowercase alphanumeric characters only and no spaces or other contain lowercase alphanumeric characters only and no spaces or other
formatting characters. formatting characters.
""" """
return '' return ""
def shortHelpString(self): def shortHelpString(self):
""" """
......
...@@ -3,7 +3,14 @@ import os ...@@ -3,7 +3,14 @@ import os
PYTHON_VERSION = sys.version_info PYTHON_VERSION = sys.version_info
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
PLUGIN_ROOT_DIR = os.path.realpath(os.path.abspath(os.path.join(SCRIPT_DIR, '..'))) PLUGIN_ROOT_DIR = os.path.realpath(os.path.abspath(os.path.join(SCRIPT_DIR, "..")))
PACKAGES_INSTALL_DIR = os.path.join(PLUGIN_ROOT_DIR, f'python{PYTHON_VERSION.major}.{PYTHON_VERSION.minor}') QGIS_PYTHON_DIR = os.path.realpath(os.path.abspath(os.path.join(PLUGIN_ROOT_DIR, "..")))
PACKAGES_INSTALL_DIR = os.path.join(
PLUGIN_ROOT_DIR, f"python{PYTHON_VERSION.major}.{PYTHON_VERSION.minor}"
)
sys.path.append(PACKAGES_INSTALL_DIR) # TODO: check for a less intrusive way to do this sys.path.append(PACKAGES_INSTALL_DIR) # TODO: check for a less intrusive way to do this
qgis_python_path = os.getenv("PYTHONPATH")
if qgis_python_path and qgis_python_path not in sys.path:
sys.path.append(qgis_python_path)
import os import os
import pytest
from pathlib import Path from pathlib import Path
import tempfile import tempfile
import unittest import unittest
from qgis.core import ( from qgis.core import (
QgsProcessingContext, QgsProcessingContext,
QgsProcessingFeedback, QgsProcessingFeedback,
) )
from ..ml import MLAlgorithm from ..ml import MLAlgorithm
from ..similarity import SimilarityAlgorithm from ..similarity import SimilarityAlgorithm
from ..clustering import ClusterAlgorithm from ..clustering import ClusterAlgorithm
from ..reduction import ReductionAlgorithm from ..reduction import ReductionAlgorithm
from ..utils.misc import get_file_md5_hash, remove_files_with_extensions from ..utils.misc import get_file_md5_hash, remove_files_with_extensions
from ..utils.geo import validate_geotiff
INPUT = os.path.join(Path(__file__).parent.parent.absolute(), 'assets', 'test.tif') INPUT = os.path.join(Path(__file__).parent.parent.absolute(), "assets", "test.tif")
OUTPUT = os.path.join(tempfile.gettempdir(), "iamap_test") OUTPUT = os.path.join(tempfile.gettempdir(), "iamap_test")
EXTENSIONS_TO_RM = ['.tif', '.pkl', '.json', '.shp', '.shx', '.prj', '.dbf', '.cpg'] EXTENSIONS_TO_RM = [".tif", ".pkl", ".json", ".shp", ".shx", ".prj", ".dbf", ".cpg"]
TEMPLATE = os.path.join(Path(__file__).parent.parent.absolute(), 'assets', 'template.shp') TEMPLATE = os.path.join(
TEMPLATE_RF = os.path.join(Path(__file__).parent.parent.absolute(), 'assets', 'ml_poly.shp') Path(__file__).parent.parent.absolute(), "assets", "template.shp"
GT_COL = 'Type' )
TEMPLATE_RF = os.path.join(
Path(__file__).parent.parent.absolute(), "assets", "ml_poly.shp"
)
GT_COL = "Type"
class TestReductionAlgorithm(unittest.TestCase): class TestReductionAlgorithm(unittest.TestCase):
""" """
Base test class, other will inherit from this Base test class, other will inherit from this
""" """
algorithm = ReductionAlgorithm() algorithm = ReductionAlgorithm()
default_parameters = {'INPUT': INPUT,'OUTPUT': OUTPUT} default_parameters = {"INPUT": INPUT, "OUTPUT": OUTPUT}
possible_hashes = [ possible_hashes = [
'd7a32c6b7a4cee1af9c73607561d7b25', "d7a32c6b7a4cee1af9c73607561d7b25",
'e04f8c86d9aad81dd9c625b9cd8f9824', "e04f8c86d9aad81dd9c625b9cd8f9824",
] ]
out_name = 'proj.tif' output_size = 4405122
output_wh = (968,379)
out_name = "proj.tif"
def setUp(self): def setUp(self):
self.context = QgsProcessingContext() self.context = QgsProcessingContext()
...@@ -38,41 +49,51 @@ class TestReductionAlgorithm(unittest.TestCase): ...@@ -38,41 +49,51 @@ class TestReductionAlgorithm(unittest.TestCase):
def test_valid_parameters(self): def test_valid_parameters(self):
self.algorithm.initAlgorithm() self.algorithm.initAlgorithm()
result = self.algorithm.processAlgorithm(self.default_parameters, self.context, self.feedback) _ = self.algorithm.processAlgorithm(
expected_result_path = os.path.join(self.algorithm.output_dir,self.out_name) self.default_parameters, self.context, self.feedback
result_file_hash = get_file_md5_hash(expected_result_path) )
expected_result_path = os.path.join(self.algorithm.output_dir, self.out_name)
@pytest.mark.parametrize("output_file", expected_result_path, "expected_output_size", self.output_size, "expected_wh", self.output_wh)
def test_geotiff_validity(output_file):
validate_geotiff(output_file)
remove_files_with_extensions(self.algorithm.output_dir, EXTENSIONS_TO_RM) remove_files_with_extensions(self.algorithm.output_dir, EXTENSIONS_TO_RM)
assert result_file_hash in self.possible_hashes
class TestClusteringAlgorithm(TestReductionAlgorithm): class TestClusteringAlgorithm(TestReductionAlgorithm):
algorithm = ClusterAlgorithm() algorithm = ClusterAlgorithm()
possible_hashes = ['0c47b0c4b4c13902db5da3ee6e5d4aef'] # possible_hashes = ["0c47b0c4b4c13902db5da3ee6e5d4aef"]
out_name = 'cluster.tif' out_name = "cluster.tif"
output_size = 4405122
class TestSimAlgorithm(TestReductionAlgorithm): class TestSimAlgorithm(TestReductionAlgorithm):
algorithm = SimilarityAlgorithm() algorithm = SimilarityAlgorithm()
default_parameters = {'INPUT': INPUT,'OUTPUT': OUTPUT,'TEMPLATE':TEMPLATE} default_parameters = {"INPUT": INPUT, "OUTPUT": OUTPUT, "TEMPLATE": TEMPLATE}
possible_hashes = ['f76eb1f0469725b49fe0252cfe86829a'] # possible_hashes = ["f76eb1f0469725b49fe0252cfe86829a"]
out_name = 'similarity.tif' out_name = "similarity.tif"
output_size = 1468988
class TestMLAlgorithm(TestReductionAlgorithm): class TestMLAlgorithm(TestReductionAlgorithm):
algorithm = MLAlgorithm() algorithm = MLAlgorithm()
default_parameters = {'INPUT': INPUT,'OUTPUT': OUTPUT,'TEMPLATE':TEMPLATE_RF,'GT_COL': GT_COL} default_parameters = {
possible_hashes = ['bd22d66180347e043fca58d494876184'] "INPUT": INPUT,
out_name = 'ml.tif' "OUTPUT": OUTPUT,
"TEMPLATE": TEMPLATE_RF,
"GT_COL": GT_COL,
}
# possible_hashes = ["bd22d66180347e043fca58d494876184"]
out_name = "ml.tif"
output_size = 367520
if __name__ == "__main__":
if __name__ == "__main__":
for algo in [ for algo in [
TestReductionAlgorithm(), TestReductionAlgorithm(),
TestClusteringAlgorithm(), TestClusteringAlgorithm(),
TestSimAlgorithm(), TestSimAlgorithm(),
TestMLAlgorithm(), TestMLAlgorithm(),
]: ]:
algo.setUp() algo.setUp()
print(algo.algorithm) print(algo.algorithm)
algo.test_valid_parameters() algo.test_valid_parameters()
...@@ -4,146 +4,132 @@ from pathlib import Path ...@@ -4,146 +4,132 @@ from pathlib import Path
import unittest import unittest
import pytest import pytest
from qgis.core import ( from qgis.core import (
QgsProcessingContext, QgsProcessingContext,
QgsProcessingFeedback, QgsProcessingFeedback,
) )
import timm import timm
import torch import torch
# from torchgeo.datasets import RasterDataset # from torchgeo.datasets import RasterDataset
from..tg.datasets import RasterDataset from ..tg.datasets import RasterDataset
from ..encoder import EncoderAlgorithm from ..encoder import EncoderAlgorithm
from ..utils.misc import get_file_md5_hash from ..utils.misc import get_file_md5_hash
from ..utils.geo import validate_geotiff
INPUT = os.path.join(Path(__file__).parent.parent.absolute(), 'assets', 'test.tif') INPUT = os.path.join(Path(__file__).parent.parent.absolute(), "assets", "test.tif")
OUTPUT = os.path.join(tempfile.gettempdir(), "iamap_test") OUTPUT = os.path.join(tempfile.gettempdir(), "iamap_test")
class TestEncoderAlgorithm(unittest.TestCase):
class TestEncoderAlgorithm(unittest.TestCase):
def setUp(self): def setUp(self):
self.context = QgsProcessingContext() self.context = QgsProcessingContext()
self.feedback = QgsProcessingFeedback() self.feedback = QgsProcessingFeedback()
self.algorithm = EncoderAlgorithm() self.algorithm = EncoderAlgorithm()
self.default_parameters = { self.default_parameters = {
'BACKBONE_CHOICE': '', "BACKBONE_CHOICE": "",
'BACKBONE_OPT': 0, "BACKBONE_OPT": 0,
'BANDS': None, "BANDS": None,
'BATCH_SIZE': 1, "BATCH_SIZE": 1,
'CKPT': 'NULL', "CKPT": "NULL",
'CRS': None, "CRS": None,
'CUDA': True, "CUDA": True,
'CUDA_ID': 0, "CUDA_ID": 0,
'EXTENT': None, "EXTENT": None,
'FEAT_OPTION': True, "FEAT_OPTION": True,
'INPUT': INPUT, "INPUT": INPUT,
'MERGE_METHOD': 0, "MERGE_METHOD": 0,
'OUTPUT': OUTPUT, "OUTPUT": OUTPUT,
'PAUSES': 0, "PAUSES": 0,
'QUANT': True, "QUANT": True,
'REMOVE_TEMP_FILES': True, "REMOVE_TEMP_FILES": True,
'RESOLUTION': None, "RESOLUTION": None,
'SIZE': 224, "SIZE": 224,
'STRIDE': 224, "STRIDE": 224,
'TEMP_FILES_CLEANUP_FREQ': 1000, "TEMP_FILES_CLEANUP_FREQ": 1000,
'WORKERS': 0, "WORKERS": 0,
'JSON_PARAM': 'NULL', "JSON_PARAM": "NULL",
'OUT_DTYPE': 0, "OUT_DTYPE": 0,
} }
def test_valid_parameters(self): def test_valid_parameters(self):
self.algorithm.initAlgorithm() self.algorithm.initAlgorithm()
_ = self.algorithm.processAlgorithm(self.default_parameters, self.context, self.feedback) _ = self.algorithm.processAlgorithm(
expected_result_path = os.path.join(self.algorithm.output_subdir,'merged.tif') self.default_parameters, self.context, self.feedback
result_file_hash = get_file_md5_hash(expected_result_path) )
expected_result_path = os.path.join(self.algorithm.output_subdir, "merged.tif")
## different rasterio versions lead to different hashes ? @pytest.mark.parametrize("output_file", expected_result_path)
## GPU and quantization as well def test_geotiff_validity(output_file):
possible_hashes = [ validate_geotiff(output_file)
'0fb32cc57a0dd427d9f0165ec6d5418f',
'48c3a78773dbc2c4c7bb7885409284ab',
'431e034b842129679b99a067f2bd3ba4',
'60153535214eaa44458db4e297af72b9',
'f1394d1950f91e4f8277a8667ae77e85',
'a23837caa3aca54aaca2974d546c5123',
]
assert result_file_hash in possible_hashes
os.remove(expected_result_path)
@pytest.mark.slow
def test_data_types(self):
self.algorithm.initAlgorithm()
parameters = self.default_parameters
parameters['OUT_DTYPE'] = 1
_ = self.algorithm.processAlgorithm(parameters, self.context, self.feedback)
expected_result_path = os.path.join(self.algorithm.output_subdir,'merged.tif')
result_file_hash = get_file_md5_hash(expected_result_path)
## different rasterio versions lead to different hashes ?
possible_hashes = [
'ef0c4b0d57f575c1cd10c0578c7114c0',
'ebfad32752de71c5555bda2b40c19b2e',
'd3705c256320b7190dd4f92ad2087247',
'65fa46916d6d0d08ad9656d7d7fabd01',
]
assert result_file_hash in possible_hashes
os.remove(expected_result_path) os.remove(expected_result_path)
def test_timm_create_model(self): def test_timm_create_model(self):
archs = [ archs = [
'vit_base_patch16_224.dino', "vit_base_patch16_224.dino",
'vit_tiny_patch16_224.augreg_in21k', "vit_tiny_patch16_224.augreg_in21k",
'vit_base_patch16_224.mae', "vit_base_patch16_224.mae",
'samvit_base_patch16.sa1b', "samvit_base_patch16.sa1b",
] ]
expected_output_size = [ expected_output_size = [
torch.Size([1,197,768]), torch.Size([1, 197, 768]),
torch.Size([1,197,192]), torch.Size([1, 197, 192]),
torch.Size([1,197,768]), torch.Size([1, 197, 768]),
torch.Size([1, 256, 64, 64]), torch.Size([1, 256, 64, 64]),
] ]
for arch, exp_feat_size in zip(archs, expected_output_size): for arch, exp_feat_size in zip(archs, expected_output_size):
model = timm.create_model( model = timm.create_model(
arch, arch,
pretrained=True, pretrained=True,
in_chans=6, in_chans=6,
num_classes=0, num_classes=0,
) )
model = model.eval() model = model.eval()
data_config = timm.data.resolve_model_data_config(model) data_config = timm.data.resolve_model_data_config(model)
_, h, w, = data_config['input_size'] (
output = model.forward_features(torch.randn(1,6,h,w)) _,
h,
w,
) = data_config["input_size"]
output = model.forward_features(torch.randn(1, 6, h, w))
assert output.shape == exp_feat_size assert output.shape == exp_feat_size
def test_RasterDataset(self): def test_RasterDataset(self):
self.algorithm.initAlgorithm() self.algorithm.initAlgorithm()
parameters = {} parameters = {}
self.algorithm.process_options(parameters, self.context, self.feedback) self.algorithm.process_options(parameters, self.context, self.feedback)
RasterDataset.filename_glob = self.algorithm.rlayer_name RasterDataset.filename_glob = self.algorithm.rlayer_name
RasterDataset.all_bands = [ RasterDataset.all_bands = [
self.algorithm.rlayer.bandName(i_band) for i_band in range(1, self.algorithm.rlayer.bandCount()+1) self.algorithm.rlayer.bandName(i_band)
for i_band in range(1, self.algorithm.rlayer.bandCount() + 1)
] ]
# currently only support rgb bands # currently only support rgb bands
input_bands = [self.algorithm.rlayer.bandName(i_band) input_bands = [
for i_band in self.algorithm.selected_bands] self.algorithm.rlayer.bandName(i_band)
for i_band in self.algorithm.selected_bands
]
if self.algorithm.crs == self.algorithm.rlayer.crs(): if self.algorithm.crs == self.algorithm.rlayer.crs():
dataset = RasterDataset( dataset = RasterDataset(
paths=self.algorithm.rlayer_dir, crs=None, res=self.algorithm.res, bands=input_bands, cache=False) paths=self.algorithm.rlayer_dir,
crs=None,
res=self.algorithm.res,
bands=input_bands,
cache=False,
)
else: else:
dataset = RasterDataset( dataset = RasterDataset(
paths=self.algorithm.rlayer_dir, crs=self.algorithm.crs.toWkt(), res=self.algorithm.res, bands=input_bands, cache=False) paths=self.algorithm.rlayer_dir,
crs=self.algorithm.crs.toWkt(),
res=self.algorithm.res,
bands=input_bands,
cache=False,
)
del dataset del dataset
def test_cuda(self): def test_cuda(self):
...@@ -151,15 +137,10 @@ class TestEncoderAlgorithm(unittest.TestCase): ...@@ -151,15 +137,10 @@ class TestEncoderAlgorithm(unittest.TestCase):
assert True assert True
if __name__ == "__main__": if __name__ == "__main__":
test_encoder = TestEncoderAlgorithm() test_encoder = TestEncoderAlgorithm()
test_encoder.setUp() test_encoder.setUp()
test_encoder.test_timm_create_model() test_encoder.test_timm_create_model()
test_encoder.test_RasterDataset() test_encoder.test_RasterDataset()
test_encoder.test_valid_parameters() test_encoder.test_valid_parameters()
test_encoder.test_data_types()
test_encoder.test_cuda() test_encoder.test_cuda()
# modified from torchgeo code # modified from torchgeo code
"""Base classes for all :mod:`torchgeo` datasets.""" """Base classes for all :mod:`torchgeo` datasets."""
...@@ -390,10 +390,10 @@ class RasterDataset(GeoDataset): ...@@ -390,10 +390,10 @@ class RasterDataset(GeoDataset):
.. versionchanged:: 0.5 .. versionchanged:: 0.5
*root* was renamed to *paths*. *root* was renamed to *paths*.
""" """
print('pre super.__init__') print("pre super.__init__")
super().__init__(transforms) super().__init__(transforms)
print('post super.__init__') print("post super.__init__")
self.paths = paths self.paths = paths
self.bands = bands or self.all_bands self.bands = bands or self.all_bands
self.cache = cache self.cache = cache
...@@ -403,7 +403,7 @@ class RasterDataset(GeoDataset): ...@@ -403,7 +403,7 @@ class RasterDataset(GeoDataset):
filename_regex = re.compile(self.filename_regex, re.VERBOSE) filename_regex = re.compile(self.filename_regex, re.VERBOSE)
for filepath in self.files: for filepath in self.files:
match = re.match(filename_regex, os.path.basename(filepath)) match = re.match(filename_regex, os.path.basename(filepath))
print('regex') print("regex")
if match is not None: if match is not None:
try: try:
with rasterio.open(filepath) as src: with rasterio.open(filepath) as src:
...@@ -580,7 +580,6 @@ class RasterDataset(GeoDataset): ...@@ -580,7 +580,6 @@ class RasterDataset(GeoDataset):
return src return src
class IntersectionDataset(GeoDataset): class IntersectionDataset(GeoDataset):
"""Dataset representing the intersection of two GeoDatasets. """Dataset representing the intersection of two GeoDatasets.
...@@ -882,4 +881,3 @@ class UnionDataset(GeoDataset): ...@@ -882,4 +881,3 @@ class UnionDataset(GeoDataset):
self._res = new_res self._res = new_res
self.datasets[0].res = new_res self.datasets[0].res = new_res
self.datasets[1].res = new_res self.datasets[1].res = new_res
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment