From ee8f0f94f1b41147399ea64403d55d5776d6adf7 Mon Sep 17 00:00:00 2001
From: ptresson <paul.tresson@ird.fr>
Date: Thu, 3 Oct 2024 10:52:16 +0200
Subject: [PATCH] save parameters to json in the output_subdir

---
 encoder.py | 40 +++++++++++++++++++++++++++++-----------
 1 file changed, 29 insertions(+), 11 deletions(-)

diff --git a/encoder.py b/encoder.py
index de64f68..3a9307b 100644
--- a/encoder.py
+++ b/encoder.py
@@ -9,6 +9,7 @@ import shutil
 import numpy as np
 from pathlib import Path
 from typing import Dict, Any
+import json
 
 import rasterio
 from qgis.PyQt.QtCore import QCoreApplication
@@ -52,6 +53,7 @@ from .utils.misc import (QGISLogHandler,
                          remove_files, 
                          check_disk_space,
                          get_unique_filename,
+                         convert_qvariant,
                          )
 
 
@@ -331,6 +333,20 @@ class EncoderAlgorithm(QgsProcessingAlgorithm):
         """
         self.process_options(parameters, context, feedback)
 
+        ## compute parameters hash to have a unique identifier for the run
+        ## some parameters do not change the encoding part of the algorithm
+        keys_to_remove = ['MERGE_METHOD', 'WORKERS', 'PAUSES']
+        param_encoder = {key: parameters[key] for key in parameters if key not in keys_to_remove}
+
+        param_hash = hashlib.md5(str(param_encoder).encode("utf-8")).hexdigest()
+        output_subdir = os.path.join(self.output_dir,param_hash)
+        output_subdir = Path(output_subdir)
+        output_subdir.mkdir(parents=True, exist_ok=True)
+        self.output_subdir = output_subdir
+        feedback.pushInfo(f'output_subdir: {output_subdir}')
+        self.save_parameters_to_json(parameters)
+        feedback.pushInfo(f'saving parameters to json file')
+
         RasterDataset.filename_glob = self.rlayer_name
         RasterDataset.all_bands = [
             self.rlayer.bandName(i_band) for i_band in range(1, self.rlayer.bandCount()+1)
@@ -457,17 +473,6 @@ class EncoderAlgorithm(QgsProcessingAlgorithm):
         feedback.pushInfo(f'Total batch num: {len(dataloader)}')
         feedback.pushInfo(f'\n\n{"-"*16}\nBegining inference \n{"-"*16}\n\n')
 
-        ## compute parameters hash to have a unique identifier for the run
-        ## some parameters do not change the encoding part of the algorithm
-        keys_to_remove = ['MERGE_METHOD', 'WORKERS', 'PAUSES']
-        param_encoder = {key: parameters[key] for key in parameters if key not in keys_to_remove}
-
-        param_hash = hashlib.md5(str(param_encoder).encode("utf-8")).hexdigest()
-        output_subdir = os.path.join(self.output_dir,param_hash)
-        output_subdir = Path(output_subdir)
-        output_subdir.mkdir(parents=True, exist_ok=True)
-        self.output_subdir = output_subdir
-        feedback.pushInfo(f'output_subdir: {output_subdir}')
 
 
         last_batch_done = self.get_last_batch_done()
@@ -637,6 +642,19 @@ class EncoderAlgorithm(QgsProcessingAlgorithm):
 
         return {"Output feature path": self.output_subdir, 'Patch samples saved': self.iPatch, 'OUTPUT_RASTER':dst_path, 'OUTPUT_LAYER_NAME':layer_name}
 
+    def save_parameters_to_json(self, parameters):
+        dst_path = os.path.join(self.output_subdir, 'parameters.json')
+        ## convert_qvariant does not work properly for 'CKPT'
+        ## converting it to a str
+        converted_parameters = convert_qvariant(parameters) 
+        print(parameters)
+        converted_parameters['CKPT'] = str(converted_parameters['CKPT'])
+
+        for key, item in converted_parameters.items():
+            print(key, type(item))
+        with open(dst_path, "w") as json_file:
+            json.dump(converted_parameters, json_file, indent=4)
+
     def remove_temp_files(self):
         """
         cleaning up temp tiles
-- 
GitLab