import os
from typing import Callable, Union
import rasterio
import rasterio.errors
import geopandas as gpd
import numpy as np
from rasterio.merge import merge

def replace_nan_with_zero(array):
    """Zero out NaN entries of *array* in place and return the same array.

    Relies on the IEEE-754 property that NaN != NaN, so no np.isnan call
    is needed; for integer arrays the mask is all-False and this is a no-op.
    """
    nan_mask = array != array
    array[nan_mask] = 0
    return array

def custom_method_avg(merged_data, new_data, merged_mask, new_mask, **kwargs):
    """Returns the average value pixel.

    Merge callback for rasterio.merge: updates ``merged_data`` in place,
    averaging where both tiles have data and copying where only the new
    tile does.
    cf. https://amanbagrecha.github.io/posts/2022-07-31-merge-rasters-the-modern-way-using-python/index.html
    """
    # mask := True where BOTH tiles are valid: OR the two "invalid" masks,
    # then invert the result (all in a single reusable buffer).
    mask = np.empty_like(merged_mask, dtype="bool")
    np.logical_or(merged_mask, new_mask, out=mask)
    np.logical_not(mask, out=mask)
    # Where both are valid, overwrite merged_data with the per-pixel mean.
    np.nanmean([merged_data, new_data], axis=0, out=merged_data, where=mask)
    # Reuse the buffer: mask := True where only the NEW tile is valid
    # (new pixel unmasked AND merged pixel masked) — copy new data there.
    np.logical_not(new_mask, out=mask)
    np.logical_and(merged_mask, mask, out=mask)
    np.copyto(merged_data, new_data, where=mask, casting="unsafe")

def get_extents(raster_files):
    """Return the bounding box of every raster in *raster_files*.

    Each element is the dataset's ``bounds`` (left, bottom, right, top),
    in the same order as the input paths.
    """
    bounds_list = []
    for raster_path in raster_files:
        with rasterio.open(raster_path, 'r') as dataset:
            bounds_list.append(dataset.bounds)
    return bounds_list
def merge_tiles(
    tiles: list,
    dst_path,
    dtype: str = "float32",
    nodata=None,
    method: Union[str, Callable] = "first",
):
    """Merge a list of raster tiles into a single raster written to *dst_path*.

    cf. https://amanbagrecha.github.io/posts/2022-07-31-merge-rasters-the-modern-way-using-python/index.html

    Parameters
    ----------
    tiles : list
        paths of the rasters to merge
    dst_path : str
        path of the merged output raster
    dtype : str
        output data type
    nodata : float, optional
        nodata value for the output raster
    method : str or Callable
        overlap strategy forwarded to rasterio.merge.merge; the string
        "average" is mapped to this module's custom_method_avg callback
    """
    file_handler = [rasterio.open(ds) for ds in tiles]
    try:
        extents = [ds.bounds for ds in file_handler]
        # Union of all tile extents
        lefts, bottoms, rights, tops = zip(*extents)
        union_extent = (
            min(lefts),  # Left
            min(bottoms),  # Bottom
            max(rights),  # Right
            max(tops),  # Top
        )

        if method == "average":
            method = custom_method_avg

        try:
            merge(
                sources=file_handler,  # list of dataset objects opened in 'r' mode
                bounds=union_extent,  # tuple
                nodata=nodata,  # float
                dtype=dtype,  # dtype
                method=method,  # strategy to combine overlapping rasters
                dst_path=dst_path,
            )
        ## different rasterio versions take different keyword args
        except TypeError:
            merge(
                datasets=file_handler,  # list of dataset objects opened in 'r' mode
                bounds=union_extent,  # tuple
                nodata=nodata,  # float
                dtype=dtype,  # dtype
                method=method,  # strategy to combine overlapping rasters
                dst_path=dst_path,
            )
    finally:
        # Fix: the original never closed these handles, leaking one open
        # file descriptor per tile on every call.
        for ds in file_handler:
            ds.close()

def get_mean_sd_by_band(path, force_compute=True, ignore_zeros=True, subset=1_000):
    """
    Reads metadata or computes mean and sd of each band of a geotiff.
    If the metadata is not available, mean and standard deviation can be computed via numpy.

    Parameters
    ----------
    path : str
        path to a geotiff file
    force_compute : bool
        kept for backward compatibility; currently unused in this body
    ignore_zeros : boolean
        ignore zeros when computing mean and sd via numpy
    subset : int, optional
        if truthy, statistics are computed on a random sample (with
        replacement) of this many pixels instead of the full band

    Returns
    -------
    means : list
        list of mean values per band
    sds : list
        list of standard deviation values per band
    """
    src = rasterio.open(path)
    means = []
    sds = []
    for band in range(1, src.count + 1):
        try:
            tags = src.tags(band)
            if "STATISTICS_MEAN" in tags and "STATISTICS_STDDEV" in tags:
                mean = float(tags["STATISTICS_MEAN"])
                sd = float(tags["STATISTICS_STDDEV"])
                means.append(mean)
                sds.append(sd)
            else:
                raise KeyError("Statistics metadata not found.")

        except KeyError:
            # No precomputed statistics in the metadata: read the band
            # and compute them with numpy.
            arr = src.read(band)
            arr = replace_nan_with_zero(arr)
            if subset:
                # Sample a subset of pixels to keep this fast on large rasters.
                arr = np.random.choice(arr.flatten(), size=subset)
            if ignore_zeros:
                mean = np.ma.masked_equal(arr, 0).mean()
                sd = np.ma.masked_equal(arr, 0).std()
            else:
                mean = np.mean(arr)
                sd = np.std(arr)
            means.append(float(mean))
            sds.append(float(sd))

        except Exception as e:
            print(f"Error processing band {band}: {e}")

    src.close()
    return means, sds

def get_random_samples_in_gdf(gdf, num_samples, seed=42):
    """Return a point GeoDataFrame with ~num_samples points drawn from *gdf*.

    If *gdf* already contains only Point geometries it is returned unchanged.
    Otherwise points are sampled inside each feature, the per-feature count
    being proportional to its area; each sampled point inherits the attribute
    values of its source feature.

    Parameters
    ----------
    gdf : geopandas.GeoDataFrame
        input features (points pass through; polygons are sampled)
    num_samples : int
        approximate total number of points (int truncation per feature)
    seed : int
        random seed forwarded to GeoPandas' sample_points

    Returns
    -------
    geopandas.GeoDataFrame
        point-geometry GeoDataFrame in the CRS of *gdf*
    """
    ## if input is not point based, we take random samples in it
    if not all(gdf.geometry.geom_type == "Point"):
        ## Calculate the area of each polygon
        ## to determine the number of samples for each category
        gdf["iamap_area"] = gdf.geometry.area
        total_area = gdf["iamap_area"].sum()
        gdf["iamap_sample_size"] = (
            gdf["iamap_area"] / total_area * num_samples
        ).astype(int)

        series = []
        # Sample polygons proportional to their size
        ## see https://geopandas.org/en/stable/docs/user_guide/sampling.html#Variable-number-of-points
        for idx, row in gdf.iterrows():
            sampled_points = (
                gdf.loc[gdf.index == idx]
                .sample_points(size=row["iamap_sample_size"], rng=seed)
                .explode(ignore_index=True)
            )
            # One output row per sampled point, copying the source attributes.
            for point in sampled_points:
                new_row = row.copy()
                new_row.geometry = point
                series.append(new_row)

        point_gdf = gpd.GeoDataFrame(series, crs=gdf.crs)
        point_gdf.index = [i for i in range(len(point_gdf))]
        # Drop the temporary helper columns added above.
        del point_gdf["iamap_area"]
        del point_gdf["iamap_sample_size"]
        return point_gdf

    # Already point-based: nothing to sample.
    return gdf

def get_unique_col_name(gdf, base_name="fold"):
    """Return a column name derived from *base_name* not already in *gdf*.

    If *base_name* is free it is returned as-is; otherwise "fold1",
    "fold2", ... are tried until an unused name is found.

    Parameters
    ----------
    gdf : (Geo)DataFrame
        dataframe whose columns the returned name must not collide with
    base_name : str
        desired column name

    Returns
    -------
    str
        a column name absent from gdf.columns
    """
    column_name = base_name
    counter = 1

    # Check if the column already exists, if yes, keep updating the name
    while column_name in gdf.columns:
        column_name = f"{base_name}{counter}"
        counter += 1  # fix: missing increment made this loop infinite

    return column_name  # fix: original fell through without returning
def validate_geotiff(output_file, expected_output_size=4428850, expected_wh=(60,24)):
    """
    tests geotiff validity by opening with rasterio,
    checking if the file weights as expected and has the correct width and height.
    Additionaly, it is checked if there is more than one value in the raster.
    """
    # 1. Check if the output file is a valid GeoTIFF
    try:
        with rasterio.open(output_file) as src:
            assert src.meta['driver'] == 'GTiff', "File is not a valid GeoTIFF."
            # 2. Read the data and check width/height
            width, height = src.width, src.height
            assert width == expected_wh[0], f"Expected width {expected_wh[0]}, got {width}."
            assert height == expected_wh[1], f"Expected height {expected_wh[1]}, got {height}."
            # 3. Read the data and check for unique values
            first_band = src.read(1)
            unique_values = np.unique(first_band)
            assert len(unique_values) > 1, "The GeoTIFF contains only one unique value."

    except rasterio.errors.RasterioIOError:
        print("The file could not be opened as a GeoTIFF, indicating it is invalid.")
        assert False

    # 4. Check if the file size is within the expected range (+/- 20%)
    expected_size_min = .8*expected_output_size
    expected_size_max = 1.2*expected_output_size
    file_size = os.path.getsize(output_file)
    assert expected_size_min <= file_size <= expected_size_max, (
        f"File size {file_size} is outside the expected range."
    )
    return

if __name__ == "__main__":
    # Manual smoke test: load the demo polygons, sample ~100 points in them,
    # and print the before/after frames.
    gdf = gpd.read_file("assets/ml_poly.shp")
    print(gdf)
    gdf = get_random_samples_in_gdf(gdf, 100)
    print(gdf)
    print(len(gdf))