import os
from typing import Callable, Union
import rasterio
import rasterio.errors
import geopandas as gpd
import numpy as np
from rasterio.merge import merge


def replace_nan_with_zero(array):
    """Set every NaN entry of ``array`` to zero, in place, and return ``array``.

    The input is modified directly; the returned object is the same array
    that was passed in.
    """
    # NaN is the only value that compares unequal to itself.
    nan_positions = array != array
    array[nan_positions] = 0
    return array


def custom_method_avg(merged_data, new_data, merged_mask, new_mask, **kwargs):
    """Merge callback that averages pixels present in both rasters.

    cf. https://amanbagrecha.github.io/posts/2022-07-31-merge-rasters-the-modern-way-using-python/index.html

    ``merged_data`` is updated in place and nothing is returned:

    * where both masks are False, the pixel becomes the NaN-aware mean of
      the value in ``merged_data`` and the value in ``new_data``;
    * where ``merged_mask`` is True and ``new_mask`` is False, the pixel is
      copied from ``new_data``;
    * every other pixel of ``merged_data`` is left untouched.

    Because the callback only ever sees two arrays at a time, three or more
    overlapping rasters produce a running pairwise average, not the
    arithmetic mean of all of them.

    Parameters
    ----------
    merged_data : numpy.ndarray
        values accumulated so far; modified in place
    new_data : array-like
        values of the raster being added, same shape as ``merged_data``
    merged_mask, new_mask : array-like of bool
        True where the corresponding array holds no valid data
    **kwargs
        extra arguments passed by the caller; ignored
    """
    shape = np.shape(merged_data)

    both_valid = np.broadcast_to(
        np.logical_not(np.logical_or(merged_mask, new_mask)), shape
    )
    # Average only the doubly-valid selection and write back only there.
    # Passing ``out=merged_data, where=...`` to np.nanmean would overwrite
    # every pixel outside the selection (0 / 0 -> NaN), destroying pixels
    # that are valid in ``merged_data`` but masked in ``new_data``.
    if both_valid.any():
        pair = np.stack(
            [np.asarray(merged_data)[both_valid], np.asarray(new_data)[both_valid]]
        )
        merged_data[both_valid] = np.nanmean(pair, axis=0)

    # Pixels that were empty so far simply take the new raster's value.
    only_new = np.logical_and(merged_mask, np.logical_not(new_mask))
    np.copyto(merged_data, new_data, where=only_new, casting="unsafe")

def get_extents(raster_files):
    """Return the ``bounds`` of each raster in ``raster_files``, in input order.

    Each file is opened read-only with rasterio and closed again before the
    next one is opened.
    """

    def _bounds_of(raster_path):
        with rasterio.open(raster_path, 'r') as dataset:
            return dataset.bounds

    return [_bounds_of(raster_path) for raster_path in raster_files]

def merge_tiles(
    tiles: list,
    dst_path,
    dtype: str = "float32",
    nodata=None,
    method: Union[str, Callable] = "first",
):
    """
    Merge several raster tiles into a single raster written to ``dst_path``.

    cf. https://amanbagrecha.github.io/posts/2022-07-31-merge-rasters-the-modern-way-using-python/index.html

    Parameters
    ----------
    tiles : list
        paths (or anything ``rasterio.open`` accepts) of the tiles to merge
    dst_path : str
        path of the merged raster to write
    dtype : str
        data type passed to ``rasterio.merge.merge``
    nodata : float, optional
        nodata value passed to ``rasterio.merge.merge``
    method : str or callable
        strategy to combine overlapping rasters; the string ``"average"``
        selects ``custom_method_avg``, anything else is passed through

    Raises
    ------
    ValueError
        if ``tiles`` is empty
    """
    tiles = list(tiles)
    if not tiles:
        raise ValueError("merge_tiles() needs at least one tile to merge.")

    if method == "average":
        method = custom_method_avg

    file_handler = []
    try:
        # Open one by one so that a failure half-way still closes the
        # datasets that were already opened.
        for tile in tiles:
            file_handler.append(rasterio.open(tile))

        # Extract individual bounds and build the union of all extents
        lefts, bottoms, rights, tops = zip(*(ds.bounds for ds in file_handler))
        union_extent = (
            min(lefts),  # Left
            min(bottoms),  # Bottom
            max(rights),  # Right
            max(tops),  # Top
        )

        ## different rasterio versions name the first argument differently
        ## (``sources`` / ``datasets``); passing it positionally works with
        ## both and avoids retrying the whole merge on an unrelated TypeError
        merge(
            file_handler,  # list of dataset objects opened in 'r' mode
            bounds=union_extent,  # tuple
            nodata=nodata,  # float
            dtype=dtype,  # dtype
            method=method,  # strategy to combine overlapping rasters
            dst_path=dst_path,
        )
    finally:
        # close datasets, even when opening or merging failed
        for ds in file_handler:
            ds.close()


def get_mean_sd_by_band(path, force_compute=True, ignore_zeros=True, subset=1_000):
    """
    Reads metadata or computes mean and sd of each band of a geotiff.
    If the metadata is not available, mean and standard deviation can be computed via numpy.

    For each band, the ``STATISTICS_MEAN`` / ``STATISTICS_STDDEV`` tags are
    used when both are present. Otherwise the band is read, NaN values are
    replaced by zero, and the statistics are computed with numpy.

    Parameters
    ----------
    path : str
        path to a geotiff file
    force_compute : boolean
        accepted for backward compatibility; not consulted by the current
        implementation
    ignore_zeros : boolean
        ignore zeros when computing mean and sd via numpy
    subset : int
        number of pixels drawn at random (with a fixed seed of 42) to
        estimate the statistics; a falsy value uses every pixel

    Returns
    -------
    means : list
        list of mean values per band
    sds : list
        list of standard deviation values per band

    Notes
    -----
    A band whose tags cannot be read or parsed is reported with ``print`` and
    skipped, so the returned lists can be shorter than the number of bands.
    """

    # Private generator: yields the same draws as ``np.random.seed(42)``
    # followed by ``np.random.choice`` but leaves the process-wide RNG alone.
    rng = np.random.RandomState(42)
    means = []
    sds = []
    with rasterio.open(path) as src:
        for band in range(1, src.count + 1):
            stats = None
            try:
                tags = src.tags(band)
                if "STATISTICS_MEAN" in tags and "STATISTICS_STDDEV" in tags:
                    stats = (
                        float(tags["STATISTICS_MEAN"]),
                        float(tags["STATISTICS_STDDEV"]),
                    )
            except KeyError:
                # treated like missing statistics: fall back to numpy below
                stats = None
            except Exception as e:
                print(f"Error processing band {band}: {e}")
                continue

            if stats is None:
                arr = replace_nan_with_zero(src.read(band))
                ## let subset by default for now
                if subset:
                    arr = rng.choice(arr.flatten(), size=subset)
                if ignore_zeros:
                    non_zero = np.ma.masked_equal(arr, 0)
                    stats = (float(non_zero.mean()), float(non_zero.std()))
                else:
                    stats = (float(np.mean(arr)), float(np.std(arr)))

            means.append(stats[0])
            sds.append(stats[1])

    return means, sds


def get_random_samples_in_gdf(gdf, num_samples, seed=42):
    """
    Turn a non-point GeoDataFrame into a GeoDataFrame of random sample points.

    If every geometry in ``gdf`` is already a ``Point``, ``gdf`` is returned
    as is. Otherwise each row receives a share of ``num_samples`` proportional
    to its area (truncated to an integer, so the total can be lower than
    ``num_samples``), that many points are sampled inside its geometry, and
    one output row is produced per point, carrying the attributes of the row
    it was sampled from.

    The input ``gdf`` is not modified.

    Parameters
    ----------
    gdf : geopandas.GeoDataFrame
        input geometries with their attributes
    num_samples : int
        total number of points to distribute over the geometries
    seed : int
        value passed as ``rng`` to ``sample_points`` for every row

    Returns
    -------
    geopandas.GeoDataFrame
        point-based GeoDataFrame with a fresh 0..n-1 index (or ``gdf`` itself
        when it was already point-based)
    """
    ## point-based input is already a set of samples
    if all(gdf.geometry.geom_type == "Point"):
        return gdf

    geometry_column = gdf.geometry.name

    ## Calculate the area of each polygon
    ## to determine the number of samples for each category.
    ## Kept in local Series so no helper column is added to the caller's gdf.
    areas = gdf.geometry.area
    sample_sizes = (areas / areas.sum() * num_samples).astype(int)

    series = []
    # Sample polygons proportional to their size
    ## see https://geopandas.org/en/stable/docs/user_guide/sampling.html#Variable-number-of-points
    for position, (_, row) in enumerate(gdf.iterrows()):
        size = int(sample_sizes.iloc[position])
        if size <= 0:
            # nothing to sample for this geometry
            continue

        # positional selection: exactly one row, even with duplicated labels
        sampled_points = (
            gdf.iloc[[position]]
            .sample_points(size=size, rng=seed)
            .explode(ignore_index=True)
        )

        for point in sampled_points:
            new_row = row.copy()
            # index by column name so a geometry column that is not called
            # "geometry" is replaced as well
            new_row[geometry_column] = point
            series.append(new_row)

    if not series:
        # no geometry was large enough to receive a sample
        return gdf.iloc[0:0].copy()

    point_gdf = gpd.GeoDataFrame(series, geometry=geometry_column, crs=gdf.crs)
    return point_gdf.reset_index(drop=True)


def get_unique_col_name(gdf, base_name="fold"):
    """Return a column name based on ``base_name`` that ``gdf`` does not use yet.

    ``base_name`` itself is returned when it is free; otherwise the first free
    name among ``base_name1``, ``base_name2``, ... is returned.
    """
    existing = gdf.columns
    candidate = base_name
    suffix = 0

    # Try increasing numeric suffixes until the name is not taken
    while candidate in existing:
        suffix += 1
        candidate = f"{base_name}{suffix}"

    return candidate


def validate_geotiff(output_file, expected_output_size=4428850, expected_wh=(60, 24)):
    """
    Tests geotiff validity by opening it with rasterio, checking that the file
    has the expected width and height and that its size on disk is within
    20% of ``expected_output_size``.
    Additionally, it is checked that the first band holds more than one
    distinct value.

    Parameters
    ----------
    output_file : str
        path to the geotiff to validate
    expected_output_size : int
        expected file size in bytes
    expected_wh : tuple
        expected (width, height) of the raster

    Raises
    ------
    AssertionError
        if any of the checks fails. The error is raised explicitly instead of
        through ``assert`` statements so the checks still run under
        ``python -O``.
    """

    expected_size_min = .8 * expected_output_size
    expected_size_max = 1.2 * expected_output_size
    # 1. Check if the output file is a valid GeoTIFF
    try:
        with rasterio.open(output_file) as src:
            if src.meta['driver'] != 'GTiff':
                raise AssertionError("File is not a valid GeoTIFF.")
            width = src.width
            height = src.height
            # 2. Check width/height
            if width != expected_wh[0]:
                raise AssertionError(f"Expected width {expected_wh[0]}, got {width}.")
            if height != expected_wh[1]:
                raise AssertionError(f"Expected height {expected_wh[1]}, got {height}.")
            # 3. Read the data and check for unique values
            data = src.read(1)  # Read the first band
            unique_values = np.unique(data)

            if len(unique_values) <= 1:
                raise AssertionError("The GeoTIFF contains only one unique value.")

    except rasterio.errors.RasterioIOError as err:
        raise AssertionError(
            "The file could not be opened as a GeoTIFF, indicating it is invalid."
        ) from err

    # 4. Check if the file size is within the expected range
    file_size = os.path.getsize(output_file)
    if not expected_size_min <= file_size <= expected_size_max:
        raise AssertionError(f"File size {file_size} is outside the expected range.")
    return

if __name__ == "__main__":
    ## manual check: load the sample polygons, print them, then print the
    ## sampled result and its number of rows
    source_gdf = gpd.read_file("assets/ml_poly.shp")
    print(source_gdf)
    sampled_gdf = get_random_samples_in_gdf(source_gdf, 100)
    print(sampled_gdf)
    print(len(sampled_gdf))