# -*- coding: UTF-8 -*-
# Python
"""11-07-2023
@author: jeremy auclair

Usage of the SAMIR model in the Modspa framework. 

"""

import os  # os management
import xarray as xr  # to manage datasets
from numba import njit, float32, int16, set_num_threads  # to compile functions in nopython mode for faster calculation
import numpy as np  # for vectorized maths
import pandas as pd  # to manage dataframes
from typing import List, Tuple  # to declare argument types
import netCDF4 as nc  # to efficiently read and write netCDF files
from tqdm import tqdm  # to show a progress bar
from psutil import virtual_memory  # to check available ram
from psutil import cpu_count  # to get number of physical cores available
from modspa_pixel.parameters.params_samir_class import samir_parameters  # to load SAMIR parameters
from modspa_pixel.source.code_toolbox import format_Byte_size  # to print memory requirements


def test_samir_parameter(table: pd.DataFrame, min_max_param_file: str) -> pd.DataFrame:
    """
    Test the values of the SAMIR parameters to check if they
    are in the correct range (defined in the csv param_range
    file) and automatically calculate the Fcover and Kcb slope
    and offset (from the NDVIsol and NDVImax values) if their
    value is equal to -9999.

    The input ``table`` is modified in place: computed slopes and
    offsets are written into it and the ``NDVIsol`` / ``NDVImax``
    rows are dropped when all parameters are in range.

    Arguments
    =========

    1. table: ``pd.DataFrame``
        samir parameter dataframe indexed by parameter name. It must
        contain a ``scale_factor`` column; every column after the first
        one is checked and completed (presumably the first column is
        ``scale_factor`` — confirm against the parameter csv layout)
    2. min_max_param_file: ``str``
        path to a csv file indexed by parameter name with
        ``min_value`` and ``max_value`` columns giving the allowed
        range of each parameter

    Returns
    =======

    1. table: ``pd.DataFrame`` or ``int``
        updated samir parameter dataframe (without the ``NDVIsol``
        and ``NDVImax`` rows), or ``0`` if at least one parameter
        is out of range
    """

    def _decimals(parameter: str) -> int:
        # Number of decimals implied by the parameter scale factor
        # (e.g. a scale factor of 1000 keeps 3 decimals)
        return round(np.log10(table.at[parameter, 'scale_factor']))

    # (slope row, offset row, maximum value row) of the Fcover and Kcb linear relations with NDVI
    linear_relations = (('Fslope', 'Foffset', 'FCmax'), ('Kslope', 'Koffset', 'Kcmax'))

    # Automatically calculate the FC and Kcb slopes and offsets if necessary (-9999 marks "to compute")
    for class_name in table.columns[1:]:
        for slope, offset, maximum in linear_relations:
            if table.at[slope, class_name] == -9999:
                table.at[slope, class_name] = np.round(table.at[maximum, class_name] / (table.at['NDVImax', class_name] - table.at['NDVIsol', class_name]), decimals = _decimals(slope))

            if table.at[offset, class_name] == -9999:
                # The offset uses the slope value, which may have just been computed above
                table.at[offset, class_name] = - np.round(table.at['NDVIsol', class_name] * table.at[slope, class_name], decimals = _decimals(offset))

    # Test values of the parameters
    min_max_table = pd.read_csv(min_max_param_file, index_col = 0)

    # Set to True as soon as one parameter is out of range; every offending
    # parameter/class pair is still reported before returning
    out_of_range_param = False

    # Loop through parameter values for all the classes
    for parameter in min_max_table.index:
        for class_name in table.columns[1:]:
            param_value = table.at[parameter, class_name]

            # Test if parameter is out of range
            if param_value > min_max_table.at[parameter, 'max_value'] or param_value < min_max_table.at[parameter, 'min_value']:

                # If parameter is out of range, print which parameter and for which class
                print(f'\nParameter {parameter} is out of range for class {class_name}')
                out_of_range_param = True

    # Return 0 if a parameter is out of range (callers test for this sentinel)
    if out_of_range_param:
        return 0

    # Remove rows only needed to compute the slopes and offsets
    table.drop(['NDVIsol', 'NDVImax'], inplace = True)

    return table


def rasterize_samir_parameters(csv_param_file: str, land_cover_raster: str) -> dict:
    """
    Creates a dictionary containing raster parameter values and the scale factors from the csv parameter file
    and the land cover raster. For each parameter, the function loops on land cover classes to fill the raster.

    Before creating the dictionary, it updates the parameter ``DataFrame`` and verifies the parameter range with
    the ``test_samir_parameter()`` function. If a parameter is out of range, a message is printed and the
    script exits with a non-zero status.

    Arguments
    =========

    1. csv_param_file: ``str``
        path to csv parameter file. The ``params_samir_min_max.csv``
        range file is looked up in the same directory.
    2. land_cover_raster: ``str``
        path to land cover netcdf raster

    Returns
    =======

    parameter_dict: ``dict``
        the dictionary containing all the rasterized parameters
        (keys ``<parameter>_``) and the scale factors
        (keys ``s_<parameter>``)

    """
    import sys  # to exit script

    # Get path of min max csv file (located next to the parameter file).
    # os.path.join avoids producing "/params_samir_min_max.csv" when
    # csv_param_file has no directory part.
    min_max_param_file = os.path.join(os.path.dirname(csv_param_file), 'params_samir_min_max.csv')

    # Load samir params into an object
    table_param = samir_parameters(csv_param_file)

    # Set general variables: the first two columns are not land cover classes
    # (presumably the scale factor and Default columns)
    class_count = table_param.table.shape[1] - 2

    # Open land cover raster; the context manager closes the file once the values are in memory
    with xr.open_dataarray(land_cover_raster) as land_cover_da:
        land_cover = land_cover_da.to_numpy()

    # int16 land cover codes: used as the class mask and as the starting value of
    # every parameter raster (pixels outside every class keep their code)
    land_cover_codes = land_cover.astype('int16')

    # Create parameter dictionary
    parameter_dict = {}

    # Modify parameter table based on its content.
    # test_samir_parameter returns 0 (instead of a DataFrame) when parameters are incorrect.
    # Evaluating ``DataFrame == 0`` in an ``if`` raises a ValueError (ambiguous truth value),
    # so the type of the result is tested instead.
    checked_table = test_samir_parameter(table_param.table, min_max_param_file)
    if not isinstance(checked_table, pd.DataFrame):
        print(f'\nModify the SAMIR parameter file: {csv_param_file}\n')
        sys.exit(1)
    table_param.table = checked_table

    # Loop on samir parameters and create their rasters
    for parameter in table_param.table.index[1:]:

        # Get scale factor
        scale_factor = table_param.table.at[parameter, 'scale_factor']

        # If scale_factor == 0, then the parameter does not have to be spatialized, it can stay scalar
        if scale_factor == 0:

            # Set scale factor to 1
            parameter_dict['s_' + parameter] = np.float32(1)

            # Take Default parameter value
            parameter_dict[parameter + '_'] = np.float32(table_param.table.at[parameter, 'Default'])

            continue

        # Assign scale factor
        # Scale factors have the following name scheme : s_ + parameter_name
        parameter_dict['s_' + parameter] = np.float32(1/int(scale_factor))

        # Create 2D array from land cover
        value = land_cover_codes.copy()

        # Loop on classes to set parameter values for each class
        for class_val, class_name in zip(range(1, class_count + 1), table_param.table.columns[2:]):

            # Get parameter value
            param_value = table_param.table.at[parameter, class_name]

            # Parameter values are multiplied by the scale factor in order to store all values as int16 types.
            # These values are then rounded to make sure there isn't any decimal point issues when casting to int16.
            # The mask is taken from the original land cover codes, not from ``value``: once a class is filled,
            # its scaled value could equal a later class code and would otherwise be overwritten.
            value = np.where(land_cover_codes == class_val, round(param_value * scale_factor), value).astype('int16')

        # Assign parameter value
        # Parameters have an underscore (_) behind their name for recognition
        parameter_dict[parameter + '_'] = value

    return parameter_dict


