From f475d532c05b5f57f7929cef7a4dcbb86392ba08 Mon Sep 17 00:00:00 2001
From: ptresson <paul.tresson@ird.fr>
Date: Thu, 24 Oct 2024 15:10:25 +0200
Subject: [PATCH] save model to pkl and metrics to json

---
 ml.py | 37 ++++++++++++++++++++++++++++++++-----
 1 file changed, 32 insertions(+), 5 deletions(-)

diff --git a/ml.py b/ml.py
index d5c998e..824cf84 100644
--- a/ml.py
+++ b/ml.py
@@ -1,3 +1,4 @@
+import os
 import ast
 import numpy as np
 from pathlib import Path
@@ -63,6 +64,7 @@ class MLAlgorithm(SHPAlgorithm):
     DO_KFOLDS = 'DO_KFOLDS'
     FOLD_COL = 'FOLD_COL'
     NFOLDS = 'NFOLDS'
+    SAVE_MODEL = 'SAVE_MODEL'
     SK_PARAM = 'SK_PARAM'
     TEMPLATE_TEST = 'TEMPLATE_TEST'
     METHOD = 'METHOD'
@@ -159,8 +161,15 @@ class MLAlgorithm(SHPAlgorithm):
             maxValue=10
         )
 
+        save_param = QgsProcessingParameterBoolean(
+                self.SAVE_MODEL,
+                self.tr("Save model after fit."),
+                defaultValue=True
+                )
+
         for param in (
                 nfold_param,
+                save_param,
                 ):
             param.setFlags(
                 param.flags() | QgsProcessingParameterDefinition.FlagAdvanced)
@@ -182,6 +191,7 @@ class MLAlgorithm(SHPAlgorithm):
 
         if self.do_kfold:
             best_metric = 0
+            best_metrics_dict = {}
             for fold in sorted(self.gdf[self.fold_col].unique()):
                 feedback.pushInfo(f'==== Fold {fold} ====')
                 self.test_gdf = self.gdf.loc[self.gdf[self.fold_col] == fold]
@@ -194,6 +204,7 @@ class MLAlgorithm(SHPAlgorithm):
                     used_metric = metrics_dict['accuracy']
                 if used_metric >= best_metric:
                     best_metric = used_metric
+                    best_metrics_dict = metrics_dict
                     self.best_model = self.model
 
         if (self.test_gdf is None) and not self.do_kfold:
@@ -202,6 +213,20 @@ class MLAlgorithm(SHPAlgorithm):
             feedback.pushWarning(f'No test set was provided and no cross-validation is done, unable to assess model quality !')
             self.best_model = self.model
 
+
+        if self.save_model:
+            # Only announce saving when it actually happens.
+            feedback.pushInfo('Fitting done, saving model\n')
+            save_file = f'{self.method_name}.pkl'.lower()
+            metrics_save_file = f'{self.method_name}-metrics.json'.lower()
+            joblib.dump(self.best_model, os.path.join(self.output_dir, save_file))
+            ## confusion matrix is a np array that does not fit in a json;
+            ## class_report is left out too. Filter a copy, keep the dict intact.
+            json_metrics = {k: v for k, v in best_metrics_dict.items()
+                            if k not in ('conf_matrix', 'class_report')}
+            with open(os.path.join(self.output_dir, metrics_save_file), "w") as json_file:
+                json.dump(json_metrics, json_file, indent=4)
+
         self.infer_model(feedback)
 
         return {'OUTPUT_RASTER':self.dst_path, 'OUTPUT_LAYER_NAME':self.layer_name, 'USED_SHP':self.used_shp_path}
@@ -257,17 +282,17 @@ class MLAlgorithm(SHPAlgorithm):
 
     def process_ml_shp(self, parameters, context, feedback):
 
-        template_test_layer = self.parameterAsVectorLayer(
+        template_test = self.parameterAsVectorLayer(
             parameters, self.TEMPLATE_TEST, context)
-        template_test = template_test_layer.dataProvider().dataSourceUri()
+        feedback.pushInfo(f'template_test: {template_test}')
 
         self.test_gdf=None
 
-        if template_test != '' :
+        if template_test is not None :
             random_samples = self.parameterAsInt(
                 parameters, self.RANDOM_SAMPLES, context)
 
-            gdf = gpd.read_file(template_test)
+            gdf = gpd.read_file(template_test.dataProvider().dataSourceUri())
             gdf = gdf.to_crs(self.crs.toWkt())
 
             feedback.pushInfo(f'before samples: {len(gdf)}')
@@ -291,6 +316,8 @@ class MLAlgorithm(SHPAlgorithm):
 
     def process_ml_options(self, parameters, context, feedback):
 
+        self.save_model = self.parameterAsBoolean(
+                parameters, self.SAVE_MODEL, context)
         self.do_kfold = self.parameterAsBoolean(
             parameters, self.DO_KFOLDS, context)
         gt_col = self.parameterAsString(
@@ -518,7 +545,7 @@ class MLAlgorithm(SHPAlgorithm):
         should provide a basic description about what the algorithm does and the
         parameters and outputs associated with it..
         """
-        return self.tr(f"Fit a Machine Learning model using input template\n{self.get_help_sk_methods()}")
+        return self.tr(f"Fit a Machine Learning model using input template. Only RandomForestClassifier is thoroughly tested. \n{self.get_help_sk_methods()}")
 
     def icon(self):
         return 'E'
-- 
GitLab