diff --git a/ml.py b/ml.py
index 1f1bb3f73a73d7453cb9e1a921191a90436330ff..a2dc8cfc2c96e5f93378dd8a2f7eb2839bf08d97 100644
--- a/ml.py
+++ b/ml.py
@@ -1,4 +1,5 @@
 import os
+import ast
 import numpy as np
 from pathlib import Path
 from typing import Dict, Any
@@ -58,6 +59,7 @@ class MLAlgorithm(SHPAlgorithm):
     DO_KFOLDS = 'DO_KFOLDS'
     FOLD_COL = 'FOLD_COL'
     NFOLDS = 'NFOLDS'
+    SK_PARAM = 'SK_PARAM'
     TEMPLATE_TEST = 'TEMPLATE_TEST'
     METHOD = 'METHOD'
     TMP_DIR = 'iamap_ml'
@@ -85,6 +87,16 @@ class MLAlgorithm(SHPAlgorithm):
             )
         )
 
+        self.addParameter (
+            QgsProcessingParameterString(
+                name = self.SK_PARAM,
+                description = self.tr(
+                    'Arguments for the initialisation of the algorithm. If empty this goes to sklearn default. It will overwrite cluster or components arguments.'),
+                defaultValue = '',
+                optional=True,
+            )
+        )
+
 
         self.addParameter(
             QgsProcessingParameterFile(
@@ -220,6 +232,13 @@ class MLAlgorithm(SHPAlgorithm):
 
         ## If no test set is provided and the option to perform kfolds is true,
         ## we perform kfolds
+        str_kwargs = self.parameterAsString(
+                parameters, self.SK_PARAM, context)
+
+        if str_kwargs != '':
+            self.passed_kwargs = ast.literal_eval(str_kwargs)
+        else:
+            self.passed_kwargs = {}
         ## If a fold column is provided, this defines the folds. Otherwise, random split
         if self.test_gdf == None and self.do_kfold:
             if fold_col != '':
@@ -227,8 +246,21 @@ class MLAlgorithm(SHPAlgorithm):
             else:
                 self.gdf['fold'] = np.random.randint(1, nfolds + 1, size=len(self.gdf))
                 print(self.gdf)
+        method_idx = self.parameterAsEnum(
+            parameters, self.METHOD, context)
+        self.method_name = self.method_opt[method_idx]
+
+        try:
+            default_args = get_arguments(ensemble, self.method_name)
+        except AttributeError:
+            default_args = get_arguments(neighbors, self.method_name)
 
+        kwargs = self.update_kwargs(default_args)
 
+        try:
+            self.model = instantiate_sklearn_algorithm(ensemble, self.method_name, **kwargs)
+        except AttributeError:
+            self.model = instantiate_sklearn_algorithm(neighbors, self.method_name, **kwargs)
 
     def get_algorithms(self):
         required_methods = ['fit', 'predict']