From ee8f0f94f1b41147399ea64403d55d5776d6adf7 Mon Sep 17 00:00:00 2001 From: ptresson <paul.tresson@ird.fr> Date: Thu, 3 Oct 2024 10:52:16 +0200 Subject: [PATCH] save parameters to json in the output_subdir --- encoder.py | 40 +++++++++++++++++++++++++++++----------- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/encoder.py b/encoder.py index de64f68..3a9307b 100644 --- a/encoder.py +++ b/encoder.py @@ -9,6 +9,7 @@ import shutil import numpy as np from pathlib import Path from typing import Dict, Any +import json import rasterio from qgis.PyQt.QtCore import QCoreApplication @@ -52,6 +53,7 @@ from .utils.misc import (QGISLogHandler, remove_files, check_disk_space, get_unique_filename, + convert_qvariant, ) @@ -331,6 +333,20 @@ class EncoderAlgorithm(QgsProcessingAlgorithm): """ self.process_options(parameters, context, feedback) + ## compute parameters hash to have a unique identifier for the run + ## some parameters do not change the encoding part of the algorithm + keys_to_remove = ['MERGE_METHOD', 'WORKERS', 'PAUSES'] + param_encoder = {key: parameters[key] for key in parameters if key not in keys_to_remove} + + param_hash = hashlib.md5(str(param_encoder).encode("utf-8")).hexdigest() + output_subdir = os.path.join(self.output_dir,param_hash) + output_subdir = Path(output_subdir) + output_subdir.mkdir(parents=True, exist_ok=True) + self.output_subdir = output_subdir + feedback.pushInfo(f'output_subdir: {output_subdir}') + self.save_parameters_to_json(parameters) + feedback.pushInfo(f'saving parameters to json file') + RasterDataset.filename_glob = self.rlayer_name RasterDataset.all_bands = [ self.rlayer.bandName(i_band) for i_band in range(1, self.rlayer.bandCount()+1) @@ -457,17 +473,6 @@ class EncoderAlgorithm(QgsProcessingAlgorithm): feedback.pushInfo(f'Total batch num: {len(dataloader)}') feedback.pushInfo(f'\n\n{"-"*16}\nBegining inference \n{"-"*16}\n\n') - ## compute parameters hash to have a unique identifier for the run - ## some parameters do not change the encoding part of the algorithm - keys_to_remove = ['MERGE_METHOD', 'WORKERS', 'PAUSES'] - param_encoder = {key: parameters[key] for key in parameters if key not in keys_to_remove} - - param_hash = hashlib.md5(str(param_encoder).encode("utf-8")).hexdigest() - output_subdir = os.path.join(self.output_dir,param_hash) - output_subdir = Path(output_subdir) - output_subdir.mkdir(parents=True, exist_ok=True) - self.output_subdir = output_subdir - feedback.pushInfo(f'output_subdir: {output_subdir}') last_batch_done = self.get_last_batch_done() @@ -637,6 +642,19 @@ class EncoderAlgorithm(QgsProcessingAlgorithm): return {"Output feature path": self.output_subdir, 'Patch samples saved': self.iPatch, 'OUTPUT_RASTER':dst_path, 'OUTPUT_LAYER_NAME':layer_name} + def save_parameters_to_json(self, parameters): + dst_path = os.path.join(self.output_subdir, 'parameters.json') + ## convert_qvariant does not work properly for 'CKPT' + ## converting it to a str + converted_parameters = convert_qvariant(parameters) + print(parameters) + converted_parameters['CKPT'] = str(converted_parameters['CKPT']) + + for key, item in converted_parameters.items(): + print(key, type(item)) + with open(dst_path, "w") as json_file: + json.dump(converted_parameters, json_file, indent=4) + def remove_temp_files(self): """ cleaning up temp tiles -- GitLab