Skip to content
Snippets Groups Projects
Commit f475d532 authored by paul.tresson_ird.fr
Browse files

save model to pkl and metrics to json

parent a82f8dd6
No related branches found
No related tags found
No related merge requests found
import os
import ast
import numpy as np
from pathlib import Path
......@@ -63,6 +64,7 @@ class MLAlgorithm(SHPAlgorithm):
DO_KFOLDS = 'DO_KFOLDS'
FOLD_COL = 'FOLD_COL'
NFOLDS = 'NFOLDS'
SAVE_MODEL = 'SAVE_MODEL'
SK_PARAM = 'SK_PARAM'
TEMPLATE_TEST = 'TEMPLATE_TEST'
METHOD = 'METHOD'
......@@ -159,8 +161,15 @@ class MLAlgorithm(SHPAlgorithm):
maxValue=10
)
save_param = QgsProcessingParameterBoolean(
self.SAVE_MODEL,
self.tr("Save model after fit."),
defaultValue=True
)
for param in (
nfold_param,
save_param,
):
param.setFlags(
param.flags() | QgsProcessingParameterDefinition.FlagAdvanced)
......@@ -182,6 +191,7 @@ class MLAlgorithm(SHPAlgorithm):
if self.do_kfold:
best_metric = 0
best_metrics_dict = {}
for fold in sorted(self.gdf[self.fold_col].unique()):
feedback.pushInfo(f'==== Fold {fold} ====')
self.test_gdf = self.gdf.loc[self.gdf[self.fold_col] == fold]
......@@ -194,6 +204,7 @@ class MLAlgorithm(SHPAlgorithm):
used_metric = metrics_dict['accuracy']
if used_metric >= best_metric:
best_metric = used_metric
best_metrics_dict = metrics_dict
self.best_model = self.model
if (self.test_gdf is None) and not self.do_kfold:
......@@ -202,6 +213,20 @@ class MLAlgorithm(SHPAlgorithm):
feedback.pushWarning(f'No test set was provided and no cross-validation is done, unable to assess model quality !')
self.best_model = self.model
feedback.pushInfo(f'Fitting done, saving model\n')
save_file = f'{self.method_name}.pkl'.lower()
metrics_save_file = f'{self.method_name}-metrics.json'.lower()
if self.save_model:
out_path = os.path.join(self.output_dir, save_file)
joblib.dump(self.best_model, out_path)
with open(os.path.join(self.output_dir, metrics_save_file), "w") as json_file:
## The confusion matrix is a NumPy array, which is not JSON-serializable; drop it (and the class report) before dumping.
best_metrics_dict.pop('conf_matrix', None)
best_metrics_dict.pop('class_report', None)
print(best_metrics_dict)
json.dump(best_metrics_dict, json_file, indent=4)
self.infer_model(feedback)
return {'OUTPUT_RASTER':self.dst_path, 'OUTPUT_LAYER_NAME':self.layer_name, 'USED_SHP':self.used_shp_path}
......@@ -257,17 +282,17 @@ class MLAlgorithm(SHPAlgorithm):
def process_ml_shp(self, parameters, context, feedback):
template_test_layer = self.parameterAsVectorLayer(
template_test = self.parameterAsVectorLayer(
parameters, self.TEMPLATE_TEST, context)
template_test = template_test_layer.dataProvider().dataSourceUri()
feedback.pushInfo(f'template_test: {template_test}')
self.test_gdf=None
if template_test != '' :
if template_test is not None :
random_samples = self.parameterAsInt(
parameters, self.RANDOM_SAMPLES, context)
gdf = gpd.read_file(template_test)
gdf = gpd.read_file(template_test.dataProvider().dataSourceUri())
gdf = gdf.to_crs(self.crs.toWkt())
feedback.pushInfo(f'before samples: {len(gdf)}')
......@@ -291,6 +316,8 @@ class MLAlgorithm(SHPAlgorithm):
def process_ml_options(self, parameters, context, feedback):
self.save_model = self.parameterAsBoolean(
parameters, self.SAVE_MODEL, context)
self.do_kfold = self.parameterAsBoolean(
parameters, self.DO_KFOLDS, context)
gt_col = self.parameterAsString(
......@@ -518,7 +545,7 @@ class MLAlgorithm(SHPAlgorithm):
should provide a basic description about what the algorithm does and the
parameters and outputs associated with it.
"""
return self.tr(f"Fit a Machine Learning model using input template\n{self.get_help_sk_methods()}")
return self.tr(f"Fit a Machine Learning model using input template. Only RandomForestClassifier is throughfully tested. \n{self.get_help_sk_methods()}")
def icon(self):
return 'E'
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment