Skip to content
Snippets Groups Projects
Commit fc13674a authored by paul.tresson_ird.fr's avatar paul.tresson_ird.fr
Browse files

fix TSNE inference

parent cb65989f
No related branches found
No related tags found
No related merge requests found
......@@ -600,12 +600,19 @@ class SKAlgorithm(IAMAPAlgorithm):
## some clustering algorithms need the entire dataset.
do_fit_predict = False
do_fit_transform = False
if not hasattr(model, "predict") and sk_module==cluster:
do_fit_predict = True
if not hasattr(model, "transform") and sk_module==manifold:
do_fit_transform = True
if do_fit_predict:
proj_img, model = self.fit_predict(model, feedback)
model = self.fit_predict(model, feedback)
if do_fit_transform:
model = self.fit_transform(model, feedback)
else:
iter = get_iter(model, fit_raster)
......@@ -619,7 +626,7 @@ class SKAlgorithm(IAMAPAlgorithm):
out_path = os.path.join(self.output_dir, save_file)
joblib.dump(model, out_path)
if not do_fit_predict:
if not (do_fit_predict or do_fit_transform):
feedback.pushInfo("Inference over raster\n")
self.infer_model(model, feedback, scaler)
......@@ -756,6 +763,52 @@ class SKAlgorithm(IAMAPAlgorithm):
return model
def fit_transform(self, model, feedback):
with rasterio.open(self.rlayer_path) as ds:
transform = ds.transform
crs = ds.crs
win = windows.from_bounds(
self.extent.xMinimum(),
self.extent.yMinimum(),
self.extent.xMaximum(),
self.extent.yMaximum(),
transform=transform,
)
raster = ds.read(window=win)
transform = ds.window_transform(win)
raster = np.transpose(raster, (1, 2, 0))
raster = raster[:, :, self.input_bands]
fit_raster = raster.reshape(-1, raster.shape[-1])
# raster = (raster-np.mean(raster))/np.std(raster)
scaler = StandardScaler()
fit_raster = scaler.fit_transform(fit_raster)
np.nan_to_num(fit_raster) # NaN to zero after normalisation
proj_img = model.fit_transform(fit_raster)
proj_img = proj_img.reshape((raster.shape[0], raster.shape[1], -1))
height, width, channels = proj_img.shape
feedback.pushInfo("Export to geotif\n")
with rasterio.open(
self.dst_path,
"w",
driver="GTiff",
height=height,
width=width,
count=channels,
dtype=self.out_dtype,
crs=crs,
transform=transform,
) as dst_ds:
dst_ds.write(np.transpose(proj_img, (2, 0, 1)))
feedback.pushInfo("Export to geotif done\n")
return model
def infer_model(self, model, feedback, scaler=None):
with rasterio.open(self.rlayer_path) as ds:
transform = ds.transform
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment