# -*- coding: UTF-8 -*-
# Python
"""
11-07-2023
@author: jeremy auclair

Usage of the SAMIR model in the Modspa framework.
"""

import os  # for path exploration
import csv  # open csv files
from fnmatch import fnmatch  # for character string comparison
from typing import List, Tuple, Union  # to declare variables
import xarray as xr  # to manage dataset
import pandas as pd  # to manage dataframes
import numpy as np  # for math and array operations
import rasterio as rio  # to open geotiff files
import geopandas as gpd  # to manage shapefile crs projections
from parameters.params_samir_class import samir_parameters


def xr_maximum(ds: xr.DataArray, value: Union[xr.DataArray, float, int]) -> xr.DataArray:
    """
    Equivalent of `numpy.maximum(ds, value)` for xarray DataArrays

    ## Arguments
    1. ds: `xr.DataArray`
        dataarray to compare
    2. value: `Union[xr.DataArray, float, int]`
        value (scalar or dataarray) to compare

    ## Returns
    1. output: `xr.DataArray`
        resulting dataarray containing the element-wise maximum
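
    ## Example
    A minimal sketch (the variable name `depletion` is illustrative): clip negative values of a dataarray to zero.

        depletion = xr_maximum(depletion, 0)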
    """
    return xr.where(ds <= value, value, ds)


def xr_minimum(ds: xr.DataArray, value: Union[xr.DataArray, float, int]) -> xr.DataArray:
    """
    Equivalent of `numpy.minimum(ds, value)` for xarray DataArrays

    ## Arguments
    1. ds: `xr.DataArray`
        dataarray to compare
    2. value: `Union[xr.DataArray, float, int]`
        value (scalar or dataarray) to compare

    ## Returns
    1. output: `xr.DataArray`
        resulting dataarray containing the element-wise minimum
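
    ## Example
    A minimal sketch (the variable names are illustrative): cap a depletion dataarray at the total available water `TAW`.

        depletion = xr_minimum(depletion, TAW)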
    """
    return xr.where(ds >= value, value, ds)


def rasterize_samir_parameters(csv_param_file: str, empty_dataset: xr.Dataset, land_cover_raster: str, chunk_size: dict) -> Tuple[xr.Dataset, dict]:
    """
    Creates a raster `xarray` dataset from the csv parameter file, the land cover raster and an empty dataset
    that contains the right structure (emptied ndvi dataset for example). For each parameter, the function loops
    on land cover classes to fill the raster.

    ## Arguments
    1. csv_param_file: `str`
        path to the csv parameter file
    2. empty_dataset: `xr.Dataset`
        empty dataset that contains the right structure (emptied ndvi dataset for example).
    3. land_cover_raster: `str`
        path to land cover netcdf raster
    4. chunk_size: `dict`
        chunk_size for dask computation

    ## Returns
    1. parameter_dataset: `xr.Dataset`
        dataset containing all the rasterized parameters
    2. scale_factor: `dict`
        dictionary containing the scale factors for each parameter
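
    ## Example
    A hypothetical call; the file paths and chunk sizes below are placeholders, not values used by the framework.

        parameter_dataset, scale_factor = rasterize_samir_parameters('params_samir.csv', empty_dataset, 'land_cover.nc', {'x': 1000, 'y': 1000})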
    """
    
    # Load samir params into an object
    table_param = samir_parameters(csv_param_file)
    
    # Set general variables
    class_count = table_param.table.shape[1] - 2  # remove dtype and default columns
    
    # Open land cover raster
    land_cover = xr.open_dataarray(land_cover_raster, chunks = chunk_size)
    
    # Create dataset
    parameter_dataset = empty_dataset.copy(deep = True)
    
    # Create dictionary containing the scale factors
    scale_factor = {}
    
    # Loop on samir parameters and create a rasterized variable for each one
    for parameter in table_param.table.index[1:]:
        
        # Create new variable and set attributes
        parameter_dataset[parameter] = land_cover.astype('i2')
        parameter_dataset[parameter].attrs['name'] = parameter
        parameter_dataset[parameter].attrs['description'] = 'cf SAMIR Doc for detail'
        parameter_dataset[parameter].attrs['scale factor'] = str(table_param.table.loc[table_param.table.index == parameter]['scale_factor'].values[0])
        
        # Assign the scale factor value in the dictionary
        scale_factor[parameter] = 1/int(table_param.table.loc[table_param.table.index == parameter]['scale_factor'].values[0])
        
        # Loop on classes to set parameter values for each class
        for class_val, class_name in zip(range(1, class_count + 1), table_param.table.columns[2:]):
            
            # Parameter values are multiplied by the scale factor in order to store all values as int16 types
            # These values are then rounded to make sure there aren't any decimal point issues when casting the values to int16
            scaled_value = round(table_param.table.loc[table_param.table.index == parameter][class_name].values[0] * table_param.table.loc[table_param.table.index == parameter]['scale_factor'].values[0])
            parameter_dataset[parameter].values = np.where(parameter_dataset[parameter].values == class_val, scaled_value, parameter_dataset[parameter].values).astype('i2')
    
    # Return dataset converted to 'int16' data type to reduce memory usage
    # and scale_factor dictionary for later conversion
    return parameter_dataset, scale_factor


def setup_time_loop(calculation_variables_t1: List[str], calculation_variables_t2: List[str], empty_dataset: xr.Dataset) -> Tuple[xr.Dataset, xr.Dataset]:
    """
    Creates two temporary `xarray Datasets` that will be used in the SAMIR time loop.
    `variables_t1` corresponds to the variables for the previous day and `variables_t2`
    corresponds to the variables for the current day. After each loop, `variables_t1`
    takes the value of `variables_t2` for the corresponding variables.

    ## Arguments
    1. calculation_variables_t1: `List[str]`
        list of strings containing the variable names
        for the previous day dataset
    2. calculation_variables_t2: `List[str]`
        list of strings containing the variable names
        for the current day dataset
    3. empty_dataset: `xr.Dataset`
        empty dataset that contains the right structure

    ## Returns
    1. variables_t1: `xr.Dataset`
        output dataset for previous day
    2. variables_t2: `xr.Dataset`
        output dataset for current day
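
    ## Example
    A hypothetical call; the variable names in the two lists are illustrative only.

        variables_t1, variables_t2 = setup_time_loop(['Dr', 'Dd', 'De'], ['Dr', 'Dd', 'De', 'E', 'Tr'], empty_dataset)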
    """
    
    # Create new dataset
    variables_t1 = empty_dataset.copy(deep = True)
    
    # Create empty DataArray for each variable
    for variable in calculation_variables_t1:
        
        # Assign new empty DataArray
        variables_t1[variable] = (empty_dataset.dims, np.zeros(tuple(empty_dataset.dims[d] for d in list(empty_dataset.dims)), dtype = 'float32'))
        variables_t1[variable].attrs['name'] = variable  # set name in attributes
    
    # Create new dataset
    variables_t2 = empty_dataset.copy(deep = True)
    
    # Create empty DataArray for each variable
    for variable in calculation_variables_t2:
        
        # Assign new empty DataArray
        variables_t2[variable] = (empty_dataset.dims, np.zeros(tuple(empty_dataset.dims[d] for d in list(empty_dataset.dims)), dtype = 'float32'))
        variables_t2[variable].attrs['name'] = variable  # set name in attributes
    
    return variables_t1, variables_t2


def prepare_outputs(empty_dataset: xr.Dataset, additional_outputs: List[str] = None) -> xr.Dataset:
    """
    Creates the `xarray Dataset` containing the outputs of the SAMIR model that will be saved.
    Additional variables can be saved by adding their names to the `additional_outputs` list.

    ## Arguments
    1. empty_dataset: `xr.Dataset`
        empty dataset that contains the right structure
    2. additional_outputs: `List[str]`
        list of additional variable names to be saved

    ## Returns
    1. model_outputs: `xr.Dataset`
        model outputs to be saved
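
    ## Example
    A hypothetical call; saving 'Kcb' as an additional output is illustrative only.

        model_outputs = prepare_outputs(empty_dataset, additional_outputs = ['Kcb'])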
    """
    
    # Evaporation and Transpiration
    model_outputs = empty_dataset.copy(deep = True)
    
    model_outputs['E'] = (empty_dataset.dims, np.zeros(tuple(empty_dataset.dims[d] for d in list(empty_dataset.dims)), dtype = 'int16'))
    model_outputs['E'].attrs['units'] = 'mm'
    model_outputs['E'].attrs['standard_name'] = 'Evaporation'
    model_outputs['E'].attrs['description'] = 'Accumulated daily evaporation in millimeters'
    model_outputs['E'].attrs['scale factor'] = '1'
    
    model_outputs['Tr'] = (empty_dataset.dims, np.zeros(tuple(empty_dataset.dims[d] for d in list(empty_dataset.dims)), dtype = 'int16'))
    model_outputs['Tr'].attrs['units'] = 'mm'
    model_outputs['Tr'].attrs['standard_name'] = 'Transpiration'
    model_outputs['Tr'].attrs['description'] = 'Accumulated daily plant transpiration in millimeters'
    model_outputs['Tr'].attrs['scale factor'] = '1'
    
    # Soil Water Content
    model_outputs['SWCe'] = (empty_dataset.dims, np.zeros(tuple(empty_dataset.dims[d] for d in list(empty_dataset.dims)), dtype = 'int16'))
    model_outputs['SWCe'].attrs['units'] = 'mm'
    model_outputs['SWCe'].attrs['standard_name'] = 'Soil Water Content of the evaporative zone'
    model_outputs['SWCe'].attrs['description'] = 'Soil water content of the evaporative zone in millimeters'
    model_outputs['SWCe'].attrs['scale factor'] = '1'
    
    model_outputs['SWCr'] = (empty_dataset.dims, np.zeros(tuple(empty_dataset.dims[d] for d in list(empty_dataset.dims)), dtype = 'int16'))
    model_outputs['SWCr'].attrs['units'] = 'mm'
    model_outputs['SWCr'].attrs['standard_name'] = 'Soil Water Content of the root zone'
    model_outputs['SWCr'].attrs['description'] = 'Soil water content of the root zone in millimeters'
    model_outputs['SWCr'].attrs['scale factor'] = '1'
    
    # Irrigation
    model_outputs['Irr'] = (empty_dataset.dims, np.zeros(tuple(empty_dataset.dims[d] for d in list(empty_dataset.dims)), dtype = 'int16'))
    model_outputs['Irr'].attrs['units'] = 'mm'
    model_outputs['Irr'].attrs['standard_name'] = 'Irrigation'
    model_outputs['Irr'].attrs['description'] = 'Simulated daily irrigation in millimeters'
    model_outputs['Irr'].attrs['scale factor'] = '1'
    
    # Deep Percolation
    model_outputs['DP'] = (empty_dataset.dims, np.zeros(tuple(empty_dataset.dims[d] for d in list(empty_dataset.dims)), dtype = 'int16'))
    model_outputs['DP'].attrs['units'] = 'mm'
    model_outputs['DP'].attrs['standard_name'] = 'Deep Percolation'
    model_outputs['DP'].attrs['description'] = 'Simulated daily deep percolation in millimeters'
    model_outputs['DP'].attrs['scale factor'] = '1'
    
    if additional_outputs:
        for var in additional_outputs:
            model_outputs[var] = (empty_dataset.dims, np.zeros(tuple(empty_dataset.dims[d] for d in list(empty_dataset.dims)), dtype = 'int16'))
    
    return model_outputs


def calculate_diff_re(variable_ds: xr.Dataset, param_ds: xr.Dataset, scale_dict: dict, var: str) -> xr.DataArray:
    """
    Calculates the diffusion between the top soil layer and the root layer.

    ## Arguments
    1. variable_ds: `xr.Dataset`
        dataset containing calculation variables
    2. param_ds: `xr.Dataset`
        dataset containing the rasterized parameters
    3. scale_dict: `dict`
        dictionary containing the scale factors for
        the rasterized parameters
    4. var: `str`
        name of the depletion variable to use ('Dei', 'Dep' or 'De')

    ## Returns
    1. diff_re: `xr.DataArray`
        the diffusion between the top soil layer and
        the root layer
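
    ## Example
    A hypothetical call inside the daily time loop, using the surface-layer depletion variable 'Dei'.

        diff_rei = calculate_diff_re(variables_t2, parameter_dataset, scale_factor, 'Dei')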
    """
    
    # Temporary variables to make calculation easier to read
    tmp1 = (((variable_ds['TAW'] - variable_ds['Dr']) / variable_ds['Zr'] - (variable_ds['RUE'] - variable_ds[var]) / (scale_dict['Ze'] * param_ds['Ze'])) / variable_ds['FCov']) * (scale_dict['DiffE'] * param_ds['DiffE'])
    tmp2 = ((variable_ds['TAW'] * scale_dict['Ze'] * param_ds['Ze']) - (variable_ds['RUE'] - variable_ds[var] - variable_ds['Dr']) * variable_ds['Zr']) / (variable_ds['Zr'] + scale_dict['Ze'] * param_ds['Ze']) - variable_ds['Dr']
    
    # Calculate diffusion according to SAMIR equation
    diff_re = xr.where(tmp1 < 0, xr_maximum(tmp1, tmp2), xr_minimum(tmp1, tmp2))
    
    # Return zero values where the 'DiffE' parameter is equal to 0
    return xr.where(param_ds['DiffE'] == 0, 0, diff_re)


def calculate_diff_dr(variable_ds: xr.Dataset, param_ds: xr.Dataset, scale_dict: dict) -> xr.DataArray:
    """
    Calculates the diffusion between the root layer and the deep layer.

    ## Arguments
    1. variable_ds: `xr.Dataset`
        dataset containing calculation variables
    2. param_ds: `xr.Dataset`
        dataset containing the rasterized parameters
    3. scale_dict: `dict`
        dictionary containing the scale factors for
        the rasterized parameters

    ## Returns
    1. diff_dr: `xr.DataArray`
        the diffusion between the root layer and the
        deep layer
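
    ## Example
    A hypothetical call inside the daily time loop.

        diff_dr = calculate_diff_dr(variables_t2, parameter_dataset, scale_factor)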
    """
    
    # Temporary variables to make calculation easier to read
    tmp1 = (((variable_ds['TDW'] - variable_ds['Dd']) / (scale_dict['Zsoil'] * param_ds['Zsoil'] - variable_ds['Zr']) - (variable_ds['TAW'] - variable_ds['Dr']) / variable_ds['Zr']) / variable_ds['FCov']) * scale_dict['DiffR'] * param_ds['DiffR']
    tmp2 = (variable_ds['TDW'] * variable_ds['Zr'] - (variable_ds['TAW'] - variable_ds['Dr'] - variable_ds['Dd']) * (scale_dict['Zsoil'] * param_ds['Zsoil'] - variable_ds['Zr'])) / (scale_dict['Zsoil'] * param_ds['Zsoil']) - variable_ds['Dd']
    
    # Calculate diffusion according to SAMIR equation
    diff_dr = xr.where(tmp1 < 0, xr_maximum(tmp1, tmp2), xr_minimum(tmp1, tmp2))
    
    # Return zero values where the 'DiffR' parameter is equal to 0
    return xr.where(param_ds['DiffR'] == 0, 0, diff_dr)


def run_samir():
    
    return None