Skip to content
Snippets Groups Projects
Commit da6c8f5c authored by paul.tresson_ird.fr's avatar paul.tresson_ird.fr
Browse files

get metrics (and add missing update_kwargs() method)

parent bd0f7824
No related branches found
No related tags found
No related merge requests found
...@@ -33,7 +33,6 @@ from qgis.core import (Qgis, ...@@ -33,7 +33,6 @@ from qgis.core import (Qgis,
import torch import torch
import torch.nn as nn import torch.nn as nn
from sklearn.ensemble import RandomForestClassifier from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import LabelEncoder from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
...@@ -49,6 +48,21 @@ from .utils.algo import ( ...@@ -49,6 +48,21 @@ from .utils.algo import (
import sklearn.ensemble as ensemble import sklearn.ensemble as ensemble
import sklearn.neighbors as neighbors import sklearn.neighbors as neighbors
from sklearn.base import ClassifierMixin, RegressorMixin from sklearn.base import ClassifierMixin, RegressorMixin
from sklearn.metrics import (
accuracy_score,
precision_score,
recall_score,
f1_score,
confusion_matrix,
classification_report
)
from sklearn.metrics import (
mean_absolute_error,
mean_squared_error,
r2_score,
)
def check_model_type(model): def check_model_type(model):
if isinstance(model, ClassifierMixin): if isinstance(model, ClassifierMixin):
return "classification" return "classification"
...@@ -269,6 +283,54 @@ class MLAlgorithm(SHPAlgorithm): ...@@ -269,6 +283,54 @@ class MLAlgorithm(SHPAlgorithm):
except AttributeError: except AttributeError:
self.model = instantiate_sklearn_algorithm(neighbors, self.method_name, **kwargs) self.model = instantiate_sklearn_algorithm(neighbors, self.method_name, **kwargs)
def update_kwargs(self, kwargs_dict):
    """
    Override default keyword arguments with the user-passed values.

    Only keys already present in ``kwargs_dict`` are overridden; any
    passed kwarg whose key is unknown is silently ignored.

    :param kwargs_dict: dict of default kwargs, updated in place.
    :return: the (mutated) ``kwargs_dict``, for call-chaining convenience.
    """
    for key, value in self.passed_kwargs.items():
        # direct membership test -- calling .keys() was redundant
        if key in kwargs_dict:
            kwargs_dict[key] = value
    return kwargs_dict
def get_metrics(self, test_gts, predictions, feedback):
    """
    Compute evaluation metrics for the fitted model and log them.

    The metric set is chosen from the model's sklearn mixin type (via
    ``check_model_type``): classification models get accuracy/precision/
    recall/F1 plus a confusion matrix and report; regression models get
    MAE/MSE/RMSE/R2. Each metric is pushed to the QGIS feedback channel.

    :param test_gts: ground-truth labels/values for the test split.
    :param predictions: model predictions aligned with ``test_gts``.
    :param feedback: QgsProcessingFeedback-like object used for logging.
    """
    task_type = check_model_type(self.model)

    if task_type == 'classification':
        accuracy = accuracy_score(test_gts, predictions)
        # 'weighted' averaging accounts for class imbalance in the
        # multiclass case; adjust if per-class scores are needed.
        precision = precision_score(test_gts, predictions, average='weighted')
        recall = recall_score(test_gts, predictions, average='weighted')
        f1 = f1_score(test_gts, predictions, average='weighted')
        conf_matrix = confusion_matrix(test_gts, predictions)
        class_report = classification_report(test_gts, predictions)

        feedback.pushInfo(f'Accuracy:\t {accuracy}')
        feedback.pushInfo(f'Precision:\t {precision}')
        feedback.pushInfo(f'Recall:\t {recall}')
        feedback.pushInfo(f'F1-Score:\t {f1}')
        feedback.pushInfo(f'Confusion Matrix:\n {conf_matrix}')
        feedback.pushInfo(f'Classification Report:\n {class_report}')

    elif task_type == 'regression':
        # (removed a dead `pass` statement that headed this branch)
        mae = mean_absolute_error(test_gts, predictions)
        mse = mean_squared_error(test_gts, predictions)
        # mse ** 0.5 avoids relying on numpy being imported in this module
        rmse = mse ** 0.5
        r2 = r2_score(test_gts, predictions)

        feedback.pushInfo(f'MAE:\t {mae}')
        feedback.pushInfo(f'MSE:\t {mse}')
        feedback.pushInfo(f'RMSE:\t {rmse}')
        feedback.pushInfo(f'R2 Score:\t {r2}')

    else:
        feedback.pushWarning('Unable to evaluate the model !!')
def get_algorithms(self): def get_algorithms(self):
required_methods = ['fit', 'predict'] required_methods = ['fit', 'predict']
ensemble_algos = get_sklearn_algorithms_with_methods(ensemble, required_methods) ensemble_algos = get_sklearn_algorithms_with_methods(ensemble, required_methods)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment