Skip to content
Snippets Groups Projects
Commit 2abd27e2 authored by paul.tresson_ird.fr's avatar paul.tresson_ird.fr
Browse files

instantiate model with kwargs

parent a2b5151e
No related branches found
No related tags found
No related merge requests found
import os import os
import ast
import numpy as np import numpy as np
from pathlib import Path from pathlib import Path
from typing import Dict, Any from typing import Dict, Any
...@@ -58,6 +59,7 @@ class MLAlgorithm(SHPAlgorithm): ...@@ -58,6 +59,7 @@ class MLAlgorithm(SHPAlgorithm):
DO_KFOLDS = 'DO_KFOLDS' DO_KFOLDS = 'DO_KFOLDS'
FOLD_COL = 'FOLD_COL' FOLD_COL = 'FOLD_COL'
NFOLDS = 'NFOLDS' NFOLDS = 'NFOLDS'
SK_PARAM = 'SK_PARAM'
TEMPLATE_TEST = 'TEMPLATE_TEST' TEMPLATE_TEST = 'TEMPLATE_TEST'
METHOD = 'METHOD' METHOD = 'METHOD'
TMP_DIR = 'iamap_ml' TMP_DIR = 'iamap_ml'
...@@ -85,6 +87,16 @@ class MLAlgorithm(SHPAlgorithm): ...@@ -85,6 +87,16 @@ class MLAlgorithm(SHPAlgorithm):
) )
) )
self.addParameter (
QgsProcessingParameterString(
name = self.SK_PARAM,
description = self.tr(
'Arguments for the initialisation of the algorithm. If empty this goes to sklearn default. It will overwrite cluster or components arguments.'),
defaultValue = '',
optional=True,
)
)
self.addParameter( self.addParameter(
QgsProcessingParameterFile( QgsProcessingParameterFile(
...@@ -220,6 +232,13 @@ class MLAlgorithm(SHPAlgorithm): ...@@ -220,6 +232,13 @@ class MLAlgorithm(SHPAlgorithm):
## If no test set is provided and the option to perform kfolds is true, ## If no test set is provided and the option to perform kfolds is true,
## we perform kfolds ## we perform kfolds
str_kwargs = self.parameterAsString(
parameters, self.SK_PARAM, context)
if str_kwargs != '':
self.passed_kwargs = ast.literal_eval(str_kwargs)
else:
self.passed_kwargs = {}
## If a fold column is provided, this defines the folds. Otherwise, random split ## If a fold column is provided, this defines the folds. Otherwise, random split
if self.test_gdf == None and self.do_kfold: if self.test_gdf == None and self.do_kfold:
if fold_col != '': if fold_col != '':
...@@ -227,8 +246,21 @@ class MLAlgorithm(SHPAlgorithm): ...@@ -227,8 +246,21 @@ class MLAlgorithm(SHPAlgorithm):
else: else:
self.gdf['fold'] = np.random.randint(1, nfolds + 1, size=len(self.gdf)) self.gdf['fold'] = np.random.randint(1, nfolds + 1, size=len(self.gdf))
print(self.gdf) print(self.gdf)
method_idx = self.parameterAsEnum(
parameters, self.METHOD, context)
self.method_name = self.method_opt[method_idx]
try:
default_args = get_arguments(ensemble, self.method_name)
except AttributeError:
default_args = get_arguments(neighbors, self.method_name)
kwargs = self.update_kwargs(default_args)
try:
self.model = instantiate_sklearn_algorithm(ensemble, self.method_name, **kwargs)
except AttributeError:
self.model = instantiate_sklearn_algorithm(neighbors, self.method_name, **kwargs)
def get_algorithms(self): def get_algorithms(self):
required_methods = ['fit', 'predict'] required_methods = ['fit', 'predict']
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment