# -*- coding: UTF-8 -*-
# Python
"""
04-07-2023
@author: jeremy auclair

Calculate NDVI images with xarray
"""

import os  # for path exploration
import csv  # open csv files
from fnmatch import fnmatch  # for character string comparison
from typing import List, Union  # to declare variables
import xarray as xr  # to manage dataset
import pandas as pd  # to manage dataframes
import rasterio as rio  # to open geotiff files
import geopandas as gpd  # to manage shapefile crs projections
from shapely.geometry import box  # to create boundary box
from modspa_pixel.config.config import config  # to import config file
from modspa_pixel.preprocessing.input_toolbox import product_str_to_datetime


def calculate_ndvi(extracted_paths: Union[List[str], str], save_dir: str, boundary_shapefile_path: str, resolution: int = 20, chunk_size: dict = {'x': 4096, 'y': 4096, 'time': 2}, acorvi_corr: int = 500) -> str:
    """
    Calculate ndvi images in a xarray dataset (a data cube) and save it.
    ndvi values are clamped to [0, 1], scaled by 255 and saved as `uint8`
    (0 to 255). The value 0 is also declared as the fill value of the
    saved variable.

    The footprint of the first red band image is written, reprojected to
    ``epsg:4326``, to ``boundary_shapefile_path``.

    .. warning:: Current version for Copernicus Sentinel-2 images

    Arguments
    =========

    1. extracted_paths: ``Union[List[str], str]``
        list of paths to extracted sentinel-2 products
        or path to ``csv`` file containing those paths
        (one path per line, blank lines are ignored)
    2. save_dir: ``str``
        directory in which to save the ndvi pre-cube
    3. boundary_shapefile_path: ``str``
        shapefile path to save the geographical bounds
        of the ndvi tile, used later to download weather
        products
    4. resolution: ``int`` ``default = 20``
        resolution in meters

        .. warning:: only 10 and 20 meters currently supported

    5. chunk_size: ``dict`` ``default = {'x': 4096, 'y': 4096, 'time': 2}``
        dictionary containing the chunk size for
        the xarray dask calculation (only read, never modified)
    6. acorvi_corr: ``int`` ``default = 500``
        acorvi correction parameter to add to the red band
        to adjust ndvi values

    Returns
    =======

    1. ndvi_cube_path: ``str``
        path of the saved ndvi pre-cube, or ``None`` (after printing
        a message) when ``resolution`` is not 10 or 20

    Raises
    ======

    ``ValueError``
        if no red band image is found for the requested resolution, or
        if the numbers of red, near infrared and mask images differ
    """

    # Check resolution for Sentinel-2
    if resolution not in (10, 20):
        print('Resolution should be equal to 10 or 20 meters for sentinel-2')
        return None

    # If a file name is provided instead of a list of paths, load the csv file that contains the list of paths
    if isinstance(extracted_paths, str):
        with open(extracted_paths, 'r') as file:
            # csv.reader yields an empty row for a blank line: skip those instead of failing on row[0]
            extracted_paths = [row[0] for row in csv.reader(file, delimiter='\n') if row and row[0]]

    # Band file name patterns: the scene classification (SCL) mask is always taken at 20 m
    if resolution == 10:
        red_pattern, nir_pattern = '*_B04_10m*', '*_B08_10m*'
    else:
        red_pattern, nir_pattern = '*_B04_20m*', '*_B8A_20m*'
    mask_pattern = '*_SCL_20m*'

    # Sort by band
    red_paths = []
    nir_paths = []
    mask_paths = []
    for product in extracted_paths:
        if fnmatch(product, red_pattern):
            red_paths.append(product)
        elif fnmatch(product, nir_pattern):
            nir_paths.append(product)
        elif fnmatch(product, mask_pattern):
            mask_paths.append(product)

    # Fail early, before anything is written to disk, with an explicit message
    if not red_paths:
        raise ValueError('No red band (B04) image found at ' + str(resolution) + ' m in the extracted paths')
    if not len(red_paths) == len(nir_paths) == len(mask_paths):
        raise ValueError('Band count mismatch: ' + str(len(red_paths)) + ' red, ' + str(len(nir_paths)) + ' nir and ' + str(len(mask_paths)) + ' mask images')

    # Create boundary shapefile from Sentinel-2 image for weather download
    # (the raster is only needed for its bounds and crs, so it is closed right away)
    with rio.open(red_paths[0]) as raster:
        footprint = gpd.GeoDataFrame({'id': [1]}, geometry=[box(*raster.bounds)], crs=raster.crs)
    footprint.to_crs('epsg:4326').to_file(boundary_shapefile_path)
    del footprint

    # Sort and get dates (the dates are read from the sorted red band paths)
    red_paths.sort()
    nir_paths.sort()
    mask_paths.sort()
    dates = [product_str_to_datetime(prod) for prod in red_paths]
    timestamps = pd.to_datetime(dates)

    # Create save path
    cube_name = 'NDVI_precube_' + dates[0].strftime('%d-%m-%Y') + '_' + dates[-1].strftime('%d-%m-%Y') + '.nc'
    ndvi_cube_path = os.path.join(save_dir, cube_name)

    def open_stack(paths):
        # Stack the images of one band along a new 'time' dimension
        return xr.open_mfdataset(paths, combine='nested', concat_dim='time', chunks=chunk_size, parallel=True)

    # The source datasets are kept open until the cube is written (the
    # calculation is lazy) and are closed on exit, even if an error occurs
    with open_stack(red_paths) as red_source, \
            open_stack(nir_paths) as nir_source, \
            open_stack(mask_paths) as mask_source:

        red = red_source.squeeze(dim=['band'], drop=True).rename({'band_data': 'red'})
        nir = nir_source.squeeze(dim=['band'], drop=True).rename({'band_data': 'nir'})
        scl = mask_source.squeeze(dim=['band'], drop=True).rename({'band_data': 'mask'})

        # Binary mask: 1 for SCL classes 4 and 5, 0 elsewhere
        # (presumably vegetation / not-vegetated in the Sentinel-2 scene classification)
        mask = xr.where((scl == 4) | (scl == 5), 1, 0)
        if resolution == 10:
            # The mask only exists at 20 m: bring it onto the 10 m grid of the red band
            mask = mask.interp(x=red.coords['x'], y=red.coords['y'], method='nearest')

        # Set time coordinate
        red['time'] = timestamps
        nir['time'] = timestamps
        mask['time'] = timestamps

        # Create ndvi dataset (coordinates of the red band, without its data) and calculate ndvi
        ndvi = red.drop_vars('red')
        ratio = (nir['nir'] - red['red'] - acorvi_corr) / (nir['nir'] + red['red'] + acorvi_corr)

        # Mask, clamp to [0, 1] and scale ndvi
        ndvi['NDVI'] = (ratio * mask['mask']).clip(min=0, max=1) * 255

        # Write attributes
        ndvi['NDVI'].attrs['units'] = 'None'
        ndvi['NDVI'].attrs['standard_name'] = 'NDVI'
        ndvi['NDVI'].attrs['description'] = 'Normalized difference Vegetation Index (of the near infrared and red band). A value of one is a high vegetation presence.'
        ndvi['NDVI'].attrs['scale factor'] = '255'

        # netCDF chunk sizes must not exceed the dimension lengths (e.g. fewer than 4 dates)
        chunksizes = tuple(min(limit, size) for limit, size in zip((4, 1024, 1024), ndvi['NDVI'].shape))

        # Save NDVI cube to netcdf
        ndvi.to_netcdf(ndvi_cube_path, encoding={"NDVI": {"dtype": "u1", "_FillValue": 0, "chunksizes": chunksizes}})
        ndvi.close()

    return ndvi_cube_path


def interpolate_ndvi(ndvi_path: str, save_dir: str, config_file: str, chunk_size: dict = {'x': 512, 'y': 512, 'time': -1}) -> str:
    """
    Interpolate the ndvi cube to a daily frequency between the
    desired dates defined in the ``json`` config file, and save
    the result as ``uint8`` values (0 to 255).

    Arguments
    =========

    1. ndvi_path: ``str``
        path to ndvi pre-cube
    2. save_dir: ``str``
        directory in which to save the interpolated ndvi cube
    3. config_file: ``str``
        path to ``json`` config file, its ``start_date`` and
        ``end_date`` define the daily time axis
    4. chunk_size: ``dict`` ``default = {'x': 512, 'y': 512, 'time': -1}``
        chunk size to use by dask for calculation,
        ``'time' = -1`` means the chunk has the whole
        time dimension in it. The Dataset can't be
        divided along the time axis for interpolation.
        The dictionary is only read, never modified.

    Returns
    =======

    1. ndvi_cube_path: ``str``
        path of the saved interpolated ndvi cube
    """

    # Open config_file
    config_params = config(config_file)

    # Daily time axis between the dates of the config file
    daily_index = pd.date_range(start=config_params.start_date, end=config_params.end_date, freq='D')

    # Open NDVI pre-cube: it stays open until the cube is written (the
    # calculation may be lazy) and is closed on exit, even if an error occurs
    with xr.open_dataset(ndvi_path, chunks=chunk_size) as precube:

        # Get Spatial reference
        spatial_ref = precube['spatial_ref'].load()

        # Sort images by time
        ndvi = precube.sortby('time')

        # Resample the dataset to a daily frequency and reindex with the new DateTimeIndex
        ndvi = ndvi.resample(time='1D').asfreq().reindex(time=daily_index)
        dates = pd.to_datetime(ndvi.time.values)

        # Interpolate the dataset along the time dimension to fill nan values
        ndvi = ndvi.interpolate_na(dim='time', method='linear', fill_value='extrapolate').round(decimals=0)

        # Extrapolation can leave the valid range: set negative values as 0
        # and values above 255 (ndvi > 1) as 255 (ndvi = 1)
        ndvi['NDVI'] = ndvi['NDVI'].clip(min=0, max=255)

        # Rewrite spatial reference
        ndvi['spatial_ref'] = spatial_ref

        # Create save path
        cube_name = 'NDVI_cube_' + dates[0].strftime('%d-%m-%Y') + '_' + dates[-1].strftime('%d-%m-%Y') + '.nc'
        ndvi_cube_path = os.path.join(save_dir, cube_name)

        # netCDF chunk sizes must not exceed the dimension lengths (e.g. short periods or small tiles)
        chunksizes = tuple(min(limit, size) for limit, size in zip((4, 1024, 1024), ndvi['NDVI'].shape))

        # Save NDVI cube to netcdf
        ndvi.to_netcdf(ndvi_cube_path, encoding={"NDVI": {"dtype": "u1", "_FillValue": 0, "chunksizes": chunksizes}})
        ndvi.close()

    return ndvi_cube_path