# -*- coding: UTF-8 -*-
# Python
"""
04-07-2023
@author: jeremy auclair

Calculate NDVI images with xarray
"""

import os  # for path management
import csv  # open csv files
from fnmatch import fnmatch  # for character string comparison
from typing import List, Union  # to declare variables
import xarray as xr  # to manage dataset
import rioxarray  # registers the rio accessor (write_crs, to_raster) on xarray objects
import pandas as pd  # to manage dataframes
import rasterio as rio  # to open geotiff files
import numpy as np  # vectorized math
import geopandas as gpd  # to manage shapefile crs projections
import datetime  # for date management
from scipy.ndimage import zoom  # to rescale different resolution images
from shapely.geometry import box  # to create boundary box
from p_tqdm import p_map  # for multiprocessing with progress bars
from psutil import cpu_count  # to get number of physical cores available
from modspa_pixel.config.config import config  # to import config file
from modspa_pixel.preprocessing.input_toolbox import product_str_to_datetime, read_product_info


def calculate_ndvi(extracted_paths: Union[List[str], str], config_file: str, chunk_size: dict = {'x': 1000, 'y': 1000, 'time': -1}, scaling: int = 255, acorvi_corr: int = 0) -> str:
    """
    Calculate NDVI images in an xarray dataset (a data cube) and save it.
    NDVI values are scaled and saved as ``uint8`` (0 to 255).
    
    .. warning:: Current version for Copernicus Sentinel-2 images
    
    .. warning:: Only 10 and 20 meter resolutions are currently supported

    Arguments
    =========

    1. extracted_paths: ``Union[List[str], str]``
        list of paths to extracted Sentinel-2 products
        or path to a ``csv`` file containing those paths
    2. config_file: ``str``
        path to configuration file
    3. chunk_size: ``dict`` ``default = {'x': 1000, 'y': 1000, 'time': -1}``
        dictionary containing the chunk size for
        the xarray dask calculation
    4. scaling: ``int`` ``default = 255``
        scaling factor used to save NDVI data as integer values
    5. acorvi_corr: ``int`` ``default = 0``
        acorvi correction parameter to add to the red band
        to adjust ndvi values for THEIA

    Returns
    =======

    1. ndvi_precube_path: ``str``
        path to the saved ndvi pre-cube
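
    Example
    =======

    A minimal usage sketch (the ``csv`` of extracted product paths and the
    ``json`` configuration file are assumed to already exist)::

        precube_path = calculate_ndvi('extracted_products.csv', 'config/config_modspa.json')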
    """
    
    # Open config_file
    config_params = config(config_file)
    
    # Load parameters from config file
    boundary_shapefile_path = config_params.shapefile_path
    run_name = config_params.run_name
    resolution = config_params.resolution
    preferred_provider = config_params.preferred_provider
    if preferred_provider == 'copernicus':
        save_dir = config_params.download_path + os.sep + 'SCIHUB' + os.sep + 'NDVI'
    else:
        save_dir = config_params.download_path + os.sep + 'THEIA' + os.sep + 'NDVI'
        
    # Check resolution for Sentinel-2
    if resolution not in [10, 20]:
        print('Resolution should be equal to 10 or 20 meters for Sentinel-2')
        return None

    # If a file name is provided instead of a list of paths, load the csv file that contains the list of paths
    if isinstance(extracted_paths, str):
        with open(extracted_paths, 'r') as file:
            extracted_paths = [row[0] for row in csv.reader(file, delimiter='\n')]
    
    # Sort by band
    red_paths = []
    nir_paths = []
    mask_paths = []
    
    # Check provider
    if config_params.preferred_provider == 'copernicus':
        
        # Integer offset to apply on optical bands for copernicus data
        offset = 0
        
        if resolution == 10:
            for product in extracted_paths:
                for file in os.listdir(product):
                    if fnmatch(file, '*_B04_10m*'):
                        red_paths.append(product + os.sep + file)
                    elif fnmatch(file, '*_B08_10m*'):
                        nir_paths.append(product + os.sep + file)
                    elif fnmatch(file, '*_SCL_20m*'):
                        mask_paths.append(product + os.sep + file)
        else:
            for product in extracted_paths:
                for file in os.listdir(product):
                    if fnmatch(file, '*_B04_20m*'):
                        red_paths.append(product + os.sep + file)
                    elif fnmatch(file, '*_B08_10m*'):
                        nir_paths.append(product + os.sep + file)
                    elif fnmatch(file, '*_SCL_20m*'):
                        mask_paths.append(product + os.sep + file)
    
    elif config_params.preferred_provider == 'theia':
        print('Theia data handling not yet implemented')
        return None

    # Sort and get dates
    red_paths.sort()
    nir_paths.sort()
    mask_paths.sort()
    dates = [product_str_to_datetime(prod) for prod in red_paths]
    
    # Get crs
    with rio.open(red_paths[0]) as temp:
        crs = temp.crs
    
    # Open shapefile to clip the dataset
    shapefile = gpd.read_file(boundary_shapefile_path)
    shapefile = shapefile.to_crs(crs)
    bounds = shapefile.bounds.values[0]
    
    # Open datasets with xarray, select only data inside the bounds of the input shapefile
    red = (xr.open_mfdataset(red_paths, combine = 'nested', concat_dim = 'time', chunks = chunk_size, parallel = True)
           .squeeze(dim = ['band'], drop = True).rename({'band_data': 'red'})
           .sel(x = slice(bounds[0], bounds[2]), y = slice(bounds[3], bounds[1])))
    nir = (xr.open_mfdataset(nir_paths, combine = 'nested', concat_dim = 'time', chunks = chunk_size, parallel = True)
           .squeeze(dim = ['band'], drop = True).rename({'band_data': 'nir'})
           .sel(x = slice(bounds[0], bounds[2]), y = slice(bounds[3], bounds[1])))
    mask = (xr.open_mfdataset(mask_paths, combine = 'nested', concat_dim = 'time', chunks = chunk_size, parallel = True)
            .squeeze(dim = ['band'], drop = True).rename({'band_data': 'mask'})
            .sel(x = slice(bounds[0], bounds[2]), y = slice(bounds[3], bounds[1])))
    
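    # Keep only clear land pixels: in the Sentinel-2 scene classification (SCL),
    # class 4 is vegetation and class 5 is not-vegetated land; everything else becomes NaN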
    if resolution == 10:
        mask = xr.where((mask == 4) | (mask == 5), 1, np.nan).interp(x = red.coords['x'], y = red.coords['y'], method = 'nearest')
    else:
        nir = nir.interp(x = red.coords['x'], y = red.coords['y'], method = 'linear')
        mask = xr.where((mask == 4) | (mask == 5), 1, np.nan)

    # Set time coordinate
    red['time'] = pd.to_datetime(dates)
    nir['time'] = pd.to_datetime(dates)
    mask['time'] = pd.to_datetime(dates)
    
    # Save single red band as reference tif for later use in reprojection algorithms
    ref = (xr.open_dataset(red_paths[0]).squeeze(dim = ['band'], drop = True).rename({'band_data': 'red'})
           .sel(x = slice(bounds[0], bounds[2]), y = slice(bounds[3], bounds[1])).rio.write_crs(crs))
    ref.rio.to_raster(save_dir + os.sep + run_name + os.sep + run_name + '_grid_reference.tif')
    
    # Create save path
    ndvi_precube_path = save_dir + os.sep + run_name + os.sep + run_name + '_NDVI_precube_' + dates[0].strftime('%Y-%m-%d') + '_' + dates[-1].strftime('%Y-%m-%d') + '.nc'
    
    # Check if file exists and ndvi overwrite is false
    if os.path.exists(ndvi_precube_path) and not config_params.ndvi_overwrite:
        
        return ndvi_precube_path

    # Create ndvi dataset and calculate ndvi
    ndvi = red.drop_vars('red')
    ndvi['NDVI'] = ((nir.nir - red.red - acorvi_corr) / (nir.nir + red.red + acorvi_corr + 2 * offset)) * mask.mask
    del red, nir, mask
    
    # Mask and scale ndvi
    ndvi['NDVI'] = xr.where(ndvi.NDVI < 0, 0, ndvi.NDVI)
    ndvi['NDVI'] = xr.where(ndvi.NDVI > 1, 1, ndvi.NDVI)
    ndvi['NDVI'] = ((ndvi.NDVI + 1/scaling) * (scaling - 1)).round()
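    # NDVI in [0, 1] is mapped to [1, 255] so that 0 stays free as the netcdf fill value
    # (see the encoding below); values scale back as ndvi / (scaling - 1) - 1 / scaling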
        
    # Write attributes
    ndvi['NDVI'].attrs['units'] = 'None'
    ndvi['NDVI'].attrs['standard_name'] = 'NDVI'
    ndvi['NDVI'].attrs['description'] = 'Normalized Difference Vegetation Index (computed from the near-infrared and red bands). A value of one indicates a high vegetation presence.'
    ndvi['NDVI'].attrs['scale factor'] = str(scaling)
    
    # Save NDVI cube to netcdf
    file_chunksize = (ndvi.dims['time'], min(500, ndvi.dims['y']), min(500, ndvi.dims['x']))
    ndvi.to_netcdf(ndvi_precube_path, encoding = {"NDVI": {"dtype": "u1", "_FillValue": 0, "chunksizes": file_chunksize}})
    ndvi.close()
    
    return ndvi_precube_path


def interpolate_ndvi(ndvi_path: str, config_file: str, chunk_size: dict = {'x': 500, 'y': 500, 'time': -1}) -> str:
    """
    Interpolate the ndvi cube to a daily frequency between the
    desired dates defined in the ``json`` config file. The extra
    month of data downloaded is used for a better interpolation,
    it is then discarded and the final NDVI cube has the dates
    defined in the config file.

    Arguments
    =========

    1. ndvi_path: ``str``
        path to ndvi pre-cube
    2. config_file: ``str``
        path to ``json`` config file
    3. chunk_size: ``dict`` ``default = {'x': 500, 'y': 500, 'time': -1}``
        chunk size to use by dask for calculation,
        ``'time' = -1`` means the chunk has the whole
        time dimension in it. The Dataset can't be
        divided along the time axis for interpolation.

    Returns
    =======

    1. ndvi_cube_path: ``str``
        path to the saved interpolated ndvi cube
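
    Example
    =======

    A minimal usage sketch, chaining on the pre-cube produced by
    ``calculate_ndvi`` (the config file path is an assumption)::

        cube_path = interpolate_ndvi(precube_path, 'config/config_modspa.json')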
    """
    
    # Open config_file
    config_params = config(config_file)
    
    # Load parameters from config file
    run_name = config_params.run_name
    if config_params.preferred_provider == 'copernicus':
        save_dir = config_params.download_path + os.sep + 'SCIHUB' + os.sep + 'NDVI'
    else:
        save_dir = config_params.download_path + os.sep + 'THEIA' + os.sep + 'NDVI'
    
    # Create save path
    ndvi_cube_path = save_dir + os.sep + run_name + os.sep + run_name + '_NDVI_cube_' + config_params.start_date + '_' + config_params.end_date + '.nc'
    
    # Check if file exists and ndvi overwrite is false
    if os.path.exists(ndvi_cube_path) and not config_params.ndvi_overwrite:
        
        return ndvi_cube_path
    
    # Open NDVI pre-cube
    ndvi = xr.open_dataset(ndvi_path, chunks = chunk_size)
    
    # Get dates of NDVI precube
    dates = ndvi.time.values
    
    # Get Spatial reference
    spatial_ref = ndvi.spatial_ref.load()
    
    # Sort images by time
    ndvi = ndvi.sortby('time')
    
    # Interpolates on a daily frequency
    # daily_index = pd.date_range(start = config_params.start_date, end = config_params.end_date, freq = 'D')
    daily_index = pd.date_range(start = dates[0], end = dates[-1], freq = 'D')

    # Resample the dataset to a daily frequency and reindex with the new DateTimeIndex
    ndvi = ndvi.resample(time = '1D').asfreq().reindex(time = daily_index)
    dates = pd.to_datetime(ndvi.time.values)

    # Interpolate the dataset along the time dimension to fill nan values
    ndvi = ndvi.interpolate_na(dim = 'time', method = 'linear', fill_value = 'extrapolate').round(decimals = 0)
    
    # Remove extra dates, only keep selected window
    ndvi = ndvi.sel({'time': slice(config_params.start_date, config_params.end_date)})
    
    # Rescale data
    ndvi = (ndvi * (255 / 254) - 1).round()
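    # (the pre-cube stores NDVI on [1, 255] with 0 as fill value; this maps it onto
    # the full [0, 255] uint8 range: 1 -> 0 and 255 -> 255)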
    
    # Set negative values to 0
    ndvi['NDVI'] = xr.where(ndvi.NDVI < 0, 0, ndvi.NDVI)
    # Set values above 255 (ndvi > 1) to 255 (ndvi = 1)
    ndvi['NDVI'] = xr.where(ndvi.NDVI > 255, 255, ndvi.NDVI)
    
    # Rewrite spatial reference
    ndvi['spatial_ref'] = spatial_ref
    
    # Save NDVI cube to netcdf
    if ndvi.dims['y'] > 1500 and ndvi.dims['x'] > 1500:
        file_chunksize = (1, 1000, 1000)
    else:
        file_chunksize = (1, ndvi.dims['y'], ndvi.dims['x'])
    ndvi.to_netcdf(ndvi_cube_path, encoding = {"NDVI": {"dtype": "u1", "_FillValue": 0, "chunksizes": file_chunksize}})
    ndvi.close()
    
    return ndvi_cube_path


def write_geotiff(path: str, data: np.ndarray, transform: tuple, projection: str, scaling: int = 255) -> None:
    """
    Writes a GeoTiff image using ``rasterio``. Takes an array and georeferencing data (obtained by ``rasterio``
    on an image with the same georeferencing) to write a well formatted image. Data format: ``uint8``

    Arguments
    =========

    1. path: ``str``
        full path of the file to save the data in (directory + file name)
    2. data: ``np.ndarray``
        numpy array containing the data
    3. transform: ``tuple``
        tuple containing the geo-transform data
    4. projection: ``str``
        string containing the EPSG projection data
    5. scaling: ``int`` ``default = 255``
        scaling factor for saving NDVI images as integer arrays

    Returns
    =======

    ``None``
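
    Example
    =======

    A minimal sketch that borrows the georeferencing of an existing image
    (file and variable names are assumptions)::

        with rio.open('T31TCJ_20230704_B04_10m.jp2') as src:
            geo_transform, geo_projection = src.transform, src.crs
        write_geotiff('product_NDVI.tif', ndvi_array, geo_transform, geo_projection)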
    """

    # Apply scaling and convert to uint8 (masked NaN pixels become the 0 nodata value)
    data = np.round((data + np.float32(1/scaling)) * (scaling - 1))
    data = np.where(np.isnan(data), 0, data).astype(np.uint8)

    # Write NDVI image with rasterio
    with rio.open(path, mode = "w", driver = "GTiff", height = data.shape[0], width = data.shape[1], count = 1, dtype = np.uint8, crs = projection, transform = transform, nodata = 0) as new_dataset:
        new_dataset.write(data, 1)

    return None


def calculate_ndvi_image(args: tuple) -> str:
    """
    Opens the red, near-infrared and scene classification bands of the given extracted product to calculate
    the ndvi and applies a mask to keep only clear land data (no clouds, no snow, no water, no unclassified
    or erroneous pixels). This array is then saved in a GeoTiff image with the parent product name and the
    ``'_NDVI.tif'`` extension. If overwrite is false, already existing files won't be calculated and saved
    again (default value is ``False``).

    This function is called in the calculate_ndvi_parcel function as a subprocess to allow multiprocessing. Variables
    that have served their purpose are directly deleted to reduce memory usage.

    Arguments (packed in args: ``tuple``)
    =====================================
    
    1. product_path: ``str``
        path to the product to extract for ndvi calculation
    2. save_dir: ``str``
        directory in which to save ndvi images
    3. overwrite: ``bool``
        boolean used to choose whether to rewrite already existing files
    4. ACORVI_correction: ``int``
        correction parameter to apply on the red band for stability of NDVI values

    Returns
    =======

    1. ndvi_path: ``str``
        path to the saved ndvi image
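
    Example
    =======

    Arguments are packed in a single tuple so the function can be mapped over
    a list of products (paths are assumptions)::

        ndvi_path = calculate_ndvi_image(('/data/S2A_MSIL2A_product.SAFE', '/data/NDVI', False, 500))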
    """

    # Unpack arguments
    product_path, save_dir, overwrite, ACORVI_correction = args
    
    # Check provider
    _, _, provider, _ = read_product_info(product_path) 

    # To ignore numpy warning when dividing by NaN
    np.seterr(invalid='ignore', divide='ignore')

    if provider == 'copernicus':

        # Define offset for copernicus data
        offset = -1000
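        # (Sentinel-2 L2A products from processing baseline 04.00 onwards carry a +1000
        # radiometric offset on reflectance values, compensated here in the NDVI denominator)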

        # Create file name
        file_name = os.path.basename(product_path)
        # file_name = os.path.splitext(file_name)[0]
        save_name = save_dir + os.sep + file_name + '_NDVI.tif'

        # Check if file is already written. If overwrite is false, there is no need to calculate
        # and write the ndvi again; the existing path is returned directly.
        if os.path.exists(save_name) and not overwrite:
            return save_name

        # Look for desired images in the directories
        for f in os.listdir(product_path):
            if fnmatch(f, '*_B04_10m.jp2'):  # Red band
                red_file = os.path.join(product_path, f)
            elif fnmatch(f, '*_B08_10m.jp2'):  # Near infrared band
                nir_file = os.path.join(product_path, f)
            elif fnmatch(f, '*_SCL_20m.jp2'):  # Scene classification for mask
                classif_file = os.path.join(product_path, f)

        # Read bands, geometry and projection information
        red_band = rio.open(red_file, mode = 'r')  # read red band
        transform = red_band.transform
        projection = red_band.crs
        del red_file

        nir_band = rio.open(nir_file, mode = 'r')  # read nir band
        del nir_file

        # Read array data
        red_data = red_band.read(1) + ACORVI_correction
        red_band.close()
        nir_data = nir_band.read(1)
        nir_band.close()

        # Calculate ndvi
        ndvi_data = np.divide((nir_data - red_data), (nir_data + red_data + 2*offset), dtype = np.float32)
        del red_data, nir_data

        # Read classif data
        classif_band = rio.open(classif_file, mode = 'r')  # read classif band
        del classif_file

        # Read array data
        classif_data = classif_band.read(1)
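        # Nearest-neighbour zoom by a factor 2 resamples the 20 m classification onto the 10 m band grid without mixing class values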
        classif_data = zoom(classif_data, zoom = 2, order = 0, output = np.uint8)
        classif_band.close()

        # Create binary mask
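        # (SCL classes kept: 4 = vegetation, 5 = not vegetated, 6 = water)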
        binary_mask = np.where((classif_data == 4) | (classif_data == 5) | (classif_data == 6), np.float32(1), np.float32(np.NaN))
        del classif_data

        # Apply mask
        np.multiply(ndvi_data, binary_mask, out=ndvi_data, dtype = np.float32)
        del binary_mask

        # Clip out of range data
        np.minimum(ndvi_data, 1, out = ndvi_data, dtype = np.float32)
        np.maximum(ndvi_data, 0, out = ndvi_data, dtype = np.float32)

        # Write image
        write_geotiff(save_name, ndvi_data, transform, projection)
        del ndvi_data, transform, projection

    elif provider == 'theia':

        # Create file name
        file_name = os.path.basename(product_path)
        # file_name = os.path.splitext(file_name)[0]
        save_name = save_dir + os.sep + file_name + '_NDVI.tif'

        # Check if file is already written. If overwrite is false, there is no need to calculate
        # and write the ndvi again; the existing path is returned directly.
        if os.path.exists(save_name) and not overwrite:
            return save_name
        
        # Look for desired images in the directories
        for f in os.listdir(product_path):
            if fnmatch(f, '*_FRE_B4.tif'):  # Red band
                red_file = os.path.join(product_path, f)
            elif fnmatch(f, '*_FRE_B8.tif'):  # Near infrared band
                nir_file = os.path.join(product_path, f)
            elif fnmatch(f, '*_MG2_R1.tif'):  # Scene classification for mask
                classif_file = os.path.join(product_path, f)

        # Read bands, geometry and projection information
        red_band = rio.open(red_file, mode = 'r')  # read red band
        transform = red_band.transform
        projection = red_band.crs
        del red_file

        nir_band = rio.open(nir_file, mode = 'r')  # read nir band
        del nir_file

        # Read array data, set -10000 as no data value on optical bands for correct masking
        red_data = red_band.read(1) + ACORVI_correction
        red_data = np.where((red_data == -10000), np.float32(np.NaN), red_data)
        red_band.close()
        nir_data = nir_band.read(1)
        nir_data = np.where((nir_data == -10000), np.float32(np.NaN), nir_data)
        nir_band.close()

        # Calculate ndvi
        ndvi_data = np.divide((nir_data - red_data), (nir_data + red_data), dtype = np.float32)
        del red_data, nir_data

        # Extract classif data
        classif_band = rio.open(classif_file, mode = 'r')  # read classif band
        del classif_file

        # Read array data
        classif_data = classif_band.read(1)
        # classif_data = zoom(classif_data, zoom = 2, order = 0, output = np.uint8)
        classif_band.close()

        # Create binary mask
        binary_mask = np.where((classif_data == 0) | (classif_data == 1), np.float32(1), np.float32(np.NaN))
        del classif_data

        # Apply mask
        np.multiply(ndvi_data, binary_mask, dtype = np.float32, out = ndvi_data)
        del binary_mask

        # Clip out of range data
        np.minimum(ndvi_data, 1, out = ndvi_data, dtype = np.float32)
        np.maximum(ndvi_data, 0, out = ndvi_data, dtype = np.float32)

        # Write image
        write_geotiff(save_name, ndvi_data, transform, projection)
        del ndvi_data, transform, projection

    return save_name


def calculate_ndvi_parcel(download_path: Union[List[str], str], save_dir: str, save_path: str, overwrite: bool = False, max_cpu: int = 4, ACORVI_correction: int = 500) -> List[str]:
    """
    Opens the red, near-infrared and scene classification bands to calculate the ndvi and applies a mask to keep only clear land data (no clouds, no snow,
    no water, no unclassified or erroneous pixels). This array is then saved in a GeoTiff image with
    the parent product name and the ``'_NDVI.tif'`` extension. If overwrite is false, already existing files won't be
    calculated and saved again (default value is ``False``).

    This function calls the calculate_ndvi_image function as a subprocess to allow multiprocessing. A modified
    version of the ``tqdm`` module (``p_tqdm``) is used for progress bars.

    Arguments
    =========
    
    1. download_path: ``list[str]`` or ``str``
        list of paths to the products to extract for ndvi calculation, or path to a ``csv`` file that contains this list
    2. save_dir: ``str``
        directory in which to save ndvi images
    3. save_path : ``str``
        path to a csv file containing the paths to the saved ndvi images
    4. overwrite: ``bool`` ``default = False``
        boolean used to choose whether to rewrite already existing files
    5. max_cpu: ``int`` ``default = 4``
        max number of CPUs that the pool can use; if max_cpu is bigger than the number of available CPUs, the pool only uses the available CPUs
    6. ACORVI_correction: ``int`` ``default = 500``
        correction parameter to apply on the red band for stability of NDVI values

    Returns
    =======
    
    1. ndvi_path: ``list[str]``
        list of paths to the saved ndvi images
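
    Example
    =======

    A minimal usage sketch (paths are assumptions)::

        ndvi_paths = calculate_ndvi_parcel('downloaded_products.csv', '/data/NDVI', '/data/ndvi_paths.csv', max_cpu = 8)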
    """

    # If a file name is provided instead of a list of paths, load the csv file that contains the list of paths
    if isinstance(download_path, str):
        with open(download_path, 'r') as file:
            download_path = [row[0] for row in csv.reader(file, delimiter='\n')]

    ndvi_path = []  # Where saved image paths will be stored

    # Prepare arguments for multiprocessing
    args = [(product, save_dir, overwrite, ACORVI_correction) for product in download_path]

    # Get number of CPU cores and limit max value (working on a cluster requires os.sched_getaffinity to get the true number of available CPUs,
    # this is not the case on a personal computer, hence the use of the min function)
    try:
        nb_cores = min([max_cpu, cpu_count(logical = False), len(os.sched_getaffinity(0))])
    except AttributeError:
        nb_cores = min([max_cpu, cpu_count(logical = False)])  # os.sched_getaffinity doesn't exist on windows

    print('\nStarting NDVI calculations with %d cores for %d images...\n' %(nb_cores, len(download_path)))

    # Start pool and get results
    results = p_map(calculate_ndvi_image, args, num_cpus = nb_cores)

    # Collect results and sort them
    for result in results:
        ndvi_path.append(result)
    ndvi_path.sort()

    # Save ndvi paths as a csv file
    with open(save_path, 'w', newline='') as f:
        # using csv.writer method from CSV package
        write = csv.writer(f)

        for ndvi_image in ndvi_path:
            write.writerow([ndvi_image])

    return ndvi_path