Skip to content
Snippets Groups Projects
geo.py 6 KiB
Newer Older
paul.tresson_ird.fr's avatar
paul.tresson_ird.fr committed
from typing import Callable, Union
import rasterio
import geopandas as gpd
import numpy as np
from rasterio.merge import merge

paul.tresson_ird.fr's avatar
paul.tresson_ird.fr committed
def replace_nan_with_zero(array):
    """Overwrite every NaN cell of ``array`` with 0, in place.

    The same array object is returned so the call can be chained or assigned.
    """
    # NaN is the only value that compares unequal to itself, so this flags
    # exactly the NaN cells (and nothing at all for integer arrays).
    nan_cells = array != array
    array[nan_cells] = 0
    return array

paul.tresson_ird.fr's avatar
paul.tresson_ird.fr committed
def custom_method_avg(merged_data, new_data, merged_mask, new_mask, **kwargs):
    """Merge callback that averages overlapping pixels.

    cf. https://amanbagrecha.github.io/posts/2022-07-31-merge-rasters-the-modern-way-using-python/index.html

    ``merged_data`` is updated in place:

    * where both ``merged_data`` and ``new_data`` are valid, the pixel becomes
      their mean (NaN values are ignored, as with ``np.nanmean``);
    * where only ``new_data`` is valid, its value is copied in;
    * everywhere else ``merged_data`` is left untouched.

    Parameters
    ----------
    merged_data : numpy.ndarray
        Mosaic built so far (written in place).
    new_data : numpy.ndarray or numpy.ma.MaskedArray
        Incoming tile values for the same window.
    merged_mask, new_mask : array_like of bool
        True where the corresponding pixel is nodata.
    **kwargs
        Extra arguments passed by the caller (ignored).
    """
    merged_vals = np.ma.getdata(merged_data)
    new_vals = np.ma.getdata(new_data)

    # Pixels where both the mosaic and the incoming tile carry data.
    both_valid = np.logical_not(np.logical_or(merged_mask, new_mask))

    # Compute the mean into a separate buffer: reducing with ``out=merged_data``
    # and ``where=`` would also overwrite the excluded cells (with 0/0 = NaN).
    stacked = np.stack([merged_vals, new_vals]).astype("float64")
    counts = np.count_nonzero(~np.isnan(stacked), axis=0)
    totals = np.nansum(stacked, axis=0)
    with np.errstate(invalid="ignore", divide="ignore"):
        averages = totals / counts
    np.copyto(merged_data, averages, where=both_valid, casting="unsafe")

    # Pixels that are nodata in the mosaic but valid in the new tile.
    only_new = np.logical_and(merged_mask, np.logical_not(new_mask))
    np.copyto(merged_data, new_vals, where=only_new, casting="unsafe")

paul.tresson_ird.fr's avatar
paul.tresson_ird.fr committed
def merge_tiles(
    tiles: list,
    dst_path,
    dtype: str = "float32",
    nodata=None,
    method: Union[str, Callable] = "first",
):
    """Mosaic several rasters into one file covering their combined extent.

    cf. https://amanbagrecha.github.io/posts/2022-07-31-merge-rasters-the-modern-way-using-python/index.html

    Parameters
    ----------
    tiles : list
        Rasters to merge (anything accepted by ``rasterio.open``).
    dst_path : str or PathLike
        Where the merged raster is written.
    dtype : str
        Data type of the output raster.
    nodata : float, optional
        Nodata value of the output raster.
    method : str or callable
        Strategy used by ``rasterio.merge.merge`` to combine overlapping
        pixels. ``"average"`` is mapped to ``custom_method_avg``; any other
        value is passed through unchanged.

    Raises
    ------
    ValueError
        If ``tiles`` is empty.
    """
    if not tiles:
        raise ValueError("merge_tiles() needs at least one tile")

    if method == "average":
        method = custom_method_avg

    datasets = []
    try:
        for tile in tiles:
            datasets.append(rasterio.open(tile))

        # Union of all tile footprints, as (left, bottom, right, top).
        lefts, bottoms, rights, tops = zip(*(ds.bounds for ds in datasets))
        union_extent = (min(lefts), min(bottoms), max(rights), max(tops))

        # The dataset list is passed positionally: its keyword is ``sources``
        # in recent rasterio and ``datasets`` in older releases.
        merge(
            datasets,
            bounds=union_extent,
            nodata=nodata,
            dtype=dtype,
            method=method,
            dst_path=dst_path,
        )
    finally:
        # Release every dataset opened so far, even if a later open or the
        # merge itself failed.
        for ds in datasets:
            ds.close()

def get_mean_sd_by_band(path, force_compute=True, ignore_zeros=True, subset=1_000):
    """
    Reads metadata or computes mean and sd of each band of a geotiff.
    If the metadata is not available, mean and standard deviation are computed via numpy.

    Parameters
    ----------
    path : str
        path to a geotiff file
    force_compute : boolean
        not referenced by this implementation (kept for backward compatibility)
    ignore_zeros : boolean
        ignore zeros when computing mean and sd via numpy
    subset : int
        when truthy and a band has more pixels than this, the statistics are
        estimated on ``subset`` pixels drawn at random (with replacement);
        otherwise every pixel of the band is used

    Returns
    -------
    means : list
        list of mean values per band
    sds : list
        list of standard deviation values per band

    Notes
    -----
    Statistics are taken from the ``STATISTICS_MEAN`` / ``STATISTICS_STDDEV``
    band tags when both are present. NaN pixels are treated as zeros when
    computing. If reading or parsing a band's tags fails, the error is printed
    and that band is skipped, so the lists can be shorter than the band count.
    """
    means = []
    sds = []
    with rasterio.open(path) as src:
        for band in range(1, src.count + 1):
            try:
                tags = src.tags(band)
                has_stats = "STATISTICS_MEAN" in tags and "STATISTICS_STDDEV" in tags
                if has_stats:
                    mean = float(tags["STATISTICS_MEAN"])
                    sd = float(tags["STATISTICS_STDDEV"])
            except Exception as e:
                print(f"Error processing band {band}: {e}")
                continue

            if not has_stats:
                arr = replace_nan_with_zero(src.read(band))
                # Subsample only when it actually reduces the work.
                if subset and arr.size > subset:
                    arr = np.random.choice(arr.ravel(), size=subset)
                if ignore_zeros:
                    non_zero = np.ma.masked_equal(arr, 0)
                    mean, sd = non_zero.mean(), non_zero.std()
                else:
                    mean, sd = np.mean(arr), np.std(arr)

            means.append(float(mean))
            sds.append(float(sd))

    return means, sds

def get_random_samples_in_gdf(gdf, num_samples, seed=42):
    """Draw random sample points inside the geometries of a GeoDataFrame.

    If every geometry of ``gdf`` is already a ``Point``, ``gdf`` is returned
    as is. Otherwise about ``num_samples`` points are drawn inside the
    geometries, each geometry receiving a share proportional to its area.
    Shares are rounded down, so the total can be slightly below
    ``num_samples``, and geometries whose share is 0 get no point. Every
    sampled point keeps the attributes of the row it was drawn from.

    Parameters
    ----------
    gdf : geopandas.GeoDataFrame
        Input features. The frame is not modified.
    num_samples : int
        Target total number of points.
    seed : int
        Seed of the random generator, for reproducible results.

    Returns
    -------
    geopandas.GeoDataFrame
        One row per sampled point, with the same columns and CRS as ``gdf``
        and a fresh 0..n-1 index.
    """
    ## if input is point based, there is nothing to sample
    if (gdf.geometry.geom_type == "Point").all():
        return gdf

    geom_col = gdf.geometry.name

    ## number of samples per feature, proportional to its area
    ## (kept in local Series so the caller's frame is not mutated)
    areas = gdf.geometry.area
    sample_sizes = (areas / areas.sum() * num_samples).astype(int)

    # One generator shared by all features: passing the bare seed to every
    # sample_points call would restart the same random stream for each one.
    rng = np.random.default_rng(seed)

    ## see https://geopandas.org/en/stable/docs/user_guide/sampling.html#Variable-number-of-points
    rows = []
    for (_, row), size in zip(gdf.iterrows(), sample_sizes):
        if size <= 0:
            continue  # no share for this feature; avoids empty-geometry rows
        points = (
            gpd.GeoSeries([row[geom_col]], crs=gdf.crs)
            .sample_points(size=int(size), rng=rng)
            .explode(ignore_index=True)
        )
        for point in points:
            if point is None or point.is_empty:
                continue
            new_row = row.copy()
            # Assign by label: attribute assignment would not update the
            # geometry if the geometry column is not called "geometry".
            new_row[geom_col] = point
            rows.append(new_row)

    point_gdf = gpd.GeoDataFrame(
        rows, columns=gdf.columns, geometry=geom_col, crs=gdf.crs
    )
    return point_gdf.reset_index(drop=True)

def get_unique_col_name(gdf, base_name="fold"):
    """Return a column name based on ``base_name`` that ``gdf`` does not use yet.

    ``base_name`` itself is returned when it is free. Otherwise a counter is
    appended (``fold1``, ``fold2``, ...) until an unused name is found.

    Parameters
    ----------
    gdf : DataFrame-like
        Any object exposing a ``columns`` container, e.g. a (Geo)DataFrame.
    base_name : str
        Preferred column name.

    Returns
    -------
    str
        A name not present in ``gdf.columns``.
    """
    column_name = base_name
    counter = 1

    # Check if the column already exists, if yes, keep updating the name
    while column_name in gdf.columns:
        column_name = f"{base_name}{counter}"
        counter += 1
    return column_name


if __name__ == "__main__":
    # Manual smoke test: sample 100 points inside the polygons of a local
    # shapefile and print the input, the result and the number of points.
    gdf = gpd.read_file("assets/ml_poly.shp")
    print(gdf)
    gdf = get_random_samples_in_gdf(gdf, 100)
    print(gdf)
    print(len(gdf))