"""
Functions to call ECMWF Reanalysis with the CDS API

- ERA5-Land daily request
- request a list of hourly variables dedicated to the calculation of ET0
  and the generation of MODSPA daily forcing files

heavily modified from @rivallandv's original file
"""

import os  # for path exploration and file management
import numpy as np  # for math on arrays
import xarray as xr  # to manage nc files
from datetime import datetime  # to manage dates
from fnmatch import fnmatch  # for file name matching
import pandas as pd  # to manage dataframes
import rasterio as rio  # to manage geotiff images
import rioxarray  # to georeference xarray datasets with the .rio accessor (used below)
import geopandas as gpd  # to manage shapefile crs projections
from rasterio.mask import mask  # to mask images
from rasterio.enums import Resampling  # reprojection algorithms
import netCDF4 as nc  # to write netcdf4 files
from tqdm import tqdm  # to follow progress
from multiprocessing import Pool, Manager  # to parallelize functions
from psutil import virtual_memory  # to check available ram
from psutil import cpu_count  # to get number of physical cores available
from modspa_pixel.config.config import config  # to import config file
from modspa_pixel.source.modspa_samir import calculate_time_slices_to_load  # to optimise I/O operations
import warnings  # to suppress pandas warning

# CDS API external library
# source: https://pypi.org/project/cdsapi/
import cdsapi  # to download cds data

# FAO ET0 calculator external library
# Notes
# source: https://github.com/Evapotranspiration/ETo
# documentation: https://eto.readthedocs.io/en/latest/
import eto  # to calculate ET0


def era5_enclosing_shp_aera(bbox: list[float], pas: float) -> list[float]:
    """
    Find the four coordinates enclosing the bounding box of the scene,
    aligned with the grid resolution.
    system projection: WGS84 lat/lon degree

    Arguments
    =========

    1. bbox: ``list[float]``
        bounding box of the demanded area
        list of floats: [lon west, lat south, lon east, lat north] in degree WGS84
    2. pas: ``float``
        gridsize

    Returns
    =======

    1. era5_area: ``list[float]``
        coordinates corresponding to the N, W, S, E corners of the grid in decimal degrees

    .. note::

        gdal coordinates reference the upper left corner of a pixel, whereas ERA5
        coordinates refer to the center of the grid cell. To resolve this difference
        an offset of pas/2 is applied.
    """

    lat_max, lon_min, lat_min, lon_max = bbox[3], bbox[0], bbox[1], bbox[2]

    era5_lat_max = round((lat_max // pas + 1) * pas, 2)
    era5_lon_min = round((lon_min // pas) * pas, 2)
    era5_lat_min = round((lat_min // pas) * pas, 2)
    era5_lon_max = round((lon_max // pas + 1) * pas, 2)

    era5_area = [era5_lat_max, era5_lon_min, era5_lat_min, era5_lon_max]

    return era5_area
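
# Illustrative example (assumed values): snapping a small bounding box around
# Toulouse to the 0.1° ERA5 grid. The result is in N, W, S, E order.
#
#     >>> era5_enclosing_shp_aera([1.03, 43.41, 1.48, 43.92], 0.1)
#     [44.0, 1.0, 43.4, 1.5]
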
def split_dates_by_year(start_date: str, end_date: str) -> list[tuple[str, str]] | list:
    """
    Given a start and end date, returns tuples of start and end dates IN THE SAME YEAR.

    Arguments
    =========

    1. start_date: ``str``
        start date in YYYY-MM-DD format
    2. end_date: ``str``
        end date in YYYY-MM-DD format

    Returns
    =======

    1. dates: ``list[tuple[str, str]] | list``
        output tuples of start and end dates
    """

    start = datetime.strptime(start_date, '%Y-%m-%d')
    end = datetime.strptime(end_date, '%Y-%m-%d')

    if start.year == end.year:
        return [(start_date, end_date)]
    dates = []
    current_start = start
    while current_start.year <= end.year:
        if current_start.year == end.year:
            current_end = end
        else:
            current_end = datetime(current_start.year, 12, 31)
        dates.append((current_start.strftime('%Y-%m-%d'), current_end.strftime('%Y-%m-%d')))
        current_start = datetime(current_start.year + 1, 1, 1)

    return dates
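
# Illustrative example: a window spanning two year boundaries is split into one
# tuple per calendar year, so each CDS request only covers a single year.
#
#     >>> split_dates_by_year('2021-11-15', '2023-02-10')
#     [('2021-11-15', '2021-12-31'), ('2022-01-01', '2022-12-31'), ('2023-01-01', '2023-02-10')]
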


def call_era5landhourly(args: tuple) -> str:
    """
    Download weather data for the given variable. Arguments are packed in a tuple for multiprocessing.

    Arguments (packed in args: ``tuple``)
    =====================================

    1. variable: ``str``
        name of the ERA5-Land weather variable
    2. output_path: ``str``
        output path to download netcdf file
    3. start_date: ``str``
        start date in YYYY-MM-DD format
        (start and end date must be in the same
        year to reduce data to download)
    4. end_date: ``str``
        end date in YYYY-MM-DD format
        (start and end date must be in the same
        year to reduce data to download)
    5. bbox: ``list[float]``
        bounding box of the area to download data
        ([lon west, lat south, lon east, lat north])
    6. gridsize: ``float`` ``default = 0.1``
        gridsize of data to download

    Returns
    =======

    1. output_filename: ``str``
        output file name
    """

    variable, output_path, start_date, end_date, bbox, gridsize = args
    # full path name of the output file
    output_filename = os.path.join(output_path, 'ERA5-land_' + variable + '_' + start_date + '_' + end_date + '.nc')
    # Get time periods for download
    start_date = datetime.strptime(start_date, '%Y-%m-%d')
    end_date = datetime.strptime(end_date, '%Y-%m-%d')
    # Generate time inputs
    months = []
    current = start_date
    while current <= end_date:
        month_str = current.strftime('%m')
        if month_str not in months:
            months.append(month_str)
        # Move to the next month
        if current.month == 12:
            current = current.replace(year=current.year + 1, month=1)
        else:
            current = current.replace(month=current.month + 1)
    # Generate the list of days
    days = [f'{day:02}' for day in range(1, 32)]
    # Generate time
    time = [f'{hour:02}:00' for hour in range(0, 24)]
    # Get modified bbox
    area = era5_enclosing_shp_aera(bbox, gridsize)
    # Check if file already exists
    if os.path.isfile(output_filename):
        print('\n', output_filename, 'already exists!\n')
    else:
        # cds api request
        client = cdsapi.Client(timeout = 300)
        try:
            client.retrieve('reanalysis-era5-single-levels',
                        request = {
                            'product_type': ['reanalysis'],
                            'variable': [variable],
                            'year': [start_date.strftime(format = '%Y')],
                            'month': months,
                            'day': days,
                            'time': time,
                            'data_format': 'netcdf',
                            'download_format': 'unarchived',
                            'area': area,
                            'grid': [gridsize, gridsize],
                        },
                        target = output_filename)
            print('\n', output_filename, ' downloaded !\n')
            
        except Exception as e:
            print('\nRequest failed, error message:\n\n', e, '\n')
            
    return output_filename
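
# Minimal usage sketch (paths and values are illustrative; a real run needs a
# configured CDS API key, see https://pypi.org/project/cdsapi/). Arguments are
# packed in tuples so several variables can be fetched in parallel:
#
#     args_list = [(variable, '/tmp/era5', '2022-01-01', '2022-12-31',
#                   [1.03, 43.41, 1.48, 43.92], 0.1)
#                  for variable in ['2m_temperature', '2m_dewpoint_temperature']]
#     with Pool(2) as p:
#         p.map(call_era5landhourly, args_list)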


def uz_to_u2(u_z: list[float], h: float) -> list[float]:
    """
    Adjust wind speed measured at a height other than 2 m above the ground
    surface to the equivalent wind speed at 2 m, according to the FAO-56
    wind profile relationship (Eq. 47):

        u2 = u_z * 4.87 / ln(67.8 * h - 5.42)

    Arguments
    ----------
    u_z : float array
        measured wind speed z m above the ground surface, m s-1.
    h : float
        height of the measurement above the ground surface, m.

    Returns
    -------
    u2 : float array
        average daily wind speed in meters per second (m s-1) measured at 2 m above the ground.
    """

    return u_z * 4.87/(np.log(67.8 * h - 5.42))
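
# Illustrative check: a 3.2 m/s wind measured at 10 m corresponds to roughly
# 2.39 m/s at 2 m, since 4.87 / ln(67.8 * 10 - 5.42) ≈ 0.748.
#
#     >>> round(float(uz_to_u2(np.array([3.2]), 10)[0]), 2)
#     2.39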


def ea_calc(T: float) -> float:
    """
    Saturation vapour pressure at temperature T (FAO-56 Eq. 11). Applied to the
    dewpoint temperature, it gives the actual vapour pressure (ea).

    Arguments
    ----------
    T : float
        temperature in degrees Celsius.

    Returns
    -------
    e_a : float
        vapour pressure in kPa.
    """

    # FAO-56 Eq. 11 uses 237.3 in the denominator (the previous 237.15 looked
    # like a mix-up with the 273.15 K conversion)
    return 0.6108 * np.exp(17.27 * T / (T + 237.3))
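
# Illustrative check, assuming the FAO-56 constant 237.3 used above:
#
#     >>> round(ea_calc(20.), 3)
#     2.338
#
# which matches the FAO-56 tabulated saturation vapour pressure at 20 °C.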


def combine_weather2netcdf(rain_file: str, ET0_tile: str, ndvi_path: str, save_path: str, available_ram: int) -> None:
    """
    Convert the Rain and ET0 geotiffs into a single weather netcdf dataset.

    Arguments
    =========

    1. rain_file: ``str``
        path to the Rain geotiff file
    2. ET0_tile: ``str``
        path to the ET0 geotiff file
    3. ndvi_path: ``str``
        path to the ndvi netcdf dataset, used for attributes and coordinates
    4. save_path: ``str``
        path of the output netcdf file
    5. available_ram: ``int``
        available ram in GiB for the conversion

    Returns
    =======

    ``None``
    """

    # Open tif files
    rain_tif = rio.open(rain_file)
    ET0_tif = rio.open(ET0_tile)

    # Open ndvi netcdf to get structure
    ndvi = xr.open_dataset(ndvi_path)
    dates = ndvi.time
    
    # Get empty dimensions
    dimensions = ndvi.drop_sel(time = ndvi.time).sizes  # create dataset with a time dimension of length 0
    weather = ndvi.drop_vars(['NDVI']).copy(deep = True)
    weather = weather.drop_sel(time = weather.time)
    # Set dataset attributes
    weather.attrs['name'] = 'ModSpa Pixel weather dataset'
    weather.attrs['description'] = 'Weather variables (Rain and ET0) for the ModSpa SAMIR (FAO-56) model at the pixel scale. Variables are scaled to be stored as integers.'
    weather.attrs['scaling'] = "{'Rain': 100, 'ET0': 1000}"
    
    # Set variable attributes
    weather['Rain'] = (dimensions, np.zeros(tuple(dimensions[d] for d in list(dimensions)), dtype = np.uint16))
    weather['Rain'].attrs['units'] = 'mm'
    weather['Rain'].attrs['standard_name'] = 'total_precipitation'
    weather['Rain'].attrs['description'] = 'Accumulated daily precipitation in mm'
    weather['Rain'].attrs['scale factor'] = '100'
    weather['ET0'] = (dimensions, np.zeros(tuple(dimensions[d] for d in list(dimensions)), dtype = np.uint16))
    weather['ET0'].attrs['standard_name'] = 'evapotranspiration'
    weather['ET0'].attrs['description'] = 'Accumulated daily reference evapotranspiration in mm'
    weather['ET0'].attrs['scale factor'] = '1000'
    
    # Create encoding dictionary
    encoding_dict = {}
    for variable in list(weather.keys()):
        # Write encoding dict
        encod = {}
        encod['dtype'] = 'u2'
        # TODO: figure out optimal file chunk size
        file_chunksize = (1, dimensions['y'], dimensions['x'])
        encod['chunksizes'] = file_chunksize
        # TODO: check if compression affects reading speed
        encod['zlib'] = True
        encod['complevel'] = 1
        encoding_dict[variable] = encod

    # Save empty output
    print('\nWriting empty weather dataset')
    weather.to_netcdf(save_path, encoding = encoding_dict, unlimited_dims = 'time')
    weather.close()

    # Get geotiff dimensions (time, x, y)
    dims = (rain_tif.count, rain_tif.height, rain_tif.width)
    
    # Determine the memory requirement of operation
    nb_bytes = 2  # uint16 values take two bytes
    nb_vars = 1  # one variable written at a time
    memory_requirement = ((dims[0] * dims[1] * dims[2]) * nb_vars * nb_bytes) / (1024**3)  # in GiB
    
    # Get the number of time bands that can be loaded at once
    time_slice, remainder, already_written = calculate_time_slices_to_load(dims[2], dims[1], dims[0], nb_vars, 0, 0, 0, nb_bytes, available_ram)
    print('\nApproximate memory requirement of conversion:', round(memory_requirement, 3), 'GiB\nAvailable memory:', available_ram, 'GiB\n\nLoading blocks of', time_slice, 'time bands.\n')
    
    # Open empty dataset
    weather = nc.Dataset(save_path, mode = 'r+')
    
    # Create progress bar
    progress_bar = tqdm(total = dims[0], desc='Writing weather data', unit=' bands')

    # Data variables
    for i in range(dims[0]):

        if time_slice == dims[0] and not already_written:  # if whole dataset fits in memory and it has not already been loaded

            weather.variables['Rain'][:,:,:] = rain_tif.read()
            weather.variables['ET0'][:,:,:] = ET0_tif.read()
            already_written = True  # avoid rewriting the full dataset on subsequent iterations

        elif i % time_slice == 0 and not already_written:  # load a time slice every time i is divisible by the size of the time slice

            if i + time_slice <= dims[0]:  # if the time slice does not go over the dataset size

                weather.variables['Rain'][i: i + time_slice, :, :] = rain_tif.read(tuple(k+1 for k in range(i, i + time_slice)))
                weather.variables['ET0'][i: i + time_slice, :, :] = ET0_tif.read(tuple(k+1 for k in range(i, i + time_slice)))

            else:  # load the remainder when the time slice would go over the dataset size

                weather.variables['Rain'][i: i + remainder, :, :] = rain_tif.read(tuple(k+1 for k in range(i, i + remainder)))
                weather.variables['ET0'][i: i + remainder, :, :] = ET0_tif.read(tuple(k+1 for k in range(i, i + remainder)))

        progress_bar.update()

    # Write dates in weather dataset
    weather.variables['time'].units = f'days since {np.datetime_as_string(dates[0], unit = "D")} 00:00:00'  # set correct unit
    weather.variables['time'][:] = np.arange(0, len(dates))  # save dates as integers representing the number of days since the first day
    weather.sync() # flush data to disk
    # Close progress bar
    progress_bar.close()

    # Close datasets
    weather.close()
    rain_tif.close()
    ET0_tif.close()
    ndvi.close()

    return None


def calculate_ET0_pixel(input_dataset: xr.Dataset, h: float = 10, safran: bool = False) -> xr.DataArray:
    """
    Calculate ET0 over the year for a single pixel of the ERA5 weather dataset.

    Arguments
    =========

    1. input_dataset: ``xr.Dataset``
        extracted dataset chunked to contain only one spatial pixel
    2. h: ``float`` ``default = 10``
        height of ERA5 wind measurement in meters
    3. safran: ``bool`` ``default = False``
        boolean to adapt to a custom SAFRAN weather dataset

    Returns
    =======

    1. output: ``xr.DataArray``
        DataArray containing the ET0 values for each day
    """
    
    # Adapt dataset structure
    if safran:
        lat, lon = input_dataset.coords['y'].values, input_dataset.coords['x'].values
        pixel_dataset = input_dataset.squeeze(dim = ['x', 'y'], drop = True)
    else:
        lat, lon = input_dataset.coords['lat'].values, input_dataset.coords['lon'].values
        pixel_dataset = input_dataset.squeeze(dim = ['lat', 'lon'], drop = True)
    
    # Conversion of xarray dataset to dataframe for ET0 calculation
    ET0 = pixel_dataset.t2m.resample(time = '1D').min().to_dataframe().rename(columns = {'t2m' : 'T_min'}) - 273.15  # conversion of temperatures from K to °C
    ET0['T_max'] = pixel_dataset.t2m.resample(time = '1D').max().to_dataframe()['t2m'].values - 273.15  # conversion of temperatures from K to °C
    
    ET0['R_s'] = pixel_dataset.ssrd.resample(time = '1D').sum().to_dataframe()['ssrd'].values / 1e6  # to convert downward total radiation from J/m² to MJ/m²
    # Calculate relative humidity
    pixel_dataset['ea'] = ea_calc(pixel_dataset.d2m - 273.15)  # actual vapour pressure, from dewpoint temperature
    pixel_dataset['es'] = ea_calc(pixel_dataset.t2m - 273.15)  # saturation vapour pressure, from air temperature
    pixel_dataset['rh'] = np.clip(100. * (pixel_dataset.ea / pixel_dataset.es), a_min = 0, a_max = 100)
    
    ET0['RH_max'] = pixel_dataset.rh.resample(time = '1D').max().to_dataframe()['rh'].values
    ET0['RH_min'] = pixel_dataset.rh.resample(time = '1D').min().to_dataframe()['rh'].values
    
    if safran:
        # Add wind
        ET0['U_z'] = pixel_dataset.U_z.to_dataframe()['U_z'].values
    
    else:
        # Conversion of eastward and northward wind values to scalar wind
        pixel_dataset['uz'] = np.sqrt(pixel_dataset.u10 ** 2 + pixel_dataset.v10 ** 2)
        ET0['U_z'] =  pixel_dataset.uz.resample(time = '1D').mean().to_dataframe()['uz'].values
    # Start ET0 calculation
    eto_calc = eto.ETo()
    warnings.filterwarnings('ignore')  # remove pandas warning

    # ET0 calculation for given pixel (lat, lon) values
    eto_calc.param_est(ET0,
                        freq = 'D',  # daily frequence
                        # Elevation of the met station above mean sea level (m) (only needed if P is not in df).
                        z_msl = 0.,
                        lat = lat,
                        lon = lon,
                        TZ_lon = None,
                        z_u = h)  # h: height of raw wind speed measurement

    # Retrieve ET0 values
    ET0_values = np.reshape(eto_calc.eto_fao(max_ETo = 15, min_ETo = 0, interp = True, maxgap = 10).values, [len(ET0.index), 1, 1])  # ETo_FAO_mm
    # Build the daily output coordinates and dimensions (resampled like the data)
    daily_template = input_dataset.resample(time = '1D').sum()
    output_coords = daily_template.coords
    output_dims = daily_template.sizes
    
    output = xr.DataArray(data = ET0_values, coords = output_coords, dims = output_dims, name = 'ET0')
    
    return output
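
# calculate_ET0_pixel is not meant to be called directly: the yearly functions
# below feed it single-pixel chunks through xarray's map_blocks (sketch, with
# placeholder dataset names):
#
#     template = daily_ds.ET0.chunk({'time': -1, 'lat': 1, 'lon': 1})
#     daily_ds['ET0'] = hourly_ds.chunk({'time': -1, 'lat': 1, 'lon': 1}).map_blocks(
#         calculate_ET0_pixel, args = (10, False), template = template)
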
def convert_interleave_mode(args: tuple[str, str, bool]) -> None:
    """
    Convert Geotiff files obtained from OTB to Band interleave mode for faster band reading.

    Arguments
    =========
    
    (packed in args: ``tuple``)
    
    1. input_image: ``str``
    2. output_image: ``str``
    3. remove: ``bool`` ``default = True``
    """
    
    input_image, output_image, remove = args
    
    # Open the input file in read mode
    with rio.open(input_image, "r") as src:

        # Open the output file in write mode
        with rio.open(output_image, 'w', driver = src.driver, height = src.height, width = src.width, count = src.count, dtype = src.dtypes[0], crs = src.crs, transform = src.transform, interleave = 'BAND',) as dst:

            # Loop over the blocks or windows of the input file
            for _, window in src.block_windows(1):

                # Write the data to the output file
                dst.write(src.read(window = window), window = window)
    
    # Remove unnecessary image
    if remove:
        os.remove(input_image)
    
    return None
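
# Usage sketch (illustrative file names): rewrite a pixel-interleaved geotiff
# produced by OTB as a band-interleaved copy and delete the original.
#
#     convert_interleave_mode(('weather_reproj.tif', 'weather_reproj_BIL.tif', True))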


def era5Land_daily_to_yearly_pixel(weather_files: list[str], variables: list[str], output_file: str, raw_S2_image_ref: str, ndvi_path: str, start_date: str, end_date: str, h: float = 10, max_ram: int = 8, use_OTB: bool = False, weather_overwrite: bool = False, safran: bool = False) -> str:
    """
    Calculate ET0 values from the ERA5 netcdf weather variables.
    Output netcdf contains the ET0 and precipitation values for
    each day in the selected time period, reprojected on the
    same grid as the NDVI values.

    Arguments
    =========

    1. weather_files: ``list[str]``
        list of paths to the netCDF raw weather files
    2. variables: ``list[str]``
        list of variables downloaded from era5
    3. output_file: ``str``
        output file name without extension
    4. raw_S2_image_ref: ``str``
        raw Sentinel 2 image at right resolution for reprojection
    5. ndvi_path: ``str``
        path to ndvi dataset, used for attributes and coordinates
    6. start_date: ``str``
        beginning of the time window to download (format: ``YYYY-MM-DD``)
    7. end_date: ``str``
        end of the time window to download (format: ``YYYY-MM-DD``)
    8. h: ``float`` ``default = 10``
        height of ERA5 wind measurements in meters
    9. max_ram: ``int`` ``default = 8``
        max ram (in GiB) for reprojection and conversion. Two
        subprocesses are spawned for OTB, each receiving
        half of the requested memory.
    10. use_OTB: ``bool`` ``default = False``
        boolean to choose to use OTB or not, tests will be added later
    11. weather_overwrite: ``bool`` ``default = False``
        boolean to choose to overwrite the weather netCDF
    12. safran: ``bool`` ``default = False``
        boolean to adapt to a custom SAFRAN weather dataset

    Returns
    =======

    1. output_file_final: ``str``
        path to ``netCDF4`` file containing precipitation and ET0 data
    """

    # Test if file exists
    if os.path.exists(output_file + '.nc') and not weather_overwrite:
        return output_file + '.nc'
    
    # Test if memory requirement is not too large
    if np.ceil(virtual_memory().available / (1024**3)) < max_ram:
        print('\nRequested', max_ram, 'GiB of memory when available memory is approximately', round(virtual_memory().available / (1024**3), 1), 'GiB.\n\nExiting script.\n')
        return None
    
    # Load all weather files in a single dataset
    raw_weather_ds = xr.Dataset()
    for var in variables:
        temp = []
        for file in weather_files:
            if fnmatch(file, '*' + var + '*'):
                temp.append(file)
        raw_weather_ds = xr.merge([raw_weather_ds, xr.open_mfdataset(temp).drop_vars(['number', 'expver']).rename({'valid_time': 'time', 'latitude': 'lat', 'longitude': 'lon'})])
    # Clip extra dates
    raw_weather_ds = raw_weather_ds.sel({'time': slice(start_date, end_date)}).sortby(variables = 'time')

    # Resample dataset to a daily frequency
    resampled_weather_ds = raw_weather_ds.resample(time = '1D').sum()

    # Create ET0 variable (that will be saved) and set attributes
    resampled_weather_ds = resampled_weather_ds.assign(ET0 = (resampled_weather_ds.sizes, np.zeros(tuple(resampled_weather_ds.sizes[d] for d in list(resampled_weather_ds.sizes)), dtype = np.float32)))

    if safran:
        # Chunk weather dataset
        raw_weather_ds = raw_weather_ds.chunk({'time': -1, 'y': 1, 'x': 1})

        # Apply ET0 function
        resampled_weather_ds['ET0'] = raw_weather_ds.map_blocks(calculate_ET0_pixel, args = (h, safran), template = resampled_weather_ds.ET0.chunk({'time': -1, 'y': 1, 'x': 1}))
        final_weather_ds = resampled_weather_ds.drop_vars(names = ['ssrd', 't2m', 'd2m', 'RH_max', 'RH_min', 'U_z'])  # remove unwanted variables

    else:
        # Chunk weather dataset
        raw_weather_ds = raw_weather_ds.chunk({'time': -1, 'lon': 1, 'lat': 1})

        # Apply ET0 function
        resampled_weather_ds['ET0'] = raw_weather_ds.map_blocks(calculate_ET0_pixel, args = (h, safran), template = resampled_weather_ds.ET0.chunk({'time': -1, 'lon': 1, 'lat': 1}))
        final_weather_ds = resampled_weather_ds.drop_vars(names = ['ssrd', 'v10', 'u10', 't2m', 'd2m'])  # remove unwanted variables
        
    # Scale data and rewrite netcdf attributes
    final_weather_ds['tp'] = final_weather_ds['tp'] * 1000  # conversion from m to mm
    
    # Change datatype to reduce memory usage
    final_weather_ds['tp'] = (final_weather_ds['tp']  * 100).astype('u2').chunk(chunks = {"time": 1})
    final_weather_ds['ET0'] = (final_weather_ds['ET0']  * 1000).astype('u2').chunk(chunks = {"time": 1})
    # Write projection
    final_weather_ds.rio.write_crs('EPSG:4326', inplace = True)

    # Set variable attributes
    final_weather_ds['ET0'].attrs['standard_name'] = 'Potential evapotranspiration'
    final_weather_ds['ET0'].attrs['comment'] = 'Potential evapotranspiration accumulated over the day, calculated with the FAO-56 method (scale factor = 1000)'

    final_weather_ds['tp'].attrs['standard_name'] = 'Precipitation'
    final_weather_ds['tp'].attrs['comment'] = 'Volume of total daily precipitation expressed as water height in millimeters (scale factor = 100)'
    # TODO: find how to test OTB installation from python
    if use_OTB:
        # Save dataset to geotiff, still in wgs84 (lat, lon) coordinates
        output_file_rain = output_file + '_rain.tif'
        output_file_ET0 = output_file + '_ET0.tif'
        final_weather_ds.tp.rio.to_raster(output_file_rain, dtype = 'uint16')
        final_weather_ds.ET0.rio.to_raster(output_file_ET0, dtype = 'uint16')
        
        # Reprojected image paths
        output_file_rain_reproj = output_file + '_rain_reproj.tif'
        output_file_ET0_reproj = output_file + '_ET0_reproj.tif'
        
        # Converted image paths
        output_file_final = output_file + '.nc'
        
        # otbcli_SuperImpose commands
        OTB_command_reproj1 = 'otbcli_Superimpose -inr ' + raw_S2_image_ref + ' -inm ' + output_file_rain + ' -out ' + output_file_rain_reproj + ' uint16 -interpolator linear -ram ' + str(int(max_ram * 1024/2))
        OTB_command_reproj2 = 'otbcli_Superimpose -inr ' + raw_S2_image_ref + ' -inm ' + output_file_ET0 + ' -out ' + output_file_ET0_reproj + ' uint16 -interpolator linear -ram ' + str(int(max_ram * 1024/2))
        commands_reproj = [OTB_command_reproj1, OTB_command_reproj2]
        
        with Pool(2) as p:
            p.map(os.system, commands_reproj)
        
        # Combine to netCDF file
        combine_weather2netcdf(output_file_rain_reproj, output_file_ET0_reproj, ndvi_path, output_file_final, available_ram = max_ram)
            
        # Remove intermediate files
        os.remove(output_file_rain)
        os.remove(output_file_ET0)
        os.remove(output_file_rain_reproj)
        os.remove(output_file_ET0_reproj)
    else:
        # Set dataset attributes
        final_weather_ds.attrs['name'] = 'ModSpa Pixel weather dataset'
        final_weather_ds.attrs['description'] = 'Weather variables (Rain and ET0) for the ModSpa SAMIR (FAO-56) model at the pixel scale. Variables are scaled to be stored as integers.'
        final_weather_ds.attrs['scaling'] = "{'Rain': 100, 'ET0': 1000}"
        
        # Set file names
        output_file_final = output_file + '.nc'
        
        # Open reference image
        ref = rioxarray.open_rasterio(raw_S2_image_ref)
        # Get metadata
        target_crs = ref.rio.crs
        spatial_ref = ref.spatial_ref.load()
        
        # Define resources
        mem_limit = min([int(np.ceil(len(ref.x) * len(ref.y) * len(final_weather_ds.time) * len(final_weather_ds.data_vars) * np.dtype(np.float32).itemsize / (1024 ** 2)) * 1.1), 0.8 * virtual_memory().available / (1024**2), max_ram * 1024])
        nb_threads = min([cpu_count(logical = True), len(os.sched_getaffinity(0))])
        
        # Reproject
        final_weather_ds = final_weather_ds.rio.reproject(target_crs, transform = ref.rio.transform(), shape = (ref.rio.height, ref.rio.width), resampling = Resampling.bilinear, num_threads = nb_threads, warp_mem_limit = mem_limit)
        # Rename
        final_weather_ds = final_weather_ds.rename({'tp': 'Rain'})
        
        # Create encoding dictionary
        for variable in list(final_weather_ds.keys()):
            # Write encoding dict
            encod = {}
            encod['dtype'] = 'u2'
            if '_FillValue' in final_weather_ds[variable].attrs:
                del final_weather_ds[variable].attrs['_FillValue']
            encod['_FillValue'] = 0
            # TODO: figure out optimal file chunk size
            file_chunksize = (1, final_weather_ds.sizes['y'], final_weather_ds.sizes['x'])
            encod['chunksizes'] = file_chunksize
            # TODO: check if compression affects reading speed
            encod['zlib'] = True
            encod['complevel'] = 1
            final_weather_ds[variable].encoding.update(encod)
            
        # Rewrite georeferencing
        final_weather_ds.rio.write_crs(target_crs, inplace = True)
        final_weather_ds['spatial_ref'] = spatial_ref
        final_weather_ds.attrs['crs'] = final_weather_ds.rio.crs.to_string()
        final_weather_ds = final_weather_ds.set_coords('spatial_ref')

        # Save output
        final_weather_ds.to_netcdf(output_file_final)
        final_weather_ds.close()

    return output_file_final
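
# Usage sketch (illustrative paths): build the yearly pixel forcing file from
# previously downloaded ERA5-Land netCDF files, reprojected on the Sentinel-2 grid.
#
#     weather_file = era5Land_daily_to_yearly_pixel(
#         weather_files, variables, 'output/weather_2022', 'S2_reference.tif',
#         'ndvi_2022.nc', '2022-01-01', '2022-12-31', max_ram = 8)
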

def era5Land_daily_to_yearly_parcel(weather_files: list[str], variables: list[str], output_file: str, start_date: str, end_date: str, h: float = 10) -> tuple[str, str]:
    """
    Calculate ET0 values from the ERA5 netcdf weather variables.
    Output netcdf contains the ET0 and precipitation values for
    each day in the selected time period.

    Arguments
    =========

    1. weather_files: ``list[str]``
        list of paths to the netCDF raw weather files
    2. variables: ``list[str]``
        list of variables downloaded from era5
    3. output_file: ``str``
        output file name without extension
    4. start_date: ``str``
        beginning of the time window to download (format: ``YYYY-MM-DD``)
    5. end_date: ``str``
        end of the time window to download (format: ``YYYY-MM-DD``)
    6. h: ``float`` ``default = 10``
        height of ERA5 wind measurements in meters

    Returns
    =======

    1. output_file_rain: ``str``
        path to ``Geotiff`` file containing precipitation data
    2. output_file_ET0: ``str``
        path to ``Geotiff`` file containing ET0 data
    """
    
    # Load all weather files in a single dataset
    raw_weather_ds = xr.Dataset()
    for var in variables:
        temp = []
        for file in weather_files:
            if fnmatch(file, '*' + var + '*'):
                temp.append(file)
        raw_weather_ds = xr.merge([raw_weather_ds, xr.open_mfdataset(temp).drop_vars(['number', 'expver']).rename({'valid_time': 'time', 'latitude': 'lat', 'longitude': 'lon'})])

    # Clip extra dates
    raw_weather_ds = raw_weather_ds.sel({'time': slice(start_date, end_date)}).sortby(variables = 'time')

    # Resample dataset to a daily frequency
    resampled_weather_ds = raw_weather_ds.resample(time = '1D').sum()
    # Create ET0 variable (that will be saved) and set attributes 
    resampled_weather_ds = resampled_weather_ds.assign(ET0 = (resampled_weather_ds.sizes, np.zeros(tuple(resampled_weather_ds.sizes[d] for d in list(resampled_weather_ds.sizes)), dtype = np.float32)))

    # Loop on latitude and longitude coordinates to calculate ET0 per "pixel"
    # Chunk weather dataset
    raw_weather_ds = raw_weather_ds.chunk({'time': -1, 'lon': 1, 'lat': 1})
    # Apply ET0 function
    resampled_weather_ds['ET0'] = raw_weather_ds.map_blocks(calculate_ET0_pixel, args = (h, False), template = resampled_weather_ds.ET0.chunk({'time': -1, 'lon': 1, 'lat': 1}))

    # Get necessary data for final dataset and rewrite netcdf attributes
    final_weather_ds = resampled_weather_ds.drop_vars(names = ['ssrd', 'v10', 'u10', 't2m', 'd2m'])  # remove unwanted variables
    final_weather_ds['tp'] = final_weather_ds['tp'] * 1000  # conversion from m to mm
    
    # Change datatype to reduce memory usage
    final_weather_ds['tp'] = (final_weather_ds['tp']  * 100).astype('u2').chunk(chunks = {"time": 1})
    final_weather_ds['ET0'] = (final_weather_ds['ET0']  * 1000).astype('u2').chunk(chunks = {"time": 1})
    
    # Write projection
    final_weather_ds = final_weather_ds.rio.write_crs('EPSG:4326')
    
    # Set variable attributes 
    final_weather_ds['ET0'].attrs['standard_name'] = 'Potential evapotranspiration'
    final_weather_ds['ET0'].attrs['comment'] = 'Potential evapotranspiration accumulated over the day, calculated with the FAO-56 method (scale factor = 1000)'

    final_weather_ds['tp'].attrs['standard_name'] = 'Precipitation'
    final_weather_ds['tp'].attrs['comment'] = 'Volume of total daily precipitation expressed as water height in millimeters (scale factor = 100)'

    # Save dataset to geotiff, still in wgs84 (lat, lon) coordinates
    output_file_rain = output_file + '_rain.tif'
    output_file_ET0 = output_file + '_ET0.tif'
    final_weather_ds.tp.rio.to_raster(output_file_rain, dtype = 'uint16')
    final_weather_ds.ET0.rio.to_raster(output_file_ET0, dtype = 'uint16')

    return output_file_rain, output_file_ET0
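
# Usage sketch (illustrative paths): the parcel version returns the two
# intermediate geotiffs instead of a reprojected netCDF.
#
#     rain_tif, ET0_tif = era5Land_daily_to_yearly_parcel(
#         weather_files, variables, 'output/weather_2022',
#         '2022-01-01', '2022-12-31')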


def extract_rasterstats(args: tuple) -> list[float]:
    """
    Generate a dataframe for a given raster and a geopandas shapefile object.
    It iterates over the features of the shapefile geometry (polygons) and
    stores the extracted information in a list.

    It returns a list that contains the raster values, a feature ``id``
    and the date for the image and every polygon in the shapefile geometry.
    It also has identification data relative to the shapefile: land cover (``LC``)
    and land cover identifier (``id``). This list is returned to be later aggregated
    in a ``DataFrame``.

    This function is used to allow multiprocessing for weather extraction.

    Arguments (packed in args: ``tuple``)
    =====================================

    1. raster_path: ``str``
        path to the weather raster from which to extract values
    2. shapefile: ``gpd.GeoDataFrame``
        geodataframe containing the polygons over which to extract statistics
    3. config_file: ``str``
        path to the configuration file

    Returns
    =======

    1. raster_stats: ``list[float]``
        list containing weather values and feature information for every
        polygon in the shapefile
    """
    
    # Unpack arguments packed in args
    raster_path, shapefile, config_file = args

    # Open config file
    config_params = config(config_file)

    # Create list where zonal statistics will be stored
    raster_stats = []

    # Get dates
    dates = pd.to_datetime(pd.date_range(start = config_params.start_date, end = config_params.end_date, freq = 'D')).values

    # Open raster image and get metadata
    raster_dataset = rio.open(raster_path)
    nodata = raster_dataset.nodata
    nbands = raster_dataset.count
    # Loop on the individual polygons in the shapefile geometry
    for index, row in shapefile.iterrows():
        
        # Get the feature geometry as a shapely object
        geom = row.geometry
        
        # id number of the current parcel geometry
        id = index + 1
        
        # Get land cover
        LC = row.LC
        
        # Crop the raster using the bounding box
        try:
            masked_raster, _ = mask(raster_dataset, [geom], crop = True)
        except Exception:
            print('\nShapefile bounds are not contained in weather dataset bounds.\n\nExiting script.')
            return None
        
        # Mask the raster using the geometry
        masked_raster, _ = mask(raster_dataset, [geom], crop = True, all_touched = True)
        
        # Replace the nodata values with nan
        masked_raster = masked_raster.astype(np.float32)
        masked_raster[masked_raster == nodata] = np.nan

        # Compute the mean over the polygon for every band
        mean = np.nanmean(masked_raster, axis = (1,2))

        # Add statistics to output list
        raster_stats.extend([[dates[i], id, mean[i], LC] for i in range(nbands)])

        # Update shared progress value (initialized in init_worker)
        with lock:
            shared_value.value += 1

    # Close raster dataset
    raster_dataset.close()

    return raster_stats


def init_worker(shared_value_, lock_) -> None:
    """
    Function to initialize the pool workers with shared value and lock.

    Arguments
    =========

    1. shared_value_: ``float``
        shared progress bar value
    2. lock_: ``Lock``
        lock to access shared value
    """
    
    global shared_value, lock
    shared_value = shared_value_
    lock = lock_


def divide_geodataframe(gdf: gpd.GeoDataFrame, n: int) -> list[gpd.GeoDataFrame]:
    """
    Divide geodataframes into n equal parts.

    Arguments
    =========

    1. gdf: ``gpd.GeoDataFrame``
        input geodataframe
    2. n: ``int``
        number of parts to divide into

    Returns
    =======

    1. divided_gdfs: ``list[gpd.GeoDataFrame]``
        output geodataframes
    """
    
    # Calculate the size of each part
    part_size = len(gdf) // n
    remainder = len(gdf) % n

    # Create a list to store the divided GeoDataFrames
    divided_gdfs = []

    start_idx = 0
    for i in range(n):
        end_idx = start_idx + part_size + (1 if i < remainder else 0)
        divided_gdfs.append(gdf.iloc[start_idx:end_idx])
        start_idx = end_idx

    return divided_gdfs
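
# Illustrative example (gdf_with_10_rows is a placeholder name): 10 polygons
# divided among 3 workers gives parts of sizes 4, 3 and 3, since the remainder
# is spread over the first parts.
#
#     >>> [len(part) for part in divide_geodataframe(gdf_with_10_rows, 3)]
#     [4, 3, 3]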


def extract_weather_dataframe(rain_path: str, ET0_path: str, shapefile: str, config_file: str, save_path: str, max_cpu: int = 4) -> None:
    """
    Extract a weather dataframe for each variable (Rain, ET0) and merge them in one
    dataframe. This dataframe is saved as ``csv`` file.

    Arguments
    =========

    1. rain_path: ``str``
        path to the Rain raster file
    2. ET0_path: ``str``
        path to the ET0 raster file
    3. shapefile: ``str``
        path to the shapefile containing the polygons
    4. config_file: ``str``
        path to the configuration file
    5. save_path: ``str``
        path to save the output ``csv`` file
    6. max_cpu: ``int`` ``default = 4``
        max number of CPU cores to use for calculation
    """

    print(f'\nStarting weather data extraction on {max_cpu} cores..\n')

    # Shared value and lock for controlling access to the value
    manager = Manager()
    shared_value = manager.Value('i', 0)
    lock = manager.Lock()
    
    # Get target epsg
    with rio.open(rain_path, mode = 'r') as src:
        target_epsg = src.crs
    
    # Total iterations (assuming each extract_rasterstats call represents one iteration)
    shapefile = gpd.read_file(shapefile)
    shapefile['geometry'] = shapefile['geometry'].to_crs(target_epsg)
    total_iterations = len(shapefile.index)
    
    args1 = [(rain_path, smaller_shapefile, config_file) for smaller_shapefile in divide_geodataframe(shapefile, max_cpu)]
    args2 = [(ET0_path, smaller_shapefile, config_file) for smaller_shapefile in divide_geodataframe(shapefile, max_cpu)]
    
    # # Generate arguments for multiprocessing
    # args = [(rain_path, shapefile, config_file), (ET0_path, shapefile, config_file)]
    args = args1 + args2

    # Create and initialize the pool
    pool = Pool(processes = 2 * max_cpu, initializer = init_worker, initargs = (shared_value, lock))
    
    # Progress bar
    with tqdm(desc = 'Extracting zonal statistics', total = total_iterations, unit = ' polygons', dynamic_ncols = True) as pbar:
        # Start the worker processes
        results = [pool.apply_async(extract_rasterstats, args=(arg,)) for arg in args]

        while shared_value.value < total_iterations: