From 93fcfdfe20cd1e054a0cf394dfb1433fc01c1b83 Mon Sep 17 00:00:00 2001 From: ptresson <paul.tresson@ird.fr> Date: Thu, 24 Oct 2024 10:50:04 +0200 Subject: [PATCH] handle gt column as string --- ml.py | 54 ++++++++++++++++++++++++++++++++---------------------- 1 file changed, 32 insertions(+), 22 deletions(-) diff --git a/ml.py b/ml.py index 490bd4d..b0e0d52 100644 --- a/ml.py +++ b/ml.py @@ -309,36 +309,26 @@ class MLAlgorithm(SHPAlgorithm): self.do_kfold = self.parameterAsBoolean( parameters, self.DO_KFOLDS, context) - self.gt_col = self.parameterAsString( + gt_col = self.parameterAsString( parameters, self.GT_COL, context) fold_col = self.parameterAsString( parameters, self.FOLD_COL, context) nfolds = self.parameterAsInt( parameters, self.NFOLDS, context) - str_kwargs = self.parameterAsString( parameters, self.SK_PARAM, context) - if str_kwargs != '': - self.passed_kwargs = ast.literal_eval(str_kwargs) - else: - self.passed_kwargs = {} - - ## If no test set is provided and the option to perform kfolds is true, we perform kfolds ## If a fold column is provided, this defines the folds. Otherwise, random split - ## check that no column with name 'fold' exists, otherwise we use 'fold1' etc.. + ## we also make a new column containing gt values self.fold_col = get_unique_col_name(self.gdf, 'fold') + self.gt_col = get_unique_col_name(self.gdf, 'gt') - if self.test_gdf == None and self.do_kfold: - if fold_col.strip() != '' : - self.gdf[self.fold_col] = self.gdf[fold_col] - else: - self.gdf[self.fold_col] = np.random.randint(1, nfolds + 1, size=len(self.gdf)) - - ## Else, self.gdf is the train set + ## Instantiate model + if str_kwargs != '': + self.passed_kwargs = ast.literal_eval(str_kwargs) else: - self.train_gdf = self.gdf + self.passed_kwargs = {} method_idx = self.parameterAsEnum( parameters, self.METHOD, context) @@ -357,6 +347,29 @@ class MLAlgorithm(SHPAlgorithm): self.model = instantiate_sklearn_algorithm(neighbors, self.method_name, **kwargs) + ## different behaviours if we are doing classification or regression + ## If classification, we create a new col with unique integers for each classes + ## to ease inference + self.task_type = check_model_type(self.model) + + if self.task_type == 'classification': + self.out_dtype = 'int8' + self.gdf[self.gt_col] = pd.factorize(self.gdf[gt_col])[0] # unique int for each class + else: + self.gt_col = gt_col + + + ## If no test set is provided and the option to perform kfolds is true, we perform kfolds + if self.test_gdf == None and self.do_kfold: + if fold_col.strip() != '' : + self.gdf[self.fold_col] = self.gdf[fold_col] + else: + self.gdf[self.fold_col] = np.random.randint(1, nfolds + 1, size=len(self.gdf)) + ## Else, self.gdf is the train set + else: + self.train_gdf = self.gdf + + def get_raster(self, mode='train'): if mode == 'train': @@ -407,10 +420,8 @@ class MLAlgorithm(SHPAlgorithm): def get_metrics(self, test_gts, predictions, feedback): - task_type = check_model_type(self.model) metrics_dict = {} - - if task_type == 'classification': + if self.task_type == 'classification': # Evaluate the model metrics_dict['accuracy'] = accuracy_score(test_gts, predictions) metrics_dict['precision'] = precision_score(test_gts, predictions, average='weighted') # Modify `average` for multiclass if necessary @@ -418,10 +429,9 @@ class MLAlgorithm(SHPAlgorithm): metrics_dict['f1'] = f1_score(test_gts, predictions, average='weighted') metrics_dict['conf_matrix'] = confusion_matrix(test_gts, predictions) metrics_dict['class_report'] = classification_report(test_gts, predictions) - self.out_dtype = 'int8' - elif task_type == 'regression': + elif self.task_type == 'regression': metrics_dict['mae'] = mean_absolute_error(test_gts, predictions) metrics_dict['mse'] = mean_squared_error(test_gts, predictions) -- GitLab