# Patch summary (commentary only; text before the first "diff --git" is ignored by patch tools).
#
# Target: ml.py, class MLAlgorithm(SHPAlgorithm).
# NOTE: line breaks and indentation of this patch have been collapsed, so hunk
# bodies below cannot be read for nesting; only statement order is reliable.
#
# What the patch does, hunk by hunk:
#   1. Adds `import os` at the top of the module.
#   2. Adds the parameter-name constant SAVE_MODEL = 'SAVE_MODEL'.
#   3. Declares a QgsProcessingParameterBoolean for SAVE_MODEL
#      ("Save model after fit.", defaultValue=True) and adds it to the tuple of
#      parameters that get QgsProcessingParameterDefinition.FlagAdvanced.
#   4. Introduces `best_metrics_dict = {}` right after `best_metric = 0`.
#   5. When a fold's accuracy is >= best_metric, stores that fold's
#      metrics_dict in best_metrics_dict alongside self.best_model.
#   6. After fitting: logs 'Fitting done, saving model', builds lower-cased
#      file names '<method_name>.pkl' and '<method_name>-metrics.json', and,
#      when self.save_model is set, joblib.dump()s self.best_model into
#      self.output_dir and writes best_metrics_dict as JSON (indent=4) after
#      removing the 'conf_matrix' and 'class_report' keys.
#   7. process_ml_shp: keeps the vector layer object in `template_test`, tests
#      it with `is not None` instead of comparing its URI to '', and only
#      resolves dataProvider().dataSourceUri() when reading it with geopandas.
#   8. process_ml_options: reads SAVE_MODEL into self.save_model.
#   9. Extends the short help string with a note that only
#      RandomForestClassifier is tested.
#
# NOTE(review): `best_metrics_dict = {}` follows `if self.do_kfold:` /
#   `best_metric = 0`; if it is nested in that branch, the save step would hit
#   an undefined name when do_kfold is False and save_model is True - confirm.
# NOTE(review): the 'Fitting done, saving model' message is emitted before the
#   `if self.save_model:` check, so it appears to be logged even when nothing
#   is saved - confirm intent.
# NOTE(review): `print(best_metrics_dict)` writes to stdout while the rest of
#   the code reports through `feedback`; looks like leftover debug output.
# NOTE(review): the two `.pop(...)` calls act on best_metrics_dict, which is
#   assigned from metrics_dict without a copy, so the best fold's dict is
#   modified in place - verify nothing reads those keys afterwards.
# NOTE(review): `joblib` and `json` are used but their imports are not part of
#   this diff - presumably already imported in ml.py; verify.
# NOTE(review): "throughfully" in the new help string is probably meant to be
#   "thoroughly".
diff --git a/ml.py b/ml.py index d5c998e9bf20bde0647295f8fc1974a03f480c2f..824cf848319feaa523f6c38868803dd7fb2388e7 100644 --- a/ml.py +++ b/ml.py @@ -1,3 +1,4 @@ +import os import ast import numpy as np from pathlib import Path @@ -63,6 +64,7 @@ class MLAlgorithm(SHPAlgorithm): DO_KFOLDS = 'DO_KFOLDS' FOLD_COL = 'FOLD_COL' NFOLDS = 'NFOLDS' + SAVE_MODEL = 'SAVE_MODEL' SK_PARAM = 'SK_PARAM' TEMPLATE_TEST = 'TEMPLATE_TEST' METHOD = 'METHOD' @@ -159,8 +161,15 @@ class MLAlgorithm(SHPAlgorithm): maxValue=10 ) + save_param = QgsProcessingParameterBoolean( + self.SAVE_MODEL, + self.tr("Save model after fit."), + defaultValue=True + ) + for param in ( nfold_param, + save_param, ): param.setFlags( param.flags() | QgsProcessingParameterDefinition.FlagAdvanced) @@ -182,6 +191,7 @@ class MLAlgorithm(SHPAlgorithm): if self.do_kfold: best_metric = 0 + best_metrics_dict = {} for fold in sorted(self.gdf[self.fold_col].unique()): feedback.pushInfo(f'==== Fold {fold} ====') self.test_gdf = self.gdf.loc[self.gdf[self.fold_col] == fold] @@ -194,6 +204,7 @@ class MLAlgorithm(SHPAlgorithm): used_metric = metrics_dict['accuracy'] if used_metric >= best_metric: best_metric = used_metric + best_metrics_dict = metrics_dict self.best_model = self.model if (self.test_gdf is None) and not self.do_kfold: @@ -202,6 +213,20 @@ class MLAlgorithm(SHPAlgorithm): feedback.pushWarning(f'No test set was provided and no cross-validation is done, unable to assess model quality !') self.best_model = self.model + + feedback.pushInfo(f'Fitting done, saving model\n') + save_file = f'{self.method_name}.pkl'.lower() + metrics_save_file = f'{self.method_name}-metrics.json'.lower() + if self.save_model: + out_path = os.path.join(self.output_dir, save_file) + joblib.dump(self.best_model, out_path) + with open(os.path.join(self.output_dir, metrics_save_file), "w") as json_file: + ## confusion matrix is a np array that does not fit in a json + best_metrics_dict.pop('conf_matrix', None) + 
best_metrics_dict.pop('class_report', None) + print(best_metrics_dict) + json.dump(best_metrics_dict, json_file, indent=4) + self.infer_model(feedback) return {'OUTPUT_RASTER':self.dst_path, 'OUTPUT_LAYER_NAME':self.layer_name, 'USED_SHP':self.used_shp_path} @@ -257,17 +282,17 @@ class MLAlgorithm(SHPAlgorithm): def process_ml_shp(self, parameters, context, feedback): - template_test_layer = self.parameterAsVectorLayer( + template_test = self.parameterAsVectorLayer( parameters, self.TEMPLATE_TEST, context) - template_test = template_test_layer.dataProvider().dataSourceUri() + feedback.pushInfo(f'template_test: {template_test}') self.test_gdf=None - if template_test != '' : + if template_test is not None : random_samples = self.parameterAsInt( parameters, self.RANDOM_SAMPLES, context) - gdf = gpd.read_file(template_test) + gdf = gpd.read_file(template_test.dataProvider().dataSourceUri()) gdf = gdf.to_crs(self.crs.toWkt()) feedback.pushInfo(f'before samples: {len(gdf)}') @@ -291,6 +316,8 @@ class MLAlgorithm(SHPAlgorithm): def process_ml_options(self, parameters, context, feedback): + self.save_model = self.parameterAsBoolean( + parameters, self.SAVE_MODEL, context) self.do_kfold = self.parameterAsBoolean( parameters, self.DO_KFOLDS, context) gt_col = self.parameterAsString( @@ -518,7 +545,7 @@ class MLAlgorithm(SHPAlgorithm): should provide a basic description about what the algorithm does and the parameters and outputs associated with it.. """ - return self.tr(f"Fit a Machine Learning model using input template\n{self.get_help_sk_methods()}") + return self.tr(f"Fit a Machine Learning model using input template. Only RandomForestClassifier is throughfully tested. \n{self.get_help_sk_methods()}") def icon(self): return 'E'