From 93fcfdfe20cd1e054a0cf394dfb1433fc01c1b83 Mon Sep 17 00:00:00 2001
From: ptresson <paul.tresson@ird.fr>
Date: Thu, 24 Oct 2024 10:50:04 +0200
Subject: [PATCH] handle gt column as string

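The ground truth (gt) column may now contain string labels. The task
type is determined once via check_model_type() and cached as
self.task_type (so get_metrics() can reuse it); for classification,
the gt values are factorized into integer codes stored in a new,
uniquely named column so an int8 raster can be written at inference,
while regression keeps the original column unchanged.

A minimal sketch of the pd.factorize() behaviour this relies on (toy
data; the 'labels' and 'gt' column names are illustrative, not the
plugin's):

    import pandas as pd

    # Toy frame standing in for self.gdf; 'labels' plays the role of gt_col.
    gdf = pd.DataFrame({'labels': ['tree', 'grass', 'tree', 'soil']})

    # pd.factorize returns (codes, uniques): codes assign one 0-based
    # integer per distinct class, in order of first appearance.
    codes, uniques = pd.factorize(gdf['labels'])
    gdf['gt'] = codes

    print(gdf['gt'].tolist())  # [0, 1, 0, 2]
    print(list(uniques))       # ['tree', 'grass', 'soil']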
---
 ml.py | 54 ++++++++++++++++++++++++++++++++----------------------
 1 file changed, 32 insertions(+), 22 deletions(-)

diff --git a/ml.py b/ml.py
index 490bd4d..b0e0d52 100644
--- a/ml.py
+++ b/ml.py
@@ -309,36 +309,26 @@ class MLAlgorithm(SHPAlgorithm):
 
         self.do_kfold = self.parameterAsBoolean(
             parameters, self.DO_KFOLDS, context)
-        self.gt_col = self.parameterAsString(
+        gt_col = self.parameterAsString(
             parameters, self.GT_COL, context)
         fold_col = self.parameterAsString(
             parameters, self.FOLD_COL, context)
         nfolds = self.parameterAsInt(
             parameters, self.NFOLDS, context)
-
         str_kwargs = self.parameterAsString(
                 parameters, self.SK_PARAM, context)
 
-        if str_kwargs != '':
-            self.passed_kwargs = ast.literal_eval(str_kwargs)
-        else:
-            self.passed_kwargs = {}
-
-        ## If no test set is provided and the option to perform kfolds is true, we perform kfolds
         ## If a fold column is provided, this defines the folds. Otherwise, random split
-
         ## check that no column with name 'fold' exists, otherwise we use 'fold1' etc..
+        ## we also create a new column holding the ground truth (gt) values
         self.fold_col = get_unique_col_name(self.gdf, 'fold')
+        self.gt_col = get_unique_col_name(self.gdf, 'gt')
 
-        if self.test_gdf == None and self.do_kfold:
-            if fold_col.strip() != '' :
-                self.gdf[self.fold_col] = self.gdf[fold_col]
-            else:
-                self.gdf[self.fold_col] = np.random.randint(1, nfolds + 1, size=len(self.gdf))
-
-        ## Else, self.gdf is the train set
+        ## Instantiate model
+        if str_kwargs != '':
+            self.passed_kwargs = ast.literal_eval(str_kwargs)
         else:
-            self.train_gdf = self.gdf
+            self.passed_kwargs = {}
 
         method_idx = self.parameterAsEnum(
             parameters, self.METHOD, context)
@@ -357,6 +347,29 @@ class MLAlgorithm(SHPAlgorithm):
             self.model = instantiate_sklearn_algorithm(neighbors, self.method_name, **kwargs)
 
 
+        ## Behaviour differs between classification and regression.
+        ## For classification, we create a new column with a unique integer per class
+        ## to ease inference
+        self.task_type = check_model_type(self.model)
+
+        if self.task_type == 'classification':
+            self.out_dtype = 'int8'
+            self.gdf[self.gt_col] = pd.factorize(self.gdf[gt_col])[0] # unique int for each class
+        else:
+            self.gt_col = gt_col
+
+
+        ## If no test set is provided and the k-fold option is enabled, we perform k-folds
+        if self.test_gdf is None and self.do_kfold:
+            if fold_col.strip() != '':
+                self.gdf[self.fold_col] = self.gdf[fold_col]
+            else:
+                self.gdf[self.fold_col] = np.random.randint(1, nfolds + 1, size=len(self.gdf))
+        ## Else, self.gdf is the train set
+        else:
+            self.train_gdf = self.gdf
+
+
     def get_raster(self, mode='train'):
 
         if mode == 'train':
@@ -407,10 +420,8 @@ class MLAlgorithm(SHPAlgorithm):
 
     def get_metrics(self, test_gts, predictions, feedback):
 
-        task_type = check_model_type(self.model)
         metrics_dict = {}
-
-        if task_type == 'classification':
+        if self.task_type == 'classification':
             # Evaluate the model
             metrics_dict['accuracy'] = accuracy_score(test_gts, predictions)
             metrics_dict['precision'] = precision_score(test_gts, predictions, average='weighted')  # Modify `average` for multiclass if necessary
@@ -418,10 +429,9 @@ class MLAlgorithm(SHPAlgorithm):
             metrics_dict['f1'] = f1_score(test_gts, predictions, average='weighted')
             metrics_dict['conf_matrix'] = confusion_matrix(test_gts, predictions)
             metrics_dict['class_report'] = classification_report(test_gts, predictions)
-            self.out_dtype = 'int8'
 
 
-        elif task_type == 'regression':
+        elif self.task_type == 'regression':
 
             metrics_dict['mae'] = mean_absolute_error(test_gts, predictions)
             metrics_dict['mse'] = mean_squared_error(test_gts, predictions)
-- 
GitLab