Skip to content
Snippets Groups Projects
Commit 44e36714 authored by paul.tresson_ird.fr's avatar paul.tresson_ird.fr
Browse files

fix seed in random samples (rather than hardcoded) and fold definition

parent 14c3b653
No related branches found
No related tags found
No related merge requests found
......@@ -223,7 +223,6 @@ class MLAlgorithm(SHPAlgorithm):
## confusion matrix is a np array that does not fit in a json
best_metrics_dict.pop('conf_matrix', None)
best_metrics_dict.pop('class_report', None)
print(best_metrics_dict)
json.dump(best_metrics_dict, json_file, indent=4)
self.infer_model(feedback)
......@@ -296,7 +295,7 @@ class MLAlgorithm(SHPAlgorithm):
feedback.pushInfo(f'before samples: {len(gdf)}')
## get random samples if geometry is not point based
gdf = get_random_samples_in_gdf(gdf, random_samples)
gdf = get_random_samples_in_gdf(gdf, random_samples, seed=self.seed)
feedback.pushInfo(f'before extent: {len(gdf)}')
bounds = box(
......@@ -374,6 +373,7 @@ class MLAlgorithm(SHPAlgorithm):
if fold_col.strip() != '' :
self.gdf[self.fold_col] = self.gdf[fold_col]
else:
np.random.seed(self.seed)
self.gdf[self.fold_col] = np.random.randint(1, nfolds + 1, size=len(self.gdf))
## Else, self.gdf is the train set
else:
......
......@@ -59,7 +59,7 @@ class TestSimAlgorithm(TestReductionAlgorithm):
class TestMLAlgorithm(TestReductionAlgorithm):
algorithm = MLAlgorithm()
default_parameters = {'INPUT': INPUT,'OUTPUT': OUTPUT,'TEMPLATE':TEMPLATE_RF,'GT_COL': GT_COL}
possible_hashes = ['514fc247e4765ca34895a3d6cb9bffd6']
possible_hashes = ['bd22d66180347e043fca58d494876184']
out_name = 'ml.tif'
if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment