# -*- coding: UTF-8 -*-
# Python
"""
11-07-2023
@author: jeremy auclair

Usage of the SAMIR model in the Modspa framework.
"""

import os  # for path exploration
import csv  # open csv files
from fnmatch import fnmatch  # for character string comparison
from typing import List, Tuple  # to declare variables
import xarray as xr  # to manage dataset
import pandas as pd  # to manage dataframes
import numpy as np  # for math and array operations
import rasterio as rio  # to open geotiff files
import geopandas as gpd  # to manage shapefile crs projections
from parameters.params_samir_class import samir_parameters

def rasterize_samir_parameters(csv_param_file: str, parameter_dataset: xr.Dataset, land_cover_raster: str) -> xr.Dataset:
    """
    Creates a raster `xarray` dataset from the csv parameter file, the land cover raster and an empty dataset
    that contains the right structure (emptied ndvi dataset for example). For each parameter, the function loops
    on land cover classes to fill the raster.

    Each parameter value is multiplied by the parameter's `scale_factor` and rounded so that it can be
    stored as `int16`. Pixels whose land cover value does not correspond to any class column (i.e. is
    outside `1..number_of_classes`) keep their land cover value.

    ## Arguments
    1. csv_param_file: `str`
        path to csv parameter file
    2. parameter_dataset: `xr.Dataset`
        empty dataset that contains the right structure (emptied ndvi dataset for example);
        new variables are added to it in place and the same object is returned
    3. land_cover_raster: `str`
        path to land cover netcdf raster

    ## Returns
    1. parameter_dataset: `xr.Dataset`
        the dataset containing all the rasterized parameters (`int16` values, scaled by `scale_factor`)
    """

    # Load samir params into an object
    table_param = samir_parameters(csv_param_file)
    table = table_param.table

    # The first two columns of the table are not land cover classes; the remaining columns are the
    # classes, numbered 1, 2, ... in column order to match the values of the land cover raster
    class_columns = table.columns[2:]

    # Open the land cover raster, load it in memory and release the file handle
    with xr.open_dataarray(land_cover_raster) as land_cover_file:
        land_cover = land_cover_file.load()

    # Original land cover classes as int16. This array is never modified, so every class mask is
    # computed against the true class ids and not against values already written for another class
    land_cover_classes = land_cover.values.astype('i2')

    # Loop on samir parameters (the first table row is skipped) and create one raster per parameter
    for parameter in table.index[1:]:

        # Table row for this parameter (first matching row) and its scale factor
        parameter_row = table.loc[table.index == parameter].iloc[0]
        scale_factor = parameter_row['scale_factor']

        # Start from the land cover classes: pixels that match no class keep their class value
        parameter_values = land_cover_classes.copy()

        # Loop on classes to set parameter values for each class
        for class_val, class_name in enumerate(class_columns, start = 1):

            # Parameter values are multiplied by the scale factor in order to store all values as int16 types
            # These values are then rounded to make sure there isn't any decimal point issues when casting the values to int16
            parameter_values[land_cover_classes == class_val] = round(parameter_row[class_name] * scale_factor)

        # Create new variable (same coordinates/attributes as the land cover raster) and set attributes
        parameter_dataset[parameter] = land_cover.astype('i2')
        parameter_dataset[parameter].values = parameter_values
        parameter_dataset[parameter].attrs['name'] = parameter
        parameter_dataset[parameter].attrs['description'] = 'cf SAMIR Doc for detail'

    # Values are stored as int16 to reduce memory usage
    # Scale factor for calculation is stored in the samir_parameters object
    return parameter_dataset


def setup_time_loop(calculation_variables: List[str], calculation_constant_values: List[str], empty_dataset: xr.Dataset) -> Tuple[xr.Dataset, xr.Dataset, xr.Dataset]:
    """
    Creates three temporary `xarray Datasets` that will be used in the SAMIR time loop.
    `variables_t1` corresponds to the variables for the previous day and `variables_t2`
    corresponds to the variables for the current day. After each loop, `variables_t1`
    takes the value of `variables_t2`. `constant_values` corresponds to model values
    that are not time dependent, and therefore stay constant during the SAMIR time loop.

    Every created variable is a zero-filled array spanning all dimensions of `empty_dataset`
    (`float32` for the time loop variables, `int16` for the constant values).

    ## Arguments
    1. calculation_variables: `List[str]`
        list of strings containing the variable names
    2. calculation_constant_values: `List[str]`
        list of strings containing the constant value names
    3. empty_dataset: `xr.Dataset`
        empty dataset that contains the right structure (it is deep-copied, not modified)

    ## Returns
    1. variables_t1: `xr.Dataset`
        output dataset for previous day
    2. variables_t2: `xr.Dataset`
        output dataset for current day
    3. constant_values: `xr.Dataset`
        output dataset for constant values
    """

    # Dimension names and lengths of the template dataset. `Dataset.sizes` is used rather than
    # indexing `Dataset.dims`, whose mapping behaviour is deprecated in recent xarray releases;
    # names and lengths both come from `sizes`, so they are guaranteed to be in the same order
    dim_names = tuple(empty_dataset.sizes)
    shape = tuple(empty_dataset.sizes.values())

    # Create new dataset
    variables_t1 = empty_dataset.copy(deep = True)

    # Create empty DataArray for each variable (a new zero array per variable, no shared memory)
    for variable in calculation_variables:
        variables_t1[variable] = (dim_names, np.zeros(shape, dtype = 'float32'))
        variables_t1[variable].attrs['name'] = variable  # set name in attributes

    # Create new copy of the variables_t1 dataset
    variables_t2 = variables_t1.copy(deep = True)

    # Create dataset for constant values
    constant_values = empty_dataset.copy(deep = True)

    # Create empty DataArray for each value
    for value in calculation_constant_values:
        constant_values[value] = (dim_names, np.zeros(shape, dtype = 'int16'))
        constant_values[value].attrs['name'] = value  # set name in attributes
        constant_values[value].attrs['description'] = 'Values which stays constant during the SAMIR time loop'  # set description in attributes

    return variables_t1, variables_t2, constant_values


def run_samir():
    """
    Placeholder for running the SAMIR model.

    The body is not implemented yet: the function performs no work and always returns `None`.
    """
    # TODO: implement the SAMIR time loop
    return