Skip to content
Snippets Groups Projects
Commit ba84c60b authored by paul.tresson_ird.fr's avatar paul.tresson_ird.fr
Browse files

log some metrics after PCA and KMeans (other algorithms are untested for now). closes #31 for now

parent e5c03435
No related branches found
No related tags found
No related merge requests found
......@@ -4,6 +4,7 @@ import tempfile
import numpy as np
import inspect
import joblib
from collections import Counter
from pathlib import Path
from typing import Dict, Any
from qgis.core import (Qgis,
......@@ -29,6 +30,7 @@ import sklearn.decomposition as decomposition
import sklearn.cluster as cluster
from sklearn.base import BaseEstimator
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import silhouette_score, silhouette_samples
if __name__ != "__main__":
from .misc import get_unique_filename, calculate_chunk_size
......@@ -632,6 +634,8 @@ class SKAlgorithm(IAMAPAlgorithm):
iter = get_iter(model, fit_raster)
model = self.fit_model(model, fit_raster, iter, feedback)
self.print_transform_metrics(model, feedback)
self.print_cluster_metrics(model,fit_raster, feedback)
feedback.pushInfo(f'Fitting done, saving model\n')
save_file = f'{self.method_name}.pkl'.lower()
if self.save_model:
......@@ -873,6 +877,37 @@ class SKAlgorithm(IAMAPAlgorithm):
return help_str
def print_transform_metrics(self, model, feedback):
"""
Log common metrics after a PCA.
"""
if hasattr(model, 'explained_variance_ratio_'):
# Explained variance ratio
explained_variance_ratio = model.explained_variance_ratio_
# Cumulative explained variance
cumulative_variance = np.cumsum(explained_variance_ratio)
# Loadings (Principal axes)
loadings = model.components_.T * np.sqrt(model.explained_variance_)
feedback.pushInfo(f'Explained Variance Ratio : \n{explained_variance_ratio}')
feedback.pushInfo(f'Cumulative Explained Variance : \n{cumulative_variance}')
feedback.pushInfo(f'Loadings (Principal axes) : \n{loadings}')
def print_cluster_metrics(self, model, fit_raster ,feedback):
"""
Log common metrics after a Kmeans.
"""
if hasattr(model, 'inertia_'):
feedback.pushInfo(f'Inertia : \n{model.inertia_}')
feedback.pushInfo(f'Cluster sizes : \n{Counter(model.labels_)}')
## silouhette score seem to heavy for now
# feedback.pushInfo(f'Silhouette Score : \n{silhouette_score(fit_raster, model.labels_)}')
# feedback.pushInfo(f'Silouhette Values : \n{silhouette_values(fit_raster, model.labels_)}')
# used to handle any thread-sensitive cleanup which is required by the algorithm.
def postProcessAlgorithm(self, context, feedback) -> Dict[str, Any]:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment