diff --git a/.github/workflows/jobs.yml b/.github/workflows/jobs.yml
new file mode 100644
index 0000000000000000000000000000000000000000..875ce6dc5dfbac804d3ed049dfc0be31f0859f99
--- /dev/null
+++ b/.github/workflows/jobs.yml
@@ -0,0 +1,49 @@
+name: CI/CD Pipeline
+
+on:
+  push:
+    branches:
+      - main
+      - dev
+  pull_request:
+    branches:
+      - main
+      - dev
+
+jobs:
+  build:
+    name: ${{ matrix.os }}, Python 3.${{ matrix.python-minor-version }}, QGIS 3.${{ matrix.qgis-minor-version }}
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      max-parallel: 6
+      matrix:
+        ## Windows automatic testing is not functioning yet
+        os: [ubuntu-latest, macos-latest]
+        # os: [ubuntu-latest, macos-latest, windows-latest]
+        python-minor-version: [11, 12]
+        qgis-minor-version: [34, 36, 38]
+
+    steps:
+    - name: Checkout
+      uses: actions/checkout@v3
+
+    - name: Set up Miniconda
+      uses: conda-incubator/setup-miniconda@v3
+      with:
+        python-version: 3.${{ matrix.python-minor-version }}
+        channels: conda-forge
+        auto-update-conda: true
+
+    - name: Set up Environment and Install Dependencies
+      run: |
+        conda create -n pytest python=3.${{ matrix.python-minor-version }} qgis=3.${{ matrix.qgis-minor-version }} --yes
+        conda install -n pytest --file requirements.txt --yes
+        conda install -n pytest pytest --yes
+      shell: bash -el {0}
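+      # login shell (-l) is needed so the conda initialization from setup-miniconda is sourced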
+
+    - name: Run Tests
+      run: |
+        conda run -n pytest pytest tests/
+      shell: bash -el {0}
diff --git a/__init__.py b/__init__.py
index 872bba5ac3e8fc795c50a15239d4e92b9c60fed8..8b843d58580ca644fcc5e8e6878ab01cebf5f303 100644
--- a/__init__.py
+++ b/__init__.py
@@ -3,16 +3,25 @@ import inspect
 
 cmd_folder = os.path.split(inspect.getfile(inspect.currentframe()))[0]
 
+
 def classFactory(iface):
     from .dialogs.check_gpu import has_gpu
     from .dialogs.packages_installer import packages_installer_dialog
+
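+    # device is a CUDA version string (e.g. "12.1"), "amd" or "cpu"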
     device = has_gpu()
-    packages_installed_allready = packages_installer_dialog.check_required_packages_and_install_if_necessary(iface=iface, device=device)
+    packages_installed_allready = (
+        packages_installer_dialog.check_required_packages_and_install_if_necessary(
+            iface=iface, device=device
+        )
+    )
     # packages_installer_dialog.check_required_packages_and_install_if_necessary(iface=iface)
     if packages_installed_allready:
         from .iamap import IAMap
+
         return IAMap(iface, cmd_folder)
 
     else:
         from .dialogs.packages_installer.packages_installer_dialog import IAMapEmpty
+
         return IAMapEmpty(iface, cmd_folder)
diff --git a/clustering.py b/clustering.py
index 7ba9940e63b06a48883f09915accfffc80117221..19cedad29c114f7a49fbd93daefbc598b98f03bf 100644
--- a/clustering.py
+++ b/clustering.py
@@ -3,17 +3,18 @@ from qgis.PyQt.QtCore import QCoreApplication
 from .utils.algo import SKAlgorithm
 from .icons import QIcon_ClusterTool
 
+
 class ClusterAlgorithm(SKAlgorithm):
-    """
-    """
-    TYPE = 'cluster'
-    TMP_DIR = 'iamap_cluster'
+    """ """
+
+    TYPE = "cluster"
+    TMP_DIR = "iamap_cluster"
 
     def tr(self, string):
         """
         Returns a translatable string with the self.tr() function.
         """
-        return QCoreApplication.translate('Processing', string)
+        return QCoreApplication.translate("Processing", string)
 
     def createInstance(self):
         return ClusterAlgorithm()
@@ -26,21 +27,21 @@ class ClusterAlgorithm(SKAlgorithm):
         lowercase alphanumeric characters only and no spaces or other
         formatting characters.
         """
-        return 'cluster'
+        return "cluster"
 
     def displayName(self):
         """
         Returns the translated algorithm name, which should be used for any
         user-visible display of the algorithm name.
         """
-        return self.tr('Clustering')
+        return self.tr("Clustering")
 
     def group(self):
         """
         Returns the name of the group this algorithm belongs to. This string
         should be localised.
         """
-        return self.tr('')
+        return self.tr("")
 
     def groupId(self):
         """
@@ -50,7 +51,7 @@ class ClusterAlgorithm(SKAlgorithm):
         contain lowercase alphanumeric characters only and no spaces or other
         formatting characters.
         """
-        return ''
+        return ""
 
     def shortHelpString(self):
         """
@@ -58,7 +59,9 @@ class ClusterAlgorithm(SKAlgorithm):
         should provide a basic description about what the algorithm does and the
         parameters and outputs associated with it..
         """
-        return self.tr(f"Cluster a raster. Only KMeans is thoughfully tested. Other algorithms are implemented as is by sklearn. {self.get_help_sk_methods()}")
+        return self.tr(
+            f"Cluster a raster. Only KMeans is thoughfully tested. Other algorithms are implemented as is by sklearn. {self.get_help_sk_methods()}"
+        )
 
     def icon(self):
         return QIcon_ClusterTool
diff --git a/dialogs/check_gpu.py b/dialogs/check_gpu.py
index e39471be52b6e94fc04bfad95c910f4fd1df95e8..6feb57ba33d4abcd826f25b56ce841f6bb83d3bd 100644
--- a/dialogs/check_gpu.py
+++ b/dialogs/check_gpu.py
@@ -1,29 +1,40 @@
 import subprocess
 import platform
 
+
 def check_nvidia_gpu():
     try:
         # Run the nvidia-smi command and capture the output
-        output = subprocess.check_output(["nvidia-smi", "--query-gpu=name,driver_version", "--format=csv,noheader"], stderr=subprocess.STDOUT)
+        output = subprocess.check_output(
+            ["nvidia-smi", "--query-gpu=name,driver_version", "--format=csv,noheader"],
+            stderr=subprocess.STDOUT,
+        )
         output = output.decode("utf-8").strip()
-        
+
         # Parse the output
-        gpu_info = output.split(',')
+        gpu_info = output.split(",")
         gpu_name = gpu_info[0].strip()
-        
-        output_cuda_version = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
-        for line in output_cuda_version.stdout.split('\n'):
-            if 'CUDA Version' in line:
-                cuda_version = line.split('CUDA Version: ')[1].split()[0]
-        
+
+        output_cuda_version = subprocess.run(
+            ["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
+        )
+        cuda_version = None  # guard against nvidia-smi output without a "CUDA Version" field
+        for line in output_cuda_version.stdout.split("\n"):
+            if "CUDA Version" in line:
+                cuda_version = line.split("CUDA Version: ")[1].split()[0]
+
         return True, gpu_name, cuda_version
     except (subprocess.CalledProcessError, FileNotFoundError):
         return False, None, None
 
+
 def check_amd_gpu():
     try:
         if platform.system() == "Windows":
-            output = subprocess.check_output(["wmic", "path", "win32_videocontroller", "get", "name"], universal_newlines=True)
+            output = subprocess.check_output(
+                ["wmic", "path", "win32_videocontroller", "get", "name"],
+                universal_newlines=True,
+            )
             if "AMD" in output or "Radeon" in output:
                 return True
         elif platform.system() == "Linux":
@@ -31,17 +42,21 @@ def check_amd_gpu():
             if "AMD" in output or "Radeon" in output:
                 return True
         elif platform.system() == "Darwin":
-            output = subprocess.check_output(["system_profiler", "SPDisplaysDataType"], universal_newlines=True)
+            output = subprocess.check_output(
+                ["system_profiler", "SPDisplaysDataType"], universal_newlines=True
+            )
             if "AMD" in output or "Radeon" in output:
                 return True
     except subprocess.CalledProcessError:
         return False
     return False
 
+
 def has_gpu():
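+    """Return a CUDA version string for NVIDIA GPUs, "amd" for AMD GPUs, else "cpu"."""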
     has_nvidia, gpu_name, cuda_version = check_nvidia_gpu()
     if has_nvidia:
         return cuda_version
     if check_amd_gpu():
-        return 'amd'
-    return 'cpu'
+        return "amd"
+    return "cpu"
diff --git a/dialogs/packages_installer/packages_installer_dialog.py b/dialogs/packages_installer/packages_installer_dialog.py
index 7a0cde4ea94fa2e59ef68821ec3534ac3be6195c..c170abcd733464727c2a06d3bc6ea1b4ff234dc3 100644
--- a/dialogs/packages_installer/packages_installer_dialog.py
+++ b/dialogs/packages_installer/packages_installer_dialog.py
@@ -9,7 +9,6 @@ import os
 import subprocess
 import sys
 import traceback
-import urllib
 from dataclasses import dataclass
 from pathlib import Path
 from threading import Thread
@@ -18,46 +17,50 @@ from typing import List
 from PyQt5.QtWidgets import (
     QAction,
     QToolBar,
-    QDialog,
-    QTextBrowser,
-    QApplication,
     QMessageBox,
-    QDialog
+    QDialog,
 )
 from PyQt5.QtCore import pyqtSignal, QObject
-from qgis.core import QgsApplication
 from qgis.gui import QgisInterface
 
 from qgis.PyQt import QtCore, uic
-from qgis.PyQt.QtCore import pyqtSignal
 from qgis.PyQt.QtGui import QCloseEvent
+
 # from qgis.PyQt.QtWidgets import QDialog, QMessageBox, QTextBrowser
-from ...icons import (QIcon_EncoderTool, 
-                    QIcon_ReductionTool, 
-                    QIcon_ClusterTool, 
-                    QIcon_SimilarityTool, 
-                    QIcon_RandomforestTool,
-                    )
+from ...icons import (
+    QIcon_EncoderTool,
+    QIcon_ReductionTool,
+    QIcon_ClusterTool,
+    QIcon_SimilarityTool,
+    QIcon_RandomforestTool,
+)
 
 PLUGIN_NAME = "iamap"
 
 PYTHON_VERSION = sys.version_info
 SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
-PLUGIN_ROOT_DIR = os.path.realpath(os.path.abspath(os.path.join(SCRIPT_DIR, '..', '..')))
-PACKAGES_INSTALL_DIR = os.path.join(PLUGIN_ROOT_DIR, f'python{PYTHON_VERSION.major}.{PYTHON_VERSION.minor}')
+PLUGIN_ROOT_DIR = os.path.realpath(
+    os.path.abspath(os.path.join(SCRIPT_DIR, "..", ".."))
+)
+PACKAGES_INSTALL_DIR = os.path.join(
+    PLUGIN_ROOT_DIR, f"python{PYTHON_VERSION.major}.{PYTHON_VERSION.minor}"
+)
 
 
-FORM_CLASS, _ = uic.loadUiType(os.path.join(
-    os.path.dirname(__file__), 'packages_installer_dialog.ui'))
+FORM_CLASS, _ = uic.loadUiType(
+    os.path.join(os.path.dirname(__file__), "packages_installer_dialog.ui")
+)
 
-_ERROR_COLOR = '#ff0000'
+_ERROR_COLOR = "#ff0000"
 
 if sys.platform == "linux" or sys.platform == "linux2":
     PYTHON_EXECUTABLE_PATH = sys.executable
 elif sys.platform == "darwin":  # MacOS
-    PYTHON_EXECUTABLE_PATH = str(Path(sys.prefix) / 'bin' / 'python3')  # sys.executable yields QGIS in macOS
+    PYTHON_EXECUTABLE_PATH = str(
+        Path(sys.prefix) / "bin" / "python3"
+    )  # sys.executable yields QGIS in macOS
 elif sys.platform == "win32":
-    PYTHON_EXECUTABLE_PATH = 'python'  # sys.executable yields QGis.exe in Windows
+    PYTHON_EXECUTABLE_PATH = "python"  # sys.executable yields QGis.exe in Windows
 else:
     raise Exception("Unsupported operating system!")
 
@@ -69,7 +72,7 @@ class PackageToInstall:
     import_name: str  # name while importing package
 
     def __str__(self):
-        return f'{self.name}{self.version}'
+        return f"{self.name}{self.version}"
 
 
 class PackagesInstallerDialog(QDialog, FORM_CLASS):
@@ -78,27 +81,33 @@ class PackagesInstallerDialog(QDialog, FORM_CLASS):
     UI design defined in the `packages_installer_dialog.ui` file.
     """
 
-    signal_log_line = pyqtSignal(str)  # we need to use signal because we cannot edit GUI from another thread
+    signal_log_line = pyqtSignal(
+        str
+    )  # we need to use signal because we cannot edit GUI from another thread
 
-    INSTALLATION_IN_PROGRESS = False  # to make sure we will not start the installation twice
+    INSTALLATION_IN_PROGRESS = (
+        False  # to make sure we will not start the installation twice
+    )
 
     def __init__(self, iface, packages_to_install, device, parent=None):
         super(PackagesInstallerDialog, self).__init__(parent)
         self.setupUi(self)
         self.iface = iface
         self.tb = self.textBrowser_log  # type: QTextBrowser
-        self.packages_to_install=packages_to_install
-        self.device=device
+        self.packages_to_install = packages_to_install
+        self.device = device
         self._create_connections()
         self._setup_message()
         self.aborted = False
         self.thread = None
 
     def move_to_top(self):
-        """ Move the window to the top.
+        """Move the window to the top.
         Although if installed from plugin manager, the plugin manager will move itself to the top anyway.
         """
-        self.setWindowState((self.windowState() & ~QtCore.Qt.WindowMinimized) | QtCore.Qt.WindowActive)
+        self.setWindowState(
+            (self.windowState() & ~QtCore.Qt.WindowMinimized) | QtCore.Qt.WindowActive
+        )
 
         if sys.platform == "linux" or sys.platform == "linux2":
             pass
@@ -111,39 +120,42 @@ class PackagesInstallerDialog(QDialog, FORM_CLASS):
 
     def _create_connections(self):
         self.pushButton_close.clicked.connect(self.close)
-        self.pushButton_install_packages.clicked.connect(self._run_packages_installation)
+        self.pushButton_install_packages.clicked.connect(
+            self._run_packages_installation
+        )
         self.signal_log_line.connect(self._log_line)
 
     def _log_line(self, txt):
-        txt = txt \
-            .replace('  ', '  ') \
-            .replace('\n', '<br>')
+        txt = txt.replace("  ", "&nbsp;&nbsp;").replace("\n", "<br>")
         self.tb.append(txt)
 
     def log(self, txt):
         self.signal_log_line.emit(txt)
 
     def _setup_message(self) -> None:
-          
-        self.log(f'<h2><span style="color: #000080;"><strong>  '
-                 f'Plugin {PLUGIN_NAME} - Packages installer </strong></span></h2> \n'
-                 f'\n'
-                 f'<b>This plugin requires the following Python packages to be installed:</b>')
-        
-        for package in self.packages_to_install:
-            self.log(f'\t- {package.name}{package.version}')
+        self.log(
+            f'<h2><span style="color: #000080;"><strong>  '
+            f"Plugin {PLUGIN_NAME} - Packages installer </strong></span></h2> \n"
+            f"\n"
+            f"<b>This plugin requires the following Python packages to be installed:</b>"
+        )
 
-        self.log('\n\n'
-                 f'If this packages are not installed in the global environment '
-                 f'(or environment in which QGIS is started) '
-                 f'you can install these packages in the local directory (which is included to the Python path).\n\n'
-                 f'This Dialog does it for you! (Though you can still install these packages manually instead).\n'
-                 f'<b>Please click "Install packages" button below to install them automatically, </b>'
-                 f'or "Test and Close" if you installed them manually...\n')
+        for package in self.packages_to_install:
+            self.log(f"\t- {package.name}{package.version}")
+
+        self.log(
+            "\n\n"
+            "If this packages are not installed in the global environment "
+            "(or environment in which QGIS is started) "
+            "you can install these packages in the local directory (which is included to the Python path).\n\n"
+            "This Dialog does it for you! (Though you can still install these packages manually instead).\n"
+            '<b>Please click "Install packages" button below to install them automatically, </b>'
+            'or "Test and Close" if you installed them manually...\n'
+        )
 
     def _run_packages_installation(self):
         if self.INSTALLATION_IN_PROGRESS:
-            self.log(f'Error! Installation already in progress, cannot start again!')
+            self.log("Error! Installation already in progress, cannot start again!")
             return
         self.aborted = False
         self.INSTALLATION_IN_PROGRESS = True
@@ -151,28 +163,32 @@ class PackagesInstallerDialog(QDialog, FORM_CLASS):
         self.thread.start()
 
     def _install_packages(self) -> None:
-        self.log('\n\n')
-        self.log('=' * 60)
-        self.log(f'<h3><b>Attempting to install required packages...</b></h3>')
+        self.log("\n\n")
+        self.log("=" * 60)
+        self.log("<h3><b>Attempting to install required packages...</b></h3>")
         os.makedirs(PACKAGES_INSTALL_DIR, exist_ok=True)
 
         self._install_pip_if_necessary()
 
-        self.log(f'<h3><b>Attempting to install required packages...</b></h3>\n')
+        self.log("<h3><b>Attempting to install required packages...</b></h3>\n")
         try:
             self._pip_install_packages(self.packages_to_install)
         except Exception as e:
-            msg = (f'\n <span style="color: {_ERROR_COLOR};"><b> '
-                   f'Packages installation failed with exception: {e}!\n'
-                   f'Please try to install the packages again. </b></span>'
-                   f'\nCheck if there is no error related to system packages, '
-                   f'which may be required to be installed by your system package manager, e.g. "apt". '
-                   f'Copy errors from the stack above and google for possible solutions. '
-                   f'Please report these as an issue on the plugin repository tracker!')
+            msg = (
+                f'\n <span style="color: {_ERROR_COLOR};"><b> '
+                f"Packages installation failed with exception: {e}!\n"
+                f"Please try to install the packages again. </b></span>"
+                f"\nCheck if there is no error related to system packages, "
+                f'which may be required to be installed by your system package manager, e.g. "apt". '
+                f"Copy errors from the stack above and google for possible solutions. "
+                f"Please report these as an issue on the plugin repository tracker!"
+            )
             self.log(msg)
 
         # finally, validate the installation, if there was no error so far...
-        self.log('\n\n <b>Installation of required packages finished. Validating installation...</b>')
+        self.log(
+            "\n\n <b>Installation of required packages finished. Validating installation...</b>"
+        )
         self._check_packages_installation_and_log()
         self.INSTALLATION_IN_PROGRESS = False
 
@@ -182,28 +198,33 @@ class PackagesInstallerDialog(QDialog, FORM_CLASS):
     def closeEvent(self, event: QCloseEvent):
         self.aborted = True
         if self._check_packages_installation_and_log():
-            res = QMessageBox.information(self.iface.mainWindow(),
-                                       f'{PLUGIN_NAME} - Installation done !',
-                                       'Restart QGIS for the plugin to load properly.',
-                                       QMessageBox.Ok)
+            res = QMessageBox.information(
+                self.iface.mainWindow(),
+                f"{PLUGIN_NAME} - Installation done !",
+                "Restart QGIS for the plugin to load properly.",
+                QMessageBox.Ok,
+            )
             if res == QMessageBox.Ok:
-                log_msg = 'User accepted to restart QGIS'
+                log_msg = "User accepted to restart QGIS"
                 event.accept()
             return
 
-        res = QMessageBox.question(self.iface.mainWindow(),
-                                   f'{PLUGIN_NAME} - skip installation?',
-                                   'Are you sure you want to abort the installation of the required python packages? '
-                                   'The plugin may not function correctly without them!',
-                                   QMessageBox.No, QMessageBox.Yes)
-        log_msg = 'User requested to close the dialog, but the packages are not installed correctly!\n'
+        res = QMessageBox.question(
+            self.iface.mainWindow(),
+            f"{PLUGIN_NAME} - skip installation?",
+            "Are you sure you want to abort the installation of the required python packages? "
+            "The plugin may not function correctly without them!",
+            QMessageBox.No,
+            QMessageBox.Yes,
+        )
+        log_msg = "User requested to close the dialog, but the packages are not installed correctly!\n"
         if res == QMessageBox.Yes:
-            log_msg += 'And the user confirmed to close the dialog, knowing the risk!'
+            log_msg += "And the user confirmed to close the dialog, knowing the risk!"
             event.accept()
         else:
-            log_msg += 'The user reconsidered their decision, and will try to install the packages again!'
+            log_msg += "The user reconsidered their decision, and will try to install the packages again!"
             event.ignore()
-        log_msg += '\n'
+        log_msg += "\n"
         self.log(log_msg)
 
     def _install_pip_if_necessary(self):
@@ -214,18 +235,22 @@ class PackagesInstallerDialog(QDialog, FORM_CLASS):
         TODO - investigate whether we can also install pip in local directory
         """
 
-        self.log(f'<h4><b>Making sure pip is installed...</b></h4>')
+        self.log("<h4><b>Making sure pip is installed...</b></h4>")
         if check_pip_installed():
-            self.log(f'<em>Pip is installed, skipping installation...</em>\n')
+            self.log("<em>Pip is installed, skipping installation...</em>\n")
             return
 
-        install_pip_command = [PYTHON_EXECUTABLE_PATH, '-m', 'ensurepip']
-        self.log(f'<em>Running command to install pip: \n  $ {" ".join(install_pip_command)} </em>')
-        with subprocess.Popen(install_pip_command,
-                              stdout=subprocess.PIPE,
-                              universal_newlines=True,
-                              stderr=subprocess.STDOUT,
-                              env={'SETUPTOOLS_USE_DISTUTILS': 'stdlib'}) as process:
+        install_pip_command = [PYTHON_EXECUTABLE_PATH, "-m", "ensurepip"]
+        self.log(
+            f'<em>Running command to install pip: \n  $ {" ".join(install_pip_command)} </em>'
+        )
+        with subprocess.Popen(
+            install_pip_command,
+            stdout=subprocess.PIPE,
+            universal_newlines=True,
+            stderr=subprocess.STDOUT,
+            env={"SETUPTOOLS_USE_DISTUTILS": "stdlib"},
+        ) as process:
             try:
                 self._do_process_output_logging(process)
             except InterruptedError as e:
@@ -233,48 +258,67 @@ class PackagesInstallerDialog(QDialog, FORM_CLASS):
                 return False
 
         if process.returncode != 0:
-            msg = (f'<span style="color: {_ERROR_COLOR};"><b>'
-                   f'pip installation failed! Consider installing it manually.'
-                   f'<b></span>')
+            msg = (
+                f'<span style="color: {_ERROR_COLOR};"><b>'
+                f"pip installation failed! Consider installing it manually."
+                f"<b></span>"
+            )
             self.log(msg)
-        self.log('\n')
+        self.log("\n")
 
     def _pip_install_packages(self, packages: List[PackageToInstall]) -> None:
-        cmd = [PYTHON_EXECUTABLE_PATH, '-m', 'pip', 'install', '-U', f'--target={PACKAGES_INSTALL_DIR}']               
-        cmd_string = ' '.join(cmd)
-        
+        cmd = [
+            PYTHON_EXECUTABLE_PATH,
+            "-m",
+            "pip",
+            "install",
+            "-U",
+            f"--target={PACKAGES_INSTALL_DIR}",
+        ]
+        cmd_string = " ".join(cmd)
+
         for pck in packages:
             if ("index-url") not in pck.version:
                 cmd.append(f" {pck}")
                 cmd_string += f" {pck}"
-            
-            elif pck.name == 'torch':
+
+            elif pck.name == "torch":
                 torch_url = pck.version.split("index-url ")[-1]
-        
-                cmd_torch = [PYTHON_EXECUTABLE_PATH, '-m', 'pip', 'install', '-U', f'--target={PACKAGES_INSTALL_DIR}', 'torch', f"--index-url={torch_url}"] 
-                cmd_torch_string = ' '.join(cmd_torch)
-
-                self.log(f'<em>Running command: \n  $ {cmd_torch_string} </em>')
-                with subprocess.Popen(cmd_torch,
-                                    stdout=subprocess.PIPE,
-                                    universal_newlines=True,
-                                    stderr=subprocess.STDOUT) as process:
-                    self._do_process_output_logging(process)
 
+                cmd_torch = [
+                    PYTHON_EXECUTABLE_PATH,
+                    "-m",
+                    "pip",
+                    "install",
+                    "-U",
+                    f"--target={PACKAGES_INSTALL_DIR}",
+                    "torch",
+                    f"--index-url={torch_url}",
+                ]
+                cmd_torch_string = " ".join(cmd_torch)
+
+                self.log(f"<em>Running command: \n  $ {cmd_torch_string} </em>")
+                with subprocess.Popen(
+                    cmd_torch,
+                    stdout=subprocess.PIPE,
+                    universal_newlines=True,
+                    stderr=subprocess.STDOUT,
+                ) as process:
+                    self._do_process_output_logging(process)
 
-        self.log(f'<em>Running command: \n  $ {cmd_string} </em>')
-        with subprocess.Popen(cmd,
-                              stdout=subprocess.PIPE,
-                              universal_newlines=True,
-                              stderr=subprocess.STDOUT) as process:
+        self.log(f"<em>Running command: \n  $ {cmd_string} </em>")
+        with subprocess.Popen(
+            cmd,
+            stdout=subprocess.PIPE,
+            universal_newlines=True,
+            stderr=subprocess.STDOUT,
+        ) as process:
             self._do_process_output_logging(process)
 
         if process.returncode != 0:
-            raise RuntimeError('Installation with pip failed')
+            raise RuntimeError("Installation with pip failed")
 
-        msg = (f'\n<b>'
-               f'Packages installed correctly!'
-               f'<b>\n\n')
+        msg = "\n<b>" "Packages installed correctly!" "<b>\n\n"
         self.log(msg)
 
     def _do_process_output_logging(self, process: subprocess.Popen) -> None:
@@ -284,36 +328,43 @@ class PackagesInstallerDialog(QDialog, FORM_CLASS):
         for stdout_line in iter(process.stdout.readline, ""):
             if stdout_line.isspace():
                 continue
-            txt = f'<span style="color: #999999;">{stdout_line.rstrip(os.linesep)}</span>'
+            txt = (
+                f'<span style="color: #999999;">{stdout_line.rstrip(os.linesep)}</span>'
+            )
             self.log(txt)
             if self.aborted:
-                raise InterruptedError('Installation aborted by user')
+                raise InterruptedError("Installation aborted by user")
 
     def _check_packages_installation_and_log(self) -> bool:
         packages_ok = are_packages_importable(self.device)
         self.pushButton_install_packages.setEnabled(not packages_ok)
 
         if packages_ok:
-            msg1 = f'All required packages are importable! You can close this window now! Be sure to restart QGIS.'
+            msg1 = "All required packages are importable! You can close this window now! Be sure to restart QGIS."
             self.log(msg1)
             return True
 
         try:
             import_packages(self.device)
-            raise Exception("Unexpected successful import of packages?!? It failed a moment ago, we shouldn't be here!")
+            raise Exception(
+                "Unexpected successful import of packages?!? It failed a moment ago, we shouldn't be here!"
+            )
         except Exception:
-            msg_base = '<b>Python packages required by the plugin could not be loaded due to the following error:</b>'
+            msg_base = "<b>Python packages required by the plugin could not be loaded due to the following error:</b>"
             logging.exception(msg_base)
             tb = traceback.format_exc()
-            msg1 = (f'<span style="color: {_ERROR_COLOR};">'
-                    f'{msg_base} \n '
-                    f'{tb}\n\n'
-                    f'<b>Please try installing the packages again.<b>'
-                    f'</span>')
+            msg1 = (
+                f'<span style="color: {_ERROR_COLOR};">'
+                f"{msg_base} \n "
+                f"{tb}\n\n"
+                f"<b>Please try installing the packages again.<b>"
+                f"</span>"
+            )
             self.log(msg1)
 
         return False
 
+
 def get_pytorch_version(cuda_version):
     # Map CUDA versions to PyTorch versions
     ## cf. https://pytorch.org/get-started/locally/
@@ -328,78 +379,69 @@ def get_pytorch_version(cuda_version):
 
 
 def get_packages_to_install(device):
-
-    requirements_path = os.path.join(PLUGIN_ROOT_DIR, 'requirements.txt')
+    requirements_path = os.path.join(PLUGIN_ROOT_DIR, "requirements.txt")
     packages_to_install = []
 
-    if device == 'cpu':
+    if device == "cpu":
         pass
 
-    else :
-        if device == 'amd':
+    else:
+        if device == "amd":
             packages_to_install.append(
-                    PackageToInstall(
-                        name='torch', 
-                        version=' --index-url https://download.pytorch.org/whl/rocm6.1', 
-                        import_name='torch'
-                        )
-                    )
+                PackageToInstall(
+                    name="torch",
+                    version=" --index-url https://download.pytorch.org/whl/rocm6.1",
+                    import_name="torch",
+                )
+            )
 
         else:
             packages_to_install.append(
-                    PackageToInstall(
-                        name='torch', 
-                        version=get_pytorch_version(device), 
-                        import_name='torch'
-                        )
-                    )
-        
-
+                PackageToInstall(
+                    name="torch",
+                    version=get_pytorch_version(device),
+                    import_name="torch",
+                )
+            )
 
-    with open(requirements_path, 'r') as f:
+    with open(requirements_path, "r") as f:
         raw_txt = f.read()
 
     libraries_versions = {}
 
-    for line in raw_txt.split('\n'):
-        if line.startswith('#') or not line.strip():
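+    # e.g. "rasterio >=1.3" parses to {"rasterio ": ">=1.3"} (keys keep a trailing space)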
+    for line in raw_txt.split("\n"):
+        if line.startswith("#") or not line.strip():
             continue
 
-        line = line.split(';')[0]
-
-        if '==' in line:
-            lib, version = line.split('==')
-            libraries_versions[lib] = '==' + version
-        elif '>=' in line:
-            lib, version = line.split('>=')
-            libraries_versions[lib] = '>=' + version
-        elif '<=' in line:
-            lib, version = line.split('<=')
-            libraries_versions[lib] = '<=' + version
+        line = line.split(";")[0]
+
+        if "==" in line:
+            lib, version = line.split("==")
+            libraries_versions[lib] = "==" + version
+        elif ">=" in line:
+            lib, version = line.split(">=")
+            libraries_versions[lib] = ">=" + version
+        elif "<=" in line:
+            lib, version = line.split("<=")
+            libraries_versions[lib] = "<=" + version
         else:
-            libraries_versions[line] = ''
-
+            libraries_versions[line] = ""
 
     for lib, version in libraries_versions.items():
-
         import_name = lib[:-1]
 
-        if lib == 'scikit-learn ':
-            import_name = 'sklearn'
-        if lib == 'umap-learn ':
-            import_name = 'umap'
+        if lib == "scikit-learn ":
+            import_name = "sklearn"
+        if lib == "umap-learn ":
+            import_name = "umap"
 
         packages_to_install.append(
-                PackageToInstall(
-                    name=lib, 
-                    version=version, 
-                    import_name=import_name
-                    )
-                )
+            PackageToInstall(name=lib, version=version, import_name=import_name)
+        )
     return packages_to_install
 
 
-
 def import_package(package: PackageToInstall):
     importlib.import_module(package.import_name)
 
@@ -414,7 +456,9 @@ def are_packages_importable(device) -> bool:
     try:
         import_packages(device)
     except Exception:
-        logging.exception(f'Python packages required by the plugin could not be loaded due to the following error:')
+        logging.exception(
+            "Python packages required by the plugin could not be loaded due to the following error:"
+        )
         return False
 
     return True
@@ -422,17 +466,21 @@
 
 def check_pip_installed() -> bool:
     try:
-        subprocess.check_output([PYTHON_EXECUTABLE_PATH, '-m', 'pip', '--version'])
+        subprocess.check_output([PYTHON_EXECUTABLE_PATH, "-m", "pip", "--version"])
         return True
     except subprocess.CalledProcessError:
         return False
 
 
 dialog = None
-def check_required_packages_and_install_if_necessary(iface, device='cpu'):
+
+
+def check_required_packages_and_install_if_necessary(iface, device="cpu"):
     os.makedirs(PACKAGES_INSTALL_DIR, exist_ok=True)
     if PACKAGES_INSTALL_DIR not in sys.path:
-        sys.path.append(PACKAGES_INSTALL_DIR)  # TODO: check for a less intrusive way to do this
+        sys.path.append(
+            PACKAGES_INSTALL_DIR
+        )  # TODO: check for a less intrusive way to do this
 
     if are_packages_importable(device):
         # if packages are importable we are fine, nothing more to do then
@@ -440,7 +487,9 @@ def check_required_packages_and_install_if_necessary(iface, device='cpu'):
 
     global dialog
     packages_to_install = get_packages_to_install(device)
-    dialog = PackagesInstallerDialog(iface, packages_to_install=packages_to_install, device=device)
+    dialog = PackagesInstallerDialog(
+        iface, packages_to_install=packages_to_install, device=device
+    )
     dialog.setWindowModality(QtCore.Qt.WindowModal)
     dialog.show()
     dialog.move_to_top()
@@ -463,34 +513,34 @@ class IAMapEmpty(QObject):
     def initGui(self):
         self.initProcessing()
 
-        self.toolbar: QToolBar = self.iface.addToolBar('IAMap Toolbar')
-        self.toolbar.setObjectName('IAMapToolbar')
-        self.toolbar.setToolTip('IAMap Toolbar')
+        self.toolbar: QToolBar = self.iface.addToolBar("IAMap Toolbar")
+        self.toolbar.setObjectName("IAMapToolbar")
+        self.toolbar.setToolTip("IAMap Toolbar")
 
         self.actionEncoder = QAction(
             QIcon_EncoderTool,
             "Install dependencies and restart QGIS ! - Deep Learning Image Encoder",
-            self.iface.mainWindow()
+            self.iface.mainWindow(),
         )
         self.actionReducer = QAction(
             QIcon_ReductionTool,
             "Install dependencies and restart QGIS ! - Reduce dimensions",
-            self.iface.mainWindow()
+            self.iface.mainWindow(),
         )
         self.actionCluster = QAction(
             QIcon_ClusterTool,
             "Install dependencies and restart QGIS ! - Cluster raster",
-            self.iface.mainWindow()
+            self.iface.mainWindow(),
         )
         self.actionSimilarity = QAction(
             QIcon_SimilarityTool,
             "Install dependencies and restart QGIS ! - Compute similarity",
-            self.iface.mainWindow()
+            self.iface.mainWindow(),
         )
         self.actionRF = QAction(
             QIcon_RandomforestTool,
             "Install dependencies and restart QGIS ! - Fit Machine Learning algorithm",
-            self.iface.mainWindow()
+            self.iface.mainWindow(),
         )
         self.actionEncoder.setObjectName("mActionEncoder")
         self.actionReducer.setObjectName("mActionReducer")
@@ -499,15 +549,20 @@
         self.actionRF.setObjectName("mactionRF")
 
         self.actionEncoder.setToolTip(
-            "Install dependencies and restart QGIS ! - Encode a raster with a deep learning backbone")
+            "Install dependencies and restart QGIS ! - Encode a raster with a deep learning backbone"
+        )
         self.actionReducer.setToolTip(
-            "Install dependencies and restart QGIS ! - Reduce raster dimensions")
+            "Install dependencies and restart QGIS ! - Reduce raster dimensions"
+        )
         self.actionCluster.setToolTip(
-            "Install dependencies and restart QGIS ! - Cluster raster")
+            "Install dependencies and restart QGIS ! - Cluster raster"
+        )
         self.actionSimilarity.setToolTip(
-            "Install dependencies and restart QGIS ! - Compute similarity")
+            "Install dependencies and restart QGIS ! - Compute similarity"
+        )
         self.actionRF.setToolTip(
-            "Install dependencies and restart QGIS ! - Fit ML model")
+            "Install dependencies and restart QGIS ! - Fit ML model"
+        )
 
         # self.actionEncoder.triggered.connect()
         # self.actionReducer.triggered.connect()
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 6e682b55bb13a164dbba852cbcff5a15d3cb0350..8c8f77dac6606fb63f1c336aaecfd31419309797 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -5,37 +5,61 @@
 
 # -- Project information -----------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
+import os
+import re
 
-project = 'iamap'
-copyright = '2024, TRESSON Paul, TULET Hadrien, LE COZ Pierre'
-author = 'TRESSON Paul, TULET Hadrien, LE COZ Pierre'
-release = '0.5.9'
+
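+# read version, author and name from the plugin's metadata.txt so the docs stay in sync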
+metadata_file_path = os.path.join(os.path.dirname(__file__), "..", "..", "metadata.txt")
+metadata_file_path = os.path.abspath(metadata_file_path)
+with open(metadata_file_path, "rt") as file:
+    file_content = file.read()
+try:
+    versions_from_metadata = re.findall(r"version=(.*)", file_content)[0]
+except IndexError:
+    raise Exception("Failed to read version from metadata!")
+
+try:
+    author_from_metadata = re.findall(r"author=(.*)", file_content)[0]
+except IndexError:
+    raise Exception("Failed to read author from metadata!")
+
+try:
+    name_from_metadata = re.findall(r"name=(.*)", file_content)[0]
+except IndexError:
+    raise Exception("Failed to read name from metadata!")
+
+project = "iamap"
+copyright = "2024, TRESSON Paul, TULET Hadrien, LE COZ Pierre"
+author = author_from_metadata
+# The short X.Y version
+version = versions_from_metadata
+# The full version, including alpha/beta/rc tags
+release = version
 
 # -- General configuration ---------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
 
-import pydata_sphinx_theme
 
 extensions = [
-    'sphinx.ext.autodoc',
-    'sphinx.ext.autosummary',
+    "sphinx.ext.autodoc",
+    "sphinx.ext.autosummary",
     "myst_parser",
     "sphinx_favicon",
 ]
 
-templates_path = ['_templates']
+templates_path = ["_templates"]
 exclude_patterns = []
 
 source_suffix = {
-    '.rst': 'restructuredtext',
-    '.md': 'markdown',
+    ".rst": "restructuredtext",
+    ".md": "markdown",
 }
 
 
 # -- Options for HTML output -------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
 
-html_theme = 'pydata_sphinx_theme'
-html_static_path = ['_static']
+html_theme = "pydata_sphinx_theme"
+html_static_path = ["_static"]
 html_favicon = "./../../icons/favicon.svg"
-
diff --git a/encoder.py b/encoder.py
index 67f995cb77f1b8c87046e639b8a3507e770ca8be..1bd4834fdaf641600abfb4f0703681e071fc8a07 100644
--- a/encoder.py
+++ b/encoder.py
@@ -1,6 +1,5 @@
 import os
 import logging
-import sys
 import time
 import tempfile
 import re
@@ -11,27 +10,21 @@ import json
 
 import rasterio
 from qgis.PyQt.QtCore import QCoreApplication
-from qgis.core import (Qgis,
-                       QgsGeometry,
-                       QgsCoordinateTransform,
-                       QgsProcessingException,
-                       QgsProcessingAlgorithm,
-                       QgsProcessingParameterRasterLayer,
-                       QgsProcessingParameterFolderDestination,
-                       QgsProcessingParameterBand,
-                       QgsProcessingParameterNumber,
-                       QgsProcessingParameterBoolean,
-                       QgsProcessingParameterFile,
-                       QgsProcessingParameterString,
-                       QgsProcessingParameterEnum,
-                       QgsProcessingParameterExtent,
-                       QgsProcessingParameterCrs,
-                       QgsProcessingParameterDefinition,
-                       )
+from qgis.core import (
+    QgsProcessingParameterRasterLayer,
+    QgsProcessingParameterFolderDestination,
+    QgsProcessingParameterBand,
+    QgsProcessingParameterNumber,
+    QgsProcessingParameterBoolean,
+    QgsProcessingParameterFile,
+    QgsProcessingParameterString,
+    QgsProcessingParameterEnum,
+    QgsProcessingParameterExtent,
+    QgsProcessingParameterCrs,
+    QgsProcessingParameterDefinition,
+)
 
 import torch
-import torch.nn as nn
-from torch import Tensor
 import torch.quantization
 from torch.utils.data import DataLoader
 import torchvision.transforms as T
@@ -46,16 +39,17 @@ import timm
 
 from .utils.geo import get_mean_sd_by_band
 from .utils.geo import merge_tiles
-from .utils.misc import (QGISLogHandler, 
-                         get_dir_size, 
-                         get_model_size, 
-                         remove_files, 
-                         check_disk_space,
-                         get_unique_filename,
-                         save_parameters_to_json,
-                         compute_md5_hash,
-                         log_parameters_to_csv,
-                         )
+from .utils.misc import (
+    QGISLogHandler,
+    get_dir_size,
+    get_model_size,
+    remove_files,
+    check_disk_space,
+    get_unique_filename,
+    save_parameters_to_json,
+    compute_md5_hash,
+    log_parameters_to_csv,
+)
 from .utils.trch import quantize_model
 from .utils.algo import IAMAPAlgorithm
 
@@ -67,36 +61,32 @@ from .tg.transforms import AugmentationSequential
 from .icons import QIcon_EncoderTool
 
 
-
-
 class EncoderAlgorithm(IAMAPAlgorithm):
-    """
-    """
-
-    FEAT_OPTION= 'FEAT_OPTION'
-    INPUT = 'INPUT'
-    CKPT = 'CKPT'
-    BANDS = 'BANDS'
-    STRIDE = 'STRIDE'
-    SIZE = 'SIZE'
-    EXTENT = 'EXTENT'
-    QUANT = 'QUANT'
-    OUTPUT = 'OUTPUT'
-    RESOLUTION = 'RESOLUTION'
-    CRS = 'CRS'
-    CUDA = 'CUDA'
-    BATCH_SIZE = 'BATCH_SIZE'
-    CUDA_ID = 'CUDA_ID'
-    BACKBONE_CHOICE = 'BACKBONE_CHOICE'
-    BACKBONE_OPT = 'BACKBONE_OPT'
-    MERGE_METHOD = 'MERGE_METHOD'
-    WORKERS = 'WORKERS'
-    PAUSES = 'PAUSES'
-    REMOVE_TEMP_FILES = 'REMOVE_TEMP_FILES'
-    TEMP_FILES_CLEANUP_FREQ = 'TEMP_FILES_CLEANUP_FREQ'
-    JSON_PARAM = 'JSON_PARAM'
-    COMPRESS = 'COMPRESS'
-    
+    """ """
+
+    FEAT_OPTION = "FEAT_OPTION"
+    INPUT = "INPUT"
+    CKPT = "CKPT"
+    BANDS = "BANDS"
+    STRIDE = "STRIDE"
+    SIZE = "SIZE"
+    EXTENT = "EXTENT"
+    QUANT = "QUANT"
+    OUTPUT = "OUTPUT"
+    RESOLUTION = "RESOLUTION"
+    CRS = "CRS"
+    CUDA = "CUDA"
+    BATCH_SIZE = "BATCH_SIZE"
+    CUDA_ID = "CUDA_ID"
+    BACKBONE_CHOICE = "BACKBONE_CHOICE"
+    BACKBONE_OPT = "BACKBONE_OPT"
+    MERGE_METHOD = "MERGE_METHOD"
+    WORKERS = "WORKERS"
+    PAUSES = "PAUSES"
+    REMOVE_TEMP_FILES = "REMOVE_TEMP_FILES"
+    TEMP_FILES_CLEANUP_FREQ = "TEMP_FILES_CLEANUP_FREQ"
+    JSON_PARAM = "JSON_PARAM"
+    COMPRESS = "COMPRESS"
 
     def initAlgorithm(self, config=None):
         """
@@ -109,17 +99,16 @@ class EncoderAlgorithm(IAMAPAlgorithm):
         self.addParameter(
             QgsProcessingParameterRasterLayer(
                 name=self.INPUT,
-                description=self.tr(
-                    'Input raster layer or image file path'),
-            defaultValue=os.path.join(cwd,'assets','test.tif'),
+                description=self.tr("Input raster layer or image file path"),
+                defaultValue=os.path.join(cwd, "assets", "test.tif"),
             ),
         )
 
         self.addParameter(
             QgsProcessingParameterBand(
                 name=self.BANDS,
-                description=self.tr('Selected Bands (defaults to all bands selected)'),
-                defaultValue = None, 
+                description=self.tr("Selected Bands (defaults to all bands selected)"),
+                defaultValue=None,
                 parentLayerParameterName=self.INPUT,
                 optional=True,
                 allowMultiple=True,
@@ -128,78 +117,82 @@ class EncoderAlgorithm(IAMAPAlgorithm):
         compress_param = QgsProcessingParameterBoolean(
             name=self.COMPRESS,
             description=self.tr(
-                'Compress final result to uint16 and JP2 to save space'),
+                "Compress final result to uint16 and JP2 to save space"
+            ),
             defaultValue=False,
             optional=True,
         )
 
         crs_param = QgsProcessingParameterCrs(
             name=self.CRS,
-            description=self.tr('Target CRS (default to original CRS)'),
+            description=self.tr("Target CRS (default to original CRS)"),
             optional=True,
         )
 
         res_param = QgsProcessingParameterNumber(
             name=self.RESOLUTION,
             description=self.tr(
-                'Target resolution in meters (default to native resolution)'),
+                "Target resolution in meters (default to native resolution)"
+            ),
             type=QgsProcessingParameterNumber.Double,
             optional=True,
             minValue=0,
-            maxValue=100000
+            maxValue=100000,
         )
 
         cuda_id_param = QgsProcessingParameterNumber(
             name=self.CUDA_ID,
             description=self.tr(
-                'CUDA Device ID (choose which GPU to use, default to device 0)'),
+                "CUDA Device ID (choose which GPU to use, default to device 0)"
+            ),
             type=QgsProcessingParameterNumber.Integer,
             defaultValue=0,
             minValue=0,
-            maxValue=9
+            maxValue=9,
         )
         nworkers_param = QgsProcessingParameterNumber(
             name=self.WORKERS,
-            description=self.tr(
-                'Number of CPU workers for dataloader (0 selects all)'),
+            description=self.tr("Number of CPU workers for dataloader (0 selects all)"),
             type=QgsProcessingParameterNumber.Integer,
             defaultValue=0,
             minValue=0,
-            maxValue=10
+            maxValue=10,
         )
         pauses_param = QgsProcessingParameterNumber(
             name=self.PAUSES,
             description=self.tr(
-                'Schedule pauses between batches to ease CPU usage (in seconds).'),
+                "Schedule pauses between batches to ease CPU usage (in seconds)."
+            ),
             type=QgsProcessingParameterNumber.Integer,
             defaultValue=0,
             minValue=0,
-            maxValue=10000
+            maxValue=10000,
         )
 
         tmp_files_cleanup_frq = QgsProcessingParameterNumber(
             name=self.TEMP_FILES_CLEANUP_FREQ,
             description=self.tr(
-                'Frequencie at which temporary files should be cleaned up (zero means no cleanup).'),
+                "Frequencie at which temporary files should be cleaned up (zero means no cleanup)."
+            ),
             type=QgsProcessingParameterNumber.Integer,
             defaultValue=1000,
             minValue=1,
-            maxValue=10000
+            maxValue=10000,
         )
 
         remove_tmp_files = QgsProcessingParameterBoolean(
             name=self.REMOVE_TEMP_FILES,
             description=self.tr(
-                'Remove temporary files after encoding. If you want to test different merging options, it may be better to keep the tiles.'),
+                "Remove temporary files after encoding. If you want to test different merging options, it may be better to keep the tiles."
+            ),
             defaultValue=True,
         )
 
         self.addParameter(
             QgsProcessingParameterExtent(
                 name=self.EXTENT,
-                description=self.tr(
-                    'Processing extent (default to the entire image)'),
-                optional=True
+                description=self.tr("Processing extent (default to the entire image)"),
+                optional=True,
             )
         )
 
@@ -207,99 +200,97 @@ class EncoderAlgorithm(IAMAPAlgorithm):
             QgsProcessingParameterNumber(
                 name=self.SIZE,
                 description=self.tr(
-                    'Sampling size (the raster will be sampled in a square with a side of that many pixel)'),
+                    "Sampling size (the raster will be sampled in a square with a side of that many pixel)"
+                ),
                 type=QgsProcessingParameterNumber.Integer,
-                defaultValue = 224,
+                defaultValue=224,
                 minValue=1,
-                maxValue=1024
+                maxValue=1024,
             )
         )
 
-
         self.addParameter(
             QgsProcessingParameterNumber(
                 name=self.STRIDE,
                 description=self.tr(
-                    'Stride (If smaller than the sampling size, tiles will overlap. If larger, it may cause errors.)'),
+                    "Stride (If smaller than the sampling size, tiles will overlap. If larger, it may cause errors.)"
+                ),
                 type=QgsProcessingParameterNumber.Integer,
-                defaultValue = 224,
+                defaultValue=224,
                 minValue=1,
-                maxValue=1024
+                maxValue=1024,
             )
         )
 
         chkpt_param = QgsProcessingParameterFile(
-                name=self.CKPT,
-                description=self.tr(
-                    'Pretrained checkpoint'),
-                # extension='pth',
-                fileFilter='Checkpoint Files (*.pth *.pkl);; All Files (*.*)',
-                optional=True,
-                defaultValue=None
-            )
-        
+            name=self.CKPT,
+            description=self.tr("Pretrained checkpoint"),
+            # extension='pth',
+            fileFilter="Checkpoint Files (*.pth *.pkl);; All Files (*.*)",
+            optional=True,
+            defaultValue=None,
+        )
 
         self.addParameter(
             QgsProcessingParameterFolderDestination(
                 self.OUTPUT,
                 self.tr(
-                    "Output directory (choose the location that the image features will be saved)"),
-            defaultValue=tmp_wd,
+                    "Output directory (choose the location that the image features will be saved)"
+                ),
+                defaultValue=tmp_wd,
             )
         )
 
         self.addParameter(
             QgsProcessingParameterBoolean(
-                self.CUDA,
-                self.tr("Use GPU if CUDA is available."),
-                defaultValue=True
+                self.CUDA, self.tr("Use GPU if CUDA is available."), defaultValue=True
             )
         )
         self.backbone_opt = [
-                            'ViT base DINO',
-                            'ViT tiny Imagenet (smallest)', 
-                            'ViT base MAE', 
-                            'SAM', 
-                            '--Empty--'
-                            ]
+            "ViT base DINO",
+            "ViT tiny Imagenet (smallest)",
+            "ViT base MAE",
+            "SAM",
+            "--Empty--",
+        ]
         self.timm_backbone_opt = [
-                            'vit_base_patch16_224.dino',
-                            'vit_tiny_patch16_224.augreg_in21k',
-                            'vit_base_patch16_224.mae',
-                            'samvit_base_patch16.sa1b',
-                            ]
-        self.addParameter (
+            "vit_base_patch16_224.dino",
+            "vit_tiny_patch16_224.augreg_in21k",
+            "vit_base_patch16_224.mae",
+            "samvit_base_patch16.sa1b",
+        ]
+        self.addParameter(
             QgsProcessingParameterEnum(
-                name = self.BACKBONE_OPT,
-                description = self.tr(
-                    "Pre-selected backbones if you don't know what to pick"),
-                defaultValue = 0,
-                options = self.backbone_opt,
-                
+                name=self.BACKBONE_OPT,
+                description=self.tr(
+                    "Pre-selected backbones if you don't know what to pick"
+                ),
+                defaultValue=0,
+                options=self.backbone_opt,
             )
         )
-        self.addParameter (
+        self.addParameter(
             QgsProcessingParameterString(
-                name = self.BACKBONE_CHOICE,
-                description = self.tr(
-                    'Enter a architecture name if you want to test another backbone (see huggingface.co/timm/)'),
-                defaultValue = None,
+                name=self.BACKBONE_CHOICE,
+                description=self.tr(
+                    "Enter a architecture name if you want to test another backbone (see huggingface.co/timm/)"
+                ),
+                defaultValue=None,
                 optional=True,
             )
         )
-        
 
-        
         self.addParameter(
             QgsProcessingParameterNumber(
                 name=self.BATCH_SIZE,
                 # large images will be sampled into patches in a grid-like fashion
                 description=self.tr(
-                    'Batch size (take effect if choose to use GPU and CUDA is available)'),
+                    "Batch size (take effect if choose to use GPU and CUDA is available)"
+                ),
                 type=QgsProcessingParameterNumber.Integer,
                 defaultValue=1,
                 minValue=1,
-                maxValue=1024
+                maxValue=1024,
             )
         )
 
@@ -307,48 +298,46 @@ class EncoderAlgorithm(IAMAPAlgorithm):
             QgsProcessingParameterBoolean(
                 self.QUANT,
                 self.tr("Quantization of the model to reduce space"),
-                defaultValue=True
+                defaultValue=True,
             )
         )
 
-        self.merge_options = ['first', 'min', 'max','average','sum', 'count', 'last']
+        self.merge_options = ["first", "min", "max", "average", "sum", "count", "last"]
         merge_param = QgsProcessingParameterEnum(
-                name=self.MERGE_METHOD,
-                description=self.tr(
-                    'Merge method at the end of inference.'),
-                options=self.merge_options,
-                defaultValue=0,
-                )
+            name=self.MERGE_METHOD,
+            description=self.tr("Merge method at the end of inference."),
+            options=self.merge_options,
+            defaultValue=0,
+        )
 
         json_param = QgsProcessingParameterFile(
-                name=self.JSON_PARAM,
-                description=self.tr(
-                    'Pass parameters as json file'),
-                # extension='pth',
-                fileFilter='JSON Files (*.json)',
-                optional=True,
-                defaultValue=None
-            )
+            name=self.JSON_PARAM,
+            description=self.tr("Pass parameters as json file"),
+            # extension='pth',
+            fileFilter="JSON Files (*.json)",
+            optional=True,
+            defaultValue=None,
+        )
 
         for param in (
-                crs_param, 
-                res_param, 
-                chkpt_param, 
-                cuda_id_param, 
-                merge_param, 
-                nworkers_param,
-                pauses_param,
-                remove_tmp_files,
-                compress_param,
-                tmp_files_cleanup_frq,
-                json_param,
-                ):
+            crs_param,
+            res_param,
+            chkpt_param,
+            cuda_id_param,
+            merge_param,
+            nworkers_param,
+            pauses_param,
+            remove_tmp_files,
+            compress_param,
+            tmp_files_cleanup_frq,
+            json_param,
+        ):
             param.setFlags(
-                param.flags() | QgsProcessingParameterDefinition.FlagAdvanced)
+                param.flags() | QgsProcessingParameterDefinition.FlagAdvanced
+            )
             self.addParameter(param)
 
-
-
     @torch.no_grad()
     def processAlgorithm(self, parameters, context, feedback):
         """
@@ -360,43 +348,56 @@ class EncoderAlgorithm(IAMAPAlgorithm):
 
         ## compute parameters hash to have a unique identifier for the run
         ## some parameters do not change the encoding part of the algorithm
-        keys_to_remove = ['MERGE_METHOD', 'WORKERS', 'PAUSES']
+        keys_to_remove = ["MERGE_METHOD", "WORKERS", "PAUSES"]
         subdir_hash = compute_md5_hash(parameters, keys_to_remove=keys_to_remove)
-        output_subdir = os.path.join(self.output_dir,subdir_hash)
+        output_subdir = os.path.join(self.output_dir, subdir_hash)
         output_subdir = Path(output_subdir)
         output_subdir.mkdir(parents=True, exist_ok=True)
         self.output_subdir = output_subdir
-        feedback.pushInfo(f'output_subdir: {output_subdir}')
-        feedback.pushInfo(f'saving parameters to json file')
+        feedback.pushInfo(f"output_subdir: {output_subdir}")
+        feedback.pushInfo("saving parameters to json file")
         save_parameters_to_json(parameters, self.output_subdir)
-        feedback.pushInfo(f'logging parameters to csv')
-        log_parameters_to_csv(parameters,self.output_dir)
+        feedback.pushInfo("logging parameters to csv")
+        log_parameters_to_csv(parameters, self.output_dir)
 
         RasterDataset.filename_glob = self.rlayer_name
         RasterDataset.all_bands = [
-            self.rlayer.bandName(i_band) for i_band in range(1, self.rlayer.bandCount()+1)
+            self.rlayer.bandName(i_band)
+            for i_band in range(1, self.rlayer.bandCount() + 1)
         ]
         # currently only support rgb bands
-        input_bands = [self.rlayer.bandName(i_band)
-                       for i_band in self.selected_bands]
+        input_bands = [self.rlayer.bandName(i_band) for i_band in self.selected_bands]
 
-        feedback.pushInfo(f'create dataset')
+        feedback.pushInfo("create dataset")
         if self.crs == self.rlayer.crs():
             dataset = RasterDataset(
-                paths=self.rlayer_dir, crs=None, res=self.res, bands=input_bands, cache=False)
+                paths=self.rlayer_dir,
+                crs=None,
+                res=self.res,
+                bands=input_bands,
+                cache=False,
+            )
         else:
             dataset = RasterDataset(
-                paths=self.rlayer_dir, crs=self.crs.toWkt(), res=self.res, bands=input_bands, cache=False)
-        extent_bbox = BoundingBox(minx=self.extent.xMinimum(), maxx=self.extent.xMaximum(), miny=self.extent.yMinimum(), maxy=self.extent.yMaximum(),
-                                  mint=dataset.index.bounds[4], maxt=dataset.index.bounds[5])
-
+                paths=self.rlayer_dir,
+                crs=self.crs.toWkt(),
+                res=self.res,
+                bands=input_bands,
+                cache=False,
+            )
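+        # torchgeo bounding boxes are spatio-temporal; reuse the dataset's temporal bounds (mint/maxt)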
+        extent_bbox = BoundingBox(
+            minx=self.extent.xMinimum(),
+            maxx=self.extent.xMaximum(),
+            miny=self.extent.yMinimum(),
+            maxy=self.extent.yMaximum(),
+            mint=dataset.index.bounds[4],
+            maxt=dataset.index.bounds[5],
+        )
 
         if feedback.isCanceled():
-            feedback.pushWarning(
-                self.tr("\n !!!Processing is canceled by user!!! \n"))
+            feedback.pushWarning(self.tr("\n !!!Processing is canceled by user!!! \n"))
             return
 
-
         ### Custom logging to have more feedback during model loading
         logging.basicConfig(level=logging.DEBUG)
         logger = logging.getLogger()
@@ -408,101 +409,105 @@ class EncoderAlgorithm(IAMAPAlgorithm):
         logger.info("Starting model loading...")
 
         # Load the model
-        feedback.pushInfo(f'creating model')
+        feedback.pushInfo("creating model")
         model = timm.create_model(
             self.backbone_name,
             pretrained=True,
             in_chans=len(input_bands),
             num_classes=0,
-            )
+        )
         logger.info("Model loaded succesfully !")
         logger.handlers.clear()
 
-
         if feedback.isCanceled():
-            feedback.pushWarning(
-                self.tr("\n !!!Processing is canceled by user!!! \n"))
+            feedback.pushWarning(self.tr("\n !!!Processing is canceled by user!!! \n"))
             return
 
-        feedback.pushInfo(f'model done')
+        feedback.pushInfo("model done")
         data_config = timm.data.resolve_model_data_config(model)
-        _, h, w, = data_config['input_size']
+        _, h, w = data_config["input_size"]
 
         if torch.cuda.is_available() and self.use_gpu:
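+            # clamp the requested CUDA device id to the last available device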
             if self.cuda_id + 1 > torch.cuda.device_count():
                 self.cuda_id = torch.cuda.device_count() - 1
-            cuda_device = f'cuda:{self.cuda_id}'
-            device = f'cuda:{self.cuda_id}'
+            cuda_device = f"cuda:{self.cuda_id}"  # noqa: F841
+            device = f"cuda:{self.cuda_id}"
         else:
             self.batch_size = 1
-            device = 'cpu'
+            device = "cpu"
 
-        feedback.pushInfo(f'Device id: {device}')
+        feedback.pushInfo(f"Device id: {device}")
 
         if self.quantization:
-
-            try :
-                feedback.pushInfo(f'before quantization : {get_model_size(model)}')
+            try:
+                feedback.pushInfo(f"before quantization : {get_model_size(model)}")
 
                 model = quantize_model(model, device)
-                feedback.pushInfo(f'after quantization : {get_model_size(model)}')
-
-            except :
+                feedback.pushInfo(f"after quantization : {get_model_size(model)}")
 
-                feedback.pushInfo(f'quantization impossible, using original model.')
+            except Exception:
+                feedback.pushInfo("quantization impossible, using original model.")
 
         transform = AugmentationSequential(
-                T.ConvertImageDtype(torch.float32), # change dtype for normalize to be possible
-                K.Normalize(self.means,self.sds), # normalize occurs only on raster, not mask
-                K.Resize((h, w)),  # resize to 224*224 pixels, regardless of sampling size
-                data_keys=["image"],
-                )
+            T.ConvertImageDtype(
+                torch.float32
+            ),  # change dtype for normalize to be possible
+            K.Normalize(
+                self.means, self.sds
+            ),  # normalize occurs only on raster, not mask
+            K.Resize((h, w)),  # resize to the model's expected input size, regardless of sampling size
+            data_keys=["image"],
+        )
         dataset.transforms = transform
 
-
         # sampler = GridGeoSampler(
-        #         dataset, 
-        #         size=self.size, 
-        #         stride=self.stride, 
-        #         roi=extent_bbox, 
+        #         dataset,
+        #         size=self.size,
+        #         stride=self.stride,
+        #         roi=extent_bbox,
         #         units=Units.PIXELS
         #         )  # Units.CRS or Units.PIXELS
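+        # NoBordersGridGeoSampler: a custom variant of torchgeo's GridGeoSampler, presumably skipping partial patches at raster borders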
         sampler = NoBordersGridGeoSampler(
-                dataset, 
-                size=self.size, 
-                stride=self.stride, 
-                roi=extent_bbox, 
-                units=Units.PIXELS
-                )  # Units.CRS or Units.PIXELS
+            dataset,
+            size=self.size,
+            stride=self.stride,
+            roi=extent_bbox,
+            units=Units.PIXELS,
+        )  # Units.CRS or Units.PIXELS
 
         if len(sampler) == 0:
             self.load_feature = False
-            feedback.pushWarning(f'\n !!!No available patch sample inside the chosen extent!!! \n')
-
+            feedback.pushWarning(
+                "\n !!!No available patch sample inside the chosen extent!!! \n"
+            )
 
-        feedback.pushInfo(f'model to dedvice')
+        feedback.pushInfo("model to dedvice")
         model.to(device=device)
 
-        feedback.pushInfo(f'Batch size: {self.batch_size}')
+        feedback.pushInfo(f"Batch size: {self.batch_size}")
         dataloader = DataLoader(
-                dataset, 
-                batch_size=self.batch_size, 
-                sampler=sampler, 
-                collate_fn=stack_samples,
-                num_workers=self.nworkers,
-                )
+            dataset,
+            batch_size=self.batch_size,
+            sampler=sampler,
+            collate_fn=stack_samples,
+            num_workers=self.nworkers,
+        )
 
-        feedback.pushInfo(f'Patch sample num: {len(sampler)}')
-        feedback.pushInfo(f'Total batch num: {len(dataloader)}')
+        feedback.pushInfo(f"Patch sample num: {len(sampler)}")
+        feedback.pushInfo(f"Total batch num: {len(dataloader)}")
         feedback.pushInfo(f'\n\n{"-"*16}\nBegining inference \n{"-"*16}\n\n')
 
-
-
         last_batch_done = self.get_last_batch_done()
         if last_batch_done >= 0:
-            feedback.pushInfo(f"\n\n {'-'*8} \n Resuming at batch number {last_batch_done}\n {'-'*8} \n\n")
+            feedback.pushInfo(
+                f"\n\n {'-'*8} \n Resuming at batch number {last_batch_done}\n {'-'*8} \n\n"
+            )
 
-        bboxes = [] # keep track of bboxes to have coordinates at the end
+        bboxes = []  # keep track of bboxes to have coordinates at the end
         elapsed_time_list = []
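+        # progress bar increment per batch, in percent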
         total = 100 / len(dataloader) if len(dataloader) else 0
 
@@ -510,7 +515,6 @@ class EncoderAlgorithm(IAMAPAlgorithm):
         self.all_encoding_done = True
 
         for current, sample in enumerate(dataloader):
-
             if current <= last_batch_done:
                 continue
 
@@ -520,50 +524,58 @@ class EncoderAlgorithm(IAMAPAlgorithm):
             if feedback.isCanceled():
                 self.load_feature = False
                 feedback.pushWarning(
-                    self.tr("\n !!!Processing is canceled by user!!! \n"))
+                    self.tr("\n !!!Processing is canceled by user!!! \n")
+                )
                 self.all_encoding_done = False
                 break
-            
+
             feedback.pushInfo(f'\n{"-"*8}\nBatch no. {current} loaded')
 
-            images = sample['image'].to(device)
+            images = sample["image"].to(device)
             if len(images.shape) > 4:
                 images = images.squeeze(1)
-            
-            feedback.pushInfo(f'Batch shape {images.shape}')
+
+            feedback.pushInfo(f"Batch shape {images.shape}")
 
             features = model.forward_features(images)
-            features = features[:,1:,:] # take only patch tokens
-            
+            features = features[:, 1:, :]  # take only patch tokens
+
             if current <= last_batch_done + 1:
-                n_patches = int(np.sqrt(features.shape[1]))   
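+                # ViT-style backbones return a flat token sequence; assume a square patch grid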
+                n_patches = int(np.sqrt(features.shape[1]))
 
-            features = features.view(features.shape[0],n_patches,n_patches,features.shape[-1])
+            features = features.view(
+                features.shape[0], n_patches, n_patches, features.shape[-1]
+            )
             features = features.detach().cpu().numpy()
-            feedback.pushInfo(f'Features shape {features.shape}')
+            feedback.pushInfo(f"Features shape {features.shape}")
 
-            self.save_features(features,sample['bbox'], current)
-            feedback.pushInfo(f'Features saved')
+            self.save_features(features, sample["bbox"], current)
+            feedback.pushInfo("Features saved")
 
             if current <= last_batch_done + 1:
-                total_space, total_used_space, free_space = check_disk_space(self.output_subdir)
+                total_space, total_used_space, free_space = check_disk_space(
+                    self.output_subdir
+                )
 
                 used_outputsubdir = get_dir_size(str(self.output_subdir))
-                
-                to_use = ((len(dataloader) / (current+1)) - 1) * used_outputsubdir
+
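+                # estimate the space still needed by extrapolating from the batches written so far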
+                to_use = ((len(dataloader) / (current + 1)) - 1) * used_outputsubdir
                 if to_use >= free_space:
                     feedback.pushWarning(
-                        self.tr(f"\n !!! only {free_space} GB disk space remaining, canceling !!! \n"))
+                        self.tr(
+                            f"\n !!! only {free_space} GB disk space remaining, canceling !!! \n"
+                        )
+                    )
                     break
 
-            bboxes.extend(sample['bbox'])
+            bboxes.extend(sample["bbox"])
 
             if self.pauses != 0:
                 time.sleep(self.pauses)
 
             end_time = time.time()
             # get the execution time of encoder, ms
-            elapsed_time = (end_time - start_time)
+            elapsed_time = end_time - start_time
             elapsed_time_list.append(elapsed_time)
             time_spent = sum(elapsed_time_list)
             time_remain = (time_spent / (current + 1)) * (len(dataloader) - current - 1)
@@ -579,83 +591,98 @@ class EncoderAlgorithm(IAMAPAlgorithm):
 
             feedback.pushInfo(f"Encoder executed with {elapsed_time:.3f}s")
             feedback.pushInfo(f"Time spent: {time_spent:.3f}s")
-                  
+
             if time_remain <= 60:
-                feedback.pushInfo(f"Estimated time remaining: {time_remain:.3f}s \n {'-'*8}")
+                feedback.pushInfo(
+                    f"Estimated time remaining: {time_remain:.3f}s \n {'-'*8}"
+                )
             else:
                 time_remain_m, time_remain_s = divmod(int(time_remain), 60)
                 time_remain_h, time_remain_m = divmod(time_remain_m, 60)
-                feedback.pushInfo(f"Estimated time remaining: {time_remain_h:d}h:{time_remain_m:02d}m:{time_remain_s:02d}s \n" )
+                feedback.pushInfo(
+                    f"Estimated time remaining: {time_remain_h:d}h:{time_remain_m:02d}m:{time_remain_s:02d}s \n"
+                )
 
             if ((current + 1) % self.cleanup_frq == 0) and self.remove_tmp_files:
-
                 ## not the cleanest way to do for now
                 ## but avoids to refactor all
                 self.all_encoding_done = False
-                feedback.pushInfo('Cleaning temporary files...')
-                all_tiles = [os.path.join(self.output_subdir,f) for f in os.listdir(self.output_subdir) if f.endswith('_tmp.tif')]
-                all_tiles = [f for f in all_tiles if not f.startswith('merged')]
+                feedback.pushInfo("Cleaning temporary files...")
+                all_tiles = [
+                    os.path.join(self.output_subdir, f)
+                    for f in os.listdir(self.output_subdir)
+                    if f.endswith("_tmp.tif")
+                ]
+                # paths are absolute, so filter on the basename
+                all_tiles = [
+                    f for f in all_tiles if not os.path.basename(f).startswith("merged")
+                ]
 
-                dst_path = Path(os.path.join(self.output_subdir,'merged_tmp.tif'))
+                dst_path = Path(os.path.join(self.output_subdir, "merged_tmp.tif"))
 
                 merge_tiles(
-                        tiles = all_tiles, 
-                        dst_path = dst_path,
-                        method = self.merge_method,
-                        )
+                    tiles=all_tiles,
+                    dst_path=dst_path,
+                    method=self.merge_method,
+                )
                 self.remove_temp_files()
                 self.all_encoding_done = True
 
             # Update the progress bar
-            feedback.setProgress(int((current+1) * total))
-
+            feedback.setProgress(int((current + 1) * total))
 
         ## merging all temp tiles
-        feedback.pushInfo(f"\n\n{'-'*8}\n Merging tiles \n{'-'*8}\n" )
-        all_tiles = [os.path.join(self.output_subdir,f) for f in os.listdir(self.output_subdir) if f.endswith('_tmp.tif')]
+        feedback.pushInfo(f"\n\n{'-'*8}\n Merging tiles \n{'-'*8}\n")
+        all_tiles = [
+            os.path.join(self.output_subdir, f)
+            for f in os.listdir(self.output_subdir)
+            if f.endswith("_tmp.tif")
+        ]
         rlayer_name, ext = os.path.splitext(self.rlayer_name)
 
-        if not self.all_encoding_done :
-            dst_path = Path(os.path.join(self.output_subdir,'merged_tmp.tif'))
-            layer_name = f'{rlayer_name} features tmp'
+        if not self.all_encoding_done:
+            dst_path = Path(os.path.join(self.output_subdir, "merged_tmp.tif"))
+            layer_name = f"{rlayer_name} features tmp"
         else:
             # dst_path = Path(os.path.join(self.output_subdir,'merged.tif'))
             ## update filename if a merged.tif file already exists
-            
-            dst_path, layer_name = get_unique_filename(self.output_subdir, f'merged.tif', f'{rlayer_name} features')
+
+            dst_path, layer_name = get_unique_filename(
+                self.output_subdir, "merged.tif", f"{rlayer_name} features"
+            )
             dst_path = Path(dst_path)
 
         merge_tiles(
-                tiles = all_tiles, 
-                dst_path = dst_path,
-                method = self.merge_method,
-                )
+            tiles=all_tiles,
+            dst_path=dst_path,
+            method=self.merge_method,
+        )
 
         if self.remove_tmp_files:
-
             self.remove_temp_files()
 
-        parameters['OUTPUT_RASTER']=dst_path
+        parameters["OUTPUT_RASTER"] = dst_path
 
         if self.compress:
             dst_path = self.tiff_to_jp2(parameters, feedback)
 
-        return {"Output feature path": self.output_subdir, 'Patch samples saved': self.iPatch, 'OUTPUT_RASTER':dst_path, 'OUTPUT_LAYER_NAME':layer_name}
+        return {
+            "Output feature path": self.output_subdir,
+            "Patch samples saved": self.iPatch,
+            "OUTPUT_RASTER": dst_path,
+            "OUTPUT_LAYER_NAME": layer_name,
+        }
 
     def load_parameters_as_json(self, feedback, parameters):
-        parameters['JSON_PARAM'] = str(parameters['JSON_PARAM'])
-        json_param = parameters['JSON_PARAM']
+        parameters["JSON_PARAM"] = str(parameters["JSON_PARAM"])
+        json_param = parameters["JSON_PARAM"]
         print(json_param)
-        if json_param != 'NULL':
+        if json_param != "NULL":
             with open(json_param) as json_file:
                 parameters = json.load(json_file)
-            feedback.pushInfo(f'Loading previous parameters from {json_param}')
-            parameters.pop('JSON_PARAM',None)
+            feedback.pushInfo(f"Loading previous parameters from {json_param}")
+            parameters.pop("JSON_PARAM", None)
         else:
-            parameters.pop('JSON_PARAM',None)
-        
-        return parameters
+            parameters.pop("JSON_PARAM", None)
 
+        return parameters
 
     def remove_temp_files(self):
         """
@@ -666,32 +693,32 @@ class EncoderAlgorithm(IAMAPAlgorithm):
         last_batch_done = self.get_last_batch_done()
         if not self.all_encoding_done:
             tiles_to_remove = [
-                    os.path.join(self.output_subdir, f)
-                    for f in os.listdir(self.output_subdir)
-                    if f.endswith('_tmp.tif') and not f.startswith(str(last_batch_done))
-                    ]
+                os.path.join(self.output_subdir, f)
+                for f in os.listdir(self.output_subdir)
+                if f.endswith("_tmp.tif") and not f.startswith(str(last_batch_done))
+            ]
             tiles_to_remove = [
-                    f for f in tiles_to_remove
-                    if not f.endswith('merged_tmp.tif')
-                    ]
+                f for f in tiles_to_remove if not f.endswith("merged_tmp.tif")
+            ]
 
         ## else cleanup all temp files
-        else : 
-            tiles_to_remove = [os.path.join(self.output_subdir, f)
-                 for f in os.listdir(self.output_subdir)
-                 if f.endswith('_tmp.tif')]
+        else:
+            tiles_to_remove = [
+                os.path.join(self.output_subdir, f)
+                for f in os.listdir(self.output_subdir)
+                if f.endswith("_tmp.tif")
+            ]
 
         remove_files(tiles_to_remove)
 
         return
 
     def get_last_batch_done(self):
-
         ## get largest batch_number achieved
         ## files are saved with the pattern '{batch_number}_{image_id_within_batch}_tmp.tif'
         # Regular expression pattern to extract numbers
         # pattern = re.compile(r'^(\d+)_\d+\.tif$')
-        pattern = re.compile(r'^(\d+)_\d+_tmp\.tif$')
+        pattern = re.compile(r"^(\d+)_\d+_tmp\.tif$")
 
         # Initialize a set to store unique first numbers
         batch_numbers = set()
@@ -712,33 +739,33 @@ class EncoderAlgorithm(IAMAPAlgorithm):
         else:
             return -1
 
-
     def save_features(
-            self,
-            feature: np.ndarray,
-            bboxes: BoundingBox,
-            nbatch: int,
-            dtype: str = 'float32'
-            ):
-
-        if dtype == 'int8':
+        self,
+        feature: np.ndarray,
+        bboxes: BoundingBox,
+        nbatch: int,
+        dtype: str = "float32",
+    ):
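+        # int8 export assumes features were normalized to [-1, 1] before scaling to the int8 range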
+        if dtype == "int8":
             feature = (feature * 127).astype(np.int8)
         # iterate over batch_size dimension
         for idx in range(feature.shape[0]):
             _, height, width, channels = feature.shape
             bbox = bboxes[idx]
-            rio_transform = rasterio.transform.from_bounds(bbox.minx, bbox.miny, bbox.maxx, bbox.maxy, width, height)  # west, south, east, north, width, height
+            rio_transform = rasterio.transform.from_bounds(
+                bbox.minx, bbox.miny, bbox.maxx, bbox.maxy, width, height
+            )  # west, south, east, north, width, height
             feature_path = os.path.join(self.output_subdir, f"{nbatch}_{idx}_tmp.tif")
             with rasterio.open(
-                    feature_path,
-                    mode="w",
-                    driver="GTiff",
-                    height=height, 
-                    width=width,
-                    count=channels,
-                    dtype=dtype,
-                    crs=self.crs.toWkt(),
-                    transform=rio_transform
+                feature_path,
+                mode="w",
+                driver="GTiff",
+                height=height,
+                width=width,
+                count=channels,
+                dtype=dtype,
+                crs=self.crs.toWkt(),
+                transform=rio_transform,
             ) as ds:
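+                # rasterio expects band-first arrays, hence the (H, W, C) -> (C, H, W) transpose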
                 ds.write(np.transpose(feature[idx, ...], (2, 0, 1)))
                 tags = {
@@ -750,88 +777,69 @@ class EncoderAlgorithm(IAMAPAlgorithm):
 
         return
 
-    def process_options(self,parameters, context, feedback):
+    def process_options(self, parameters, context, feedback):
         self.iPatch = 0
-        
+
         self.feature_dir = ""
 
-        feedback.pushInfo(
-                f'PARAMETERS :\n{parameters}')
-        
-        feedback.pushInfo(
-                f'CONTEXT :\n{context}')
-        
-        feedback.pushInfo(
-                f'FEEDBACK :\n{feedback}')
+        feedback.pushInfo(f"PARAMETERS :\n{parameters}")
 
-        self.process_geo_parameters(parameters, context, feedback)
+        feedback.pushInfo(f"CONTEXT :\n{context}")
 
-        ckpt_path = self.parameterAsFile(
-            parameters, self.CKPT, context)
+        feedback.pushInfo(f"FEEDBACK :\n{feedback}")
 
+        self.process_geo_parameters(parameters, context, feedback)
+
+        ckpt_path = self.parameterAsFile(parameters, self.CKPT, context)  # noqa: F841
 
         ## Use the given backbone name if any, use preselected models otherwise.
-        input_name = self.parameterAsString(
-            parameters, self.BACKBONE_CHOICE, context)
-        
+        input_name = self.parameterAsString(parameters, self.BACKBONE_CHOICE, context)
+
         if input_name:
             self.backbone_name = input_name
         else:
-            backbone_idx = self.parameterAsEnum(
-                parameters, self.BACKBONE_OPT, context)
+            backbone_idx = self.parameterAsEnum(parameters, self.BACKBONE_OPT, context)
             self.backbone_name = self.timm_backbone_opt[backbone_idx]
-            feedback.pushInfo(f'self.backbone_name:{self.backbone_name}')
-
-        self.compress = self.parameterAsBoolean(
-            parameters, self.COMPRESS, context)
-        self.stride = self.parameterAsInt(
-            parameters, self.STRIDE, context)
-        self.size = self.parameterAsInt(
-            parameters, self.SIZE, context)
-        self.quantization = self.parameterAsBoolean(
-            parameters, self.QUANT, context)
-        self.use_gpu = self.parameterAsBoolean(
-            parameters, self.CUDA, context)
-        self.batch_size = self.parameterAsInt(
-            parameters, self.BATCH_SIZE, context)
-        self.output_dir = self.parameterAsString(
-            parameters, self.OUTPUT, context)
-        self.cuda_id = self.parameterAsInt(
-            parameters, self.CUDA_ID, context)
-        self.pauses = self.parameterAsInt(
-            parameters, self.PAUSES, context)
+            feedback.pushInfo(f"self.backbone_name:{self.backbone_name}")
+
+        self.compress = self.parameterAsBoolean(parameters, self.COMPRESS, context)
+        self.stride = self.parameterAsInt(parameters, self.STRIDE, context)
+        self.size = self.parameterAsInt(parameters, self.SIZE, context)
+        self.quantization = self.parameterAsBoolean(parameters, self.QUANT, context)
+        self.use_gpu = self.parameterAsBoolean(parameters, self.CUDA, context)
+        self.batch_size = self.parameterAsInt(parameters, self.BATCH_SIZE, context)
+        self.output_dir = self.parameterAsString(parameters, self.OUTPUT, context)
+        self.cuda_id = self.parameterAsInt(parameters, self.CUDA_ID, context)
+        self.pauses = self.parameterAsInt(parameters, self.PAUSES, context)
         self.cleanup_frq = self.parameterAsInt(
-            parameters, self.TEMP_FILES_CLEANUP_FREQ, context)
-        self.nworkers = self.parameterAsInt(
-            parameters, self.WORKERS, context)
-        merge_method_idx = self.parameterAsEnum(
-            parameters, self.MERGE_METHOD, context)
+            parameters, self.TEMP_FILES_CLEANUP_FREQ, context
+        )
+        self.nworkers = self.parameterAsInt(parameters, self.WORKERS, context)
+        merge_method_idx = self.parameterAsEnum(parameters, self.MERGE_METHOD, context)
         self.merge_method = self.merge_options[merge_method_idx]
         self.remove_tmp_files = self.parameterAsBoolean(
-            parameters, self.REMOVE_TEMP_FILES, context)
-
+            parameters, self.REMOVE_TEMP_FILES, context
+        )
 
         # get mean and sd of dataset from raster metadata
-        feedback.pushInfo(f'Computing means and sds for normalization')
+        feedback.pushInfo("Computing means and sds for normalization")
         means, sds = get_mean_sd_by_band(self.rlayer_path)
         # subset with selected_bands
-        feedback.pushInfo(f'Selected bands: {self.selected_bands}')
-        self.means = [means[i-1] for i in self.selected_bands]
-        self.sds = [sds[i-1] for i in self.selected_bands]
-        feedback.pushInfo(f'Means for normalization: {self.means}')
-        feedback.pushInfo(f'Std. dev. for normalization: {self.sds}')
-
+        feedback.pushInfo(f"Selected bands: {self.selected_bands}")
+        self.means = [means[i - 1] for i in self.selected_bands]
+        self.sds = [sds[i - 1] for i in self.selected_bands]
+        feedback.pushInfo(f"Means for normalization: {self.means}")
+        feedback.pushInfo(f"Std. dev. for normalization: {self.sds}")
 
     # used to handle any thread-sensitive cleanup which is required by the algorithm.
     def postProcessAlgorithm(self, context, feedback) -> Dict[str, Any]:
         return {}
 
-
     def tr(self, string):
         """
         Returns a translatable string with the self.tr() function.
         """
-        return QCoreApplication.translate('Processing', string)
+        return QCoreApplication.translate("Processing", string)
 
     def createInstance(self):
         return EncoderAlgorithm()
@@ -844,21 +852,21 @@ class EncoderAlgorithm(IAMAPAlgorithm):
         lowercase alphanumeric characters only and no spaces or other
         formatting characters.
         """
-        return 'encoder'
+        return "encoder"
 
     def displayName(self):
         """
         Returns the translated algorithm name, which should be used for any
         user-visible display of the algorithm name.
         """
-        return self.tr('Image Encoder')
+        return self.tr("Image Encoder")
 
     def group(self):
         """
         Returns the name of the group this algorithm belongs to. This string
         should be localised.
         """
-        return self.tr('')
+        return self.tr("")
 
     def groupId(self):
         """
@@ -868,7 +876,7 @@ class EncoderAlgorithm(IAMAPAlgorithm):
         contain lowercase alphanumeric characters only and no spaces or other
         formatting characters.
         """
-        return ''
+        return ""
 
     def shortHelpString(self):
         """
@@ -880,4 +888,3 @@ class EncoderAlgorithm(IAMAPAlgorithm):
 
     def icon(self):
         return QIcon_EncoderTool
-
diff --git a/iamap.py b/iamap.py
index ddcbdda890457883dba57013ab609f0c9c966d71..f7f4a90fb37da16640404e289fa85abf9b135bcd 100644
--- a/iamap.py
+++ b/iamap.py
@@ -1,20 +1,16 @@
 import processing
-from PyQt5.QtWidgets import (
-    QAction,
-    QToolBar,
-    QApplication,
-    QDialog
-)
+from PyQt5.QtWidgets import QAction, QToolBar
 from PyQt5.QtCore import pyqtSignal, QObject
 from qgis.core import QgsApplication
 from qgis.gui import QgisInterface
 from .provider import IAMapProvider
-from .icons import (QIcon_EncoderTool, 
-                    QIcon_ReductionTool, 
-                    QIcon_ClusterTool, 
-                    QIcon_SimilarityTool, 
-                    QIcon_RandomforestTool,
-                    )
+from .icons import (
+    QIcon_EncoderTool,
+    QIcon_ReductionTool,
+    QIcon_ClusterTool,
+    QIcon_SimilarityTool,
+    QIcon_RandomforestTool,
+)
 
 
 class IAMap(QObject):
@@ -32,34 +28,26 @@ class IAMap(QObject):
     def initGui(self):
         self.initProcessing()
 
-        self.toolbar: QToolBar = self.iface.addToolBar('IAMap Toolbar')
-        self.toolbar.setObjectName('IAMapToolbar')
-        self.toolbar.setToolTip('IAMap Toolbar')
+        self.toolbar: QToolBar = self.iface.addToolBar("IAMap Toolbar")
+        self.toolbar.setObjectName("IAMapToolbar")
+        self.toolbar.setToolTip("IAMap Toolbar")
 
         self.actionEncoder = QAction(
-            QIcon_EncoderTool,
-            "Deep Learning Image Encoder",
-            self.iface.mainWindow()
+            QIcon_EncoderTool, "Deep Learning Image Encoder", self.iface.mainWindow()
         )
         self.actionReducer = QAction(
-            QIcon_ReductionTool,
-            "Reduce dimensions",
-            self.iface.mainWindow()
+            QIcon_ReductionTool, "Reduce dimensions", self.iface.mainWindow()
         )
         self.actionCluster = QAction(
-            QIcon_ClusterTool,
-            "Cluster raster",
-            self.iface.mainWindow()
+            QIcon_ClusterTool, "Cluster raster", self.iface.mainWindow()
         )
         self.actionSimilarity = QAction(
-            QIcon_SimilarityTool,
-            "Compute similarity",
-            self.iface.mainWindow()
+            QIcon_SimilarityTool, "Compute similarity", self.iface.mainWindow()
         )
         self.actionRF = QAction(
             QIcon_RandomforestTool,
             "Fit Machine Learning algorithm",
-            self.iface.mainWindow()
+            self.iface.mainWindow(),
         )
         self.actionEncoder.setObjectName("mActionEncoder")
         self.actionReducer.setObjectName("mActionReducer")
@@ -67,16 +55,11 @@ class IAMap(QObject):
         self.actionSimilarity.setObjectName("mactionSimilarity")
         self.actionRF.setObjectName("mactionRF")
 
-        self.actionEncoder.setToolTip(
-            "Encode a raster with a deep learning backbone")
-        self.actionReducer.setToolTip(
-            "Reduce raster dimensions")
-        self.actionCluster.setToolTip(
-            "Cluster raster")
-        self.actionSimilarity.setToolTip(
-            "Compute similarity")
-        self.actionRF.setToolTip(
-            "Fit ML model")
+        self.actionEncoder.setToolTip("Encode a raster with a deep learning backbone")
+        self.actionReducer.setToolTip("Reduce raster dimensions")
+        self.actionCluster.setToolTip("Cluster raster")
+        self.actionSimilarity.setToolTip("Compute similarity")
+        self.actionRF.setToolTip("Fit ML model")
 
         self.actionEncoder.triggered.connect(self.encodeImage)
         self.actionReducer.triggered.connect(self.reduceImage)
@@ -107,119 +90,111 @@ class IAMap(QObject):
         QgsApplication.processingRegistry().removeProvider(self.provider)
 
     def encodeImage(self):
-        '''
-        '''
-        result = processing.execAlgorithmDialog('iamap:encoder', {})
+        """ """
+        result = processing.execAlgorithmDialog("iamap:encoder", {})
         print(result)
-                # Check if algorithm execution was successful
+        # Check if algorithm execution was successful
         if result:
             # Retrieve output parameters from the result dictionary
-            if 'OUTPUT_RASTER' in result:
-                output_raster_path = result['OUTPUT_RASTER']
-                output_layer_name = result['OUTPUT_LAYER_NAME']
+            if "OUTPUT_RASTER" in result:
+                output_raster_path = result["OUTPUT_RASTER"]
+                output_layer_name = result["OUTPUT_LAYER_NAME"]
 
                 # Add the output raster layer to the map canvas
-                self.iface.addRasterLayer(str(output_raster_path),output_layer_name)
+                self.iface.addRasterLayer(str(output_raster_path), output_layer_name)
             else:
                 # Handle missing or unexpected output
-                print('Output raster not found in algorithm result.')
+                print("Output raster not found in algorithm result.")
         else:
             # Handle algorithm execution failure or cancellation
-            print('Algorithm execution was not successful.')
+            print("Algorithm execution was not successful.")
         # processing.execAlgorithmDialog('', {})
         # self.close_all_dialogs()
 
-
     def reduceImage(self):
-        '''
-        '''
-        result = processing.execAlgorithmDialog('iamap:reduction', {})
+        """ """
+        result = processing.execAlgorithmDialog("iamap:reduction", {})
         print(result)
-                # Check if algorithm execution was successful
+        # Check if algorithm execution was successful
         if result:
             # Retrieve output parameters from the result dictionary
-            if 'OUTPUT_RASTER' in result:
-                output_raster_path = result['OUTPUT_RASTER']
-                output_layer_name = result['OUTPUT_LAYER_NAME']
+            if "OUTPUT_RASTER" in result:
+                output_raster_path = result["OUTPUT_RASTER"]
+                output_layer_name = result["OUTPUT_LAYER_NAME"]
 
                 # Add the output raster layer to the map canvas
                 self.iface.addRasterLayer(str(output_raster_path), output_layer_name)
             else:
                 # Handle missing or unexpected output
-                print('Output raster not found in algorithm result.')
+                print("Output raster not found in algorithm result.")
         else:
             # Handle algorithm execution failure or cancellation
-            print('Algorithm execution was not successful.')
+            print("Algorithm execution was not successful.")
         # processing.execAlgorithmDialog('', {})
 
-
     def clusterImage(self):
-        '''
-        '''
-        result = processing.execAlgorithmDialog('iamap:cluster', {})
+        """ """
+        result = processing.execAlgorithmDialog("iamap:cluster", {})
         print(result)
-                # Check if algorithm execution was successful
+        # Check if algorithm execution was successful
         if result:
             # Retrieve output parameters from the result dictionary
-            if 'OUTPUT_RASTER' in result:
-                output_raster_path = result['OUTPUT_RASTER']
-                output_layer_name = result['OUTPUT_LAYER_NAME']
+            if "OUTPUT_RASTER" in result:
+                output_raster_path = result["OUTPUT_RASTER"]
+                output_layer_name = result["OUTPUT_LAYER_NAME"]
 
                 # Add the output raster layer to the map canvas
                 self.iface.addRasterLayer(str(output_raster_path), output_layer_name)
             else:
                 # Handle missing or unexpected output
-                print('Output raster not found in algorithm result.')
+                print("Output raster not found in algorithm result.")
         else:
             # Handle algorithm execution failure or cancellation
-            print('Algorithm execution was not successful.')
+            print("Algorithm execution was not successful.")
         # processing.execAlgorithmDialog('', {})
 
-
     def similarityImage(self):
-        '''
-        '''
-        result = processing.execAlgorithmDialog('iamap:similarity', {})
+        """ """
+        result = processing.execAlgorithmDialog("iamap:similarity", {})
         print(result)
-                # Check if algorithm execution was successful
+        # Check if algorithm execution was successful
         if result:
             # Retrieve output parameters from the result dictionary
-            if 'OUTPUT_RASTER' in result:
-                output_raster_path = result['OUTPUT_RASTER']
-                output_layer_name = result['OUTPUT_LAYER_NAME']
-                used_shp = result['USED_SHP']
+            if "OUTPUT_RASTER" in result:
+                output_raster_path = result["OUTPUT_RASTER"]
+                output_layer_name = result["OUTPUT_LAYER_NAME"]
+                used_shp = result["USED_SHP"]
 
                 # Add the output raster layer to the map canvas
                 self.iface.addRasterLayer(str(output_raster_path), output_layer_name)
-                self.iface.addVectorLayer(str(used_shp), 'used points', "ogr")
+                self.iface.addVectorLayer(str(used_shp), "used points", "ogr")
             else:
                 # Handle missing or unexpected output
-                print('Output raster not found in algorithm result.')
+                print("Output raster not found in algorithm result.")
         else:
             # Handle algorithm execution failure or cancellation
-            print('Algorithm execution was not successful.')
+            print("Algorithm execution was not successful.")
         # processing.execAlgorithmDialog('', {})
-        
+
     def rfImage(self):
-        '''
-        '''
-        result = processing.execAlgorithmDialog('iamap:ml', {})
+        """ """
+        result = processing.execAlgorithmDialog("iamap:ml", {})
         print(result)
-                # Check if algorithm execution was successful
+        # Check if algorithm execution was successful
         if result:
             # Retrieve output parameters from the result dictionary
-            if 'OUTPUT_RASTER' in result:
-                output_raster_path = result['OUTPUT_RASTER']
-                output_layer_name = result['OUTPUT_LAYER_NAME']
-                used_shp = result['USED_SHP']
+            if "OUTPUT_RASTER" in result:
+                output_raster_path = result["OUTPUT_RASTER"]
+                output_layer_name = result["OUTPUT_LAYER_NAME"]
+                used_shp = result["USED_SHP"]
 
                 # Add the output raster layer to the map canvas
                 self.iface.addRasterLayer(str(output_raster_path), output_layer_name)
-                self.iface.addVectorLayer(str(used_shp), 'used points', "ogr")
+                self.iface.addVectorLayer(str(used_shp), "used points", "ogr")
             else:
                 # Handle missing or unexpected output
-                print('Output raster not found in algorithm result.')
+                print("Output raster not found in algorithm result.")
         else:
             # Handle algorithm execution failure or cancellation
-            print('Algorithm execution was not successful.')
+            print("Algorithm execution was not successful.")
         # processing.execAlgorithmDialog('', {})
diff --git a/icons/__init__.py b/icons/__init__.py
index d5e77abd5bd3433f0d24c69aa5f94e2ba8eb54ec..a540df97650c253e11c085960e72942c3b1e5450 100644
--- a/icons/__init__.py
+++ b/icons/__init__.py
@@ -2,11 +2,11 @@ import os
 from PyQt5.QtGui import QIcon
 
 cwd = os.path.abspath(os.path.dirname(__file__))
-encoder_tool_path = os.path.join(cwd, 'encoder.svg')
-reduction_tool_path = os.path.join(cwd, 'proj.svg')
-cluster_tool_path = os.path.join(cwd, 'cluster.svg')
-similarity_tool_path = os.path.join(cwd, 'sim.svg')
-random_forest_tool_path = os.path.join(cwd, 'forest.svg')
+encoder_tool_path = os.path.join(cwd, "encoder.svg")
+reduction_tool_path = os.path.join(cwd, "proj.svg")
+cluster_tool_path = os.path.join(cwd, "cluster.svg")
+similarity_tool_path = os.path.join(cwd, "sim.svg")
+random_forest_tool_path = os.path.join(cwd, "forest.svg")
 
 QIcon_EncoderTool = QIcon(encoder_tool_path)
 QIcon_ReductionTool = QIcon(reduction_tool_path)
diff --git a/metadata.txt b/metadata.txt
index ee350e18b24c1595589d4a0827d5ad3c8c712e37..66f4ad379c1ad4dd432f75b4db4305435d46e326 100644
--- a/metadata.txt
+++ b/metadata.txt
@@ -2,7 +2,7 @@
 name=iamap
 description=Extract and manipulate deep learning features from rasters
 about= This plugin is still a work in progress, feel free to fill an issue on github.
-version=0.5.9
+version=0.6.0
 icon=icons/favicon.svg
 qgisMinimumVersion=3.12
 author=Paul Tresson, Pierre Lecoz, Hadrien Tulet
diff --git a/ml.py b/ml.py
index 444bd7fd6bd7b3ed334ad564a442468ff445529a..bef9e86dc7a66d8af544c9422f53f71cc22c7a5b 100644
--- a/ml.py
+++ b/ml.py
@@ -12,39 +12,39 @@ import pandas as pd
 from shapely.geometry import box
 from qgis.PyQt.QtCore import QCoreApplication
 from qgis.core import (
-                       QgsProcessingParameterBoolean,
-                       QgsProcessingParameterEnum,
-                       QgsProcessingParameterVectorLayer,
-                       QgsProcessingParameterString,
-                       QgsProcessingParameterNumber,
-                       QgsProcessingParameterDefinition
-                       )
+    QgsProcessingParameterBoolean,
+    QgsProcessingParameterEnum,
+    QgsProcessingParameterVectorLayer,
+    QgsProcessingParameterString,
+    QgsProcessingParameterNumber,
+    QgsProcessingParameterDefinition,
+)
 
 from .icons import QIcon_RandomforestTool
 from .utils.geo import get_random_samples_in_gdf, get_unique_col_name
 from .utils.algo import (
-                        SHPAlgorithm,
-                        get_sklearn_algorithms_with_methods,
-                        instantiate_sklearn_algorithm,
-                        get_arguments,
-                        )
+    SHPAlgorithm,
+    get_sklearn_algorithms_with_methods,
+    instantiate_sklearn_algorithm,
+    get_arguments,
+)
 
 import sklearn.ensemble as ensemble
 import sklearn.neighbors as neighbors
 from sklearn.base import ClassifierMixin, RegressorMixin
 from sklearn.metrics import (
-                            accuracy_score,
-                            precision_score, 
-                            recall_score, 
-                            f1_score, 
-                            confusion_matrix, 
-                            classification_report
-                            )
+    accuracy_score,
+    precision_score,
+    recall_score,
+    f1_score,
+    confusion_matrix,
+    classification_report,
+)
 from sklearn.metrics import (
-                            mean_absolute_error, 
-                            mean_squared_error, 
-                            r2_score,
-                            )
+    mean_absolute_error,
+    mean_squared_error,
+    r2_score,
+)
 
 
 def check_model_type(model):
@@ -57,20 +57,19 @@ def check_model_type(model):
 
 
 class MLAlgorithm(SHPAlgorithm):
-    """
-    """
-
-    GT_COL = 'GT_COL'
-    DO_KFOLDS = 'DO_KFOLDS'
-    FOLD_COL = 'FOLD_COL'
-    NFOLDS = 'NFOLDS'
-    SAVE_MODEL = 'SAVE_MODEL'
-    SK_PARAM = 'SK_PARAM'
-    TEMPLATE_TEST = 'TEMPLATE_TEST'
-    METHOD = 'METHOD'
-    TMP_DIR = 'iamap_ml'
-    DEFAULT_TEMPLATE = 'ml_poly.shp'
-    TYPE = 'ml'
+    """ """
+
+    GT_COL = "GT_COL"
+    DO_KFOLDS = "DO_KFOLDS"
+    FOLD_COL = "FOLD_COL"
+    NFOLDS = "NFOLDS"
+    SAVE_MODEL = "SAVE_MODEL"
+    SK_PARAM = "SK_PARAM"
+    TEMPLATE_TEST = "TEMPLATE_TEST"
+    METHOD = "METHOD"
+    TMP_DIR = "iamap_ml"
+    DEFAULT_TEMPLATE = "ml_poly.shp"
+    TYPE = "ml"
 
     def initAlgorithm(self, config=None):
         """
@@ -82,100 +81,96 @@ class MLAlgorithm(SHPAlgorithm):
         self.init_input_shp()
 
         self.method_opt = self.get_algorithms()
-        default_index = self.method_opt.index('RandomForestClassifier')
-        self.addParameter (
+        default_index = self.method_opt.index("RandomForestClassifier")
+        self.addParameter(
             QgsProcessingParameterEnum(
-                name = self.METHOD,
-                description = self.tr(
-                    'Sklearn algorithm used'),
-                defaultValue = default_index,
-                options = self.method_opt,
+                name=self.METHOD,
+                description=self.tr("Sklearn algorithm used"),
+                defaultValue=default_index,
+                options=self.method_opt,
             )
         )
 
-        self.addParameter (
+        self.addParameter(
             QgsProcessingParameterString(
-                name = self.SK_PARAM,
-                description = self.tr(
-                    'Arguments for the initialisation of the algorithm. If empty this goes to sklearn default. It will overwrite cluster or components arguments.'),
-                defaultValue = '',
+                name=self.SK_PARAM,
+                description=self.tr(
+                    "Arguments for the initialisation of the algorithm. If empty this goes to sklearn default. It will overwrite cluster or components arguments."
+                ),
+                defaultValue="",
                 optional=True,
             )
         )
 
-
         self.addParameter(
             QgsProcessingParameterVectorLayer(
                 name=self.TEMPLATE,
                 description=self.tr(
-                    'Input shapefile path for training data set for random forest (if no test data_set, will be devised in train and test)'),
-            # defaultValue=os.path.join(self.cwd,'assets',self.DEFAULT_TEMPLATE),
+                    "Input shapefile path for training data set for random forest (if no test data_set, will be devised in train and test)"
+                ),
+                # defaultValue=os.path.join(self.cwd,'assets',self.DEFAULT_TEMPLATE),
             ),
         )
-        
+
         self.addParameter(
             QgsProcessingParameterVectorLayer(
                 name=self.TEMPLATE_TEST,
-                description=self.tr(
-                    'Input shapefile path for test dataset.'),
-                optional = True
+                description=self.tr("Input shapefile path for test dataset."),
+                optional=True,
             ),
         )
 
-
-        self.addParameter (
+        self.addParameter(
             QgsProcessingParameterString(
-                name = self.GT_COL,
-                description = self.tr(
-                    'Name of the column containing ground truth values.'),
-                defaultValue = '',
+                name=self.GT_COL,
+                description=self.tr(
+                    "Name of the column containing ground truth values."
+                ),
+                defaultValue="",
             )
         )
 
-        self.addParameter (
+        self.addParameter(
             QgsProcessingParameterBoolean(
-                name = self.DO_KFOLDS,
-                description = self.tr(
-                    'Perform cross-validation'),
-                defaultValue = True,
+                name=self.DO_KFOLDS,
+                description=self.tr("Perform cross-validation"),
+                defaultValue=True,
             )
         )
-        self.addParameter (
+        self.addParameter(
             QgsProcessingParameterString(
-                name = self.FOLD_COL,
-                description = self.tr(
-                    'Name of the column defining folds in case of cross-validation. If none is selected, random sampling is used.'),
-                defaultValue = '',
+                name=self.FOLD_COL,
+                description=self.tr(
+                    "Name of the column defining folds in case of cross-validation. If none is selected, random sampling is used."
+                ),
+                defaultValue="",
                 optional=True,
             )
         )
 
         nfold_param = QgsProcessingParameterNumber(
             name=self.NFOLDS,
-            description=self.tr(
-                'Number of folds performed'),
+            description=self.tr("Number of folds performed"),
             type=QgsProcessingParameterNumber.Integer,
             optional=True,
             minValue=2,
             defaultValue=5,
-            maxValue=10
+            maxValue=10,
         )
 
         save_param = QgsProcessingParameterBoolean(
-                self.SAVE_MODEL,
-                self.tr("Save model after fit."),
-                defaultValue=True
-                )
+            self.SAVE_MODEL, self.tr("Save model after fit."), defaultValue=True
+        )
 
         for param in (
-                nfold_param,
-                save_param,
-                ):
+            nfold_param,
+            save_param,
+        ):
             param.setFlags(
-                param.flags() | QgsProcessingParameterDefinition.FlagAdvanced)
+                param.flags() | QgsProcessingParameterDefinition.FlagAdvanced
+            )
             self.addParameter(param)
 
-
     def processAlgorithm(self, parameters, context, feedback):
         """
         Here is where the processing itself takes place.
@@ -192,156 +187,153 @@ class MLAlgorithm(SHPAlgorithm):
         if self.do_kfold:
             best_metric = 0
             best_metrics_dict = {}
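+            # leave-one-fold-out: each unique fold value is held out once as the test set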
-            for fold in sorted(self.gdf[self.fold_col].unique()): # pyright: ignore[reportAttributeAccessIssue]
-                feedback.pushInfo(f'==== Fold {fold} ====')
+            for fold in sorted(self.gdf[self.fold_col].unique()):  # pyright: ignore[reportAttributeAccessIssue]
+                feedback.pushInfo(f"==== Fold {fold} ====")
                 self.test_gdf = self.gdf.loc[self.gdf[self.fold_col] == fold]
                 self.train_gdf = self.gdf.loc[self.gdf[self.fold_col] != fold]
                 metrics_dict = self.train_test_loop(feedback)
 
-                if 'accuracy' in metrics_dict.keys():
-                    used_metric = metrics_dict['accuracy']
-                if 'r2' in metrics_dict.keys():
-                    used_metric = metrics_dict['accuracy']
+                if "accuracy" in metrics_dict.keys():
+                    used_metric = metrics_dict["accuracy"]
+                if "r2" in metrics_dict.keys():
+                    used_metric = metrics_dict["accuracy"]
                 if used_metric >= best_metric:
                     best_metric = used_metric
                     best_metrics_dict = metrics_dict
                     self.best_model = self.model
 
         if (self.test_gdf is None) and not self.do_kfold:
-            train_set, train_gts = self.get_raster(mode='train')
+            train_set, train_gts = self.get_raster(mode="train")
             self.model.fit(train_set, train_gts)
-            feedback.pushWarning(f'No test set was provided and no cross-validation is done, unable to assess model quality !')
+            feedback.pushWarning(
+                "No test set was provided and no cross-validation is done, unable to assess model quality !"
+            )
             self.best_model = self.model
 
-        
-        feedback.pushInfo(f'Fitting done, saving model\n')
-        save_file = f'{self.method_name}.pkl'.lower()
-        metrics_save_file = f'{self.method_name}-metrics.json'.lower()
+        feedback.pushInfo("Fitting done, saving model\n")
+        save_file = f"{self.method_name}.pkl".lower()
+        metrics_save_file = f"{self.method_name}-metrics.json".lower()
         if self.save_model:
             out_path = os.path.join(self.output_dir, save_file)
             joblib.dump(self.best_model, out_path)
-            with open(os.path.join(self.output_dir, metrics_save_file), "w") as json_file:
+            with open(
+                os.path.join(self.output_dir, metrics_save_file), "w"
+            ) as json_file:
                 ## confusion matrix is a np array that does not fit in a json
-                best_metrics_dict.pop('conf_matrix', None)
-                best_metrics_dict.pop('class_report', None)
+                best_metrics_dict.pop("conf_matrix", None)
+                best_metrics_dict.pop("class_report", None)
                 json.dump(best_metrics_dict, json_file, indent=4)
 
         self.infer_model(feedback)
 
-        return {'OUTPUT_RASTER':self.dst_path, 'OUTPUT_LAYER_NAME':self.layer_name, 'USED_SHP':self.used_shp_path}
-
+        return {
+            "OUTPUT_RASTER": self.dst_path,
+            "OUTPUT_LAYER_NAME": self.layer_name,
+            "USED_SHP": self.used_shp_path,
+        }
 
     def train_test_loop(self, feedback):
-        train_set, train_gts = self.get_raster(mode='train')
-        test_set, test_gts = self.get_raster(mode='test')
+        train_set, train_gts = self.get_raster(mode="train")
+        test_set, test_gts = self.get_raster(mode="test")
 
         self.model.fit(train_set, train_gts)
         predictions = self.model.predict(test_set)
-        return self.get_metrics(test_gts,predictions, feedback)
-
+        return self.get_metrics(test_gts, predictions, feedback)
 
     def infer_model(self, feedback):
-
         with rasterio.open(self.rlayer_path) as ds:
-
             transform = ds.transform
             crs = ds.crs
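+            # restrict reading to the processing extent via a rasterio window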
             win = windows.from_bounds(
-                    self.extent.xMinimum(), 
-                    self.extent.yMinimum(), 
-                    self.extent.xMaximum(), 
-                    self.extent.yMaximum(), 
-                    transform=transform
-                    )
+                self.extent.xMinimum(),
+                self.extent.yMinimum(),
+                self.extent.xMaximum(),
+                self.extent.yMaximum(),
+                transform=transform,
+            )
             raster = ds.read(window=win)
             transform = ds.window_transform(win)
-            raster = np.transpose(raster, (1,2,0))
-            raster = raster[:,:,self.input_bands]
-
+            raster = np.transpose(raster, (1, 2, 0))
+            raster = raster[:, :, self.input_bands]
 
             inf_raster = raster.reshape(-1, raster.shape[-1])
-            np.nan_to_num(inf_raster) # NaN to zero after normalisation
+            inf_raster = np.nan_to_num(inf_raster)  # NaN to zero after normalisation
 
             proj_img = self.best_model.predict(inf_raster)
 
-            proj_img = proj_img.reshape((raster.shape[0], raster.shape[1],-1))
+            proj_img = proj_img.reshape((raster.shape[0], raster.shape[1], -1))
             height, width, channels = proj_img.shape
 
-            feedback.pushInfo(f'Export to geotif\n')
-            with rasterio.open(self.dst_path, 'w', driver='GTiff',
-                               height=height, 
-                               width=width, 
-                               count=channels, 
-                               dtype=self.out_dtype,
-                               crs=crs, 
-                               transform=transform) as dst_ds:
+            feedback.pushInfo("Export to geotif\n")
+            with rasterio.open(
+                self.dst_path,
+                "w",
+                driver="GTiff",
+                height=height,
+                width=width,
+                count=channels,
+                dtype=self.out_dtype,
+                crs=crs,
+                transform=transform,
+            ) as dst_ds:
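+                ## rasterio expects band-first (bands, rows, cols) ordering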
                 dst_ds.write(np.transpose(proj_img, (2, 0, 1)))
-            feedback.pushInfo(f'Export to geotif done\n')
-
+            feedback.pushInfo("Export to geotif done\n")
 
     def process_ml_shp(self, parameters, context, feedback):
-
         template_test = self.parameterAsVectorLayer(
-            parameters, self.TEMPLATE_TEST, context)
-        feedback.pushInfo(f'template_test: {template_test}')
+            parameters, self.TEMPLATE_TEST, context
+        )
+        feedback.pushInfo(f"template_test: {template_test}")
 
-        self.test_gdf=None
+        self.test_gdf = None
 
-        if template_test is not None :
+        if template_test is not None:
             random_samples = self.parameterAsInt(
-                parameters, self.RANDOM_SAMPLES, context)
+                parameters, self.RANDOM_SAMPLES, context
+            )
 
             gdf = gpd.read_file(template_test.dataProvider().dataSourceUri())
             gdf = gdf.to_crs(self.crs.toWkt())
 
-            feedback.pushInfo(f'before samples: {len(gdf)}')
+            feedback.pushInfo(f"before samples: {len(gdf)}")
             ## get random samples if geometry is not point based
             gdf = get_random_samples_in_gdf(gdf, random_samples, seed=self.seed)
 
-            feedback.pushInfo(f'before extent: {len(gdf)}')
+            feedback.pushInfo(f"before extent: {len(gdf)}")
             bounds = box(
-                    self.extent.xMinimum(), 
-                    self.extent.yMinimum(), 
-                    self.extent.xMaximum(), 
-                    self.extent.yMaximum(), 
-                    )
+                self.extent.xMinimum(),
+                self.extent.yMinimum(),
+                self.extent.xMaximum(),
+                self.extent.yMaximum(),
+            )
             self.test_gdf = gdf[gdf.within(bounds)]
-            feedback.pushInfo(f'after extent: {len(self.test_gdf)}')
+            feedback.pushInfo(f"after extent: {len(self.test_gdf)}")
 
             if len(self.test_gdf) == 0:
                 feedback.pushWarning("No template points within extent !")
                 return False
 
-
     def process_ml_options(self, parameters, context, feedback):
-
-        self.save_model = self.parameterAsBoolean(
-                parameters, self.SAVE_MODEL, context)
-        self.do_kfold = self.parameterAsBoolean(
-            parameters, self.DO_KFOLDS, context)
-        gt_col = self.parameterAsString(
-            parameters, self.GT_COL, context)
-        fold_col = self.parameterAsString(
-            parameters, self.FOLD_COL, context)
-        nfolds = self.parameterAsInt(
-            parameters, self.NFOLDS, context)
-        str_kwargs = self.parameterAsString(
-                parameters, self.SK_PARAM, context)
+        self.save_model = self.parameterAsBoolean(parameters, self.SAVE_MODEL, context)
+        self.do_kfold = self.parameterAsBoolean(parameters, self.DO_KFOLDS, context)
+        gt_col = self.parameterAsString(parameters, self.GT_COL, context)
+        fold_col = self.parameterAsString(parameters, self.FOLD_COL, context)
+        nfolds = self.parameterAsInt(parameters, self.NFOLDS, context)
+        str_kwargs = self.parameterAsString(parameters, self.SK_PARAM, context)
 
         ## If a fold column is provided, this defines the folds. Otherwise, random split
         ## check that no column with name 'fold' exists, otherwise we use 'fold1' etc..
         ## we also make a new column containing gt values
-        self.fold_col = get_unique_col_name(self.gdf, 'fold')
-        self.gt_col = get_unique_col_name(self.gdf, 'gt')
+        self.fold_col = get_unique_col_name(self.gdf, "fold")
+        self.gt_col = get_unique_col_name(self.gdf, "gt")
 
         ## Instantiate model
-        if str_kwargs != '':
+        if str_kwargs != "":
             self.passed_kwargs = ast.literal_eval(str_kwargs)
         else:
             self.passed_kwargs = {}
 
-        method_idx = self.parameterAsEnum(
-            parameters, self.METHOD, context)
+        method_idx = self.parameterAsEnum(parameters, self.METHOD, context)
         self.method_name = self.method_opt[method_idx]
 
         try:
@@ -352,144 +344,151 @@ class MLAlgorithm(SHPAlgorithm):
         kwargs = self.update_kwargs(default_args)
 
         try:
-            self.model = instantiate_sklearn_algorithm(ensemble, self.method_name, **kwargs)
+            self.model = instantiate_sklearn_algorithm(
+                ensemble, self.method_name, **kwargs
+            )
         except AttributeError:
-            self.model = instantiate_sklearn_algorithm(neighbors, self.method_name, **kwargs)
-
+            self.model = instantiate_sklearn_algorithm(
+                neighbors, self.method_name, **kwargs
+            )
 
         ## different behaviours if we are doing classification or regression
         ## If classification, we create a new col with unique integers for each class
         ## to ease inference
         self.task_type = check_model_type(self.model)
-        
-        if self.task_type == 'classification':
-            self.out_dtype = 'int8'
-            self.gdf[self.gt_col] = pd.factorize(self.gdf[gt_col])[0] # unique int for each class
+
+        if self.task_type == "classification":
+            self.out_dtype = "int8"
+            self.gdf[self.gt_col] = pd.factorize(self.gdf[gt_col])[
+                0
+            ]  # unique int for each class
         else:
             self.gt_col = gt_col
 
-
         ## If no test set is provided and the option to perform kfolds is true, we perform kfolds
-        if self.test_gdf == None and self.do_kfold:
-            if fold_col.strip() != '' :
+        if self.test_gdf is None and self.do_kfold:
+            if fold_col.strip() != "":
                 self.gdf[self.fold_col] = self.gdf[fold_col]
             else:
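+                ## no fold column was given: draw a random fold in [1, nfolds] for each sample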
                 np.random.seed(self.seed)
-                self.gdf[self.fold_col] = np.random.randint(1, nfolds + 1, size=len(self.gdf))
+                self.gdf[self.fold_col] = np.random.randint(
+                    1, nfolds + 1, size=len(self.gdf)
+                )
         ## Else, self.gdf is the train set
         else:
             self.train_gdf = self.gdf
 
-        feedback.pushInfo(f'saving modified dataframe to: {self.used_shp_path}')
+        feedback.pushInfo(f"saving modified dataframe to: {self.used_shp_path}")
         self.gdf.to_file(self.used_shp_path)
 
-
-    def get_raster(self, mode='train'):
-
-        if mode == 'train':
+    def get_raster(self, mode="train"):
+        if mode == "train":
             gdf = self.train_gdf
         else:
             gdf = self.test_gdf
 
         with rasterio.open(self.rlayer_path) as ds:
-
             gdf = gdf.to_crs(ds.crs)
             pixel_values = []
             gts = []
 
             transform = ds.transform
             win = windows.from_bounds(
-                    self.extent.xMinimum(), 
-                    self.extent.yMinimum(), 
-                    self.extent.xMaximum(), 
-                    self.extent.yMaximum(), 
-                    transform=transform
-                    )
+                self.extent.xMinimum(),
+                self.extent.yMinimum(),
+                self.extent.xMaximum(),
+                self.extent.yMaximum(),
+                transform=transform,
+            )
             raster = ds.read(window=win)
             transform = ds.window_transform(win)
-            raster = raster[self.input_bands,:,:]
+            raster = raster[self.input_bands, :, :]
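+            ## keep only the user-selected bands (band-first layout here)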
 
             for index, data in gdf.iterrows():
                 # Get the coordinates of the point in the raster's pixel space
                 x, y = data.geometry.x, data.geometry.y
 
                 # Convert point coordinates to pixel coordinates within the window
-                col, row = ~transform * (x, y)  # Convert from map coordinates to pixel coordinates
+                col, row = ~transform * (
+                    x,
+                    y,
+                )  # Convert from map coordinates to pixel coordinates
                 col, row = int(col), int(row)
-                pixel_values.append(list(raster[:,row, col]))
+                pixel_values.append(list(raster[:, row, col]))
                 gts.append(data[self.gt_col])
 
-
         return np.asarray(pixel_values), np.asarray(gts)
 
-
     def update_kwargs(self, kwargs_dict):
-
         for key, value in self.passed_kwargs.items():
             if key in kwargs_dict.keys():
                 kwargs_dict[key] = value
-        
-        kwargs_dict['random_state'] = self.seed
 
-        return kwargs_dict
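+        ## the plugin-wide seed always overrides a user-supplied random_state for reproducibility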
+        kwargs_dict["random_state"] = self.seed
 
+        return kwargs_dict
 
     def get_metrics(self, test_gts, predictions, feedback):
-
         metrics_dict = {}
-        if self.task_type == 'classification':
+        if self.task_type == "classification":
             # Evaluate the model
-            metrics_dict['accuracy'] = accuracy_score(test_gts, predictions)
-            metrics_dict['precision'] = precision_score(test_gts, predictions, average='weighted')  # Modify `average` for multiclass if necessary
-            metrics_dict['recall'] = recall_score(test_gts, predictions, average='weighted')
-            metrics_dict['f1'] = f1_score(test_gts, predictions, average='weighted')
-            metrics_dict['conf_matrix'] = confusion_matrix(test_gts, predictions)
-            metrics_dict['class_report'] = classification_report(test_gts, predictions)
-
-
-        elif self.task_type == 'regression':
+            metrics_dict["accuracy"] = accuracy_score(test_gts, predictions)
+            metrics_dict["precision"] = precision_score(
+                test_gts, predictions, average="weighted"
+            )  # 'weighted' averaging already handles multiclass; change `average` if another scheme is needed
+            metrics_dict["recall"] = recall_score(
+                test_gts, predictions, average="weighted"
+            )
+            metrics_dict["f1"] = f1_score(test_gts, predictions, average="weighted")
+            metrics_dict["conf_matrix"] = confusion_matrix(test_gts, predictions)
+            metrics_dict["class_report"] = classification_report(test_gts, predictions)
 
-            metrics_dict['mae'] = mean_absolute_error(test_gts, predictions)
-            metrics_dict['mse'] = mean_squared_error(test_gts, predictions)
-            metrics_dict['rmse'] = np.sqrt(metrics_dict['mse'])
-            metrics_dict['r2'] = r2_score(test_gts, predictions)
+        elif self.task_type == "regression":
+            metrics_dict["mae"] = mean_absolute_error(test_gts, predictions)
+            metrics_dict["mse"] = mean_squared_error(test_gts, predictions)
+            metrics_dict["rmse"] = np.sqrt(metrics_dict["mse"])
+            metrics_dict["r2"] = r2_score(test_gts, predictions)
 
         else:
-            feedback.pushWarning('Unable to evaluate the model !!')
+            feedback.pushWarning("Unable to evaluate the model !!")
 
         for key, value in metrics_dict.items():
-            feedback.pushInfo(f'{key}:\t {value}')
+            feedback.pushInfo(f"{key}:\t {value}")
 
         return metrics_dict
-        
 
     def get_algorithms(self):
-        required_methods = ['fit', 'predict']
+        required_methods = ["fit", "predict"]
         ensemble_algos = get_sklearn_algorithms_with_methods(ensemble, required_methods)
-        neighbors_algos = get_sklearn_algorithms_with_methods(neighbors, required_methods)
-        return sorted(ensemble_algos+neighbors_algos)
-
+        neighbors_algos = get_sklearn_algorithms_with_methods(
+            neighbors, required_methods
+        )
+        return sorted(ensemble_algos + neighbors_algos)
 
     def get_help_sk_methods(self):
         """
         Generate help string with default arguments of supported sklearn algorithms.
         """
-            
-        help_str = '\n\n Here are the default arguments of the supported algorithms:\n\n'
 
-        required_methods = ['fit', 'predict']
+        help_str = (
+            "\n\n Here are the default arguments of the supported algorithms:\n\n"
+        )
+
+        required_methods = ["fit", "predict"]
 
         ensemble_algos = get_sklearn_algorithms_with_methods(ensemble, required_methods)
-        for algo in ensemble_algos :
-            args = get_arguments(ensemble,algo)
-            help_str += f'- {algo}:\n'
-            help_str += f'{args}\n'
+        for algo in ensemble_algos:
+            args = get_arguments(ensemble, algo)
+            help_str += f"- {algo}:\n"
+            help_str += f"{args}\n"
 
-        neighbors_algos = get_sklearn_algorithms_with_methods(neighbors, required_methods)
-        for algo in neighbors_algos :
-            args = get_arguments(neighbors,algo)
-            help_str += f'- {algo}:\n'
-            help_str += f'{args}\n'
+        neighbors_algos = get_sklearn_algorithms_with_methods(
+            neighbors, required_methods
+        )
+        for algo in neighbors_algos:
+            args = get_arguments(neighbors, algo)
+            help_str += f"- {algo}:\n"
+            help_str += f"{args}\n"
 
         return help_str
 
@@ -497,12 +496,11 @@ class MLAlgorithm(SHPAlgorithm):
     def postProcessAlgorithm(self, context, feedback) -> Dict[str, Any]:
         return {}
 
-
     def tr(self, string):
         """
         Returns a translatable string with the self.tr() function.
         """
-        return QCoreApplication.translate('Processing', string)
+        return QCoreApplication.translate("Processing", string)
 
     def createInstance(self):
         return MLAlgorithm()
@@ -515,21 +513,21 @@ class MLAlgorithm(SHPAlgorithm):
         lowercase alphanumeric characters only and no spaces or other
         formatting characters.
         """
-        return 'ml'
+        return "ml"
 
     def displayName(self):
         """
         Returns the translated algorithm name, which should be used for any
         user-visible display of the algorithm name.
         """
-        return self.tr('Machine Learning')
+        return self.tr("Machine Learning")
 
     def group(self):
         """
         Returns the name of the group this algorithm belongs to. This string
         should be localised.
         """
-        return self.tr('')
+        return self.tr("")
 
     def groupId(self):
         """
@@ -539,7 +537,7 @@ class MLAlgorithm(SHPAlgorithm):
         contain lowercase alphanumeric characters only and no spaces or other
         formatting characters.
         """
-        return ''
+        return ""
 
     def shortHelpString(self):
         """
@@ -547,7 +545,9 @@ class MLAlgorithm(SHPAlgorithm):
         should provide a basic description about what the algorithm does and the
         parameters and outputs associated with it..
         """
-        return self.tr(f"Fit a Machine Learning model using input template. Only RandomForestClassifier is throughfully tested. \n{self.get_help_sk_methods()}")
+        return self.tr(
+            f"Fit a Machine Learning model using the input template. Only RandomForestClassifier is thoroughly tested.\n{self.get_help_sk_methods()}"
+        )
 
     def icon(self):
         return QIcon_RandomforestTool
diff --git a/provider.py b/provider.py
index aca942999311d531100d58023b515e913ca3c376..b460103d085d2bd0c7d6fe6f4cad249a5f9864c1 100644
--- a/provider.py
+++ b/provider.py
@@ -9,7 +9,6 @@ from .icons import QIcon_EncoderTool
 
 
 class IAMapProvider(QgsProcessingProvider):
-
     def loadAlgorithms(self, *args, **kwargs):
         self.addAlgorithm(EncoderAlgorithm())
         self.addAlgorithm(ReductionAlgorithm())
@@ -25,7 +24,7 @@ class IAMapProvider(QgsProcessingProvider):
         This string should be a unique, short, character only string,
         eg "qgis" or "gdal". This string should not be localised.
         """
-        return 'iamap'
+        return "iamap"
 
     def name(self, *args, **kwargs):
         """The human friendly name of your plugin in Processing.
@@ -33,7 +32,7 @@ class IAMapProvider(QgsProcessingProvider):
         This string should be as short as possible (e.g. "Lastools", not
         "Lastools version 1.0.1 64-bit") and localised.
         """
-        return self.tr('IAMap')
+        return self.tr("IAMap")
 
     def icon(self):
         """Should return a QIcon which is used for your provider inside
@@ -43,4 +42,3 @@ class IAMapProvider(QgsProcessingProvider):
 
     def longName(self) -> str:
         return self.name()
-
diff --git a/reduction.py b/reduction.py
index f6028641284d57b8dcb0e8a38c261ddb396369bb..767ddb1d9637dba957ad91fb78ffacccfe44c860 100644
--- a/reduction.py
+++ b/reduction.py
@@ -1,55 +1,20 @@
-import os
-import tempfile
-import numpy as np
-from pathlib import Path
-from typing import Dict, Any
-import joblib
-import pandas as pd
-import psutil
-import json
-
-
-import rasterio
-from rasterio import windows
 from qgis.PyQt.QtCore import QCoreApplication
-from qgis.core import (Qgis,
-                       QgsGeometry,
-                       QgsProcessingParameterBoolean,
-                       QgsProcessingParameterEnum,
-                       QgsCoordinateTransform,
-                       QgsProcessingException,
-                       QgsProcessingAlgorithm,
-                       QgsProcessingParameterRasterLayer,
-                       QgsProcessingParameterFolderDestination,
-                       QgsProcessingParameterBand,
-                       QgsProcessingParameterNumber,
-                       QgsProcessingParameterExtent,
-                       QgsProcessingParameterCrs,
-                       QgsProcessingParameterDefinition,
-                       )
 
 
-from sklearn.preprocessing import StandardScaler
-from sklearn.decomposition import PCA, IncrementalPCA
-from sklearn.cluster import KMeans
-
-from .utils.misc import get_unique_filename
 from .utils.algo import SKAlgorithm
 
 from .icons import QIcon_ReductionTool
-#from umap.umap_ import UMAP
-
+# from umap.umap_ import UMAP
 
 
 class ReductionAlgorithm(SKAlgorithm):
-    """
-    """
+    """ """
 
     def tr(self, string):
         """
         Returns a translatable string with the self.tr() function.
         """
-        return QCoreApplication.translate('Processing', string)
+        return QCoreApplication.translate("Processing", string)
 
     def createInstance(self):
         return ReductionAlgorithm()
@@ -62,21 +27,21 @@ class ReductionAlgorithm(SKAlgorithm):
         lowercase alphanumeric characters only and no spaces or other
         formatting characters.
         """
-        return 'reduction'
+        return "reduction"
 
     def displayName(self):
         """
         Returns the translated algorithm name, which should be used for any
         user-visible display of the algorithm name.
         """
-        return self.tr('Dimension Reduction')
+        return self.tr("Dimension Reduction")
 
     def group(self):
         """
         Returns the name of the group this algorithm belongs to. This string
         should be localised.
         """
-        return self.tr('')
+        return self.tr("")
 
     def groupId(self):
         """
@@ -86,7 +51,7 @@ class ReductionAlgorithm(SKAlgorithm):
         contain lowercase alphanumeric characters only and no spaces or other
         formatting characters.
         """
-        return ''
+        return ""
 
     def shortHelpString(self):
         """
@@ -94,7 +59,9 @@ class ReductionAlgorithm(SKAlgorithm):
         should provide a basic description about what the algorithm does and the
         parameters and outputs associated with it..
         """
-        return self.tr(f"Reduce the dimension of deep learning features. Only PCA is thoughfully tested. Other algorithms are implemented as is by sklearn. {self.get_help_sk_methods()}")
+        return self.tr(
+            f"Reduce the dimension of deep learning features. Only PCA is thoroughly tested. Other algorithms are exposed as implemented by sklearn. {self.get_help_sk_methods()}"
+        )
 
     def icon(self):
         return QIcon_ReductionTool
diff --git a/requirements-ga.txt b/requirements-ga.txt
new file mode 100644
index 0000000000000000000000000000000000000000..066935452485f829bd69401b77c620ac240f9214
--- /dev/null
+++ b/requirements-ga.txt
@@ -0,0 +1,14 @@
+geopandas >= 0.14.4
+scikit-learn >= 1.5.1
+psutil >= 5.0.0
+# from torchgeo
+rasterio >= 1.2
+rtree >= 0.9
+einops >= 0.3
+fiona >= 1.8.19
+kornia >= 0.6.9
+numpy >= 1.19.3
+pyproj >= 3.3
+shapely >= 1.7.1
+timm >= 0.4.12
+pytest
diff --git a/requirements.txt b/requirements.txt
index 9e22d502d0f5b04223f31aa1de69e293566197be..a2d31d0b5a68a42242e7d6deda18a7a2bc829f1e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,7 +6,7 @@ psutil >= 5.0.0
 # torchgeo == 0.5.2
 # from torchgeo
 rasterio >= 1.2
-rtree <= 0.9
+rtree >= 1
 einops >= 0.3
 fiona >= 1.8.19
 kornia >= 0.6.9
diff --git a/similarity.py b/similarity.py
index ed0b38a3c4b918d8a14db0ffd4b3ecbf5e4430d4..d911d0fb35341f49eda1783c14609a31c6a01852 100644
--- a/similarity.py
+++ b/similarity.py
@@ -3,14 +3,12 @@ from .utils.algo import SHPAlgorithm
 from .icons import QIcon_SimilarityTool
 
 
-
 class SimilarityAlgorithm(SHPAlgorithm):
-
     def tr(self, string):
         """
         Returns a translatable string with the self.tr() function.
         """
-        return QCoreApplication.translate('Processing', string)
+        return QCoreApplication.translate("Processing", string)
 
     def createInstance(self):
         return SimilarityAlgorithm()
@@ -23,21 +21,21 @@ class SimilarityAlgorithm(SHPAlgorithm):
         lowercase alphanumeric characters only and no spaces or other
         formatting characters.
         """
-        return 'similarity'
+        return "similarity"
 
     def displayName(self):
         """
         Returns the translated algorithm name, which should be used for any
         user-visible display of the algorithm name.
         """
-        return self.tr('Similarity')
+        return self.tr("Similarity")
 
     def group(self):
         """
         Returns the name of the group this algorithm belongs to. This string
         should be localised.
         """
-        return self.tr('')
+        return self.tr("")
 
     def groupId(self):
         """
@@ -47,7 +45,7 @@ class SimilarityAlgorithm(SHPAlgorithm):
         contain lowercase alphanumeric characters only and no spaces or other
         formatting characters.
         """
-        return ''
+        return ""
 
     def shortHelpString(self):
         """
diff --git a/tests/__init__.py b/tests/__init__.py
index c3de68ef53f20832da722033737c1995b311443e..c0a9a565b53893ed7f0ac80e4786c3b9db9c7ecc 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -3,7 +3,14 @@ import os
 
 PYTHON_VERSION = sys.version_info
 SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
-PLUGIN_ROOT_DIR = os.path.realpath(os.path.abspath(os.path.join(SCRIPT_DIR, '..')))
-PACKAGES_INSTALL_DIR = os.path.join(PLUGIN_ROOT_DIR, f'python{PYTHON_VERSION.major}.{PYTHON_VERSION.minor}')
+PLUGIN_ROOT_DIR = os.path.realpath(os.path.abspath(os.path.join(SCRIPT_DIR, "..")))
+QGIS_PYTHON_DIR = os.path.realpath(os.path.abspath(os.path.join(PLUGIN_ROOT_DIR, "..")))
+PACKAGES_INSTALL_DIR = os.path.join(
+    PLUGIN_ROOT_DIR, f"python{PYTHON_VERSION.major}.{PYTHON_VERSION.minor}"
+)
 
 sys.path.append(PACKAGES_INSTALL_DIR)  # TODO: check for a less intrusive way to do this
+
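+## also expose QGIS's PYTHONPATH entries inside the test environment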
+qgis_python_path = os.getenv("PYTHONPATH")
+if qgis_python_path and qgis_python_path not in sys.path:
+    sys.path.append(qgis_python_path)
diff --git a/tests/test_common.py b/tests/test_common.py
index 7e222a25d97a9b58b2f90a7cba68430494eda87f..792cc558d2afd87d5674f93df194a4d11b88c37d 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -1,36 +1,47 @@
 import os
+import pytest
 from pathlib import Path
 import tempfile
 import unittest
 from qgis.core import (
-        QgsProcessingContext, 
-        QgsProcessingFeedback,
-        )
+    QgsProcessingContext,
+    QgsProcessingFeedback,
+)
 
 from ..ml import MLAlgorithm
 from ..similarity import SimilarityAlgorithm
 from ..clustering import ClusterAlgorithm
 from ..reduction import ReductionAlgorithm
 from ..utils.misc import get_file_md5_hash, remove_files_with_extensions
+from ..utils.geo import validate_geotiff
+
 
-INPUT = os.path.join(Path(__file__).parent.parent.absolute(), 'assets', 'test.tif')
+INPUT = os.path.join(Path(__file__).parent.parent.absolute(), "assets", "test.tif")
 OUTPUT = os.path.join(tempfile.gettempdir(), "iamap_test")
-EXTENSIONS_TO_RM = ['.tif', '.pkl', '.json', '.shp', '.shx', '.prj', '.dbf', '.cpg']
-TEMPLATE = os.path.join(Path(__file__).parent.parent.absolute(), 'assets', 'template.shp')
-TEMPLATE_RF = os.path.join(Path(__file__).parent.parent.absolute(), 'assets', 'ml_poly.shp')
-GT_COL = 'Type'
+EXTENSIONS_TO_RM = [".tif", ".pkl", ".json", ".shp", ".shx", ".prj", ".dbf", ".cpg"]
+TEMPLATE = os.path.join(
+    Path(__file__).parent.parent.absolute(), "assets", "template.shp"
+)
+TEMPLATE_RF = os.path.join(
+    Path(__file__).parent.parent.absolute(), "assets", "ml_poly.shp"
+)
+GT_COL = "Type"
+
 
 class TestReductionAlgorithm(unittest.TestCase):
     """
     Base test class, other will inherit from this
     """
+
     algorithm = ReductionAlgorithm()
-    default_parameters = {'INPUT': INPUT,'OUTPUT': OUTPUT}
+    default_parameters = {"INPUT": INPUT, "OUTPUT": OUTPUT}
     possible_hashes = [
-            'd7a32c6b7a4cee1af9c73607561d7b25',
-            'e04f8c86d9aad81dd9c625b9cd8f9824',
-                       ]
-    out_name = 'proj.tif'
+        "d7a32c6b7a4cee1af9c73607561d7b25",
+        "e04f8c86d9aad81dd9c625b9cd8f9824",
+    ]
+    output_size = 4405122
+    output_wh = (968, 379)
+    out_name = "proj.tif"
 
     def setUp(self):
         self.context = QgsProcessingContext()
@@ -38,41 +49,51 @@ class TestReductionAlgorithm(unittest.TestCase):
 
     def test_valid_parameters(self):
         self.algorithm.initAlgorithm()
-        result = self.algorithm.processAlgorithm(self.default_parameters, self.context, self.feedback)
-        expected_result_path = os.path.join(self.algorithm.output_dir,self.out_name)
-        result_file_hash = get_file_md5_hash(expected_result_path)
+        _ = self.algorithm.processAlgorithm(
+            self.default_parameters, self.context, self.feedback
+        )
+        expected_result_path = os.path.join(self.algorithm.output_dir, self.out_name)
+        @pytest.mark.parametrize("output_file", expected_result_path, "expected_output_size", self.output_size, "expected_wh", self.output_wh)
+        def test_geotiff_validity(output_file):
+            validate_geotiff(output_file)
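+        ## a module-level, parametrized equivalent could look like this
+        ## (hypothetical sketch):
+        ##     @pytest.mark.parametrize("output_file", [expected_result_path])
+        ##     def test_geotiff_validity(output_file):
+        ##         validate_geotiff(output_file)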
         remove_files_with_extensions(self.algorithm.output_dir, EXTENSIONS_TO_RM)
-        assert result_file_hash in self.possible_hashes
-        
 
 
 class TestClusteringAlgorithm(TestReductionAlgorithm):
     algorithm = ClusterAlgorithm()
-    possible_hashes = ['0c47b0c4b4c13902db5da3ee6e5d4aef']
-    out_name = 'cluster.tif'
+    # possible_hashes = ["0c47b0c4b4c13902db5da3ee6e5d4aef"]
+    out_name = "cluster.tif"
+    output_size = 4405122
 
 
 class TestSimAlgorithm(TestReductionAlgorithm):
     algorithm = SimilarityAlgorithm()
-    default_parameters = {'INPUT': INPUT,'OUTPUT': OUTPUT,'TEMPLATE':TEMPLATE}
-    possible_hashes = ['f76eb1f0469725b49fe0252cfe86829a']
-    out_name = 'similarity.tif'
+    default_parameters = {"INPUT": INPUT, "OUTPUT": OUTPUT, "TEMPLATE": TEMPLATE}
+    # possible_hashes = ["f76eb1f0469725b49fe0252cfe86829a"]
+    out_name = "similarity.tif"
+    output_size = 1468988
 
 
 class TestMLAlgorithm(TestReductionAlgorithm):
     algorithm = MLAlgorithm()
-    default_parameters = {'INPUT': INPUT,'OUTPUT': OUTPUT,'TEMPLATE':TEMPLATE_RF,'GT_COL': GT_COL}
-    possible_hashes = ['bd22d66180347e043fca58d494876184']
-    out_name = 'ml.tif'
+    default_parameters = {
+        "INPUT": INPUT,
+        "OUTPUT": OUTPUT,
+        "TEMPLATE": TEMPLATE_RF,
+        "GT_COL": GT_COL,
+    }
+    # possible_hashes = ["bd22d66180347e043fca58d494876184"]
+    out_name = "ml.tif"
+    output_size = 367520
 
-if __name__ == "__main__":
 
+if __name__ == "__main__":
     for algo in [
-            TestReductionAlgorithm(),
-            TestClusteringAlgorithm(),
-            TestSimAlgorithm(),
-            TestMLAlgorithm(),
-            ]:
+        TestReductionAlgorithm(),
+        TestClusteringAlgorithm(),
+        TestSimAlgorithm(),
+        TestMLAlgorithm(),
+    ]:
         algo.setUp()
         print(algo.algorithm)
         algo.test_valid_parameters()
diff --git a/tests/test_encoder.py b/tests/test_encoder.py
index 57c6456fb8219a7dba900308c2cccded6b6caf06..ead5a66275b2c2538e47a1841a718baf262304c1 100644
--- a/tests/test_encoder.py
+++ b/tests/test_encoder.py
@@ -4,146 +4,132 @@ from pathlib import Path
 import unittest
 import pytest
 from qgis.core import (
-        QgsProcessingContext, 
-        QgsProcessingFeedback,
-        )
+    QgsProcessingContext,
+    QgsProcessingFeedback,
+)
 
 import timm
 import torch
+
 # from torchgeo.datasets import RasterDataset
-from..tg.datasets import RasterDataset
+from ..tg.datasets import RasterDataset
 
 from ..encoder import EncoderAlgorithm
 from ..utils.misc import get_file_md5_hash
+from ..utils.geo import validate_geotiff
 
 
-INPUT = os.path.join(Path(__file__).parent.parent.absolute(), 'assets', 'test.tif')
+INPUT = os.path.join(Path(__file__).parent.parent.absolute(), "assets", "test.tif")
 OUTPUT = os.path.join(tempfile.gettempdir(), "iamap_test")
 
-class TestEncoderAlgorithm(unittest.TestCase):
 
+class TestEncoderAlgorithm(unittest.TestCase):
     def setUp(self):
         self.context = QgsProcessingContext()
         self.feedback = QgsProcessingFeedback()
         self.algorithm = EncoderAlgorithm()
         self.default_parameters = {
-                'BACKBONE_CHOICE': '', 
-                'BACKBONE_OPT': 0, 
-                'BANDS': None, 
-                'BATCH_SIZE': 1, 
-                'CKPT': 'NULL', 
-                'CRS': None, 
-                'CUDA': True, 
-                'CUDA_ID': 0, 
-                'EXTENT': None, 
-                'FEAT_OPTION': True, 
-                'INPUT': INPUT,
-                'MERGE_METHOD': 0, 
-                'OUTPUT': OUTPUT,
-                'PAUSES': 0, 
-                'QUANT': True, 
-                'REMOVE_TEMP_FILES': True, 
-                'RESOLUTION': None, 
-                'SIZE': 224, 
-                'STRIDE': 224, 
-                'TEMP_FILES_CLEANUP_FREQ': 1000, 
-                'WORKERS': 0,
-                'JSON_PARAM': 'NULL', 
-                'OUT_DTYPE': 0, 
-                      }
+            "BACKBONE_CHOICE": "",
+            "BACKBONE_OPT": 0,
+            "BANDS": None,
+            "BATCH_SIZE": 1,
+            "CKPT": "NULL",
+            "CRS": None,
+            "CUDA": True,
+            "CUDA_ID": 0,
+            "EXTENT": None,
+            "FEAT_OPTION": True,
+            "INPUT": INPUT,
+            "MERGE_METHOD": 0,
+            "OUTPUT": OUTPUT,
+            "PAUSES": 0,
+            "QUANT": True,
+            "REMOVE_TEMP_FILES": True,
+            "RESOLUTION": None,
+            "SIZE": 224,
+            "STRIDE": 224,
+            "TEMP_FILES_CLEANUP_FREQ": 1000,
+            "WORKERS": 0,
+            "JSON_PARAM": "NULL",
+            "OUT_DTYPE": 0,
+        }
 
     def test_valid_parameters(self):
-
         self.algorithm.initAlgorithm()
-        _ = self.algorithm.processAlgorithm(self.default_parameters, self.context, self.feedback)
-        expected_result_path = os.path.join(self.algorithm.output_subdir,'merged.tif')
-        result_file_hash = get_file_md5_hash(expected_result_path)
-
-        ## different rasterio versions lead to different hashes ? 
-        ## GPU and quantization as well
-        possible_hashes = [
-                '0fb32cc57a0dd427d9f0165ec6d5418f',
-                '48c3a78773dbc2c4c7bb7885409284ab',
-                '431e034b842129679b99a067f2bd3ba4',
-                '60153535214eaa44458db4e297af72b9',
-                'f1394d1950f91e4f8277a8667ae77e85',
-                'a23837caa3aca54aaca2974d546c5123',
-                           ]
-        assert result_file_hash in possible_hashes
-        os.remove(expected_result_path)
-
-    @pytest.mark.slow
-    def test_data_types(self):
-
-        self.algorithm.initAlgorithm()
-        parameters = self.default_parameters
-        parameters['OUT_DTYPE'] = 1
-        _ = self.algorithm.processAlgorithm(parameters, self.context, self.feedback)
-        expected_result_path = os.path.join(self.algorithm.output_subdir,'merged.tif')
-        result_file_hash = get_file_md5_hash(expected_result_path)
-
-        ## different rasterio versions lead to different hashes ? 
-        possible_hashes = [
-                'ef0c4b0d57f575c1cd10c0578c7114c0',
-                'ebfad32752de71c5555bda2b40c19b2e',
-                'd3705c256320b7190dd4f92ad2087247',
-                '65fa46916d6d0d08ad9656d7d7fabd01',
-                           ]
-        assert result_file_hash in possible_hashes
+        _ = self.algorithm.processAlgorithm(
+            self.default_parameters, self.context, self.feedback
+        )
+        expected_result_path = os.path.join(self.algorithm.output_subdir, "merged.tif")
+        @pytest.mark.parametrize("output_file", expected_result_path)
+        def test_geotiff_validity(output_file):
+            validate_geotiff(output_file)
         os.remove(expected_result_path)
 
 
     def test_timm_create_model(self):
-
         archs = [
-                'vit_base_patch16_224.dino',
-                'vit_tiny_patch16_224.augreg_in21k',
-                'vit_base_patch16_224.mae',
-                'samvit_base_patch16.sa1b',
-                ]
+            "vit_base_patch16_224.dino",
+            "vit_tiny_patch16_224.augreg_in21k",
+            "vit_base_patch16_224.mae",
+            "samvit_base_patch16.sa1b",
+        ]
         expected_output_size = [
-                torch.Size([1,197,768]),
-                torch.Size([1,197,192]),
-                torch.Size([1,197,768]),
-                torch.Size([1, 256, 64, 64]),
-                ]
+            torch.Size([1, 197, 768]),
+            torch.Size([1, 197, 192]),
+            torch.Size([1, 197, 768]),
+            torch.Size([1, 256, 64, 64]),
+        ]
 
         for arch, exp_feat_size in zip(archs, expected_output_size):
-
             model = timm.create_model(
                 arch,
                 pretrained=True,
                 in_chans=6,
                 num_classes=0,
-                )
+            )
             model = model.eval()
 
             data_config = timm.data.resolve_model_data_config(model)
-            _, h, w, = data_config['input_size']
-            output = model.forward_features(torch.randn(1,6,h,w))
+            _, h, w = data_config["input_size"]
+            output = model.forward_features(torch.randn(1, 6, h, w))
 
             assert output.shape == exp_feat_size
 
-
     def test_RasterDataset(self):
-
         self.algorithm.initAlgorithm()
         parameters = {}
         self.algorithm.process_options(parameters, self.context, self.feedback)
         RasterDataset.filename_glob = self.algorithm.rlayer_name
         RasterDataset.all_bands = [
-            self.algorithm.rlayer.bandName(i_band) for i_band in range(1, self.algorithm.rlayer.bandCount()+1)
+            self.algorithm.rlayer.bandName(i_band)
+            for i_band in range(1, self.algorithm.rlayer.bandCount() + 1)
         ]
         # currently only support rgb bands
-        input_bands = [self.algorithm.rlayer.bandName(i_band)
-                       for i_band in self.algorithm.selected_bands]
+        input_bands = [
+            self.algorithm.rlayer.bandName(i_band)
+            for i_band in self.algorithm.selected_bands
+        ]
 
         if self.algorithm.crs == self.algorithm.rlayer.crs():
             dataset = RasterDataset(
-                paths=self.algorithm.rlayer_dir, crs=None, res=self.algorithm.res, bands=input_bands, cache=False)
+                paths=self.algorithm.rlayer_dir,
+                crs=None,
+                res=self.algorithm.res,
+                bands=input_bands,
+                cache=False,
+            )
         else:
             dataset = RasterDataset(
-                paths=self.algorithm.rlayer_dir, crs=self.algorithm.crs.toWkt(), res=self.algorithm.res, bands=input_bands, cache=False)
+                paths=self.algorithm.rlayer_dir,
+                crs=self.algorithm.crs.toWkt(),
+                res=self.algorithm.res,
+                bands=input_bands,
+                cache=False,
+            )
         del dataset
 
     def test_cuda(self):
@@ -151,15 +137,10 @@ class TestEncoderAlgorithm(unittest.TestCase):
             assert True
 
 
-
-
-
 if __name__ == "__main__":
-
     test_encoder = TestEncoderAlgorithm()
     test_encoder.setUp()
     test_encoder.test_timm_create_model()
     test_encoder.test_RasterDataset()
     test_encoder.test_valid_parameters()
-    test_encoder.test_data_types()
     test_encoder.test_cuda()
diff --git a/tg/datasets.py b/tg/datasets.py
index 792d265527c8afd55bd47f26ee8449b130f22f22..1398bd62c231a0495ea1c9ae9c33523b6c0d576e 100644
--- a/tg/datasets.py
+++ b/tg/datasets.py
@@ -1,4 +1,4 @@
-# modified from torchgeo code 
+# modified from torchgeo code
 
 """Base classes for all :mod:`torchgeo` datasets."""
 
@@ -390,10 +390,10 @@ class RasterDataset(GeoDataset):
         .. versionchanged:: 0.5
            *root* was renamed to *paths*.
         """
-        print('pre super.__init__')
+        print("pre super.__init__")
         super().__init__(transforms)
 
-        print('post super.__init__')
+        print("post super.__init__")
         self.paths = paths
         self.bands = bands or self.all_bands
         self.cache = cache
@@ -403,7 +403,7 @@ class RasterDataset(GeoDataset):
         filename_regex = re.compile(self.filename_regex, re.VERBOSE)
         for filepath in self.files:
             match = re.match(filename_regex, os.path.basename(filepath))
-            print('regex')
+            print("regex")
             if match is not None:
                 try:
                     with rasterio.open(filepath) as src:
@@ -580,7 +580,6 @@ class RasterDataset(GeoDataset):
             return src
 
 
-
 class IntersectionDataset(GeoDataset):
     """Dataset representing the intersection of two GeoDatasets.
 
@@ -882,4 +881,3 @@ class UnionDataset(GeoDataset):
         self._res = new_res
         self.datasets[0].res = new_res
         self.datasets[1].res = new_res
-
diff --git a/tg/samplers.py b/tg/samplers.py
index cf3da11f7d5b9c5da4552f14cde0d9300ab11569..24ea788a814983ec4429e51fcada1b8496da3e86 100644
--- a/tg/samplers.py
+++ b/tg/samplers.py
@@ -267,7 +267,6 @@ class GridGeoSampler(GeoSampler):
 
 
 class NoBordersGridGeoSampler(GridGeoSampler):
-
     def __iter__(self) -> Iterator[BoundingBox]:
         """
         Modification of original Torchgeo sampler to avoid overlapping borders of a dataset.
diff --git a/tg/transforms.py b/tg/transforms.py
index 28386342fcdac6969df50d844f5d3a33bf7a23b3..a08e684587817d55495d559ed7d63d4520530196 100644
--- a/tg/transforms.py
+++ b/tg/transforms.py
@@ -178,4 +178,3 @@ class _NCropGenerator(K.random_generator.CropGenerator):
             "input_size": out[0]["input_size"],
             "output_size": out[0]["output_size"],
         }
-
diff --git a/tg/utils.py b/tg/utils.py
index fe99302d52fa649fd14b9adb9bc73b5956f01ff4..e5dd031ad366e8370d66a14b0300565052a788ef 100644
--- a/tg/utils.py
+++ b/tg/utils.py
@@ -1,4 +1,4 @@
-# modified from torchgeo code 
+# modified from torchgeo code
 
 """Common dataset utilities."""
 
@@ -33,7 +33,6 @@ __all__ = (
 )
 
 
-
 @dataclass(frozen=True)
 class BoundingBox:
     """Data class for indexing spatiotemporal data."""
@@ -308,7 +307,6 @@ def disambiguate_timestamp(date_str: str, format: str) -> tuple[float, float]:
     return mint.timestamp(), maxt.timestamp()
 
 
-
 def _list_dict_to_dict_list(samples: Iterable[dict[Any, Any]]) -> dict[Any, list[Any]]:
     """Convert a list of dictionaries to a dictionary of lists.
 
@@ -415,7 +413,6 @@ def merge_samples(samples: Iterable[dict[Any, Any]]) -> dict[Any, Any]:
     return collated
 
 
-
 def rasterio_loader(path: str) -> np.typing.NDArray[np.int_]:
     """Load an image file using rasterio.
 
@@ -432,8 +429,6 @@ def rasterio_loader(path: str) -> np.typing.NDArray[np.int_]:
     return array
 
 
-
-
 def path_is_vsi(path: str) -> bool:
     """Checks if the given path is pointing to a Virtual File System.
 
@@ -458,19 +453,15 @@ def path_is_vsi(path: str) -> bool:
     return "://" in path or path.startswith("/vsi")
 
 
-
 """Common sampler utilities."""
 
 
-
 @overload
-def _to_tuple(value: Union[tuple[int, int], int]) -> tuple[int, int]:
-    ...
+def _to_tuple(value: Union[tuple[int, int], int]) -> tuple[int, int]: ...
 
 
 @overload
-def _to_tuple(value: Union[tuple[float, float], float]) -> tuple[float, float]:
-    ...
+def _to_tuple(value: Union[tuple[float, float], float]) -> tuple[float, float]: ...
 
 
 def _to_tuple(value: Union[tuple[float, float], float]) -> tuple[float, float]:
diff --git a/utils/algo.py b/utils/algo.py
index 12ab8a82f4f2e68798159af2eb9859a7d32a1383..ce21e15997d42d1c4d87efe49db8d030f6a75af9 100644
--- a/utils/algo.py
+++ b/utils/algo.py
@@ -7,22 +7,24 @@ import joblib
 from collections import Counter
 from pathlib import Path
 from typing import Dict, Any
-from qgis.core import (Qgis,
-                       QgsGeometry,
-                       QgsCoordinateTransform,
-                       QgsProcessingException,
-                       QgsProcessingAlgorithm, QgsProcessingParameterBoolean,
-                       QgsProcessingParameterRasterLayer,
-                       QgsProcessingParameterFolderDestination,
-                       QgsProcessingParameterBand,
-                       QgsProcessingParameterNumber,
-                       QgsProcessingParameterEnum,
-                       QgsProcessingParameterVectorLayer,
-                       QgsProcessingParameterExtent,
-                       QgsProcessingParameterString,
-                       QgsProcessingParameterCrs,
-                       QgsProcessingParameterDefinition, QgsProcessingParameterVectorLayer,
-                       )
+from qgis.core import (
+    Qgis,
+    QgsGeometry,
+    QgsCoordinateTransform,
+    QgsProcessingException,
+    QgsProcessingAlgorithm,
+    QgsProcessingParameterBoolean,
+    QgsProcessingParameterRasterLayer,
+    QgsProcessingParameterFolderDestination,
+    QgsProcessingParameterBand,
+    QgsProcessingParameterNumber,
+    QgsProcessingParameterEnum,
+    QgsProcessingParameterExtent,
+    QgsProcessingParameterString,
+    QgsProcessingParameterCrs,
+    QgsProcessingParameterDefinition,
+    QgsProcessingParameterVectorLayer,
+)
 import rasterio
 from rasterio import windows
 from rasterio.enums import Resampling
@@ -49,12 +51,13 @@ def get_sklearn_algorithms_with_methods(module, required_methods):
     # Get all classes in the module that are subclasses of BaseEstimator
     algorithms = []
     for name, obj in inspect.getmembers(module, inspect.isclass):
-        if issubclass(obj, BaseEstimator) and not name.startswith('_'):
+        if issubclass(obj, BaseEstimator) and not name.startswith("_"):
             # Check if the class has all the required methods
             if all(hasattr(obj, method) for method in required_methods):
                 algorithms.append(name)
     return algorithms
 
+
 def instantiate_sklearn_algorithm(module, algorithm_name, **kwargs):
     # Retrieve the class from the module by name
     AlgorithmClass = getattr(module, algorithm_name)
@@ -66,50 +69,47 @@ def get_arguments(module, algorithm_name):
     AlgorithmClass = getattr(module, algorithm_name)
     # Get the signature of the __init__ method
     init_signature = inspect.signature(AlgorithmClass.__init__)
-    
+
     # Retrieve the parameters of the __init__ method
     parameters = init_signature.parameters
     default_kwargs = {}
-    
+
     for param_name, param in parameters.items():
         # Skip 'self'
-        if param_name != 'self':
+        if param_name != "self":
             # if param.default == None:
             #     required_kwargs[param_name] = None  # Placeholder for the required value
             # else:
             default_kwargs[param_name] = param.default
-    
+
     # return required_kwargs, default_kwargs
     return default_kwargs
 
 
 def get_iter(model, fit_raster):
-
     iter = None
-    if hasattr(model, 'partial_fit') and hasattr(model, 'max_iter'):
+    if hasattr(model, "partial_fit") and hasattr(model, "max_iter"):
         iter = range(model.max_iter)
 
-    if hasattr(model, 'partial_fit') and not hasattr(model, 'max_iter'):
+    if hasattr(model, "partial_fit") and not hasattr(model, "max_iter"):
         chunk_size = calculate_chunk_size(fit_raster)
-        iter = range(0, len(fit_raster), chunk_size) 
+        iter = range(0, len(fit_raster), chunk_size)
 
     return iter
 
 
 class IAMAPAlgorithm(QgsProcessingAlgorithm):
-    """
-    """
-
-    INPUT = 'INPUT'
-    BANDS = 'BANDS'
-    EXTENT = 'EXTENT'
-    OUTPUT = 'OUTPUT'
-    RESOLUTION = 'RESOLUTION'
-    RANDOM_SEED = 'RANDOM_SEED'
-    CRS = 'CRS'
-    COMPRESS = 'COMPRESS'
-    TMP_DIR = 'iamap_tmp'
-    
+    """ """
+
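+    ## these string constants double as QGIS Processing parameter keys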
+    INPUT = "INPUT"
+    BANDS = "BANDS"
+    EXTENT = "EXTENT"
+    OUTPUT = "OUTPUT"
+    RESOLUTION = "RESOLUTION"
+    RANDOM_SEED = "RANDOM_SEED"
+    CRS = "CRS"
+    COMPRESS = "COMPRESS"
+    TMP_DIR = "iamap_tmp"
 
     def initAlgorithm(self, config=None):
         """
@@ -121,20 +121,17 @@ class IAMAPAlgorithm(QgsProcessingAlgorithm):
 
         compress_param = QgsProcessingParameterBoolean(
             name=self.COMPRESS,
-            description=self.tr(
-                'Compress final result to JP2'),
+            description=self.tr("Compress final result to JP2"),
             defaultValue=True,
             optional=True,
         )
 
-        for param in (
-                compress_param,
-                ):
+        for param in (compress_param,):
             param.setFlags(
-                param.flags() | QgsProcessingParameterDefinition.FlagAdvanced)
+                param.flags() | QgsProcessingParameterDefinition.FlagAdvanced
+            )
             self.addParameter(param)
 
-
     def init_input_output_raster(self):
         self.cwd = Path(__file__).parent.absolute()
         tmp_wd = os.path.join(tempfile.gettempdir(), self.TMP_DIR)
@@ -142,17 +139,16 @@ class IAMAPAlgorithm(QgsProcessingAlgorithm):
         self.addParameter(
             QgsProcessingParameterRasterLayer(
                 name=self.INPUT,
-                description=self.tr(
-                    'Input raster layer or image file path'),
-            # defaultValue=os.path.join(self.cwd,'assets','test.tif'),
+                description=self.tr("Input raster layer or image file path"),
+                # defaultValue=os.path.join(self.cwd,'assets','test.tif'),
             ),
         )
 
         self.addParameter(
             QgsProcessingParameterBand(
                 name=self.BANDS,
-                description=self.tr('Selected Bands (defaults to all bands selected)'),
-                defaultValue = None, 
+                description=self.tr("Selected Bands (defaults to all bands selected)"),
+                defaultValue=None,
                 parentLayerParameterName=self.INPUT,
                 optional=True,
                 allowMultiple=True,
@@ -161,26 +157,26 @@ class IAMAPAlgorithm(QgsProcessingAlgorithm):
 
         crs_param = QgsProcessingParameterCrs(
             name=self.CRS,
-            description=self.tr('Target CRS (default to original CRS)'),
+            description=self.tr("Target CRS (default to original CRS)"),
             optional=True,
         )
 
         res_param = QgsProcessingParameterNumber(
             name=self.RESOLUTION,
             description=self.tr(
-                'Target resolution in meters (default to native resolution)'),
+                "Target resolution in meters (default to native resolution)"
+            ),
             type=QgsProcessingParameterNumber.Double,
             optional=True,
             minValue=0,
-            maxValue=100000
+            maxValue=100000,
         )
 
         self.addParameter(
             QgsProcessingParameterExtent(
                 name=self.EXTENT,
-                description=self.tr(
-                    'Processing extent (default to the entire image)'),
-                optional=True
+                description=self.tr("Processing extent (default to the entire image)"),
+                optional=True,
             )
         )
 
@@ -188,84 +184,82 @@ class IAMAPAlgorithm(QgsProcessingAlgorithm):
             QgsProcessingParameterFolderDestination(
                 self.OUTPUT,
                 self.tr(
-                    "Output directory (choose the location that the image features will be saved)"),
-            defaultValue=tmp_wd,
+                    "Output directory (choose the location that the image features will be saved)"
+                ),
+                defaultValue=tmp_wd,
             )
         )
 
         for param in (
-                crs_param, 
-                res_param, 
-                ):
+            crs_param,
+            res_param,
+        ):
             param.setFlags(
-                param.flags() | QgsProcessingParameterDefinition.FlagAdvanced)
+                param.flags() | QgsProcessingParameterDefinition.FlagAdvanced
+            )
             self.addParameter(param)
 
     def init_seed(self):
         seed_param = QgsProcessingParameterNumber(
             name=self.RANDOM_SEED,
-            description=self.tr(
-                'Random seed'),
+            description=self.tr("Random seed"),
             type=QgsProcessingParameterNumber.Integer,
             defaultValue=42,
             minValue=0,
-            maxValue=100000
+            maxValue=100000,
         )
         seed_param.setFlags(
-            seed_param.flags() | QgsProcessingParameterDefinition.FlagAdvanced)
+            seed_param.flags() | QgsProcessingParameterDefinition.FlagAdvanced
+        )
         self.addParameter(seed_param)
 
-
-
-    def process_geo_parameters(self,parameters, context, feedback):
+    def process_geo_parameters(self, parameters, context, feedback):
         """
         Handle geographic parameters that are common to all algorithms (CRS, resolution, extent, selected bands).
         """
 
-        rlayer = self.parameterAsRasterLayer(
-            parameters, self.INPUT, context)
-        
+        rlayer = self.parameterAsRasterLayer(parameters, self.INPUT, context)
+
         if rlayer is None:
             raise QgsProcessingException(
-                self.invalidRasterError(parameters, self.INPUT))
+                self.invalidRasterError(parameters, self.INPUT)
+            )
 
         self.rlayer_path = rlayer.dataProvider().dataSourceUri()
         self.rlayer_dir = os.path.dirname(self.rlayer_path)
         self.rlayer_name = os.path.basename(self.rlayer_path)
 
-        self.selected_bands = self.parameterAsInts(
-            parameters, self.BANDS, context)
+        self.selected_bands = self.parameterAsInts(parameters, self.BANDS, context)
 
         if len(self.selected_bands) == 0:
-            self.selected_bands = list(range(1, rlayer.bandCount()+1))
+            self.selected_bands = list(range(1, rlayer.bandCount() + 1))
 
         if max(self.selected_bands) > rlayer.bandCount():
             raise QgsProcessingException(
                 self.tr("The chosen bands exceed the largest band number!")
             )
-        res = self.parameterAsDouble(
-            parameters, self.RESOLUTION, context)
-        crs = self.parameterAsCrs(
-            parameters, self.CRS, context)
-        extent = self.parameterAsExtent(
-            parameters, self.EXTENT, context)
+        res = self.parameterAsDouble(parameters, self.RESOLUTION, context)
+        crs = self.parameterAsCrs(parameters, self.CRS, context)
+        extent = self.parameterAsExtent(parameters, self.EXTENT, context)
 
         # handle crs
         if crs is None or not crs.isValid():
             crs = rlayer.crs()
             feedback.pushInfo(
-                f'Layer CRS unit is {crs.mapUnits()}')  # 0 for meters, 6 for degrees, 9 for unknown
+                f"Layer CRS unit is {crs.mapUnits()}"
+            )  # 0 for meters, 6 for degrees, 9 for unknown
             feedback.pushInfo(
-                f'whether the CRS is a geographic CRS (using lat/lon coordinates) {crs.isGeographic()}')
+                f"whether the CRS is a geographic CRS (using lat/lon coordinates) {crs.isGeographic()}"
+            )
             if crs.mapUnits() == Qgis.DistanceUnit.Degrees:
                 crs = self.estimate_utm_crs(rlayer.extent())
 
         # target crs should use meters as units
         if crs.mapUnits() != Qgis.DistanceUnit.Meters:
+            feedback.pushInfo(f"Layer CRS unit is {crs.mapUnits()}")
             feedback.pushInfo(
-                f'Layer CRS unit is {crs.mapUnits()}')
-            feedback.pushInfo(
-                f'whether the CRS is a geographic CRS (using lat/lon coordinates) {crs.isGeographic()}')
+                f"whether the CRS is a geographic CRS (using lat/lon coordinates) {crs.isGeographic()}"
+            )
             raise QgsProcessingException(
                 self.tr("Only support CRS with the units as meters")
             )
@@ -278,31 +272,37 @@ class IAMAPAlgorithm(QgsProcessingAlgorithm):
             res = rlayer.rasterUnitsPerPixelX()  # rasterUnitsPerPixelY() is negative
         else:
             # when given res in meters by users, convert crs to utm if the original crs unit is degree
-            if crs.mapUnits() != UNIT_METERS: # Qgis.DistanceUnit.Meters:
-                if rlayer.crs().mapUnits() == UNIT_DEGREES: # Qgis.DistanceUnit.Degrees:
+            if crs.mapUnits() != UNIT_METERS:  # Qgis.DistanceUnit.Meters:
+                if (
+                    rlayer.crs().mapUnits() == UNIT_DEGREES
+                ):  # Qgis.DistanceUnit.Degrees:
                     # estimate utm crs based on layer extent
                     crs = self.estimate_utm_crs(rlayer.extent())
                 else:
                     raise QgsProcessingException(
-                        f"Resampling of image with the CRS of {crs.authid()} in meters is not supported.")
+                        f"Resampling of image with the CRS of {crs.authid()} in meters is not supported."
+                    )
             # else:
             #     res = (rlayer_extent.xMaximum() -
             #            rlayer_extent.xMinimum()) / rlayer.width()
 
         # handle extent
         if extent.isNull():
-            extent = rlayer.extent()  # QgsProcessingUtils.combineLayerExtents(layers, crs, context)
+            extent = (
+                rlayer.extent()
+            )  # QgsProcessingUtils.combineLayerExtents(layers, crs, context)
             extent_crs = rlayer.crs()
         else:
             if extent.isEmpty():
                 raise QgsProcessingException(
-                    self.tr("The extent for processing can not be empty!"))
-            extent_crs = self.parameterAsExtentCrs(
-                parameters, self.EXTENT, context)
+                    self.tr("The extent for processing can not be empty!")
+                )
+            extent_crs = self.parameterAsExtentCrs(parameters, self.EXTENT, context)
         # if extent crs != target crs, convert it to target crs
         if extent_crs != crs:
             transform = QgsCoordinateTransform(
-                extent_crs, crs, context.transformContext())
+                extent_crs, crs, context.transformContext()
+            )
             # extent = transform.transformBoundingBox(extent)
             # to ensure coverage of the transformed extent
             # convert extent to polygon, transform polygon, then get boundingBox of the new polygon
@@ -314,36 +314,41 @@ class IAMAPAlgorithm(QgsProcessingAlgorithm):
         # check intersects between extent and rlayer_extent
         if rlayer.crs() != crs:
             transform = QgsCoordinateTransform(
-                rlayer.crs(), crs, context.transformContext())
-            rlayer_extent = transform.transformBoundingBox(
-                rlayer.extent())
+                rlayer.crs(), crs, context.transformContext()
+            )
+            rlayer_extent = transform.transformBoundingBox(rlayer.extent())
         else:
             rlayer_extent = rlayer.extent()
         if not rlayer_extent.intersects(extent):
             raise QgsProcessingException(
-                self.tr("The extent for processing is not intersected with the input image!"))
-
+                self.tr(
+                    "The extent for processing is not intersected with the input image!"
+                )
+            )
 
-        img_width_in_extent = round(
-            (extent.xMaximum() - extent.xMinimum())/res)
-        img_height_in_extent = round(
-            (extent.yMaximum() - extent.yMinimum())/res)
+        img_width_in_extent = round((extent.xMaximum() - extent.xMinimum()) / res)
+        img_height_in_extent = round((extent.yMaximum() - extent.yMinimum()) / res)
 
         feedback.pushInfo(
-            (f'Processing extent: minx:{extent.xMinimum():.6f}, maxx:{extent.xMaximum():.6f},'
-             f'miny:{extent.yMinimum():.6f}, maxy:{extent.yMaximum():.6f}'))
+            (
+                f"Processing extent: minx:{extent.xMinimum():.6f}, maxx:{extent.xMaximum():.6f},"
+                f"miny:{extent.yMinimum():.6f}, maxy:{extent.yMaximum():.6f}"
+            )
+        )
         feedback.pushInfo(
-            (f'Processing image size: (width {img_width_in_extent}, '
-             f'height {img_height_in_extent})'))
+            (
+                f"Processing image size: (width {img_width_in_extent}, "
+                f"height {img_height_in_extent})"
+            )
+        )
 
         # Send some information to the user
-        feedback.pushInfo(
-            f'Layer path: {rlayer.dataProvider().dataSourceUri()}')
+        feedback.pushInfo(f"Layer path: {rlayer.dataProvider().dataSourceUri()}")
         # feedback.pushInfo(
         #     f'Layer band scale: {rlayer_data_provider.bandScale(self.selected_bands[0])}')
-        feedback.pushInfo(f'Layer name: {rlayer.name()}')
+        feedback.pushInfo(f"Layer name: {rlayer.name()}")
 
-        feedback.pushInfo(f'Bands selected: {self.selected_bands}')
+        feedback.pushInfo(f"Bands selected: {self.selected_bands}")
 
         self.extent = extent
         self.rlayer = rlayer
@@ -354,50 +359,50 @@ class IAMAPAlgorithm(QgsProcessingAlgorithm):
         """
         Compress final file to JP2.
         """
-        
-        feedback.pushInfo(f'Compressing to JP2')
 
-        file = parameters['OUTPUT_RASTER']
-        dst_path = Path(file).with_suffix('.jp2')
+        feedback.pushInfo("Compressing to JP2")
+
+        file = parameters["OUTPUT_RASTER"]
+        dst_path = Path(file).with_suffix(".jp2")
 
         ## update in the parameters
-        parameters['OUTPUT_RASTER'] = dst_path
+        parameters["OUTPUT_RASTER"] = dst_path
 
         with rasterio.open(file) as src:
             # Read the data
             float_data = src.read(resampling=Resampling.nearest)
-    
+
             # Initialize an array for the normalized uint16 data
             uint16_data = np.empty_like(float_data, dtype=np.uint16)
-            
+
             # Loop through each band to normalize individually
             for i in range(float_data.shape[0]):
                 band = float_data[i]
-                
+
                 # Find min and max of the current band
                 band_min = np.min(band)
                 band_max = np.max(band)
-                
+
                 # Normalize to the range [0, 1]
                 normalized_band = (band - band_min) / (band_max - band_min)
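+                # NB: a constant band (band_max == band_min) would divide by zero here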
-                
+
                 # Scale to the uint16 range [0, 65535]
                 uint16_data[i] = (normalized_band * 65535).astype(np.uint16)
-            
+
             # Define metadata for the output JP2
             profile = src.profile
             profile.update(
-                driver='JP2OpenJPEG',   # Specify JP2 driver
-                dtype='uint16',        # Keep data as float32
-                compress='jp2',         # Compression type (note: might be driver-specific)
-                crs=src.crs,            # Coordinate system
-                transform=src.transform # Affine transform
+                driver="JP2OpenJPEG",  # Specify JP2 driver
+                dtype="uint16",  # Keep data as float32
+                compress="jp2",  # Compression type (note: might be driver-specific)
+                crs=src.crs,  # Coordinate system
+                transform=src.transform,  # Affine transform
             )
             # profile.update(tiled=False)
             profile.update(tiled=True, blockxsize=256, blockysize=256)
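+            # tiling lets sub-windows of the JP2 be decoded without reading the whole file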
 
             # Write to JP2 file
-            with rasterio.open(dst_path, 'w', **profile) as dst:
+            with rasterio.open(dst_path, "w", **profile) as dst:
                 dst.write(uint16_data)
 
         return dst_path
@@ -407,24 +412,22 @@ class IAMAPAlgorithm(QgsProcessingAlgorithm):
         return {}
 
 
-
 class SKAlgorithm(IAMAPAlgorithm):
     """
     Common class that handles helper functions for sklearn algorithms.
     Behaviour defaults to projection algorithms (PCA etc...)
     """
 
-    LOAD = 'LOAD'
-    OUTPUT = 'OUTPUT'
-    MAIN_PARAM = 'MAIN_PARAM'
-    SUBSET = 'SUBSET'
-    METHOD = 'METHOD'
-    SAVE_MODEL = 'SAVE_MODEL'
-    COMPRESS = 'COMPRESS'
-    SK_PARAM = 'SK_PARAM'
-    TMP_DIR = 'iamap_reduction'
-    TYPE = 'proj'
-    
+    LOAD = "LOAD"
+    OUTPUT = "OUTPUT"
+    MAIN_PARAM = "MAIN_PARAM"
+    SUBSET = "SUBSET"
+    METHOD = "METHOD"
+    SAVE_MODEL = "SAVE_MODEL"
+    COMPRESS = "COMPRESS"
+    SK_PARAM = "SK_PARAM"
+    TMP_DIR = "iamap_reduction"
+    TYPE = "proj"
 
     def initAlgorithm(self, config=None):
         """
@@ -434,86 +437,86 @@ class SKAlgorithm(IAMAPAlgorithm):
         self.init_input_output_raster()
         self.init_seed()
 
-        proj_methods = ['fit', 'transform']
-        clust_methods = ['fit', 'fit_predict']
-        if self.TYPE == 'proj':
-            method_opt1 = get_sklearn_algorithms_with_methods(decomposition, proj_methods)
+        proj_methods = ["fit", "transform"]
+        clust_methods = ["fit", "fit_predict"]
+        if self.TYPE == "proj":
+            method_opt1 = get_sklearn_algorithms_with_methods(
+                decomposition, proj_methods
+            )
             method_opt2 = get_sklearn_algorithms_with_methods(cluster, proj_methods)
             self.method_opt = method_opt1 + method_opt2
 
             self.addParameter(
                 QgsProcessingParameterNumber(
                     name=self.MAIN_PARAM,
-                    description=self.tr(
-                        'Number of target components'),
+                    description=self.tr("Number of target components"),
                     type=QgsProcessingParameterNumber.Integer,
-                    defaultValue = 3,
+                    defaultValue=3,
                     minValue=1,
-                    maxValue=1024
+                    maxValue=1024,
                 )
             )
-            default_index = self.method_opt.index('PCA')
-        else :
-            self.method_opt = get_sklearn_algorithms_with_methods(cluster, clust_methods)
+            default_index = self.method_opt.index("PCA")
+        else:
+            self.method_opt = get_sklearn_algorithms_with_methods(
+                cluster, clust_methods
+            )
             self.addParameter(
                 QgsProcessingParameterNumber(
                     name=self.MAIN_PARAM,
-                    description=self.tr(
-                        'Number of target clusters'),
+                    description=self.tr("Number of target clusters"),
                     type=QgsProcessingParameterNumber.Integer,
-                    defaultValue = 3,
+                    defaultValue=3,
                     minValue=1,
-                    maxValue=1024
+                    maxValue=1024,
                 )
             )
-            default_index = self.method_opt.index('KMeans')
+            default_index = self.method_opt.index("KMeans")
 
-        self.addParameter (
+        self.addParameter(
             QgsProcessingParameterEnum(
-                name = self.METHOD,
-                description = self.tr(
-                    'Sklearn algorithm used'),
-                defaultValue = default_index,
-                options = self.method_opt,
+                name=self.METHOD,
+                description=self.tr("Sklearn algorithm used"),
+                defaultValue=default_index,
+                options=self.method_opt,
             )
         )
 
-        self.addParameter (
+        self.addParameter(
             QgsProcessingParameterString(
-                name = self.SK_PARAM,
-                description = self.tr(
-                    'Arguments for the initialisation of the algorithm. If empty this goes to sklearn default. It will overwrite cluster or components arguments.'),
-                defaultValue = '',
+                name=self.SK_PARAM,
+                description=self.tr(
+                    "Arguments for the initialisation of the algorithm. If empty this goes to sklearn default. It will overwrite cluster or components arguments."
+                ),
+                defaultValue="",
                 optional=True,
             )
         )
 
         subset_param = QgsProcessingParameterNumber(
-                name=self.SUBSET,
-                description=self.tr(
-                    'Select a subset of random pixels of the image to fit transform'),
-                type=QgsProcessingParameterNumber.Integer,
-                defaultValue=None,
-                minValue=1,
-                maxValue=10_000,
-                optional=True,
-                )
+            name=self.SUBSET,
+            description=self.tr(
+                "Select a subset of random pixels of the image to fit transform"
+            ),
+            type=QgsProcessingParameterNumber.Integer,
+            defaultValue=None,
+            minValue=1,
+            maxValue=10_000,
+            optional=True,
+        )
 
         save_param = QgsProcessingParameterBoolean(
-                self.SAVE_MODEL,
-                self.tr("Save projection model after fit."),
-                defaultValue=True
-                )
+            self.SAVE_MODEL,
+            self.tr("Save projection model after fit."),
+            defaultValue=True,
+        )
 
-        for param in (
-                subset_param, 
-                save_param
-                ):
+        for param in (subset_param, save_param):
             param.setFlags(
-                param.flags() | QgsProcessingParameterDefinition.FlagAdvanced)
+                param.flags() | QgsProcessingParameterDefinition.FlagAdvanced
+            )
             self.addParameter(param)
 
-
     def processAlgorithm(self, parameters, context, feedback):
         """
         Here is where the processing itself takes place.
@@ -527,14 +530,18 @@ class SKAlgorithm(IAMAPAlgorithm):
         rlayer_name, ext = os.path.splitext(rlayer_basename)
 
         ### Handle args that differ between clustering and projection methods
-        if self.TYPE == 'proj':
-            self.out_dtype = 'float32'
-            self.dst_path, self.layer_name = get_unique_filename(self.output_dir, f'{self.TYPE}.tif', f'{rlayer_name} reduction')
+        if self.TYPE == "proj":
+            self.out_dtype = "float32"
+            self.dst_path, self.layer_name = get_unique_filename(
+                self.output_dir, f"{self.TYPE}.tif", f"{rlayer_name} reduction"
+            )
         else:
-            self.out_dtype = 'uint8'
-            self.dst_path, self.layer_name = get_unique_filename(self.output_dir, f'{self.TYPE}.tif', f'{rlayer_name} cluster')
+            self.out_dtype = "uint8"
+            self.dst_path, self.layer_name = get_unique_filename(
+                self.output_dir, f"{self.TYPE}.tif", f"{rlayer_name} cluster"
+            )
 
-        parameters['OUTPUT_RASTER']=self.dst_path
+        parameters["OUTPUT_RASTER"] = self.dst_path
 
         try:
             default_args = get_arguments(decomposition, self.method_name)
@@ -547,15 +554,22 @@ class SKAlgorithm(IAMAPAlgorithm):
         do_fit_predict = False
 
         try:
-            model = instantiate_sklearn_algorithm(decomposition, self.method_name, **kwargs)
+            model = instantiate_sklearn_algorithm(
+                decomposition, self.method_name, **kwargs
+            )
         except AttributeError:
             model = instantiate_sklearn_algorithm(cluster, self.method_name, **kwargs)
-            ## if model does not have a 'predict()' method, then we do a fit_predict in one go 
-            if not hasattr(model, 'predict'):
+            ## if model does not have a 'predict()' method, then we do a fit_predict in one go
+            if not hasattr(model, "predict"):
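+                # e.g. sklearn's AgglomerativeClustering exposes fit_predict but no predict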
                 do_fit_predict = True
-        except:
-            feedback.pushWarning(f'{self.method_name} not properly initialized ! Try passing custom parameters')
-            return {'OUTPUT_RASTER':self.dst_path, 'OUTPUT_LAYER_NAME':self.layer_name}
+        except Exception:
+            feedback.pushWarning(
+                f"{self.method_name} not properly initialized ! Try passing custom parameters"
+            )
+            return {
+                "OUTPUT_RASTER": self.dst_path,
+                "OUTPUT_LAYER_NAME": self.layer_name,
+            }
 
         if do_fit_predict:
             proj_img, model = self.fit_predict(model, feedback)
@@ -565,102 +579,94 @@ class SKAlgorithm(IAMAPAlgorithm):
             model = self.fit_model(model, fit_raster, iter, feedback)
 
         self.print_transform_metrics(model, feedback)
-        self.print_cluster_metrics(model,fit_raster, feedback)
-        feedback.pushInfo(f'Fitting done, saving model\n')
-        save_file = f'{self.method_name}.pkl'.lower()
+        self.print_cluster_metrics(model, fit_raster, feedback)
+        feedback.pushInfo("Fitting done, saving model\n")
+        save_file = f"{self.method_name}.pkl".lower()
         if self.save_model:
             out_path = os.path.join(self.output_dir, save_file)
             joblib.dump(model, out_path)
 
         if not do_fit_predict:
-            feedback.pushInfo(f'Inference over raster\n')
+            feedback.pushInfo("Inference over raster\n")
             self.infer_model(model, feedback, scaler)
 
-        return {'OUTPUT_RASTER':self.dst_path, 'OUTPUT_LAYER_NAME':self.layer_name}
-
+        return {"OUTPUT_RASTER": self.dst_path, "OUTPUT_LAYER_NAME": self.layer_name}
 
-    def process_common_sklearn(self,parameters, context):
+    def process_common_sklearn(self, parameters, context):
+        self.subset = self.parameterAsInt(parameters, self.SUBSET, context)
+        self.seed = self.parameterAsInt(parameters, self.RANDOM_SEED, context)
+        self.main_param = self.parameterAsInt(parameters, self.MAIN_PARAM, context)
 
-        self.subset = self.parameterAsInt(
-            parameters, self.SUBSET, context)
-        self.seed = self.parameterAsInt(
-            parameters, self.RANDOM_SEED, context)
-        self.main_param = self.parameterAsInt(
-            parameters, self.MAIN_PARAM, context)
-        self.subset = self.parameterAsInt(
-            parameters, self.SUBSET, context)
-
-        method_idx = self.parameterAsEnum(
-            parameters, self.METHOD, context)
+        method_idx = self.parameterAsEnum(parameters, self.METHOD, context)
         self.method_name = self.method_opt[method_idx]
 
-        str_kwargs = self.parameterAsString(
-                parameters, self.SK_PARAM, context)
-        if str_kwargs != '':
+        str_kwargs = self.parameterAsString(parameters, self.SK_PARAM, context)
+        if str_kwargs != "":
             self.passed_kwargs = ast.literal_eval(str_kwargs)
         else:
             self.passed_kwargs = {}
 
-        self.input_bands = [i_band -1 for i_band in self.selected_bands]
+        self.input_bands = [i_band - 1 for i_band in self.selected_bands]
 
-        self.save_model = self.parameterAsBoolean(
-                parameters, self.SAVE_MODEL, context)
-        output_dir = self.parameterAsString(
-            parameters, self.OUTPUT, context)
+        self.save_model = self.parameterAsBoolean(parameters, self.SAVE_MODEL, context)
+        output_dir = self.parameterAsString(parameters, self.OUTPUT, context)
         self.output_dir = Path(output_dir)
         self.output_dir.mkdir(parents=True, exist_ok=True)
 
-
     def get_fit_raster(self, feedback):
-
         with rasterio.open(self.rlayer_path) as ds:
-
             transform = ds.transform
             win = windows.from_bounds(
-                    self.extent.xMinimum(), 
-                    self.extent.yMinimum(), 
-                    self.extent.xMaximum(), 
-                    self.extent.yMaximum(), 
-                    transform=transform
-                    )
+                self.extent.xMinimum(),
+                self.extent.yMinimum(),
+                self.extent.xMaximum(),
+                self.extent.yMaximum(),
+                transform=transform,
+            )
             raster = ds.read(window=win)
             transform = ds.window_transform(win)
-            raster = np.transpose(raster, (1,2,0))
-            raster = raster[:,:,self.input_bands]
+            raster = np.transpose(raster, (1, 2, 0))
+            raster = raster[:, :, self.input_bands]
             fit_raster = raster.reshape(-1, raster.shape[-1])
             scaler = StandardScaler()
             scaler.fit(fit_raster)
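+            # the scaler is fit on the full extent so a random subset shares the same normalisation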
 
             if self.subset:
-
-                feedback.pushInfo(f'Using a random subset of {self.subset} pixels, random seed is {self.seed}')
+                feedback.pushInfo(
+                    f"Using a random subset of {self.subset} pixels, random seed is {self.seed}"
+                )
 
                 fit_raster = raster.reshape(-1, raster.shape[-1])
                 nsamples = fit_raster.shape[0]
-    
+
                 # Generate random indices to select subset_size number of samples
                 np.random.seed(self.seed)
-                random_indices = np.random.choice(nsamples, size=self.subset, replace=False)
-                fit_raster = fit_raster[random_indices,:]
+                random_indices = np.random.choice(
+                    nsamples, size=self.subset, replace=False
+                )
+                fit_raster = fit_raster[random_indices, :]
 
                 # remove nans
                 fit_raster = fit_raster[~np.isnan(fit_raster).any(axis=1)]
 
             fit_raster = scaler.transform(fit_raster)
-            np.nan_to_num(fit_raster) # NaN to zero after normalisation
+            fit_raster = np.nan_to_num(fit_raster)  # NaN to zero after normalisation
 
         return fit_raster, scaler
 
-    def fit_model(self, model, fit_raster, iter,feedback):
-
-        feedback.pushInfo(f'Starting fit. If it goes for too long, consider setting a subset.\n')
+    def fit_model(self, model, fit_raster, iter, feedback):
+        feedback.pushInfo(
+            "Starting fit. If it goes for too long, consider setting a subset.\n"
+        )
 
         ## if fitting can be divided, we provide the possibility to cancel and to have progression
-        if iter and hasattr(model, 'partial_fit'):
+        if iter and hasattr(model, "partial_fit"):
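+            # `iter` is expected to be a sized iterable of passes; note it shadows the Python builtin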
             for i in iter:
                 if feedback.isCanceled():
                     feedback.pushWarning(
-                        self.tr("\n !!!Processing is canceled by user!!! \n"))
+                        self.tr("\n !!!Processing is canceled by user!!! \n")
+                    )
                     break
                 model.partial_fit(fit_raster)
                 feedback.setProgress((i / len(iter)) * 100)
@@ -672,102 +678,102 @@ class SKAlgorithm(IAMAPAlgorithm):
         return model
 
     def fit_predict(self, model, feedback):
-
         with rasterio.open(self.rlayer_path) as ds:
-
             transform = ds.transform
             crs = ds.crs
             win = windows.from_bounds(
-                    self.extent.xMinimum(), 
-                    self.extent.yMinimum(), 
-                    self.extent.xMaximum(), 
-                    self.extent.yMaximum(), 
-                    transform=transform
-                    )
+                self.extent.xMinimum(),
+                self.extent.yMinimum(),
+                self.extent.xMaximum(),
+                self.extent.yMaximum(),
+                transform=transform,
+            )
             raster = ds.read(window=win)
             transform = ds.window_transform(win)
-            raster = np.transpose(raster, (1,2,0))
-            raster = raster[:,:,self.input_bands]
-
+            raster = np.transpose(raster, (1, 2, 0))
+            raster = raster[:, :, self.input_bands]
 
             # raster = (raster-np.mean(raster))/np.std(raster)
             scaler = StandardScaler()
 
             raster = scaler.fit_transform(raster)
-            np.nan_to_num(raster) # NaN to zero after normalisation
+            raster = np.nan_to_num(raster)  # NaN to zero after normalisation
 
             proj_img = model.fit_predict(raster.reshape(-1, raster.shape[-1]))
 
-            proj_img = proj_img.reshape((raster.shape[0], raster.shape[1],-1))
+            proj_img = proj_img.reshape((raster.shape[0], raster.shape[1], -1))
             height, width, channels = proj_img.shape
 
-            feedback.pushInfo(f'Export to geotif\n')
-            with rasterio.open(self.dst_path, 'w', driver='GTiff',
-                               height=height, 
-                               width=width, 
-                               count=channels, 
-                               dtype=self.out_dtype,
-                               crs=crs, 
-                               transform=transform) as dst_ds:
+            feedback.pushInfo("Export to geotif\n")
+            with rasterio.open(
+                self.dst_path,
+                "w",
+                driver="GTiff",
+                height=height,
+                width=width,
+                count=channels,
+                dtype=self.out_dtype,
+                crs=crs,
+                transform=transform,
+            ) as dst_ds:
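+                # rasterio expects band-first (C, H, W) arrays, hence the transpose below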
                 dst_ds.write(np.transpose(proj_img, (2, 0, 1)))
-            feedback.pushInfo(f'Export to geotif done\n')
+            feedback.pushInfo("Export to geotif done\n")
 
         return model
 
-
     def infer_model(self, model, feedback, scaler=None):
-
         with rasterio.open(self.rlayer_path) as ds:
-
             transform = ds.transform
             crs = ds.crs
             win = windows.from_bounds(
-                    self.extent.xMinimum(), 
-                    self.extent.yMinimum(), 
-                    self.extent.xMaximum(), 
-                    self.extent.yMaximum(), 
-                    transform=transform
-                    )
+                self.extent.xMinimum(),
+                self.extent.yMinimum(),
+                self.extent.xMaximum(),
+                self.extent.yMaximum(),
+                transform=transform,
+            )
             raster = ds.read(window=win)
             transform = ds.window_transform(win)
-            raster = np.transpose(raster, (1,2,0))
-            raster = raster[:,:,self.input_bands]
-
+            raster = np.transpose(raster, (1, 2, 0))
+            raster = raster[:, :, self.input_bands]
 
             inf_raster = raster.reshape(-1, raster.shape[-1])
             if scaler:
                 inf_raster = scaler.transform(inf_raster)
-            np.nan_to_num(inf_raster) # NaN to zero after normalisation
+            inf_raster = np.nan_to_num(inf_raster)  # NaN to zero after normalisation
 
-            if self.TYPE == 'cluster':
+            if self.TYPE == "cluster":
                 proj_img = model.predict(inf_raster)
 
             else:
                 proj_img = model.transform(inf_raster)
 
-            proj_img = proj_img.reshape((raster.shape[0], raster.shape[1],-1))
+            proj_img = proj_img.reshape((raster.shape[0], raster.shape[1], -1))
             height, width, channels = proj_img.shape
 
-            feedback.pushInfo(f'Export to geotif\n')
-            with rasterio.open(self.dst_path, 'w', driver='GTiff',
-                               height=height, 
-                               width=width, 
-                               count=channels, 
-                               dtype=self.out_dtype,
-                               crs=crs, 
-                               transform=transform) as dst_ds:
+            feedback.pushInfo("Export to geotif\n")
+            with rasterio.open(
+                self.dst_path,
+                "w",
+                driver="GTiff",
+                height=height,
+                width=width,
+                count=channels,
+                dtype=self.out_dtype,
+                crs=crs,
+                transform=transform,
+            ) as dst_ds:
                 dst_ds.write(np.transpose(proj_img, (2, 0, 1)))
-            feedback.pushInfo(f'Export to geotif done\n')
+            feedback.pushInfo("Export to geotif done\n")
 
     def update_kwargs(self, kwargs_dict):
+        if "n_clusters" in kwargs_dict.keys():
+            kwargs_dict["n_clusters"] = self.main_param
+        if "n_components" in kwargs_dict.keys():
+            kwargs_dict["n_components"] = self.main_param
 
-        if 'n_clusters' in kwargs_dict.keys():
-            kwargs_dict['n_clusters'] = self.main_param
-        if 'n_components' in kwargs_dict.keys():
-            kwargs_dict['n_components'] = self.main_param
-
-        if 'random_state' in kwargs_dict.keys():
-            kwargs_dict['random_state'] = self.seed
+        if "random_state" in kwargs_dict.keys():
+            kwargs_dict["random_state"] = self.seed
 
         for key, value in self.passed_kwargs.items():
             if key in kwargs_dict.keys():
@@ -779,29 +785,31 @@ class SKAlgorithm(IAMAPAlgorithm):
         """
         Generate help string with default arguments of supported sklearn algorithms.
         """
-            
-        proj_methods = ['fit', 'transform']
-        clust_methods = ['fit', 'fit_predict']
-        help_str = '\n\n Here are the default arguments of the supported algorithms:\n\n'
 
-        if self.TYPE == 'proj':
+        proj_methods = ["fit", "transform"]
+        clust_methods = ["fit", "fit_predict"]
+        help_str = (
+            "\n\n Here are the default arguments of the supported algorithms:\n\n"
+        )
+
+        if self.TYPE == "proj":
             algos = get_sklearn_algorithms_with_methods(decomposition, proj_methods)
-            for algo in algos :
-                args = get_arguments(decomposition,algo)
-                help_str += f'- {algo}:\n'
-                help_str += f'{args}\n'
+            for algo in algos:
+                args = get_arguments(decomposition, algo)
+                help_str += f"- {algo}:\n"
+                help_str += f"{args}\n"
             algos = get_sklearn_algorithms_with_methods(cluster, proj_methods)
-            for algo in algos :
-                args = get_arguments(cluster,algo)
-                help_str += f'- {algo}:\n'
-                help_str += f'{args}\n'
+            for algo in algos:
+                args = get_arguments(cluster, algo)
+                help_str += f"- {algo}:\n"
+                help_str += f"{args}\n"
 
-        if self.TYPE == 'cluster':
+        if self.TYPE == "cluster":
             algos = get_sklearn_algorithms_with_methods(cluster, clust_methods)
-            for algo in algos :
-                args = get_arguments(cluster,algo)
-                help_str += f'- {algo}:\n'
-                help_str += f'{args}\n'
+            for algo in algos:
+                args = get_arguments(cluster, algo)
+                help_str += f"- {algo}:\n"
+                help_str += f"{args}\n"
 
         return help_str
 
@@ -809,8 +817,8 @@ class SKAlgorithm(IAMAPAlgorithm):
         """
         Log common metrics after a PCA.
         """
-        
-        if hasattr(model, 'explained_variance_ratio_'):
+
+        if hasattr(model, "explained_variance_ratio_"):
             # Explained variance ratio
             explained_variance_ratio = model.explained_variance_ratio_
 
@@ -820,19 +828,22 @@ class SKAlgorithm(IAMAPAlgorithm):
             # Loadings (Principal axes)
             loadings = model.components_.T * np.sqrt(model.explained_variance_)
 
-            feedback.pushInfo(f'Explained Variance Ratio : \n{explained_variance_ratio}')
-            feedback.pushInfo(f'Cumulative Explained Variance : \n{cumulative_variance}')
-            feedback.pushInfo(f'Loadings (Principal axes) : \n{loadings}')
+            feedback.pushInfo(
+                f"Explained Variance Ratio : \n{explained_variance_ratio}"
+            )
+            feedback.pushInfo(
+                f"Cumulative Explained Variance : \n{cumulative_variance}"
+            )
+            feedback.pushInfo(f"Loadings (Principal axes) : \n{loadings}")
 
-    def print_cluster_metrics(self, model, fit_raster ,feedback):
+    def print_cluster_metrics(self, model, fit_raster, feedback):
         """
         Log common metrics after a Kmeans.
         """
-        
-        if hasattr(model, 'inertia_'):
 
-            feedback.pushInfo(f'Inertia : \n{model.inertia_}')
-            feedback.pushInfo(f'Cluster sizes : \n{Counter(model.labels_)}')
+        if hasattr(model, "inertia_"):
+            feedback.pushInfo(f"Inertia : \n{model.inertia_}")
+            feedback.pushInfo(f"Cluster sizes : \n{Counter(model.labels_)}")
             ## silouhette score seem to heavy for now
             # feedback.pushInfo(f'Silhouette Score : \n{silhouette_score(fit_raster, model.labels_)}')
             # feedback.pushInfo(f'Silouhette Values : \n{silhouette_values(fit_raster, model.labels_)}')
@@ -842,18 +853,16 @@ class SKAlgorithm(IAMAPAlgorithm):
         return {}
 
 
-
 class SHPAlgorithm(IAMAPAlgorithm):
     """
     Common class for algorithms relying on shapefile data.
     """
 
-    TEMPLATE = 'TEMPLATE'
-    RANDOM_SAMPLES = 'RANDOM_SAMPLES'
-    TMP_DIR = 'iamap_sim'
-    DEFAULT_TEMPLATE = 'template.shp'
-    TYPE = 'similarity'
-    
+    TEMPLATE = "TEMPLATE"
+    RANDOM_SAMPLES = "RANDOM_SAMPLES"
+    TMP_DIR = "iamap_sim"
+    DEFAULT_TEMPLATE = "template.shp"
+    TYPE = "similarity"
 
     def initAlgorithm(self, config=None):
         """
@@ -864,7 +873,6 @@ class SHPAlgorithm(IAMAPAlgorithm):
         self.init_seed()
         self.init_input_shp()
 
-
     def processAlgorithm(self, parameters, context, feedback):
         """
         Here is where the processing itself takes place.
@@ -876,141 +884,144 @@ class SHPAlgorithm(IAMAPAlgorithm):
 
         self.inf_raster(fit_raster)
 
-        return {'OUTPUT_RASTER':self.dst_path, 'OUTPUT_LAYER_NAME':self.layer_name, 'USED_SHP':self.used_shp_path}
+        return {
+            "OUTPUT_RASTER": self.dst_path,
+            "OUTPUT_LAYER_NAME": self.layer_name,
+            "USED_SHP": self.used_shp_path,
+        }
 
     def init_input_shp(self):
         samples_param = QgsProcessingParameterNumber(
             name=self.RANDOM_SAMPLES,
             description=self.tr(
-                'Random samples taken if input is not in point geometry'),
+                "Random samples taken if input is not in point geometry"
+            ),
             type=QgsProcessingParameterNumber.Integer,
             optional=True,
             minValue=0,
             defaultValue=500,
-            maxValue=100_000
+            maxValue=100_000,
         )
 
         self.addParameter(
             QgsProcessingParameterVectorLayer(
                 name=self.TEMPLATE,
-                description=self.tr(
-                    'Input shapefile path'),
-            # defaultValue=os.path.join(self.cwd,'assets',self.DEFAULT_TEMPLATE),
+                description=self.tr("Input shapefile path"),
+                # defaultValue=os.path.join(self.cwd,'assets',self.DEFAULT_TEMPLATE),
             ),
         )
 
         samples_param.setFlags(
-            samples_param.flags() | QgsProcessingParameterDefinition.FlagAdvanced)
+            samples_param.flags() | QgsProcessingParameterDefinition.FlagAdvanced
+        )
         self.addParameter(samples_param)
 
-
     def get_fit_raster(self):
-
         with rasterio.open(self.rlayer_path) as ds:
             gdf = self.gdf.to_crs(ds.crs)
             pixel_values = []
 
             transform = ds.transform
             win = windows.from_bounds(
-                    self.extent.xMinimum(), 
-                    self.extent.yMinimum(), 
-                    self.extent.xMaximum(), 
-                    self.extent.yMaximum(), 
-                    transform=transform
-                    )
+                self.extent.xMinimum(),
+                self.extent.yMinimum(),
+                self.extent.xMaximum(),
+                self.extent.yMaximum(),
+                transform=transform,
+            )
             raster = ds.read(window=win)
             transform = ds.window_transform(win)
-            raster = raster[self.input_bands,:,:]
+            raster = raster[self.input_bands, :, :]
 
             for index, data in gdf.iterrows():
                 # Get the coordinates of the point in the raster's pixel space
                 x, y = data.geometry.x, data.geometry.y
 
                 # Convert point coordinates to pixel coordinates within the window
-                col, row = ~transform * (x, y)  # Convert from map coordinates to pixel coordinates
+                col, row = ~transform * (
+                    x,
+                    y,
+                )  # Convert from map coordinates to pixel coordinates
                 col, row = int(col), int(row)
-                pixel_values.append(list(raster[:,row, col]))
+                pixel_values.append(list(raster[:, row, col]))
 
             return np.asarray(pixel_values)
 
-
     def inf_raster(self, fit_raster):
-
         with rasterio.open(self.rlayer_path) as ds:
-
             transform = ds.transform
             crs = ds.crs
             win = windows.from_bounds(
-                    self.extent.xMinimum(), 
-                    self.extent.yMinimum(), 
-                    self.extent.xMaximum(), 
-                    self.extent.yMaximum(), 
-                    transform=transform
-                    )
+                self.extent.xMinimum(),
+                self.extent.yMinimum(),
+                self.extent.xMaximum(),
+                self.extent.yMaximum(),
+                transform=transform,
+            )
             raster = ds.read(window=win)
             transform = ds.window_transform(win)
-            raster = raster[self.input_bands,:,:]
+            raster = raster[self.input_bands, :, :]
 
-            raster = np.transpose(raster, (1,2,0))
+            raster = np.transpose(raster, (1, 2, 0))
 
             template = torch.from_numpy(fit_raster).to(torch.float32)
             template = torch.mean(template, dim=0)
-        
+
             feat_img = torch.from_numpy(raster)
             cos = nn.CosineSimilarity(dim=-1, eps=1e-6)
-        
-            sim = cos(feat_img,template)
+
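+            # cosine similarity of each pixel's feature vector to the mean template vector, in [-1, 1]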
+            sim = cos(feat_img, template)
             sim = sim.unsqueeze(-1)
             sim = sim.numpy()
             height, width, channels = sim.shape
-                        
-            with rasterio.open(self.dst_path, 'w', driver='GTiff',
-                               height=height, width=width, count=channels, dtype=self.out_dtype,
-                               crs=crs, transform=transform) as dst_ds:
-                dst_ds.write(np.transpose(sim, (2, 0, 1)))
-
 
+            with rasterio.open(
+                self.dst_path,
+                "w",
+                driver="GTiff",
+                height=height,
+                width=width,
+                count=channels,
+                dtype=self.out_dtype,
+                crs=crs,
+                transform=transform,
+            ) as dst_ds:
+                dst_ds.write(np.transpose(sim, (2, 0, 1)))
 
     def process_common_shp(self, parameters, context, feedback):
-
-        output_dir = self.parameterAsString(
-            parameters, self.OUTPUT, context)
+        output_dir = self.parameterAsString(parameters, self.OUTPUT, context)
         self.output_dir = Path(output_dir)
         self.output_dir.mkdir(parents=True, exist_ok=True)
 
-        self.seed = self.parameterAsInt(
-            parameters, self.RANDOM_SEED, context)
+        self.seed = self.parameterAsInt(parameters, self.RANDOM_SEED, context)
 
-        self.input_bands = [i_band -1 for i_band in self.selected_bands]
+        self.input_bands = [i_band - 1 for i_band in self.selected_bands]
 
-        template = self.parameterAsVectorLayer(
-            parameters, self.TEMPLATE, context)
+        template = self.parameterAsVectorLayer(parameters, self.TEMPLATE, context)
         self.template = template.dataProvider().dataSourceUri()
-        random_samples = self.parameterAsInt(
-            parameters, self.RANDOM_SAMPLES, context)
+        random_samples = self.parameterAsInt(parameters, self.RANDOM_SAMPLES, context)
 
         gdf = gpd.read_file(self.template)
         gdf = gdf.to_crs(self.crs.toWkt())
 
-        feedback.pushInfo(f'before sampling: {len(gdf)}')
+        feedback.pushInfo(f"before sampling: {len(gdf)}")
         ## If gdf is not point geometry, we take random samples in it
         gdf = get_random_samples_in_gdf(gdf, random_samples, seed=self.seed)
-        feedback.pushInfo(f'after samples:\n {len(gdf)}')
+        feedback.pushInfo(f"after samples:\n {len(gdf)}")
 
-        self.used_shp_path = os.path.join(self.output_dir, 'used.shp')
-        feedback.pushInfo(f'saving used dataframe to: {self.used_shp_path}')
+        self.used_shp_path = os.path.join(self.output_dir, "used.shp")
+        feedback.pushInfo(f"saving used dataframe to: {self.used_shp_path}")
         gdf.to_file(self.used_shp_path)
 
-
-        feedback.pushInfo(f'before extent: {len(gdf)}')
+        feedback.pushInfo(f"before extent: {len(gdf)}")
         bounds = box(
-                self.extent.xMinimum(), 
-                self.extent.yMinimum(), 
-                self.extent.xMaximum(), 
-                self.extent.yMaximum(), 
-                )
+            self.extent.xMinimum(),
+            self.extent.yMinimum(),
+            self.extent.xMaximum(),
+            self.extent.yMaximum(),
+        )
         self.gdf = gdf[gdf.within(bounds)]
-        feedback.pushInfo(f'after extent: {len(self.gdf)}')
+        feedback.pushInfo(f"after extent: {len(self.gdf)}")
 
         if len(self.gdf) == 0:
             feedback.pushWarning("No template points within extent !")
@@ -1020,36 +1031,37 @@ class SHPAlgorithm(IAMAPAlgorithm):
         rlayer_name, ext = os.path.splitext(rlayer_basename)
 
         ### Handle args that differ between clustering and projection methods
-        if self.TYPE == 'similarity':
-            self.dst_path, self.layer_name = get_unique_filename(self.output_dir, f'{self.TYPE}.tif', f'{rlayer_name} similarity')
+        if self.TYPE == "similarity":
+            self.dst_path, self.layer_name = get_unique_filename(
+                self.output_dir, f"{self.TYPE}.tif", f"{rlayer_name} similarity"
+            )
         else:
-            self.dst_path, self.layer_name = get_unique_filename(self.output_dir, f'{self.TYPE}.tif', f'{rlayer_name} ml')
+            self.dst_path, self.layer_name = get_unique_filename(
+                self.output_dir, f"{self.TYPE}.tif", f"{rlayer_name} ml"
+            )
 
         ## default to float32 until overriden if ML algo
-        self.out_dtype = 'float32'
-
+        self.out_dtype = "float32"
 
     # used to handle any thread-sensitive cleanup which is required by the algorithm.
     def postProcessAlgorithm(self, context, feedback) -> Dict[str, Any]:
         return {}
 
-if __name__ == "__main__":
-
 
-    methods = ['fit_predict', 'partial_fit'] 
+if __name__ == "__main__":
+    methods = ["fit_predict", "partial_fit"]
     algos = get_sklearn_algorithms_with_methods(cluster, methods)
     print(algos)
-    methods = ['predict', 'fit'] 
+    methods = ["predict", "fit"]
     algos = get_sklearn_algorithms_with_methods(cluster, methods)
     print(algos)
 
-    for algo in algos :
-        args = get_arguments(cluster,algo)
+    for algo in algos:
+        args = get_arguments(cluster, algo)
         print(algo, args)
 
-
-    methods = ['transform', 'fit'] 
+    methods = ["transform", "fit"]
     algos = get_sklearn_algorithms_with_methods(decomposition, methods)
-    for algo in algos :
-        args = get_arguments(decomposition,algo)
+    for algo in algos:
+        args = get_arguments(decomposition, algo)
         print(algo, args)
diff --git a/utils/geo.py b/utils/geo.py
index fb603eb33fa3bcdad9502440e3a2cd99d46f4c43..cfd410fe3db172be6855003991603ee9fc9f5203 100644
--- a/utils/geo.py
+++ b/utils/geo.py
@@ -1,17 +1,17 @@
+import os
 from typing import Callable, Union
-import sys
 import rasterio
+import rasterio.errors
 import geopandas as gpd
-import pandas as pd
 import numpy as np
-import warnings
-from rasterio.io import MemoryFile
 from rasterio.merge import merge
 
+
 def replace_nan_with_zero(array):
     array[array != array] = 0  # Replace NaN values with zero
     return array
 
+
 def custom_method_avg(merged_data, new_data, merged_mask, new_mask, **kwargs):
     """Returns the average value pixel.
     cf. https://amanbagrecha.github.io/posts/2022-07-31-merge-rasters-the-modern-way-using-python/index.html
@@ -24,14 +24,15 @@ def custom_method_avg(merged_data, new_data, merged_mask, new_mask, **kwargs):
     np.logical_and(merged_mask, mask, out=mask)
     np.copyto(merged_data, new_data, where=mask, casting="unsafe")
 
+
 def merge_tiles(
-        tiles:list, 
-        dst_path,
-        dtype:str = 'float32',
-        nodata=None,
-        #method:str | Callable ='first',
-        method: Union[str, Callable] = 'first',
-        ):
+    tiles: list,
+    dst_path,
+    dtype: str = "float32",
+    nodata=None,
+    # method:str | Callable ='first',
+    method: Union[str, Callable] = "first",
+):
     """
     cf. https://amanbagrecha.github.io/posts/2022-07-31-merge-rasters-the-modern-way-using-python/index.html
     """
@@ -41,41 +42,44 @@ def merge_tiles(
     # Extract individual bounds
     lefts, bottoms, rights, tops = zip(*extents)
     union_extent = (
-        min(lefts),     # Left
-        min(bottoms),   # Bottom
-        max(rights),    # Right
-        max(tops)       # Top
+        min(lefts),  # Left
+        min(bottoms),  # Bottom
+        max(rights),  # Right
+        max(tops),  # Top
     )
 
-    if method == 'average':
+    if method == "average":
         method = custom_method_avg
 
     # memfile = MemoryFile()
     try:
-        merge(sources=file_handler, # list of dataset objects opened in 'r' mode
-            bounds=union_extent, # tuple
-            nodata=nodata, # float
-            dtype=dtype, # dtype
+        merge(
+            sources=file_handler,  # list of dataset objects opened in 'r' mode
+            bounds=union_extent,  # tuple
+            nodata=nodata,  # float
+            dtype=dtype,  # dtype
             # resampling=Resampling.nearest,
-            method=method, # strategy to combine overlapping rasters
+            method=method,  # strategy to combine overlapping rasters
             # dst_path=memfile.name, # str or PathLike to save raster
             dst_path=dst_path,
             # dst_kwds={'blockysize':512, 'blockxsize':512} # Dictionary
-          )
+        )
     except TypeError:
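+        # fall back for older rasterio releases, where the keyword is `datasets` instead of `sources`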
-        merge(datasets=file_handler, # list of dataset objects opened in 'r' mode
-            bounds=union_extent, # tuple
-            nodata=nodata, # float
-            dtype=dtype, # dtype
+        merge(
+            datasets=file_handler,  # list of dataset objects opened in 'r' mode
+            bounds=union_extent,  # tuple
+            nodata=nodata,  # float
+            dtype=dtype,  # dtype
             # resampling=Resampling.nearest,
-            method=method, # strategy to combine overlapping rasters
+            method=method,  # strategy to combine overlapping rasters
             # dst_path=memfile.name, # str or PathLike to save raster
             dst_path=dst_path,
             # dst_kwds={'blockysize':512, 'blockxsize':512} # Dictionary
-          )
+        )
+
 
 def get_mean_sd_by_band(path, force_compute=True, ignore_zeros=True, subset=1_000):
-    '''
+    """
     Reads metadata or computes mean and sd of each band of a geotiff.
     If the metadata is not available, mean and standard deviation can be computed via numpy.
 
@@ -92,30 +96,29 @@ def get_mean_sd_by_band(path, force_compute=True, ignore_zeros=True, subset=1_00
         list of mean values per band
     sds : list
         list of standard deviation values per band
-    '''
+    """
 
     np.random.seed(42)
     src = rasterio.open(path)
     means = []
     sds = []
-    for band in range(1, src.count+1):
+    for band in range(1, src.count + 1):
         try:
             tags = src.tags(band)
-            if 'STATISTICS_MEAN' in tags and 'STATISTICS_STDDEV' in tags:
-                mean = float(tags['STATISTICS_MEAN'])
-                sd = float(tags['STATISTICS_STDDEV'])
+            if "STATISTICS_MEAN" in tags and "STATISTICS_STDDEV" in tags:
+                mean = float(tags["STATISTICS_MEAN"])
+                sd = float(tags["STATISTICS_STDDEV"])
                 means.append(mean)
                 sds.append(sd)
             else:
                 raise KeyError("Statistics metadata not found.")
 
         except KeyError:
-
             arr = src.read(band)
             arr = replace_nan_with_zero(arr)
             ## let subset by default for now
             if subset:
-                arr = np.random.choice(arr.flatten(), size=subset) 
+                arr = np.random.choice(arr.flatten(), size=subset)
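+                # NB: np.random.choice samples with replacement by default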
             if ignore_zeros:
                 mean = np.ma.masked_equal(arr, 0).mean()
                 sd = np.ma.masked_equal(arr, 0).std()
@@ -128,7 +131,6 @@ def get_mean_sd_by_band(path, force_compute=True, ignore_zeros=True, subset=1_00
         except Exception as e:
             print(f"Error processing band {band}: {e}")
 
-
     src.close()
     return means, sds
 
@@ -136,19 +138,23 @@ def get_mean_sd_by_band(path, force_compute=True, ignore_zeros=True, subset=1_00
 def get_random_samples_in_gdf(gdf, num_samples, seed=42):
     ## if input is not point based, we take random samples in it
     if not all(gdf.geometry.geom_type == "Point"):
-
         ## Calculate the area of each polygon
         ## to determine the number of samples for each category
-        gdf['iamap_area'] = gdf.geometry.area
-        total_area = gdf['iamap_area'].sum()
-        gdf['iamap_sample_size'] = (gdf['iamap_area'] / total_area * num_samples).astype(int)
+        gdf["iamap_area"] = gdf.geometry.area
+        total_area = gdf["iamap_area"].sum()
+        gdf["iamap_sample_size"] = (
+            gdf["iamap_area"] / total_area * num_samples
+        ).astype(int)
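+        # int truncation means the total number of points drawn can fall slightly below num_samples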
 
         series = []
         # Sample polygons proportional to their size
         ## see https://geopandas.org/en/stable/docs/user_guide/sampling.html#Variable-number-of-points
         for idx, row in gdf.iterrows():
-
-            sampled_points = gdf.loc[gdf.index == idx].sample_points(size=row['iamap_sample_size'], rng=seed).explode(ignore_index=True)
+            sampled_points = (
+                gdf.loc[gdf.index == idx]
+                .sample_points(size=row["iamap_sample_size"], rng=seed)
+                .explode(ignore_index=True)
+            )
 
             for point in sampled_points:
                 new_row = row.copy()
@@ -157,28 +163,63 @@ def get_random_samples_in_gdf(gdf, num_samples, seed=42):
 
         point_gdf = gpd.GeoDataFrame(series, crs=gdf.crs)
         point_gdf.index = [i for i in range(len(point_gdf))]
-        del point_gdf['iamap_area']
-        del point_gdf['iamap_sample_size']
+        del point_gdf["iamap_area"]
+        del point_gdf["iamap_sample_size"]
 
         return point_gdf
-            
+
     return gdf
 
-def get_unique_col_name(gdf, base_name='fold'):
+
+def get_unique_col_name(gdf, base_name="fold"):
     column_name = base_name
     counter = 1
 
     # Check if the column already exists, if yes, keep updating the name
     while column_name in gdf.columns:
-        column_name = f'{base_name}{counter}'
+        column_name = f"{base_name}{counter}"
         counter += 1
 
     return column_name
 
 
+def validate_geotiff(output_file, expected_output_size=4428850, expected_wh=(60, 24)):
+    """
+    Tests GeoTIFF validity by opening the file with rasterio and
+    checking that it has the expected file size, width and height.
+    Additionally, it checks that the raster contains more than one unique value.
+    """
+
+    expected_size_min = 0.8 * expected_output_size
+    expected_size_max = 1.2 * expected_output_size
+    # 1. Check if the output file is a valid GeoTIFF
+    try:
+        with rasterio.open(output_file) as src:
+            assert src.meta["driver"] == "GTiff", "File is not a valid GeoTIFF."
+            width = src.width
+            height = src.height
+            # 2. Check width and height
+            assert width == expected_wh[0], (
+                f"Expected width {expected_wh[0]}, got {width}."
+            )
+            assert height == expected_wh[1], (
+                f"Expected height {expected_wh[1]}, got {height}."
+            )
+            # 3. Read the data and check for unique values
+            data = src.read(1)  # Read the first band
+            unique_values = np.unique(data)
+
+            assert len(unique_values) > 1, "The GeoTIFF contains only one unique value."
+
+    except rasterio.errors.RasterioIOError:
+        print("The file could not be opened as a GeoTIFF, indicating it is invalid.")
+        assert False
+
+    # 4. Check if the file size is within the expected range
+    file_size = os.path.getsize(output_file)
+    assert expected_size_min <= file_size <= expected_size_max, (
+        f"File size {file_size} is outside the expected range."
+    )
+    return
+
+
 if __name__ == "__main__":
-    
-    gdf = gpd.read_file('assets/ml_poly.shp')
+    gdf = gpd.read_file("assets/ml_poly.shp")
     print(gdf)
     gdf = get_random_samples_in_gdf(gdf, 100)
     print(gdf)
diff --git a/utils/misc.py b/utils/misc.py
index 4ec762548fce14c58000d047ec66b515659e932f..3a4c47f07b9c3d7b924479688472b18a4f072493 100644
--- a/utils/misc.py
+++ b/utils/misc.py
@@ -1,6 +1,5 @@
 import shutil
 import psutil
-import time
 import os
 from pathlib import Path
 import torch
@@ -13,6 +12,7 @@ from PyQt5.QtCore import QVariant
 ## for hashing without using to much memory
 BUF_SIZE = 65536
 
+
 class QGISLogHandler(logging.Handler):
     def __init__(self, feedback):
         super().__init__()
@@ -22,38 +22,40 @@ class QGISLogHandler(logging.Handler):
         msg = self.format(record)
         self.feedback.pushInfo(msg)
 
+
 def get_model_size(model):
     torch.save(model.state_dict(), "temp.p")
-    size = os.path.getsize("temp.p")/1e6
-    os.remove('temp.p')
+    size = os.path.getsize("temp.p") / 1e6
+    os.remove("temp.p")
     return size
 
+
 def calculate_chunk_size(X, memory_buffer=0.1):
     # Estimate available memory
     available_memory = psutil.virtual_memory().available
-    
+
     # Estimate the memory footprint of one sample
     sample_memory = X[0].nbytes
-    
+
     # Determine maximum chunk size within available memory (leaving some buffer)
     max_chunk_size = int(available_memory * (1 - memory_buffer) / sample_memory)
-    
-    return max_chunk_size
 
+    return max_chunk_size
 
 
 def check_disk_space(path):
     # Get disk usage statistics about the given path
     total, used, free = shutil.disk_usage(path)
-    
+
     # Convert bytes to a more readable format (e.g., GB)
-    total_gb = total / (1024 ** 3)
-    used_gb = used / (1024 ** 3)
-    free_gb = free / (1024 ** 3)
-    
+    total_gb = total / (1024**3)
+    used_gb = used / (1024**3)
+    free_gb = free / (1024**3)
+
     return total_gb, used_gb, free_gb
 
-def get_dir_size(path='.'):
+
+def get_dir_size(path="."):
     total = 0
     with os.scandir(path) as it:
         for entry in it:
@@ -61,7 +63,8 @@ def get_dir_size(path='.'):
                 total += entry.stat().st_size
             elif entry.is_dir():
                 total += get_dir_size(entry.path)
-    return total / (1024 ** 3)
+    return total / (1024**3)
+
 
 def remove_files(file_paths):
     for file_path in file_paths:
@@ -73,13 +76,15 @@ def remove_files(file_paths):
         except Exception as e:
             print(f"Error removing {file_path}: {e}")
 
+
 def remove_files_with_extensions(directory, extensions):
     dir_path = Path(directory)
     for ext in extensions:
-        for file in dir_path.glob(f'*{ext}'):
+        for file in dir_path.glob(f"*{ext}"):
             file.unlink()  # Removes the file
 
-def get_unique_filename(directory, filename, layer_name='merged features'):
+
+def get_unique_filename(directory, filename, layer_name="merged features"):
     """
     Check if the filename exists in the given directory. If it does, append a numbered suffix.
     :param directory: The directory where the file will be saved.
@@ -90,7 +95,7 @@ def get_unique_filename(directory, filename, layer_name='merged features'):
     candidate = filename
     updated_layer_name = layer_name
     i = 1
-    
+
     # Check if file exists and update filename
     while os.path.exists(os.path.join(directory, candidate)):
         candidate = f"{base}-{i}{ext}"
@@ -99,19 +104,24 @@ def get_unique_filename(directory, filename, layer_name='merged features'):
 
     return os.path.join(directory, candidate), updated_layer_name
 
-def compute_md5_hash(parameters,keys_to_remove = ['MERGE_METHOD', 'WORKERS', 'PAUSES']):
-        param_encoder = {key: parameters[key] for key in parameters if key not in keys_to_remove}
-        return hashlib.md5(str(param_encoder).encode("utf-8")).hexdigest()
+
+def compute_md5_hash(parameters, keys_to_remove=["MERGE_METHOD", "WORKERS", "PAUSES"]):
+    param_encoder = {
+        key: parameters[key] for key in parameters if key not in keys_to_remove
+    }
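+    # NB: the hash depends on the insertion order of `parameters` (dicts keep order in Python 3.7+)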
+    return hashlib.md5(str(param_encoder).encode("utf-8")).hexdigest()
+
 
 def get_file_md5_hash(path):
-        md5 = hashlib.md5()
-        with open(path, 'rb') as f:
-            while True:
-                data = f.read(BUF_SIZE)
-                if not data:
-                    break
-                md5.update(data)
-        return md5.hexdigest()
+    md5 = hashlib.md5()
+    with open(path, "rb") as f:
+        while True:
+            data = f.read(BUF_SIZE)
+            if not data:
+                break
+            md5.update(data)
+    return md5.hexdigest()
+
 
 def convert_qvariant_obj(obj):
     if isinstance(obj, QVariant):
@@ -119,6 +129,7 @@ def convert_qvariant_obj(obj):
     else:
         return obj
 
+
 def convert_qvariant(obj):
     if isinstance(obj, QVariant):
         return obj.value()  # Extract the native Python value from QVariant
@@ -129,40 +140,41 @@ def convert_qvariant(obj):
     else:
         return obj
 
-def save_parameters_to_json(parameters, output_dir):
 
-    dst_path = os.path.join(output_dir, 'parameters.json')
+def save_parameters_to_json(parameters, output_dir):
+    dst_path = os.path.join(output_dir, "parameters.json")
     ## convert_qvariant does not handle 'CKPT' properly,
     ## so it is converted to a str explicitly
-    converted_parameters = convert_qvariant(parameters) 
-    converted_parameters['CKPT'] = str(converted_parameters['CKPT'])
+    converted_parameters = convert_qvariant(parameters)
+    converted_parameters["CKPT"] = str(converted_parameters["CKPT"])
 
     with open(dst_path, "w") as json_file:
         json.dump(converted_parameters, json_file, indent=4)
 
 
 def log_parameters_to_csv(parameters, output_dir):
-
     # Compute the MD5 hash of the parameters
     params_hash = compute_md5_hash(parameters)
-    
+
     # Define the CSV file path
     csv_file_path = os.path.join(output_dir, "parameters.csv")
-    
+
     # Check if the CSV file exists
     file_exists = os.path.isfile(csv_file_path)
-    
+
     # Read the CSV file and check for the hash if it exists
     if file_exists:
-        with open(csv_file_path, mode='r', newline='') as csvfile:
+        with open(csv_file_path, mode="r", newline="") as csvfile:
             reader = csv.DictReader(csvfile)
             for row in reader:
-                if row['md5hash'] == params_hash:
+                if row["md5hash"] == params_hash:
                     return  # No need to add this set of parameters
 
     # If not already logged, append the new parameters
-    with open(csv_file_path, mode='a', newline='') as csvfile:
-        fieldnames = ['md5hash'] + list(parameters.keys())  # Columns: md5hash + parameter keys
+    with open(csv_file_path, mode="a", newline="") as csvfile:
+        fieldnames = ["md5hash"] + list(
+            parameters.keys()
+        )  # Columns: md5hash + parameter keys
         writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
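+        # Caveat: the header row is only written when the file is first
+        # created, so if the parameter keys change between runs, later rows
+        # will not line up with the existing header columns.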
 
         # Write the header if the file is being created for the first time
@@ -170,7 +182,7 @@ def log_parameters_to_csv(parameters, output_dir):
             writer.writeheader()
 
         # Prepare the row with hash + parameters
-        row = {'md5hash': params_hash}
+        row = {"md5hash": params_hash}
         row.update(parameters)
 
         # Write the new row
diff --git a/utils/trch.py b/utils/trch.py
index be7c158f6e18fd6df0e990d9a7bd670f994af12d..47b0020939860568cb905cfd0a8842e276145f22 100644
--- a/utils/trch.py
+++ b/utils/trch.py
@@ -3,11 +3,10 @@ import torch.nn as nn
 
 
 def quantize_model(model, device):
-
     ## Dynamic quantization is not supported on CUDA, hence static conversion
-    if 'cuda' in device:
+    if "cuda" in device:
         # set quantization config for server (x86)
-        model.qconfig = torch.quantization.get_default_config('fbgemm')
+        model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
 
         # insert observers
         torch.quantization.prepare(model, inplace=True)
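+
+        # Static quantization normally needs a calibration pass over
+        # representative inputs before torch.quantization.convert() is
+        # called; that step is assumed to happen later in this function.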
diff --git a/utils/trchg.py b/utils/trchg.py
index 122fdfc588839115a946f6f7e849ba1c3340454d..3a78e4b184a14d89990e1aea7a862adfa8042836 100644
--- a/utils/trchg.py
+++ b/utils/trchg.py
@@ -4,8 +4,8 @@ from collections.abc import Iterator
 from torchgeo.datasets import BoundingBox
 from torchgeo.samplers.utils import tile_to_chips
 
-class NoBordersGridGeoSampler(GridGeoSampler):
 
+class NoBordersGridGeoSampler(GridGeoSampler):
     def __iter__(self) -> Iterator[BoundingBox]:
         """
         Modification of the original TorchGeo sampler to avoid overlapping the borders of a dataset.