Skip to content
Snippets Groups Projects
Commit 93fcfdfe authored by paul.tresson_ird.fr's avatar paul.tresson_ird.fr
Browse files

handle gt column as string

parent 0f56e308
No related branches found
No related tags found
No related merge requests found
...@@ -309,36 +309,26 @@ class MLAlgorithm(SHPAlgorithm): ...@@ -309,36 +309,26 @@ class MLAlgorithm(SHPAlgorithm):
self.do_kfold = self.parameterAsBoolean( self.do_kfold = self.parameterAsBoolean(
parameters, self.DO_KFOLDS, context) parameters, self.DO_KFOLDS, context)
self.gt_col = self.parameterAsString( gt_col = self.parameterAsString(
parameters, self.GT_COL, context) parameters, self.GT_COL, context)
fold_col = self.parameterAsString( fold_col = self.parameterAsString(
parameters, self.FOLD_COL, context) parameters, self.FOLD_COL, context)
nfolds = self.parameterAsInt( nfolds = self.parameterAsInt(
parameters, self.NFOLDS, context) parameters, self.NFOLDS, context)
str_kwargs = self.parameterAsString( str_kwargs = self.parameterAsString(
parameters, self.SK_PARAM, context) parameters, self.SK_PARAM, context)
if str_kwargs != '':
self.passed_kwargs = ast.literal_eval(str_kwargs)
else:
self.passed_kwargs = {}
## If no test set is provided and the option to perform kfolds is true, we perform kfolds
## If a fold column is provided, this defines the folds. Otherwise, random split ## If a fold column is provided, this defines the folds. Otherwise, random split
## check that no column with name 'fold' exists, otherwise we use 'fold1' etc.. ## check that no column with name 'fold' exists, otherwise we use 'fold1' etc..
## we also make a new column containing gt values
self.fold_col = get_unique_col_name(self.gdf, 'fold') self.fold_col = get_unique_col_name(self.gdf, 'fold')
self.gt_col = get_unique_col_name(self.gdf, 'gt')
if self.test_gdf == None and self.do_kfold: ## Instantiate model
if fold_col.strip() != '' : if str_kwargs != '':
self.gdf[self.fold_col] = self.gdf[fold_col] self.passed_kwargs = ast.literal_eval(str_kwargs)
else:
self.gdf[self.fold_col] = np.random.randint(1, nfolds + 1, size=len(self.gdf))
## Else, self.gdf is the train set
else: else:
self.train_gdf = self.gdf self.passed_kwargs = {}
method_idx = self.parameterAsEnum( method_idx = self.parameterAsEnum(
parameters, self.METHOD, context) parameters, self.METHOD, context)
...@@ -357,6 +347,29 @@ class MLAlgorithm(SHPAlgorithm): ...@@ -357,6 +347,29 @@ class MLAlgorithm(SHPAlgorithm):
self.model = instantiate_sklearn_algorithm(neighbors, self.method_name, **kwargs) self.model = instantiate_sklearn_algorithm(neighbors, self.method_name, **kwargs)
## different behaviours if we are doing classification or regression
## If classification, we create a new col with unique integers for each classes
## to ease inference
self.task_type = check_model_type(self.model)
if self.task_type == 'classification':
self.out_dtype = 'int8'
self.gdf[self.gt_col] = pd.factorize(self.gdf[gt_col])[0] # unique int for each class
else:
self.gt_col = gt_col
## If no test set is provided and the option to perform kfolds is true, we perform kfolds
if self.test_gdf == None and self.do_kfold:
if fold_col.strip() != '' :
self.gdf[self.fold_col] = self.gdf[fold_col]
else:
self.gdf[self.fold_col] = np.random.randint(1, nfolds + 1, size=len(self.gdf))
## Else, self.gdf is the train set
else:
self.train_gdf = self.gdf
def get_raster(self, mode='train'): def get_raster(self, mode='train'):
if mode == 'train': if mode == 'train':
...@@ -407,10 +420,8 @@ class MLAlgorithm(SHPAlgorithm): ...@@ -407,10 +420,8 @@ class MLAlgorithm(SHPAlgorithm):
def get_metrics(self, test_gts, predictions, feedback): def get_metrics(self, test_gts, predictions, feedback):
task_type = check_model_type(self.model)
metrics_dict = {} metrics_dict = {}
if self.task_type == 'classification':
if task_type == 'classification':
# Evaluate the model # Evaluate the model
metrics_dict['accuracy'] = accuracy_score(test_gts, predictions) metrics_dict['accuracy'] = accuracy_score(test_gts, predictions)
metrics_dict['precision'] = precision_score(test_gts, predictions, average='weighted') # Modify `average` for multiclass if necessary metrics_dict['precision'] = precision_score(test_gts, predictions, average='weighted') # Modify `average` for multiclass if necessary
...@@ -418,10 +429,9 @@ class MLAlgorithm(SHPAlgorithm): ...@@ -418,10 +429,9 @@ class MLAlgorithm(SHPAlgorithm):
metrics_dict['f1'] = f1_score(test_gts, predictions, average='weighted') metrics_dict['f1'] = f1_score(test_gts, predictions, average='weighted')
metrics_dict['conf_matrix'] = confusion_matrix(test_gts, predictions) metrics_dict['conf_matrix'] = confusion_matrix(test_gts, predictions)
metrics_dict['class_report'] = classification_report(test_gts, predictions) metrics_dict['class_report'] = classification_report(test_gts, predictions)
self.out_dtype = 'int8'
elif task_type == 'regression': elif self.task_type == 'regression':
metrics_dict['mae'] = mean_absolute_error(test_gts, predictions) metrics_dict['mae'] = mean_absolute_error(test_gts, predictions)
metrics_dict['mse'] = mean_squared_error(test_gts, predictions) metrics_dict['mse'] = mean_squared_error(test_gts, predictions)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment