diff --git a/ml.py b/ml.py index 1f1bb3f73a73d7453cb9e1a921191a90436330ff..a2dc8cfc2c96e5f93378dd8a2f7eb2839bf08d97 100644 --- a/ml.py +++ b/ml.py @@ -1,4 +1,5 @@ import os +import ast import numpy as np from pathlib import Path from typing import Dict, Any @@ -58,6 +59,7 @@ class MLAlgorithm(SHPAlgorithm): DO_KFOLDS = 'DO_KFOLDS' FOLD_COL = 'FOLD_COL' NFOLDS = 'NFOLDS' + SK_PARAM = 'SK_PARAM' TEMPLATE_TEST = 'TEMPLATE_TEST' METHOD = 'METHOD' TMP_DIR = 'iamap_ml' @@ -85,6 +87,16 @@ class MLAlgorithm(SHPAlgorithm): ) ) + self.addParameter ( + QgsProcessingParameterString( + name = self.SK_PARAM, + description = self.tr( + 'Arguments for the initialisation of the algorithm. If empty this goes to sklearn default. It will overwrite cluster or components arguments.'), + defaultValue = '', + optional=True, + ) + ) + self.addParameter( QgsProcessingParameterFile( @@ -220,6 +232,13 @@ class MLAlgorithm(SHPAlgorithm): ## If no test set is provided and the option to perform kfolds is true, ## we perform kfolds + str_kwargs = self.parameterAsString( + parameters, self.SK_PARAM, context) + + if str_kwargs != '': + self.passed_kwargs = ast.literal_eval(str_kwargs) + else: + self.passed_kwargs = {} ## If a fold column is provided, this defines the folds. Otherwise, random split if self.test_gdf == None and self.do_kfold: if fold_col != '': @@ -227,8 +246,21 @@ class MLAlgorithm(SHPAlgorithm): else: self.gdf['fold'] = np.random.randint(1, nfolds + 1, size=len(self.gdf)) print(self.gdf) + method_idx = self.parameterAsEnum( + parameters, self.METHOD, context) + self.method_name = self.method_opt[method_idx] + + try: + default_args = get_arguments(ensemble, self.method_name) + except AttributeError: + default_args = get_arguments(neighbors, self.method_name) + kwargs = self.update_kwargs(default_args) + try: + self.model = instantiate_sklearn_algorithm(ensemble, self.method_name, **kwargs) + except AttributeError: + self.model = instantiate_sklearn_algorithm(neighbors, self.method_name, **kwargs) def get_algorithms(self): required_methods = ['fit', 'predict']