def prepare_output_dataset(ndvi_path: str, dimensions: dict, scaling_dict: dict = None, additional_outputs: dict = None) -> xr.Dataset:
    """
    Creates the `xarray Dataset` containing the outputs of the SAMIR model that will be saved.
    Additional variables can be saved by adding them to the `additional_outputs` dictionary.

    Every output variable is an int16 array of zeros shaped by ``dimensions`` and carries
    a ``scale factor`` attribute. The whole scaling dictionary is also stored as a string
    in the ``scaling`` global attribute.

    Arguments
    =========

    1. ndvi_path: ``str``
        path to ndvi cube; its ``NDVI`` variable is dropped and all its
        time steps are removed, the remaining content is used as template
    2. dimensions: ``dict``
        mapping of dimension names to sizes (e.g. a frozen dimensions
        mapping); its key order gives the dimension order of the outputs
    3. scaling_dict: ``dict`` ``default = None``
        scaling dictionary for the nominal outputs. ``None`` uses
        ``{'E': 1000, 'Tr': 1000, 'SWCe': 1000, 'SWCr': 1000, 'DP': 100, 'Irr': 100, 'ET0': 1000}``
    4. additional_outputs: ``dict`` ``default = None``
        names of additional variables to be saved, mapped to their
        scale factor

    Returns
    =======

    1. model_outputs: ``xr.Dataset``
        model outputs to be saved
    """

    # Build the default here rather than in the signature (mutable default argument)
    if scaling_dict is None:
        scaling_dict = {'E': 1000, 'Tr': 1000, 'SWCe': 1000, 'SWCr': 1000, 'DP': 100, 'Irr': 100, 'ET0': 1000}

    # Empty template: same coordinates as the NDVI cube, without NDVI and without any time step
    model_outputs = xr.open_dataset(ndvi_path).drop_vars(['NDVI']).copy(deep = True)
    model_outputs = model_outputs.drop_sel(time = model_outputs.time)

    # Add scaling dictionary to global attribute
    model_outputs.attrs['scaling'] = str(scaling_dict)

    # Shape of every output array, in the order of the dimensions mapping
    shape = tuple(dimensions[d] for d in dimensions)

    # (variable, units, standard_name, description) of the nominal outputs, in saving order:
    # Evaporation and Transpiration, Soil Water Content, Irrigation, Deep Percolation
    nominal_outputs = (
        ('E', 'mm', 'Evaporation', 'Accumulated daily evaporation in milimeters'),
        ('Tr', 'mm', 'Transpiration', 'Accumulated daily plant transpiration in milimeters'),
        ('SWCe', '%', 'Soil Water Content of the evaporative layer', 'Soil water content of the evaporative layer in milimeters'),
        ('SWCr', '%', 'Soil Water Content of the root layer', 'Soil water content of the root layer in milimeters'),
        ('Irr', 'mm', 'Irrigation', 'Simulated daily irrigation in milimeters'),
        ('DP', 'mm', 'Deep Percolation', 'Simulated daily Deep Percolation in milimeters'),
    )

    for var, units, standard_name, description in nominal_outputs:
        model_outputs[var] = (dimensions, np.zeros(shape, dtype = np.int16))
        model_outputs[var].attrs['units'] = units
        model_outputs[var].attrs['standard_name'] = standard_name
        model_outputs[var].attrs['description'] = description
        model_outputs[var].attrs['scale factor'] = scaling_dict[var]

    # Additional outputs only get a scale factor attribute
    if additional_outputs:
        for var, scale in additional_outputs.items():
            model_outputs[var] = (dimensions, np.zeros(shape, dtype = np.int16))
            model_outputs[var].attrs['scale factor'] = scale

    return model_outputs


@njit((float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:]), nogil = True, parallel = True, fastmath = True)
def calculate_diff_re(TAW: np.ndarray, Dr: np.ndarray, Zr: np.ndarray, RUE: np.ndarray, De: np.ndarray, Wfc: np.ndarray, Ze: np.ndarray, DiffE: np.ndarray) -> np.ndarray:
    """
    Calculates the diffusion between the top soil layer and the root layer. Uses numba for faster and parallel calculation.
    All arguments are 2D ``float32`` arrays (enforced by the numba signature).

    A positive result means water moves from the root layer to the
    evaporative layer; a negative one means the opposite.

    Arguments
    =========
    
    1. TAW: ``np.ndarray``
        water capacity of root layer
    2. Dr: ``np.ndarray``
        depletion of root layer
    3. Zr: ``np.ndarray``
        height of root layer
    4. RUE: ``np.ndarray``
        total available surface water = (Wfc-Wwp)*Ze
    5. De: ``np.ndarray``
        depletion of the evaporative layer
    6. Wfc: ``np.ndarray``
        field capacity
    7. Ze: ``np.ndarray``
        height of evaporative layer (parameter)
    8. DiffE: ``np.ndarray``
        diffusion coefficient between evaporative
        and root layers (unitless, parameter)

    Returns
    =======
    
    1. diff_re: ``np.ndarray``
        the diffusion between the top soil layer and
        the root layer (0 where ``DiffE == 0``)
    """

    # Temporary variables to make calculation easier to read
    # tmp1: difference of available water per unit depth between the root layer
    # ((TAW - Dr) / Zr) and the evaporative layer ((RUE - De) / Ze), divided by Wfc * DiffE.
    # NOTE(review): DiffE divides here whereas DiffR multiplies in calculate_diff_dr;
    # confirm which form the SAMIR equation uses.
    tmp1 = ((TAW - Dr) / Zr - (RUE - De) / Ze) / (Wfc * DiffE)
    # tmp2: transfer that would equalize the available water per unit depth of both layers;
    # algebraically tmp2 = (Ze * (TAW - Dr) - Zr * (RUE - De)) / (Zr + Ze)
    tmp2 = ((TAW * Ze) - (RUE - De - Dr) * Zr) / (Zr + Ze) - Dr
    
    # Calculate diffusion according to SAMIR equation:
    # the flux tmp1 is limited in magnitude by the equalizing transfer tmp2
    # (max for negative fluxes, min for positive ones).
    # Return zero values where the 'DiffE' parameter is equal to 0
    # (tmp1 divides by zero there; np.where discards those values, but note that
    # fastmath lets the compiler assume no inf/NaN occur)
    return np.where(DiffE == 0, np.float32(0), np.where(tmp1 < 0, np.maximum(tmp1, tmp2), np.minimum(tmp1, tmp2)))


@njit((float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:]), nogil = True, parallel = True, fastmath = True)
def calculate_diff_dr(TAW: np.ndarray, TDW: np.ndarray, Dr: np.ndarray, Zr: np.ndarray, Dd: np.ndarray, Wfc: np.ndarray, Zsoil: np.ndarray, DiffR: np.ndarray) -> np.ndarray:
    """
    Calculates the diffusion between the root layer and the deep layer. Uses numba for faster and parallel calculation.
    All arguments are 2D ``float32`` arrays (enforced by the numba signature).

    A positive result means water moves from the deep layer to the
    root layer; a negative one means the opposite.

    Arguments
    =========

    1. TAW: ``np.ndarray``
        water capacity of root layer
    2. TDW: ``np.ndarray``
        water capacity of deep layer
    3. Dr: ``np.ndarray``
        depletion of root layer
    4. Zr: ``np.ndarray``
        height of root layer
    5. Dd: ``np.ndarray``
        depletion of deep layer
    6. Wfc: ``np.ndarray``
        field capacity
    7. Zsoil: ``np.ndarray``
        total height of soil (parameter)
    8. DiffR: ``np.ndarray``
        Diffusion coefficient between root
        and deep layers (unitless, parameter)

    Returns
    =======
    
    1. Diff_dr: ``np.ndarray``
        the diffusion between the root layer and the
        deep layer (0 where ``DiffR == 0``)
    """

    # Temporary variables to make calculation easier to read
    # tmp1: difference of available water per unit depth between the deep layer
    # ((TDW - Dd) / (Zsoil - Zr)) and the root layer ((TAW - Dr) / Zr), divided by Wfc, times DiffR
    tmp1 = (((TDW - Dd) / (Zsoil - Zr) - (TAW - Dr) / Zr) / Wfc) * DiffR
    # tmp2: transfer that would equalize the available water per unit depth of both layers;
    # algebraically tmp2 = (Zr * (TDW - Dd) - (Zsoil - Zr) * (TAW - Dr)) / Zsoil
    tmp2 = ((TDW * Zr - (TAW - Dr - Dd) * (Zsoil - Zr)) / Zsoil) - Dd

    # Calculate diffusion according to SAMIR equation:
    # the flux tmp1 is limited in magnitude by the equalizing transfer tmp2
    # (max for negative fluxes, min for positive ones).
    # Return zero values where the 'DiffR' parameter is equal to 0
    return np.where(DiffR == 0, np.float32(0), np.where(tmp1 < 0, np.maximum(tmp1, tmp2), np.minimum(tmp1, tmp2)))


@njit((float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:]), nogil = True, parallel = True, fastmath = True)
def calculate_W(TEW: np.ndarray, Dei: np.ndarray, Dep: np.ndarray, fewi: np.ndarray, fewp: np.ndarray) -> np.ndarray:
    """
    Calculate W, the weighting factor to split the energy available
    for evaporation depending on the difference in water availability
    in the two evaporation components, ensuring that the larger and
    the wetter, the more the evaporation occurs from that component.
    Uses numba for faster and parallel calculation.

    Arguments
    =========

    1. TEW: ``np.ndarray``
        water capacity of evaporative layer
    2. Dei: ``np.ndarray``
        depletion of the evaporative layer
        (irrigation part)
    3. Dep: ``np.ndarray``
        depletion of the evaporative layer
        (precipitation part)
    4. fewi: ``np.ndarray``
        soil fraction which is wetted by irrigation
        and exposed to evaporation
    5. fewp: ``np.ndarray``
        soil fraction which is wetted by precipitation
        and exposed to evaporation

    Returns
    =======
    
    1. W: ``np.ndarray``
        weighting factor W (share of the irrigated component),
        0 where ``fewi * (TEW - Dei) <= 0``
    """

    # * Equation: W = 1 / (1 + (fewp * (TEW - Dep)) / (fewi * (TEW - Dei)))
    # The rainfed term is DIVIDED by the irrigated term: both products are
    # parenthesized explicitly (without the parentheses, operator precedence
    # multiplies by (TEW - Dei) instead of dividing by it, which reverses the
    # effect of a wetter irrigated fraction).
    # * Equation: W = where(fewi * (TEW - Dei) > 0, W, 0)
    # The guard is the denominator: where it is not positive, W is 0 and the
    # division computed there is discarded.
    return np.where(fewi * (TEW - Dei) > 0, 1 / (1 + (fewp * (TEW - Dep)) / (fewi * (TEW - Dei))), np.float32(0))


@njit((float32[:,:], float32[:,:], float32[:,:]), nogil = True, parallel = True, fastmath = True)
def calculate_Kr(TEW: np.ndarray, De: np.ndarray, REW: np.ndarray) -> np.ndarray:
    """
    calculates of the reduction coefficient for evaporation dependent
    on the amount of water in the soil using the FAO-56 method. Uses numba for faster and parallel calculation.
    All arguments are 2D ``float32`` arrays (enforced by the numba signature).

    Arguments
    =========
    
    1. TEW: ``np.ndarray``
        water capacity of evaporative layer
    2. De: ``np.ndarray``
        depletion of evaporative layer
    3. REW: ``np.ndarray``
        fixed readily evaporable water

    Returns
    =======
    
    1. Kr: ``np.ndarray``
        Kr coefficient, clipped to [0, 1]
    """

    # * Equation: Kr = clip((TEW - De) / (TEW - REW), 0, 1)
    # Kr is 1 while De <= REW and decreases linearly to 0 when De reaches TEW.
    # NOTE(review): TEW == REW divides by zero; with fastmath the result there is not reliable.
    return np.maximum(np.float32(0), np.minimum((TEW - De) / (TEW - REW), np.float32(1)))


@njit((float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:]), nogil = True, parallel = True, fastmath = True)
def calculate_Ks(Dr: np.ndarray, TAW: np.ndarray, p: np.ndarray, E0: np.ndarray, Tr0: np.ndarray) -> np.ndarray:
    """
    Calculate Ks coefficient after day 1. Uses numba for faster and parallel calculation.
    All arguments are 2D ``float32`` arrays (enforced by the numba signature).

    Arguments
    =========

    1. Dr: ``np.ndarray``
        depletion of the root layer
    2. TAW: ``np.ndarray``
        water capacity of the root layer
    3. p: ``np.ndarray``
        fraction of TAW available for plant without inducing stress
    4. E0: ``np.ndarray``
        surface evaporation of previous day
    5. Tr0: ``np.ndarray``
        plant transpiration of previous day

    Returns
    =======

    1. Ks: ``np.ndarray``
        Ks coefficient (float32), capped at 1; no lower bound is applied
    """

    # Return Ks
    # * Equation: Ks = min((TAW - Dr) / (TAW * (1 - (p + 0.04 * (5 - (E0 + Tr0))))), 1)
    # p is adjusted by 0.04 * (5 - (E0 + Tr0)), i.e. with the previous day's evaporation
    # plus transpiration. The 0.04 literal is a float64, hence the final cast back to float32.
    return np.minimum((TAW - Dr) / (TAW * (np.float32(1) - (p + 0.04 * (np.float32(5) - (E0 + Tr0))))), np.float32(1)).astype(np.float32)


def calculate_Ke(W: np.ndarray, De: np.ndarray, TEW: np.ndarray, REW: np.ndarray, Kcmax: float, Kcb: np.ndarray, few: np.ndarray) -> np.ndarray:
    """
    Compute the evaporation coefficient Ke of one evaporative component
    (irrigated or rainfed part of the soil).

    The coefficient is the smaller of the energy-limited evaporation
    ``W * Kr * (Kcmax - Kcb)`` and the limit set by the wetted, exposed
    fraction ``few * Kcmax``. Kr comes from ``calculate_Kr``.

    Arguments
    =========

    1. W: ``np.ndarray``
        weighting factor to split the energy available
        for evaporation (use W for the irrigated part, 1 - W
        for the rainfed part)
    2. De: ``np.ndarray``
        Dei or Dep, depletion of the evaporative layer
    3. TEW: ``np.ndarray``
        water capacity of the evaporative layer
    4. REW: ``np.ndarray``
        fixed readily evaporable water
    5. Kcmax: ``float``
        maximum possible evaporation in the atmosphere
    6. Kcb: ``np.ndarray``
        crop coefficient
    7. few: ``np.ndarray``
        fewi or fewp, soil fraction which is wetted by
        irrigation or precipitation and exposed to evaporation

    Returns
    =======

    1. Ke: ``np.ndarray``
        evaporation coefficient
    """

    # Reduction coefficient for the water left in this component
    reduction = calculate_Kr(TEW, De, REW)

    # * Equation: Kei = min(W * Kri * (Kc_max - Kcb), fewi * Kc_max)
    # * Equation: Kep = min((1 - W) * Krp * (Kc_max - Kcb), fewp * Kc_max)
    energy_limited = W * reduction * (Kcmax - Kcb)
    fraction_limited = few * Kcmax

    return np.minimum(energy_limited, fraction_limited)


@njit((float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:], int16[:,:], float32[:,:], float32[:,:], float32[:,:]), nogil = True, parallel = True, fastmath = True)
def calculate_irrig(Dr: np.ndarray, TAW: np.ndarray, p: np.ndarray, Rain: np.ndarray, Kcb: np.ndarray, Irrig_auto: np.ndarray, Kcb_stop_irrig: np.ndarray, E0: np.ndarray, Tr0: np.ndarray) -> np.ndarray:
    """
    Calculate automatic irrigation after day one. Uses numba for faster and parallel calculation.
    All arguments are 2D arrays: ``Irrig_auto`` is ``int16``, the others are ``float32``
    (enforced by the numba signature).

    The irrigation amount is the root layer depletion not covered by the day's
    rain, ``Irrig_auto * max(Dr - Rain, 0)``. It is applied only where both
    conditions hold:

    - the depletion exceeds ``TAW * (p + 0.04 * (5 - (E0 + Tr0)))``
      (same adjusted threshold as in ``calculate_Ks``)
    - ``Kcb > max(Kcb, 1) * Kcb_stop_irrig``

    Arguments
    =========

    1. Dr: ``np.ndarray``
        depletion of the root layer
    2. TAW: ``np.ndarray``
        water capacity of the root layer
    3. p: ``np.ndarray``
        fraction of TAW available for plant without inducing stress
    4. Rain: ``np.ndarray``
        precipitation of the current day
    5. Kcb: ``np.ndarray``
        crop coefficient value
    6. Irrig_auto: ``np.ndarray``
        parameter for automatic irrigation (multiplies the amount,
        so 0 disables irrigation)
    7. Kcb_stop_irrig: ``np.ndarray``
        Kcb threshold to stop irrigation
    8. E0: ``np.ndarray``
        surface evaporation of previous day
    9. Tr0: ``np.ndarray``
        plant transpiration of previous day

    Returns
    =======

    1. Irrig: ``np.ndarray``
        simulated irrigation for the current day
    """
    
    # First step: refill the depletion of the root layer that the rain does not cover
    Irrig = Irrig_auto * np.maximum(Dr - Rain, np.float32(0))
    
    # Return modified Irrig
    # NOTE(review): np.maximum(Kcb, 1) is element-wise. For Kcb <= 1 the second condition
    # reduces to Kcb > Kcb_stop_irrig; for Kcb > 1 it reduces to Kcb_stop_irrig < 1.
    # If a seasonal maximum of Kcb was intended, this differs; confirm against the SAMIR spec.
    return np.where((Dr > TAW * (p + 0.04 * (np.float32(5) - (E0 + Tr0)))) & (Kcb > np.maximum(Kcb, np.float32(1)) * Kcb_stop_irrig), Irrig, np.float32(0))
    

@njit((float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:]), nogil = True, parallel = True, fastmath = True)
def calculate_Te(De: np.ndarray, Dr: np.ndarray, Ks: np.ndarray, Kcb: np.ndarray, Ze: np.ndarray, Zr: np.ndarray, TEW: np.ndarray, TAW: np.ndarray, ET0: np.ndarray) -> np.ndarray:
    """
    Calculate Te (root uptake of water) for the current day. Uses numba for faster and parallel calculation.
    All arguments are 2D ``float32`` arrays (enforced by the numba signature).

    Arguments
    =========

    1. De: ``np.ndarray``
        Dei or Dep, depletion of the evaporative layer
    2. Dr: ``np.ndarray``
        depletion of the root layer
    3. Ks: ``np.ndarray``
        stress coefficient
    4. Kcb: ``np.ndarray``
        crop coefficient
    5. Ze: ``np.ndarray``
        height of the evaporative layer
    6. Zr: ``np.ndarray``
        height of the root layer
    7. TEW: ``np.ndarray``
        water capacity of the evaporative layer
    8. TAW: ``np.ndarray``
        water capacity of the root layer
    9. ET0: ``np.ndarray``
        reference evapotranspiration

    Returns
    =======

    1. Te: ``np.ndarray``
        root uptake of water from the evaporative layer (float32)
    """
    
    # * Equation: Kt = min( ((Ze / Zr)**0.6) * (1 - De/TEW) / max(1 - Dr / TAW, 0.001), 1)
    # * Equation: Te = Kt * Ks * Kcb * ET0
    # The 0.001 floor avoids a division by zero when the root layer is fully depleted.
    # Python float literals (0.6, 0.001) are float64, hence the final cast back to float32.
    return (np.minimum(((Ze / Zr)**0.6) * (np.float32(1) - De / TEW) / np.maximum(np.float32(1) - Dr / TAW, 0.001), np.float32(1)) * Ks * Kcb * ET0).astype(np.float32)


@njit((float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:]), nogil = True, parallel = True, fastmath = True)
def update_De_from_Diff(De : np.ndarray, few: np.ndarray, Ke: np.ndarray, Te: np.ndarray, Diff_re: np.ndarray, TEW: np.ndarray, ET0: np.ndarray) -> np.ndarray:
    """
    Last update step for Dei and Dep depletions. Uses numba for faster and parallel calculation.
    All arguments are 2D ``float32`` arrays (enforced by the numba signature).

    The depletion is increased by root uptake (``Te``) and decreased by the
    diffusion from the root layer (``Diff_re``). Where ``few > 0``, the
    evaporation ``ET0 * Ke`` is also added, divided by ``few`` so that it is
    concentrated on the wetted, exposed fraction. The result is clipped to
    ``[0, TEW]``.

    Arguments
    =========

    1. De: ``np.ndarray``
        Dei or Dep, depletion of the evaporative layer
    2. few: ``np.ndarray``
        fewi or fewp, soil fraction which is wetted by
        irrigation or precipitation and exposed to evaporation
    3. Ke: ``np.ndarray``
        Kei or Kep, evaporation coefficient for soil fraction
        irrigated or rainfed and exposed to evaporation
    4. Te: ``np.ndarray``
        root uptake of water
    5. Diff_re: ``np.ndarray``
        diffusion between the root and evaporation layers
    6. TEW: ``np.ndarray``
        water capacity of the evaporative layer
    7. ET0: ``np.ndarray``
        reference evapotranspiration of the current day

    Returns
    =======

    De: ``np.ndarray``
        updated Dei or Dep
    """
    
    # Update Dei and Dep depletions
    # * Equation: De = where(few > 0, min(max(De + ET0 * Ke / few + Te - Diff_re, 0), TEW), min(max(De + Te - Diff_re, 0), TEW))
    return np.where(few > np.float32(0), np.minimum(np.maximum(De + ET0 * Ke / few + Te - Diff_re, np.float32(0)), TEW), np.minimum(np.maximum(De + Te - Diff_re, np.float32(0)), TEW))
        

@njit((float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:]), nogil = True, parallel = True, fastmath = True)
def update_Dr_from_root(Wfc: np.ndarray, Wwp: np.ndarray, Zr: np.ndarray, Zsoil: np.ndarray, Dr0: np.ndarray, Dd0: np.ndarray, Zr0: np.ndarray) -> np.ndarray:
    """
    Return the updated depletion for the root layer after a change of root depth. Uses numba for faster and parallel calculation.
    All arguments are 2D ``float32`` arrays (enforced by the numba signature).

    - Roots growing (``Zr > Zr0``): the root layer takes from the deep layer the
      share ``(Zr - Zr0) / (Zsoil - Zr0)`` of its depletion ``Dd0``.
    - Otherwise: the root layer depletion is scaled by ``Zr / Zr0``
      (``Dr0 + Dr0 * (Zr - Zr0) / Zr0``), i.e. unchanged when ``Zr == Zr0``.

    The ``(Wfc - Wwp)`` factors cancel algebraically, but yield NaN where ``Wfc == Wwp``.
    Both branches are evaluated everywhere before selection. The result is floored at 0.

    Arguments
    =========

    1. Wfc: ``np.ndarray``
        field capacity
    2. Wwp: ``np.ndarray``
        field wilting point
    3. Zr: ``np.ndarray``
        root depth for current day
    4. Zsoil: ``np.ndarray``
        total soil depth (parameter)
    5. Dr0: ``np.ndarray``
        depletion of the root layer for previous day
    6. Dd0: ``np.ndarray``
        depletion of the deep layer for previous day
    7. Zr0: ``np.ndarray``
        root layer height for previous day

    Returns
    =======
    
    1. output: ``np.ndarray``
        updated depletion for the root layer
    """

    # Temporary variables to make calculation easier to read
    # tmp1: root growth, part of the deep layer depletion moves into the root layer
    tmp1 = np.maximum(Dr0 + Dd0 * ((Wfc - Wwp) * (Zr - Zr0)) / ((Wfc - Wwp) * (Zsoil - Zr0)), np.float32(0))
    # tmp2: root shrinkage (or no change), depletion proportional to the new root depth
    tmp2 = np.maximum(Dr0 + Dr0 * ((Wfc - Wwp) * (Zr - Zr0)) / ((Wfc - Wwp) * Zr0), np.float32(0))

    # Return updated Dr
    return np.where(Zr > Zr0, tmp1, tmp2)


@njit((float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:]), nogil = True, parallel = True, fastmath = True)
def update_Dd_from_root(Wfc: np.ndarray, Wwp: np.ndarray, Zr: np.ndarray, Zsoil: np.ndarray, Dr0: np.ndarray, Dd0: np.ndarray, Zr0: np.ndarray) -> np.ndarray:
    """
    Return the updated depletion for the deep layer after a change of root depth. Uses numba for faster and parallel calculation.
    All arguments are 2D ``float32`` arrays (enforced by the numba signature).

    This mirrors ``update_Dr_from_root``:

    - Roots growing (``Zr > Zr0``): the deep layer loses the share
      ``(Zr - Zr0) / (Zsoil - Zr0)`` of its depletion ``Dd0``.
    - Otherwise: the deep layer receives ``Dr0 * (Zr0 - Zr) / Zr0``, the depletion
      released by the root layer.

    The ``(Wfc - Wwp)`` factors cancel algebraically, but yield NaN where ``Wfc == Wwp``.
    Both branches are evaluated everywhere before selection. The result is floored at 0.

    Arguments
    =========

    1. Wfc: ``np.ndarray``
        field capacity
    2. Wwp: ``np.ndarray``
        field wilting point
    3. Zr: ``np.ndarray``
        root depth for current day
    4. Zsoil: ``np.ndarray``
        total soil depth (parameter)
    5. Dr0: ``np.ndarray``
        depletion of the root layer for previous day
    6. Dd0: ``np.ndarray``
        depletion of the deep layer for previous day
    7. Zr0: ``np.ndarray``
        root layer height for previous day

    Returns
    =======
    
    1. output: ``np.ndarray``
        updated depletion for the deep layer
    """

    # Temporary variables to make calculation easier to read
    # tmp1: root growth, part of the deep layer depletion moves into the root layer
    tmp1 = np.maximum(Dd0 - Dd0 * ((Wfc - Wwp) * (Zr - Zr0)) / ((Wfc - Wwp) * (Zsoil - Zr0)), np.float32(0))
    # tmp2: root shrinkage (or no change), the released root depletion moves to the deep layer
    tmp2 = np.maximum(Dd0 - Dr0 * ((Wfc - Wwp) * (Zr - Zr0)) / ((Wfc - Wwp) * Zr0), np.float32(0))

    # Return updated Dd
    return np.where(Zr > Zr0, tmp1, tmp2)


@njit((float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:]), nogil = True, parallel = True, fastmath = True)
def calculate_SWCe(Dei: np.ndarray, Dep: np.ndarray, fewi: np.ndarray, fewp: np.ndarray, TEW: np.ndarray) -> np.ndarray:
    """
    Calculate the soil water content of the evaporative layer.
    Uses numba for faster and parallel calculation.
    All arguments are 2D ``float32`` arrays (enforced by the numba signature).

    The depletion of the layer is the average of ``Dei`` and ``Dep`` weighted
    by ``fewi`` and ``fewp``, or their plain mean where ``fewi + fewp == 0``.
    The result is ``(TEW - depletion) / TEW``, a ratio of the water capacity
    (between 0 and 1 when the depletions lie in ``[0, TEW]``).

    Arguments
    =========

    1. Dei: ``np.ndarray``
        depletion of the evaporative layer for the irrigated part
    2. Dep: ``np.ndarray``
        depletion of the evaporative layer for the rainfed part
    3. fewi: ``np.ndarray``
        soil fraction which is wetted by irrigation and exposed
        to evaporation
    4. fewp: ``np.ndarray``
        soil fraction which is wetted by precipitation and exposed
        to evaporation
    5. TEW: ``np.ndarray``
        water capacity of the evaporative layer

    Returns
    =======

    1. SWCe: ``np.ndarray``
        soil water content of the evaporative layer (ratio of TEW)
    """
    
    # Return SWCe
    return np.where((fewi + fewp) > 0, (TEW - (Dei * fewi + Dep * fewp) / (fewi + fewp)) / TEW, (TEW - (Dei + Dep) / 2) / TEW)


def calculate_memory_requirement(x_size: int, y_size: int, time_size: int, nb_inputs: int, nb_outputs: int, nb_variables: int, nb_params: int, nb_bits: int) -> float:
    """
    Estimate the memory (GiB) needed by the calculation if every dataset
    were fully loaded in memory. The estimate is used to split the
    datasets into time chunks for more efficient I/O operations.

    The estimate is the sum of three terms, each counted as a number of
    values multiplied by ``nb_bits`` and divided by ``1024**3``:

    - input datasets: ``x_size * y_size * time_size * nb_inputs`` values
    - calculation arrays (no time dimension):
      ``x_size * y_size * (nb_params + nb_variables)`` values
    - output datasets: ``x_size * y_size * time_size * nb_outputs`` values

    NOTE(review): the result is in GiB only if ``nb_bits`` is the size of
    one value in *bytes* (e.g. 4 for float32). A size in bits would
    overestimate by a factor of 8; confirm what callers pass.

    Arguments
    =========

    1. x_size: ``int``
        x size of dataset
    2. y_size: ``int``
        y size of dataset
    3. time_size: ``int``
        number of time bands
    4. nb_inputs: ``int``
        number of input variables
    5. nb_outputs: ``int``
        number of output variables
    6. nb_variables: ``int``
        number of calculation variables
    7. nb_params: ``int``
        number of raster parameters
    8. nb_bits: ``int``
        size of one value of the datatype

    Returns
    =======

    1. total_memory_requirement: ``float``
        calculation memory requirement in GiB
    """

    bytes_per_gib = 1024**3
    pixel_count = x_size * y_size

    inputs_gib = (pixel_count * time_size * nb_inputs * nb_bits) / bytes_per_gib
    work_arrays_gib = (pixel_count * (nb_params + nb_variables) * nb_bits) / bytes_per_gib
    outputs_gib = (pixel_count * time_size * nb_outputs * nb_bits) / bytes_per_gib

    return inputs_gib + work_arrays_gib + outputs_gib


def calculate_time_slices_to_load(memory_requirement: float, time_size: int, security_factor: float, available_ram: int) -> Tuple[int, int, bool]:
    """
    Calculate how many time slices to load in memory (for input and output data)
    based on available ram and calculation requirements.

    The memory requirement is divided by increasing factors
    (1, 2, 3, 4, 8, ..., 256) until it fits in
    ``security_factor * available_ram``. If nothing fits, or if the
    fitting factor is larger than ``time_size`` (which would give a
    zero-sized slice), one time slice is loaded per loop.

    Arguments
    =========

    1. memory_requirement: ``float``
        amount of memory needed if whole input/output
        datasets would be loaded.
    2. time_size: ``int``
        number of time slices in datasets
    3. security_factor: ``float``
        float between 0 and 1 to adjust memory requirements
    4. available_ram: ``int``
        available ram for computation in GiB

    Returns
    =======

    1. time_slice: ``int``
        number of times slices to load for
        input and output data
    2. remainder_to_load: ``int``
        remainder of euclidean division for
        the number of time slices to load
        (last block of data to load). ``None`` when
        the whole time series fits in memory, 1 in
        the one-slice-per-loop fallback
    3. already_loaded: ``bool``
        used to know whether data has been loaded
        when the whole time values fit in memory
        (``False`` then), ``None`` otherwise
    """

    # Possible division factors
    division_factors = [1, 2, 3, 4, 8, 16, 32, 64, 128, 256]

    # Determine the time slice to load
    for scale in division_factors:

        if memory_requirement / scale < security_factor * available_ram:

            if scale == 1:
                # The whole time series fits in memory
                return time_size, None, False

            time_slice = time_size // scale
            if time_slice == 0:
                # Fewer time steps than the division factor: a zero-sized slice would be
                # useless and ``time_size % 0`` raises ZeroDivisionError, so fall back to
                # loading one time slice per loop (larger factors would also give 0)
                break

            return time_slice, time_size % time_slice, None

    # If dataset is too big, load only one time slice per loop
    # (remainder of 1 in order to correctly save last date)
    return 1, 1, None


def get_empty_arrays(x_size: int, y_size: int, time_size: int, nb_arrays: int, list: bool = False) -> "Tuple[np.ndarray, ...] | List[np.ndarray]":
    """
    Short function to make the `run_samir()` function easier to read.
    Generates a varying number (`nb_arrays`) of uninitialized ``float32``
    numpy arrays of shape `(time_size, y_size, x_size)` in list or tuple
    mode. Used to store variable values before writing them in the
    output file.

    Arguments
    =========

    1. x_size: ``int``
        x size of dataset
    2. y_size: ``int``
        y size of dataset
    3. time_size: ``int``
        number of time bands
    4. nb_arrays: ``int``
        number of arrays to generate
    5. list: ``bool`` ``default = False``
        whether to return a list (``True``) or a tuple (``False``)

    Returns
    =======

    output: ``Tuple[np.ndarray, ...]`` or ``List[np.ndarray]``
        ``nb_arrays`` uninitialized arrays of shape
        ``(time_size, y_size, x_size)`` and dtype ``float32``
    """

    # Arrays are allocated time-first, matching the (time, y, x) layout
    # of the netCDF input and output files. np.empty does not initialize
    # memory: every slot must be written before it is read.
    arrays = [np.empty((time_size, y_size, x_size), dtype = np.float32) for _ in range(nb_arrays)]

    # The boolean parameter shadows the builtin `list`, so the comprehension
    # result (already a list) is returned as-is in list mode
    if list:
        return arrays

    return tuple(arrays)


# @profile  # type: ignore
def read_inputs(ndvi_cube_path: str, weather_path: str, i: int, time_slice: int, load_all: bool = False) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Load a block of NDVI, Rain and ET0 dates from the input netCDF files
    and rescale them into ``float32`` numpy arrays.

    Arguments
    =========

    1. ndvi_cube_path: ``str``
        path of input NDVI file
    2. weather_path: ``str``
        path of input weather file
    3. i: ``int``
        index of the first date to load (ignored when ``load_all`` is set)
    4. time_slice: ``int``
        number of dates to load (ignored when ``load_all`` is set)
    5. load_all: ``bool`` ``default = False``
        load every date of the datasets instead of a block

    Returns
    =======

    1. NDVI: ``np.ndarray``
        NDVI values divided by 255, shape (time, y, x)
    2. Rain: ``np.ndarray``
        Rain values divided by 100, shape (time, y, x)
    3. ET0: ``np.ndarray``
        ET0 values divided by 1000, shape (time, y, x)
    """

    # Time window along the first (time) axis: the full axis, or the
    # `time_slice` dates starting at index `i`
    window = slice(None) if load_all else slice(i, i + time_slice)

    # NDVI file, dimensions (time, y, x)
    with nc.Dataset(ndvi_cube_path, mode='r') as ndvi_ds:
        ndvi = np.asarray(ndvi_ds.variables['NDVI'][window, :, :] / 255, dtype = np.float32)

    # Weather file, dimensions (time, y, x)
    with nc.Dataset(weather_path, mode='r') as weather_ds:
        rain = np.asarray(weather_ds.variables['Rain'][window, :, :] / 100, dtype = np.float32)
        et0 = np.asarray(weather_ds.variables['ET0'][window, :, :] / 1000, dtype = np.float32)

    return ndvi, rain, et0


def write_outputs(save_path: str, DP: np.ndarray, SWCe: np.ndarray, SWCr: np.ndarray, E: np.ndarray, Tr: np.ndarray, Irrig: np.ndarray, scaling_dict: dict, additional_outputs: dict, additional_outputs_data: List[np.ndarray], i: int, time_slice: int, write_all: bool = False) -> None:
    """
    Write outputs to netcdf file based on conditions of current loop.
    Values are multiplied by their scale factor and rounded before
    being written.

    Arguments
    =========

    1. save_path: ``str``
        output netcdf save path (opened in append mode)
    2. DP: ``np.ndarray``
        deep percolation ``np.ndarray``
    3. SWCe: ``np.ndarray``
        soil water content of evaporative layer ``np.ndarray``
    4. SWCr: ``np.ndarray``
        soil water content of root layer ``np.ndarray``
    5. E: ``np.ndarray``
        surface evaporation ``np.ndarray``
    6. Tr: ``np.ndarray``
        plant transpiration ``np.ndarray``
    7. Irrig: ``np.ndarray``
        simulated irrigation ``np.ndarray`` (written to variable ``'Irr'``)
    8. scaling_dict: ``dict``
        scaling dictionnary for the nominal outputs, keyed by output
        variable name (``'DP'``, ``'SWCe'``, ``'SWCr'``, ``'E'``, ``'Tr'``, ``'Irr'``)
    9. additional_outputs: ``dict``
        dictionnary containing additional outputs and their scale factors
    10. additional_outputs_data: ``List[np.ndarray]``
        list of additional output ``np.ndarray``, in the same order as
        ``additional_outputs``. Is ``None`` if no additional outputs
    11. i: ``int``
        current loop counter (index of the last date of the block)
    12. time_slice: ``int``
        number of loaded time slices
    13. write_all: ``bool`` ``default = False``
        whether to write the whole output dataset

    Returns
    =======

    ``None``
    """

    # Target dates in the output file (dimensions are (time, y, x)):
    # either the whole time axis, or the `time_slice` dates ending at i
    if write_all:
        time_window = slice(None)
    else:
        time_window = slice(i - time_slice + 1, i + 1)

    # Nominal outputs as (netCDF variable / scaling key, data), in write order
    nominal_outputs = [('DP', DP), ('SWCe', SWCe), ('SWCr', SWCr), ('E', E), ('Tr', Tr), ('Irr', Irrig)]

    with nc.Dataset(save_path, mode='a') as outputs:
        for name, data in nominal_outputs:
            outputs.variables[name][time_window, :, :] = np.round(data * scaling_dict[name])

        # Additional outputs: data arrays are matched to names by position
        if additional_outputs:
            for k, (name, scale) in enumerate(additional_outputs.items()):
                outputs.variables[name][time_window, :, :] = np.round(additional_outputs_data[k][:, :, :] * scale)

    return None


# @profile  # type: ignore
def samir_daily(NDVI: np.ndarray, ET0: np.ndarray, Rain: np.ndarray, Wfc: np.ndarray, Wwp: np.ndarray, params: dict, Dr0: np.ndarray, Dd0: np.ndarray, Zr0: np.ndarray, E0: np.ndarray, Tr0: np.ndarray, Dei0: np.ndarray, Dep0: np.ndarray, iday: int) -> Tuple[np.ndarray]:
    """
    Run the SAMIR model on a single day. Inputs and outputs are `numpy.ndarray`.
    Calls functions compiled with numba for intermediary calculations.

    On the first day (``iday == 0``) the previous-day arguments are not used:
    depletions are initialised from the ``Init_RU`` parameter instead.
    On later days ``Dei0`` and ``Dep0`` are updated in place.

    Arguments
    =========

    1. NDVI: ``np.ndarray``
        input NDVI
    2. ET0: ``np.ndarray``
        input ET0
    3. Rain: ``np.ndarray``
        input Rain
    4. Wfc: ``np.ndarray``
        field capacity
    5. Wwp: ``np.ndarray``
        field wilting point
    6. params: ``dict``
        dictionnary containing the rasterized
        samir parameters (``<name>_``) and their scale factors (``s_<name>``)
    7. Dr0: ``np.ndarray``
        previous day root layer depletion
    8. Dd0: ``np.ndarray``
        previous day deep layer depletion
    9. Zr0: ``np.ndarray``
        previous day root depth
    10. E0: ``np.ndarray``
        previous day surface evaporation
    11. Tr0: ``np.ndarray``
        previous day plant transpiration
    12. Dei0: ``np.ndarray``
        previous day surface layer depletion
        for irrigation part
    13. Dep0: ``np.ndarray``
        previous day surface layer depletion
        for precipitation part
    14. iday: ``int``
        current loop counter (``0`` for the first simulated day)

    Returns
    =======

    1. current_day_outputs: ``Tuple[np.ndarray]``
        29 arrays for the current day, in this order:
        ``DP, Dd, De, Dei, Dep, Diff_dr, Diff_rei, Diff_rep, Dr, E, FCov,
        Irrig, Kcb, Kei, Kep, Ks, RUE, SWCe, SWCr, TAW, TDW, TEW, Tei,
        Tep, Tr, W, Zr, fewi, fewp``
    """
    
    # Create aliases
    DiffE_ = params['DiffE_']
    DiffR_ = params['DiffR_']
    FW_ = params['FW_']
    FCmax_ = params['FCmax_']
    Foffset_ = params['Foffset_']
    Fslope_ = params['Fslope_']
    Init_RU_ = params['Init_RU_']
    Irrig_auto_ = params['Irrig_auto_']
    Kcmax_ = params['Kcmax_']
    Kslope_ = params['Kslope_']
    Koffset_ = params['Koffset_']
    Kcb_stop_irrig_ = params['Kcb_stop_irrig_']
    REW_ = params['REW_']  # used to calculate Kri and Krp (see calculate_Ke)
    Ze_ = params['Ze_']
    Zsoil_ = params['Zsoil_']
    maxZr_ = params['maxZr_']
    minZr_ = params['minZr_']
    p_ = params['p_']
    
    s_DiffE = params['s_DiffE']
    s_DiffR = params['s_DiffR']
    s_FW = params['s_FW']
    s_FCmax = params['s_FCmax']
    s_Foffset = params['s_Foffset']
    s_Fslope = params['s_Fslope']
    s_Init_RU = params['s_Init_RU']
    s_Kcmax = params['s_Kcmax']
    s_Kslope = params['s_Kslope']
    s_Koffset = params['s_Koffset']
    s_Kcb_stop_irrig = params['s_Kcb_stop_irrig']
    s_REW = params['s_REW']  # used to calculate Kri and Krp (see calculate_Ke)
    s_Ze = params['s_Ze']
    s_Zsoil = params['s_Zsoil']
    s_maxZr = params['s_maxZr']
    s_minZr = params['s_minZr']
    s_p = params['s_p']
    
    # Frequently used parameters
    Ze = s_Ze * Ze_
    FW = s_FW * FW_

    # Update variables
    # Fraction cover
    # * Equation: Fslope * NDVI + Foffset
    FCov = s_Fslope * Fslope_ * NDVI + s_Foffset * Foffset_
    # * Equation: min(max(FCov, 0), FCmax)
    FCov = np.minimum(np.maximum(FCov, 0, dtype = np.float32), s_FCmax * FCmax_, dtype = np.float32)

    # Root depth update
    # * Equation: Zr = max(minZr + (FCov / FCmax) * (maxZr - minZr), Ze + 0.001)
    Zr = np.maximum(s_minZr * minZr_ + (FCov / (s_FCmax * FCmax_)) * (s_maxZr * maxZr_ - s_minZr * minZr_), Ze + 0.001, dtype = np.float32)

    # Water capacities
    TAW = (Wfc - Wwp) * Zr
    TDW = (Wfc - Wwp) * (s_Zsoil * Zsoil_ - Zr)
    TEW = (Wfc - Wwp/2) * Ze
    RUE = (Wfc - Wwp) * Ze

    # Depletions: initialised from Init_RU on the first day; otherwise carried
    # over from the previous day (Dei/Dep alias Dei0/Dep0 and are overwritten
    # in place below) and updated for the change in root depth (Dr/Dd)
    if iday == 0:
        Dei = RUE * (1 - s_Init_RU * Init_RU_)
        Dep = RUE * (1 - s_Init_RU * Init_RU_)
        Dr = TAW * (1 - s_Init_RU * Init_RU_)
        Dd = TDW * (1 - s_Init_RU * Init_RU_)
    else:
        Dei = Dei0
        Dep = Dep0
        Dr = update_Dr_from_root(Wfc, Wwp, Zr, s_Zsoil * Zsoil_, Dr0, Dd0, Zr0)
        Dd = update_Dd_from_root(Wfc, Wwp, Zr, s_Zsoil * Zsoil_, Dr0, Dd0, Zr0)
    
    # Kcb
    # * Equation: min(max(Kslope * NDVI + Koffset, 0), Kcmax)
    Kcb = np.minimum(np.maximum(s_Kslope * Kslope_ * NDVI + s_Koffset * Koffset_, 0, dtype = np.float32), s_Kcmax * Kcmax_, dtype = np.float32)

    # Irrigation 
    if iday == 0:  # First day of simulation
        Irrig = Irrig_auto_ * np.maximum(Dr - Rain, 0, dtype = np.float32)
        # * Equation: Irrig = where((Dr > TAW * p) & (Kcb > max(Kcb, 1) * Kcb_stop_irrig), Irrig, 0)
        irrigation_allowed = (Dr > TAW * s_p * p_) & (Kcb > np.maximum(Kcb, np.float32(1)) * s_Kcb_stop_irrig * Kcb_stop_irrig_)
        # Negate the whole condition: `~(A) & (B)` would only negate A
        Irrig[~irrigation_allowed] = np.float32(0)
    else:
        Irrig = calculate_irrig(Dr, TAW, s_p * p_, Rain, Kcb, Irrig_auto_, s_Kcb_stop_irrig * Kcb_stop_irrig_, E0, Tr0)

    # Create temporary variable used multiple times
    # * Equation: temp = Dr - (Rain + Irrig)
    temp = np.empty_like(Dr)
    np.subtract(Dr, Rain + Irrig, out = temp)
    
    # DP (Deep percolation)
    # * Equation: DP = -min(Dd + min(temp, 0), 0)
    DP = - np.minimum(Dd + np.minimum(temp, 0, dtype = np.float32), 0, dtype = np.float32)

    # Update depletions with Rainfall and/or irrigation
    # Dei and Dep
    # * Equation: Dei = min(max(Dei - Rain - Irrig / FW, 0), TEW)
    np.minimum(np.maximum(Dei - Rain - Irrig / FW, 0, dtype = np.float32), TEW, dtype = np.float32, out = Dei)
    # * Equation: Dep = min(max(Dep - Rain, 0), TEW)  (precipitation part: no irrigation)
    np.minimum(np.maximum(Dep - Rain, 0, dtype = np.float32), TEW, dtype = np.float32, out = Dep)

    fewi = np.minimum(1 - FCov, FW, dtype = np.float32)
    fewp = 1 - FCov - fewi

    # De
    # * Equation: De = (Dei * fewi + Dep * fewp) / (fewi + fewp)
    De = np.nansum([Dei * fewi, Dep * fewp], dtype = np.float32, axis = 0) / np.nansum([fewi, fewp], dtype = np.float32, axis = 0)
    # * Equation: De = where(De.isfinite, De, Dei * FW + Dep * (1 - FW))
    De[~np.isfinite(De)] = (Dei * FW + Dep * (1 - FW))[~np.isfinite(De)]

    # Update depletions from rain and irrigation
    # * Equation: Dr = min(max(temp, 0), TAW)
    np.minimum(np.maximum(temp, 0, dtype = np.float32), TAW, dtype = np.float32, out = Dr)
    # * Equation: Dd = min(max(Dd + min(temp, 0), 0), TDW)
    np.minimum(np.maximum(Dd + np.minimum(temp, 0, dtype = np.float32), 0, dtype = np.float32), TDW, dtype = np.float32, out = Dd)
    del temp  # remove temp variable

    # Diffusion coefficients
    # * Equation: check calculate_diff_re() and calculate_diff_dr functions
    Diff_rei = calculate_diff_re(TAW, Dr, Zr, RUE, Dei, Wfc, Ze, s_DiffE * DiffE_)
    Diff_rep = calculate_diff_re(TAW, Dr, Zr, RUE, Dep, Wfc, Ze, s_DiffE * DiffE_)
    Diff_dr = calculate_diff_dr(TAW, TDW, Dr, Zr, Dd, Wfc, s_Zsoil * Zsoil_, s_DiffR * DiffR_)

    # Water Stress coefficient
    if iday == 0:
        # * Equation: Ks = min((TAW - Dr) / (TAW * (1 - p)), 1)
        Ks = np.minimum((TAW - Dr) / (TAW * (1 - s_p * p_)), 1, dtype = np.float32)
    else:
        # When not first day
        Ks = calculate_Ks(Dr, TAW, s_p * p_, E0, Tr0)

    # Reduction coefficient for evaporation (weighting between the
    # irrigated and precipitation-only parts, see calculate_W)
    W = calculate_W(TEW, Dei, Dep, fewi, fewp)

    # * Equation: Kei = np.minimum(W * Kri * (Kc_max - Kcb), fewi * Kc_max)
    Kei = calculate_Ke(W, Dei, TEW, s_REW * REW_, s_Kcmax * Kcmax_, Kcb, fewi)

    # * Equation: Kep = np.minimum((1 - W) * Krp * (Kc_max - Kcb), fewp * Kc_max)
    Kep = calculate_Ke((1-W), Dep, TEW, s_REW * REW_, s_Kcmax * Kcmax_, Kcb, fewp)

    # Prepare coefficients for evapotranspiration
    Tei = calculate_Te(Dei, Dr, Ks, Kcb, s_Ze * Ze_, Zr, TEW, TAW, ET0)
    Tep = calculate_Te(Dep, Dr, Ks, Kcb, s_Ze * Ze_, Zr, TEW, TAW, ET0)

    # Update depletions
    Dei = update_De_from_Diff(Dei, fewi, Kei, Tei, Diff_rei, TEW, ET0)
    Dep = update_De_from_Diff(Dep, fewp, Kep, Tep, Diff_rep, TEW, ET0)

    # * Equation: De = (Dei * fewi + Dep * fewp) / (fewi + fewp)
    # (the quotient must be assigned: writing only the numerator into De
    # would leave the weighted sum undivided)
    De = np.nansum([Dei * fewi, Dep * fewp], dtype = np.float32, axis = 0) / np.nansum([fewi, fewp], dtype = np.float32, axis = 0)
    # * Equation: De = where(De.isfinite, De, Dei * FW + Dep * (1 - FW))
    De[~np.isfinite(De)] = (Dei * FW + Dep * (1 - FW))[~np.isfinite(De)]

    # Evaporation
    # * Equation: E = max((Kei + Kep) * ET0, 0)
    E = np.maximum((Kei + Kep) * ET0, 0, dtype = np.float32)

    # Transpiration
    Tr = Kcb * Ks * ET0

    # Update depletions (root and deep zones) at the end of the day
    # * Equation: Dr = min(max(Dr + E + Tr - Diff_dr, 0), TAW)
    np.minimum(np.maximum(Dr + E + Tr - Diff_dr, 0, dtype = np.float32), TAW, dtype = np.float32, out = Dr)
    # * Equation: Dd = min(max(Dd + Diff_dr, 0), TDW)
    np.minimum(np.maximum(Dd + Diff_dr, 0, dtype = np.float32), TDW, dtype = np.float32, out = Dd)

    # Soil water content of evaporative layer
    SWCe = calculate_SWCe(Dei, Dep, fewi, fewp, TEW)

    # Soil water content of root layer
    SWCr = 1 - Dr/TAW

    return DP, Dd, De, Dei, Dep, Diff_dr, Diff_rei, Diff_rep, Dr, E, FCov, Irrig, Kcb, Kei, Kep, Ks, RUE, SWCe, SWCr, TAW, TDW, TEW, Tei, Tep, Tr, W, Zr, fewi, fewp


# @profile  # type: ignore
def run_samir(csv_param_file: str, ndvi_cube_path: str, weather_path: str, soil_params_path: str, land_cover_path: str, save_path: str, scaling_dict: dict = None, additional_outputs: dict = None, available_ram: int = 8, available_cpu: int = 4) -> None:
    """
    Run the *SAMIR* model on the prepared inputs. Calls the ``samir_daily()`` in a time loop.
    Maximizes memory usage with given limits to run faster: inputs are read and
    outputs written in blocks of dates whose size depends on ``available_ram``.

    Arguments
    =========

    1. csv_param_file: ``str``
        SAMIR csv parameter file
    2. ndvi_cube_path: ``str``
        path to ndvi cube
    3. weather_path: ``str``
        path to weather cube
    4. soil_params_path: ``str``
        path to soil dataset
    5. land_cover_path: ``str``
        path to land cover raster
    6. save_path: ``str``
        path to save the model outputs
    7. scaling_dict: ``dict`` ``default = None``
        scaling dictionnary for the nominal outputs. ``None`` selects
        ``{'E': 1000, 'Tr': 1000, 'SWCe': 1000, 'SWCr': 1000, 'DP': 100, 'Irr': 100, 'ET0': 1000}``
    8. additional_outputs: ``dict`` ``default = None``
        dictionnary containing the names and scale
        factors of potential additional outputs. Each name is evaluated
        with ``eval`` inside the time loop, so it must name a variable
        available there (e.g. a ``samir_daily()`` output such as ``Ks``)
    9. available_ram: ``int`` ``default = 8``
        available RAM in **GiB** for the model
    10. available_cpu: ``int`` ``default = 4``
        number of available **physical** CPU cores

    Returns
    =======

    ``None``
    """
    # numba is already a dependency of this module; its config exposes the
    # upper bound accepted by set_num_threads()
    from numba import config as numba_config

    # Build the default per call: a dict default argument would be one object
    # shared by every call (and by every function it is passed to)
    if scaling_dict is None:
        scaling_dict = {'E': 1000, 'Tr': 1000, 'SWCe': 1000, 'SWCr': 1000, 'DP': 100, 'Irr': 100, 'ET0': 1000}

    # Turn off numpy division / invalid value warnings (process-wide setting)
    np.seterr(divide='ignore', invalid='ignore')
    
    # Exit if the requested memory is larger than the memory currently available
    if np.ceil(virtual_memory().available / (1024**3)) < available_ram:
        print('\nRequested', available_ram, 'GiB of memory when available memory is approximately', round(virtual_memory().available / (1024**3), 1), 'GiB.\n\nExiting script.\n')
        return None

    # Set maximum number of usable CPU cores
    # On a cluster, os.sched_getaffinity gives the CPUs actually allotted to the
    # process; it does not exist on every platform (e.g. macOS, Windows), and
    # psutil may return None when it cannot determine the physical core count.
    core_limits = [available_cpu]
    physical_cores = cpu_count(logical = False)
    if physical_cores:
        core_limits.append(physical_cores)
    if hasattr(os, 'sched_getaffinity'):
        core_limits.append(len(os.sched_getaffinity(0)))
    nb_cores = max(1, min(core_limits))
    # Two threads per core, bounded by numba's thread pool size
    # (set_num_threads raises ValueError above NUMBA_NUM_THREADS)
    set_num_threads(min(nb_cores * 2, numba_config.NUMBA_NUM_THREADS))

    # ============ Manage inputs ============ #
    # NDVI (to have a correct empty dataset structure)
    ndvi_cube = xr.open_dataset(ndvi_cube_path)
    ndvi_cube_empty = ndvi_cube.drop_sel(time = ndvi_cube.time)  # create dataset with a time dimension of length 0

    # SAMIR Parameters
    params = rasterize_samir_parameters(csv_param_file, land_cover_path)
    
    # ============ Get size of dataset ============ #
    x_size = ndvi_cube.sizes['x']
    y_size = ndvi_cube.sizes['y']
    time_size = ndvi_cube.sizes['time']
    dimensions = ndvi_cube_empty.dims  # to create empty output dataset
    dates = ndvi_cube.time
    ndvi_cube.close()
    
    # ============ Memory handling ============ #
    # Check how much memory the calculation would take if all the inputs would be loaded in memory
    # Unit is GiB
    # Datatype of variables is float32 for calculation
    nb_bits = 4  # float 32
    nb_inputs = 3  # NDVI, Rain, ET0
    if additional_outputs:
        nb_outputs = 6 + len(additional_outputs)  # DP, E, Irr, SWCe, SWCr, Tr
    else:
        nb_outputs = 6  # DP, E, Irr, SWCe, SWCr, Tr
    nb_variables = 31  # calculation variables (e.g:  Dd, Diff_rei, FCov, etc.)
    security_factor = 0.8  # it is difficult to estimate true memory usage, apply a security factor to prevent memory overload
    
    # Get memory requirement
    total_memory_requirement = calculate_memory_requirement(x_size, y_size, time_size, nb_inputs, nb_outputs, nb_variables, len(params)/2, nb_bits)
    
    # Determine how many time slices can be loaded in memory at once
    # This will allow faster I/O operations and a faster runtime
    time_slice, remainder_to_load, already_loaded = calculate_time_slices_to_load(total_memory_requirement, time_size, security_factor, available_ram)
    
    print_size, print_unit = format_Byte_size(total_memory_requirement)
    print('\nApproximate memory requirement of calculation:', print_size, print_unit, '\nAvailable memory:', available_ram, 'GiB\n\nLoading blocks of', time_slice, 'time bands.\n')
    
    # ============ Prepare outputs ============ #
    model_outputs = prepare_output_dataset(ndvi_cube_path, dimensions, scaling_dict, additional_outputs=additional_outputs)

    # Create encoding dictionnary: int16 storage, one chunk per date
    encoding_dict = {variable: {'dtype': 'i2', 'chunksizes': (1, y_size, x_size)} for variable in model_outputs.keys()}

    # Save empty output
    print('Writing empty output dataset...\n')
    model_outputs.to_netcdf(save_path, encoding = encoding_dict, unlimited_dims = 'time')  # add time as an unlimited dimension, allows to append data along the time dimension
    model_outputs.close()

    # ============ Prepare time iterations ============#

    # Set parameter to None if no additional outputs
    if not additional_outputs:
        additional_outputs_data = None

    # input soil data
    with nc.Dataset(soil_params_path, mode='r') as ds:
        Wfc = np.asarray(ds.variables['Wfc'][:, :], dtype = np.float32)
        Wwp = np.asarray(ds.variables['Wwp'][:, :], dtype = np.float32)

    # ============ Time loop ============ #
    # The progress bar is closed even if the loop raises
    with tqdm(total=len(dates), desc='Running model', unit=' days') as progress_bar:

        for i in range(0, len(dates)):

            # Position of the current date inside the loaded block
            day_index = i % time_slice

            # ============ Load input data and prepare output data ============ #
            # Based on available memory and previous calculation of time slices to load,
            # load a given number of time slices
            if time_slice == time_size and not already_loaded:  # if whole dataset fits in memory and it has not already been loaded
                
                NDVI, Rain, ET0 = read_inputs(ndvi_cube_path, weather_path, i, time_slice, load_all = True)
                already_loaded = True
                write_remainder = False
                
                # Prepare output data
                # Standard outputs
                DP, E, Irrig, SWCe, SWCr, Tr = get_empty_arrays(x_size, y_size, time_slice, 6)
                # Additional outputs
                if additional_outputs:
                    additional_outputs_data = get_empty_arrays(x_size, y_size, time_slice, len(additional_outputs), list = True)
                
            elif day_index == 0:  # load a time slice every time i is divisible by the size of the time slice
                if i + time_slice <= time_size:  # if the time slice does not go over the dataset size
                    
                    NDVI, Rain, ET0 = read_inputs(ndvi_cube_path, weather_path, i, time_slice)
                    write_remainder = False
                    
                    # Prepare output data
                    DP, E, Irrig, SWCe, SWCr, Tr = get_empty_arrays(x_size, y_size, time_slice, 6)
                    # Additional outputs
                    if additional_outputs:
                        additional_outputs_data = get_empty_arrays(x_size, y_size, time_slice, len(additional_outputs), list = True)

                else:  # load the remainder when the time slice would go over the dataset size
                    
                    NDVI, Rain, ET0 = read_inputs(ndvi_cube_path, weather_path, i, remainder_to_load)
                    write_remainder = True
                    
                    # Prepare output data
                    DP, E, Irrig, SWCe, SWCr, Tr = get_empty_arrays(x_size, y_size, remainder_to_load, 6)
                    # Additional outputs
                    if additional_outputs:
                        additional_outputs_data = get_empty_arrays(x_size, y_size, remainder_to_load, len(additional_outputs), list = True)

            # Previous-day state is not used by samir_daily on the first day
            if i == 0:
                Dei0, Dep0, Dr0, Dd0, E0, Tr0, Zr0 = np.float32(0), np.float32(0), np.float32(0), np.float32(0), np.float32(0), np.float32(0), np.float32(0)
            
            DP[day_index,:,:], Dd, De, Dei, Dep, Diff_dr, Diff_rei, Diff_rep, Dr, E[day_index,:,:], FCov, Irrig[day_index,:,:], Kcb, Kei, Kep, Ks, RUE, SWCe[day_index,:,:], SWCr[day_index,:,:], TAW, TDW, TEW, Tei, Tep, Tr[day_index,:,:], W, Zr, fewi, fewp = samir_daily(NDVI[day_index,:,:], ET0[day_index], Rain[day_index], Wfc, Wwp, params, Dr0, Dd0, Zr0, E0, Tr0, Dei0, Dep0, i)
            
            # Collect additional outputs
            # NOTE(review): eval() runs each configured name as a Python expression in
            # this function's scope; additional_outputs must come from trusted configuration.
            if additional_outputs:
                for k, var in enumerate(additional_outputs.keys()):
                    additional_outputs_data[k][day_index,:,:] = eval(var)
            
            # Update previous day values
            Dei0, Dep0, Dr0, Dd0, E0, Tr0, Zr0 = Dei, Dep, Dr, Dd, E[day_index,:,:], Tr[day_index,:,:], Zr
            
            # ============ Write outputs ============ #
            # Based on available memory and previous calculation of time slices to load,
            # Write a given number of time slices
            if time_slice == time_size and i == time_size - 1:  # if whole dataset fits in memory, write data at the end of the loop
                
                write_outputs(save_path, DP, SWCe, SWCr, E, Tr, Irrig, scaling_dict, additional_outputs, additional_outputs_data, i, time_slice, write_all = True)
                
                # Remove written data
                del DP, SWCe, SWCr, E, Tr, Irrig, additional_outputs_data
                additional_outputs_data = None
                
            elif (day_index == time_slice - 1) and (remainder_to_load is not None):  # write a time slice when the last date of a full block is reached
                
                write_outputs(save_path, DP, SWCe, SWCr, E, Tr, Irrig, scaling_dict, additional_outputs, additional_outputs_data, i, time_slice)
                
                # Remove written data
                del DP, SWCe, SWCr, E, Tr, Irrig, additional_outputs_data
                additional_outputs_data = None

            elif i == time_size - 1 and write_remainder:  # write the remainder when the time slice would go over the dataset size
                
                write_outputs(save_path, DP, SWCe, SWCr, E, Tr, Irrig, scaling_dict, additional_outputs, additional_outputs_data, i, remainder_to_load)
                
                # Remove written data
                del DP, SWCe, SWCr, E, Tr, Irrig, additional_outputs_data
                additional_outputs_data = None
            
            # Update progress bar
            progress_bar.update()
    
    print('\nWritting Output:', save_path)
    
    # Reopen output model file to reinsert dates  
    with nc.Dataset(save_path, mode = 'a') as model_outputs:
        model_outputs.variables['time'].units = f'days since {np.datetime_as_string(dates[0], unit = "D")} 00:00:00'  # set correct unit
        model_outputs.variables['time'][:] = np.arange(0, len(dates))  # save dates as integers representing the number of days since the first day
        model_outputs.sync() # flush data to disk

    return None