from typing import Callable, Union

import rasterio
import geopandas as gpd
import numpy as np
import pandas as pd
from rasterio.merge import merge

def replace_nan_with_zero(array):
    array[array != array] = 0  # Replace NaN values with zero
    return array
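
# Example: NaN is the only float that compares unequal to itself, which is
# what the `array != array` mask above relies on.
#
# >>> replace_nan_with_zero(np.array([1.0, np.nan, 3.0]))
# array([1., 0., 3.])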

def custom_method_avg(merged_data, new_data, merged_mask, new_mask, **kwargs):
    """Merge strategy that writes the average value of overlapping pixels.
    cf. https://amanbagrecha.github.io/posts/2022-07-31-merge-rasters-the-modern-way-using-python/index.html
    """
    mask = np.empty_like(merged_mask, dtype="bool")
    # Average where both the accumulated mosaic and the new tile are valid
    np.logical_or(merged_mask, new_mask, out=mask)
    np.logical_not(mask, out=mask)
    np.nanmean([merged_data, new_data], axis=0, out=merged_data, where=mask)
    # Where only the new tile is valid, copy its values over
    np.logical_not(new_mask, out=mask)
    np.logical_and(merged_mask, mask, out=mask)
    np.copyto(merged_data, new_data, where=mask, casting="unsafe")
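
# Note: rasterio.merge.merge invokes its `method` callable with extra keyword
# arguments (index, roff, coff); the **kwargs in the signature above absorbs them.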

def merge_tiles(
        tiles: list,
        dst_path,
        dtype: str = 'float32',
        nodata=None,
        # Union rather than 'str | Callable' keeps Python < 3.10 compatibility
        method: Union[str, Callable] = 'first',
        ):
    """
    cf. https://amanbagrecha.github.io/posts/2022-07-31-merge-rasters-the-modern-way-using-python/index.html
    """

    file_handler = [rasterio.open(ds) for ds in tiles]
    extents = [ds.bounds for ds in file_handler]
    # Extract individual bounds
    lefts, bottoms, rights, tops = zip(*extents)
    union_extent = (
        min(lefts),     # Left
        min(bottoms),   # Bottom
        max(rights),    # Right
        max(tops)       # Top
    )

    if method == 'average':
        method = custom_method_avg

    merge_kwargs = dict(
        bounds=union_extent,    # tuple (left, bottom, right, top)
        nodata=nodata,          # float
        dtype=dtype,            # dtype
        method=method,          # strategy to combine overlapping rasters
        dst_path=dst_path,      # str or PathLike to save raster
    )
    try:
        # recent rasterio releases name the first argument 'sources'
        merge(sources=file_handler, **merge_kwargs)
    except TypeError:
        # older releases expect 'datasets' instead
        merge(datasets=file_handler, **merge_kwargs)
    finally:
        # close the source datasets opened above
        for ds in file_handler:
            ds.close()
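
# Usage sketch (the tile file names here are hypothetical): mosaic two
# overlapping tiles, averaging pixels where they overlap.
#
# merge_tiles(
#     tiles=['tile_north.tif', 'tile_south.tif'],
#     dst_path='mosaic.tif',
#     nodata=0,
#     method='average',
# )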

def get_mean_sd_by_band(path, force_compute=True, ignore_zeros=True, subset=1_000):
    '''
    Reads metadata or computes mean and sd of each band of a geotiff.
    If the statistics metadata is not available, mean and standard deviation
    are computed via numpy.

    Parameters
    ----------
    path : str
        path to a geotiff file
    force_compute : boolean
        currently unused; intended to force recomputation even when
        statistics metadata is present
    ignore_zeros : boolean
        ignore zeros when computing mean and sd via numpy
    subset : int
        number of pixels sampled (with replacement) per band before computing
        statistics; set to None or 0 to use the whole band

    Returns
    -------
    means : list
        list of mean values per band
    sds : list
        list of standard deviation values per band
    '''

    src = rasterio.open(path)
    means = []
    sds = []
    for band in range(1, src.count+1):
        try:
            tags = src.tags(band)
            if 'STATISTICS_MEAN' in tags and 'STATISTICS_STDDEV' in tags:
                mean = float(tags['STATISTICS_MEAN'])
                sd = float(tags['STATISTICS_STDDEV'])
                means.append(mean)
                sds.append(sd)
            else:
                raise KeyError("Statistics metadata not found.")

        except KeyError:
            # no statistics in the metadata: compute them from the pixels
            arr = src.read(band)
            arr = replace_nan_with_zero(arr)
            # subsample pixels to keep the computation cheap
            if subset:
                arr = np.random.choice(arr.flatten(), size=subset)
            if ignore_zeros:
                mean = np.ma.masked_equal(arr, 0).mean()
                sd = np.ma.masked_equal(arr, 0).std()
            else:
                mean = np.mean(arr)
                sd = np.std(arr)
            means.append(float(mean))
            sds.append(float(sd))

        except Exception as e:
            print(f"Error processing band {band}: {e}")
            # keep the returned lists aligned with band indices
            means.append(np.nan)
            sds.append(np.nan)


    src.close()
    return means, sds
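
# Usage sketch (hypothetical path): per-band statistics, e.g. to normalise a
# raster before feeding it to a model.
#
# means, sds = get_mean_sd_by_band('mosaic.tif', ignore_zeros=True)
# print(means, sds)  # one value per band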


def get_random_samples_in_gdf(gdf, num_samples):
    ## if input is not point based, we take random samples in it
    if not all(gdf.geometry.geom_type == "Point"):
        # Calculate the area of each polygon
        gdf['area'] = gdf.geometry.area

        # Total area
        total_area = gdf['area'].sum()

        # Give each polygon a share of num_samples proportional to its area
        gdf['sample_size'] = (gdf['area'] / total_area * num_samples).astype(int)

        # Initialize a list to store sampled polygons
        sampled_polygons_list = []

        # Sample polygons proportional to their size
        for idx, row in gdf.iterrows():
            # Ensure we don't exceed the population size
            n = min(row['sample_size'], len(gdf))

            # Append the samples to the list
            sampled_polygons_list.append(gdf.sample(n=n, replace=False))

        # Combine the sampled polygons
        gdf = gpd.GeoDataFrame(pd.concat(sampled_polygons_list, ignore_index=True), crs=gdf.crs)
    return gdf
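
# Usage sketch (hypothetical file name): draw roughly 200 samples from a
# polygon layer, allocated proportionally to polygon area.
#
# gdf = gpd.read_file('parcels.gpkg')
# samples = get_random_samples_in_gdf(gdf, num_samples=200)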