Skip to content
Snippets Groups Projects
calculate_ndvi.py 34.53 KiB
# -*- coding: UTF-8 -*-
# Python
"""
04-07-2023
@author: jeremy auclair

Calculate NDVI images with xarray
"""

import os  # for path management
import csv  # open csv files
from pystac_client import Client  # to send requests to copernicus server
from planetary_computer import sign_inplace  # to get access to landsat data
from odc.stac import load  # to download copernicus data
from odc.geo.geobox import GeoBox  # to create geobox for data download
from fnmatch import fnmatch  # for character string comparison
from typing import List, Union, Tuple  # to declare variables
import xarray as xr  # to manage dataset
import rioxarray  # to manage dataset projections
import pandas as pd  # to manage dataframes
import rasterio as rio  # to open geotiff files
import numpy as np  # vectorized math
import geopandas as gpd  # to manage shapefile crs projections
from rasterio.enums import Resampling  # reprojection algorithms
from datetime import datetime  # manage dates
from dateutil.relativedelta import relativedelta  # date math
from p_tqdm import p_map  # for multiprocessing with progress bars
from dask.distributed import progress  # for simple progress bar with dask
from psutil import cpu_count  # to get number of physical cores available
from modspa_pixel.config.config import config  # to import config file
from modspa_pixel.preprocessing.input_toolbox import product_str_to_datetime, read_product_info, get_band_paths


def download_ndvi_imagery(config_file: str, interp_chunk: dict = {'x': 400, 'y': 400, 'time': -1}, scaling: int = 255, acorvi_corr: int = 0) -> Union[str, Tuple[str, str]]:
    """
    Use the new Copernicus data ecosystem or Planetarycomputer ecosystem to search and download
    clipped and interpolated S2 or LandSat data, calculate NDVI and save it into a precube.

    Data is searched from one month before the start date to one month after the end date
    of the config file, to improve the temporal interpolation at the edges of the period.
    In ``pixel`` mode, the NDVI is interpolated to a daily time step, cut to the config
    period and saved as a ``netcdf`` cube. In ``parcel`` mode, the dates containing at
    least one valid pixel are saved as a multiband ``GeoTiff`` pre-cube and their dates
    are saved in a ``.npy`` file.

    Arguments
    =========

    1. config_file: ``str``
        path to configuration file
    2. interp_chunk: ``dict`` ``default = {'x': 400, 'y': 400, 'time': -1}``
        dictionary containing the chunk size for
        the xarray dask ndvi interpolation (only read, never modified)
    3. scaling: ``int`` ``default = 255``
        integer scaling to save NDVI data as integer values
    4. acorvi_corr: ``int`` ``default = 0``
        acorvi correction parameter to add to the red band
        to adjust ndvi values for THEIA

    Returns
    =======

    ndvi_cube_path: ``str``
        in ``pixel`` mode, path to the saved ndvi cube
    (ndvi_cube_path, dates_file): ``Tuple[str, str]``
        in ``parcel`` mode, paths to the saved ndvi pre-cube
        and to the ``.npy`` file containing its dates

    Raises
    ======

    ValueError
        if the preferred provider of the config file is neither
        ``'copernicus'`` nor ``'usgs'``
    """
    
    # Open config_file
    config_params = config(config_file)
    
    # Load parameters from config file
    preferred_provider = config_params.preferred_provider
    run_name = config_params.run_name
    mode = config_params.mode
    start_date = config_params.start_date
    end_date = config_params.end_date
    shapefile_path = config_params.shapefile_path
    resolution = config_params.resolution
    cloud_cover_limit = config_params.cloud_cover_limit
    
    if preferred_provider == 'copernicus':
        # Set api parameters
        url = 'https://earth-search.aws.element84.com/v1'
        collection = 'sentinel-2-l2a'
        query = {'eo:cloud_cover' : {'lt' : cloud_cover_limit}}
        modifier = None
        
        # Set data parameters (classification values kept as valid pixels)
        red, nir, mask_name = 'red', 'nir', 'scl'
        val1, val2 = 4, 5
        
        # Set paths
        save_dir = config_params.data_path + os.sep + 'IMAGERY' + os.sep + 'SCIHUB' + os.sep + 'NDVI'
        
    elif preferred_provider == 'usgs':
        # Set api parameters
        url = 'https://planetarycomputer.microsoft.com/api/stac/v1'
        collection = 'landsat-c2-l2'
        query = {'eo:cloud_cover' : {'lt' : cloud_cover_limit}, 'platform': {'in': ['landsat-8', 'landsat-9']}}
        modifier = sign_inplace
        
        # Set data parameters (qa_pixel value kept as valid pixels)
        red, nir, mask_name = 'red', 'nir08', 'qa_pixel'
        val1, val2 = 21824, 21824
        
        # Set paths
        save_dir = config_params.data_path + os.sep + 'IMAGERY' + os.sep + 'USGS' + os.sep + 'NDVI'
    
    else:
        # Fail early: otherwise the variables above are unbound and the function
        # crashes later with an unhelpful UnboundLocalError
        raise ValueError("preferred_provider must be 'copernicus' or 'usgs' to download NDVI imagery, got %r" % (preferred_provider,))

    # Search parameters
    bands = [red, nir, mask_name]
    resampling_dict = {red: 'bilinear', nir: 'bilinear', mask_name: 'nearest'}
    
    # Create save paths
    run_dir = save_dir + os.sep + run_name
    ndvi_cube_path = run_dir + os.sep + run_name + '_NDVI_' + 'pre' * (mode == 'parcel') + 'cube_' + start_date + '_' + end_date + '.nc' * (mode == 'pixel') + '.tif' * (mode == 'parcel')
    dates_file = run_dir + os.sep + run_name + '_NDVI_precube_' + start_date + '_' + end_date + '_dates.npy'
    
    # Check if file exists and ndvi overwrite is false
    if os.path.exists(ndvi_cube_path) and not config_params.ndvi_overwrite:
        if mode == 'pixel':
            return ndvi_cube_path
        else:
            return ndvi_cube_path, dates_file
    
    # Open shapefile containing geometry
    shapefile = gpd.read_file(shapefile_path)
    bbox = shapefile.to_crs('EPSG:4326').geometry.total_bounds
    
    # Change start and end date to better cover the chosen period
    new_start_date = (datetime.strptime(start_date, '%Y-%m-%d') - relativedelta(months = 1)).strftime('%Y-%m-%d')
    new_end_date = (datetime.strptime(end_date, '%Y-%m-%d') + relativedelta(months = 1)).strftime('%Y-%m-%d')
    
    # Create request
    client = Client.open(url, modifier = modifier)
    search = client.search(collections = [collection], bbox = bbox, datetime = new_start_date + '/' + new_end_date, query = query, max_items = 200)
    
    # Create geobox to better control data load geometry
    geo_box = GeoBox.from_bbox(shapefile.geometry.total_bounds, shapefile.crs, resolution = resolution)
    
    # Get data with required bands
    data = load(search.items(), geobox = geo_box, groupby = 'solar_day', bands = bands, chunks = {}, resampling = resampling_dict)
    
    if preferred_provider == 'usgs':
        # Scale optical bands
        data[[red, nir]] = data[[red, nir]] * 2.75e-5 - 0.2
    
    # Create validity mask (1 for valid pixels, NaN otherwise)
    mask = xr.where((data[mask_name] == val1) | (data[mask_name] == val2), 1, np.nan)
    
    # Save single red band as reference tif for later use in reprojection algorithms
    os.makedirs(run_dir, exist_ok = True)
    ref = data[red].isel(time = 0).rio.write_crs(data.spatial_ref.values)
    ref.rio.to_raster(run_dir + os.sep + run_name + '_grid_reference.tif')
    
    # Calculate NDVI
    ndvi = ((data[nir] - data[red] - acorvi_corr) / (data[nir] + data[red] + acorvi_corr) * mask).to_dataset(name = 'NDVI')

    # Convert timestamp to dates
    ndvi['time'] = pd.to_datetime(pd.to_datetime(ndvi.time.values, format = '%Y-%m-%d').date)
    dates = ndvi.time.values

    # Mask NDVI: values above 1 are invalid, negative values are set to 0
    ndvi['NDVI'] = xr.where(ndvi.NDVI > 1, np.nan, ndvi.NDVI)
    ndvi['NDVI'] = xr.where(ndvi.NDVI < 0, 0, ndvi.NDVI)

    # Save NDVI cube to netcdf
    if mode == 'pixel':
        # Interpolates on a daily frequency
        daily_index = pd.date_range(start = dates[0], end = dates[-1], freq = 'D')

        # Resample the dataset to a daily frequency and reindex with the new DateTimeIndex
        ndvi = ndvi.resample(time = '1D').asfreq().reindex(time = daily_index).chunk(chunks = interp_chunk)

        # Interpolate the dataset along the time dimension to fill nan values
        ndvi = ndvi.interpolate_na(dim = 'time', method = 'linear', fill_value = 'extrapolate')
        
        # Remove extra dates, only keep selected window
        ndvi = ndvi.sel({'time': slice(config_params.start_date, config_params.end_date)})
        
        # TODO: understand why dimensions are not in the right order anymore
        # Reorder dimensions
        ndvi = ndvi[['time', 'y', 'x', 'NDVI']]
        
        # Scale ndvi
        ndvi['NDVI'] = (ndvi.NDVI * scaling).round()
            
        # Write attributes
        ndvi['NDVI'].attrs['units'] = 'None'
        ndvi['NDVI'].attrs['standard_name'] = 'NDVI'
        ndvi['NDVI'].attrs['description'] = 'Normalized difference Vegetation Index (of the near infrared and red band). A value of one is a high vegetation presence.'
        ndvi['NDVI'].attrs['scale factor'] = str(scaling)
        
        # Write transform
        ndvi.rio.write_crs(ref.rio.crs, inplace = True)
        ndvi.rio.write_transform(inplace = True)
        ndvi = ndvi.set_coords('spatial_ref')
        
        # Save NDVI cube to netcdf (one time step per file chunk, spatially chunked for large cubes)
        if ndvi.sizes['y'] > 1000 and ndvi.sizes['x'] > 1000:
            file_chunksize = (1, interp_chunk['y'], interp_chunk['x'])
        else:
            file_chunksize = (1, ndvi.sizes['y'], ndvi.sizes['x'])
        
        write_job = ndvi.to_netcdf(ndvi_cube_path, encoding = {"NDVI": {"dtype": "u1", "_FillValue": 0, "chunksizes": file_chunksize}}, compute = False)
        write_job = write_job.persist()
        
        progress(write_job)
    
        return ndvi_cube_path
        
    else:
        # Drop dates with only NaNs
        ndvi = ndvi.dropna(dim = 'time', how = 'all')

        # Scale: NDVI 0 is stored as 1 so that 0 stays available as nodata
        ndvi['NDVI'] = ((ndvi.NDVI + 1/scaling) * (scaling - 1)).round()
        
        # Save dates as string formats in numpy file
        dates = pd.to_datetime(ndvi.time.values).strftime('%Y-%m-%d')
        np.save(dates_file, dates, allow_pickle = True)

        # Write crs
        ndvi.rio.write_crs(ref.rio.crs, inplace = True)

        # Write nodata
        ndvi.NDVI.rio.write_nodata(0, inplace = True)
        
        # Replace NaNs with 0
        ndvi = ndvi.fillna(0)

        # Save ndvi cube to multiband geotiff
        ndvi.NDVI.rio.to_raster(ndvi_cube_path, dtype = np.uint8)
        
        return ndvi_cube_path, dates_file


def calculate_ndvi(extracted_paths: Union[List[str], str], config_file: str, calc_chunk: dict = {'x': 400, 'y': 400, 'time': -1}, interp_chunk: dict = {'x': 400, 'y': 400, 'time': -1}, scaling: int = 255, acorvi_corr: int = 0) -> Union[str, Tuple[str, str]]:
    """
    Calculate ndvi images in a xarray dataset, interpolate is to a daily time step (a data cube) and save it.
    ndvi values are scaled and saved as ``uint8`` (0 to 255).

    In ``pixel`` mode the NDVI is interpolated to a daily time step and saved as a ``netcdf``
    cube (``round(ndvi * scaling)``). In ``parcel`` mode the dates with at least one valid
    pixel are saved as a multiband ``GeoTiff`` pre-cube, encoded as
    ``round((ndvi + 1/scaling) * (scaling - 1))`` so that 0 is only used for nodata. The
    dates of that pre-cube are saved in a ``.npy`` file.
    
    .. warning:: only 10 and 20 meters currently supported

    Arguments
    =========

    1. extracted_paths: ``Union[List[str], str]``
        list of paths to extracted sentinel-2 products
        or path to ``csv`` file containing those paths (one path per row)
    2. config_file: ``str``
        path to configuration file
    3. calc_chunk: ``dict`` ``default = {'x': 400, 'y': 400, 'time': -1}``
        dictionary containing the chunk size for
        the xarray dask ndvi calculation
    4. interp_chunk: ``dict`` ``default = {'x': 400, 'y': 400, 'time': -1}``
        dictionary containing the chunk size for
        the xarray dask ndvi interpolation
    5. scaling: ``int`` ``default = 255``
        integer scaling to save NDVI data as integer values
    6. acorvi_corr: ``int`` ``default = 0``
        acorvi correction parameter to add to the red band
        to adjust ndvi values

    Returns
    =======

    1. ndvi_path: ``str``
        in ``pixel`` mode, path to the saved ndvi cube
    2. (ndvi_path, dates_file): ``Tuple[str, str]``
        in ``parcel`` mode, paths to the saved ndvi pre-cube
        and to the ``.npy`` file containing its dates

    Raises
    ======

    ValueError
        if the preferred provider of the config file is not
        ``'copernicus'``, ``'theia'`` or ``'usgs'``
    """
    
    # Open config_file
    config_params = config(config_file)
    print('\nCalculating NDVI cube and saving it...\n')
    
    # Load parameters from config file
    boundary_shapefile_path = config_params.shapefile_path
    start_date = config_params.start_date
    end_date = config_params.end_date
    mode = config_params.mode
    run_name = config_params.run_name
    resolution = config_params.resolution
    preferred_provider = config_params.preferred_provider
    
    # Storage folder name and original resolution (in meters) of the red band for each provider
    provider_settings = {'copernicus': ('SCIHUB', 10), 'theia': ('THEIA', 10), 'usgs': ('USGS', 30)}
    if preferred_provider not in provider_settings:
        # Fail early instead of crashing later on unbound variables
        raise ValueError("preferred_provider must be one of 'copernicus', 'theia' or 'usgs', got %r" % (preferred_provider,))
    provider_folder, base_resolution = provider_settings[preferred_provider]
    save_dir = config_params.data_path + os.sep + 'IMAGERY' + os.sep + provider_folder + os.sep + 'NDVI'
    run_dir = save_dir + os.sep + run_name

    # If a file name is provided instead of a list of paths, load the csv file that contains the list of paths
    if isinstance(extracted_paths, str):
        with open(extracted_paths, 'r', newline = '') as file:
            # One path per row, blank rows skipped. The default delimiter is used because
            # Python 3.13+ rejects '\n' as a csv delimiter; paths containing commas are
            # quoted by csv.writer and are still read whole.
            extracted_paths = [row[0] for row in csv.reader(file) if row]
    
    # Get list of paths to red, nir and mask images
    red_paths, nir_paths, mask_paths = get_band_paths(config_params, extracted_paths)
    dates = [product_str_to_datetime(prod) for prod in red_paths]
    
    # Get crs
    with rio.open(red_paths[0]) as temp:
        crs = temp.crs
    
    # Open shapefile to clip the dataset
    shapefile = gpd.read_file(boundary_shapefile_path)
    shapefile = shapefile.to_crs(crs)
    bounds = shapefile.total_bounds
    # Bounding box selectors: y goes from top to bottom because datasets are sorted by descending y
    x_window = slice(bounds[0], bounds[2])
    y_window = slice(bounds[3], bounds[1])
    
    # Open datasets with xarray
    red = xr.open_mfdataset(red_paths, combine = 'nested', concat_dim = 'time', chunks = calc_chunk, parallel = True).squeeze(dim = ['band'], drop = True).rename({'band_data': 'red'}).sortby('y', ascending = False)
    nir = xr.open_mfdataset(nir_paths, combine = 'nested', concat_dim = 'time', chunks = calc_chunk, parallel = True).squeeze(dim = ['band'], drop = True).rename({'band_data': 'nir'}).sortby('y', ascending = False)
    mask = xr.open_mfdataset(mask_paths, combine = 'nested', concat_dim = 'time', chunks = calc_chunk, parallel = True).squeeze(dim = ['band'], drop = True).rename({'band_data': 'mask'}).sortby('y', ascending = False)
    
    # Clip bands and build the validity mask (1 for valid pixels, NaN otherwise) based on provider
    if preferred_provider == 'copernicus':
        red = red.sel(x = x_window, y = y_window)
        nir = nir.sel(x = x_window, y = y_window)
        # The classification band has its own grid: resample it onto the red band grid
        mask = mask.interp(x = red.coords['x'], y = red.coords['y'], method = 'nearest')
        mask = xr.where((mask == 4) | (mask == 5), 1, np.nan).astype(np.float32)
    
    elif preferred_provider == 'theia':
        red = red.sel(x = x_window, y = y_window)
        nir = nir.sel(x = x_window, y = y_window)
        mask = mask.sel(x = x_window, y = y_window)

        # -10000 values of the optical bands are treated as nodata
        red = xr.where(red == -10000, np.nan, red)
        nir = xr.where(nir == -10000, np.nan, nir)
        mask = xr.where((mask == 0), 1, np.nan).astype(np.float32)
    
    else:  # usgs: rescale optical data for LandSat
        red = red.sel(x = x_window, y = y_window) * 2.75e-5 - 0.2
        nir = nir.sel(x = x_window, y = y_window) * 2.75e-5 - 0.2
        mask = mask.sel(x = x_window, y = y_window)
        mask = xr.where((mask == 21824), 1, np.nan).astype(np.float32)

    # Set time coordinate
    red['time'] = pd.to_datetime(dates)
    nir['time'] = pd.to_datetime(dates)
    mask['time'] = pd.to_datetime(dates)
    
    # Sort by date
    red = red.sortby(variables = 'time')
    nir = nir.sortby(variables = 'time')
    mask = mask.sortby(variables = 'time')
    dates.sort()
    
    # Save single red band as reference tif for later use in reprojection algorithms.
    # The reprojected reference is computed in memory, so the source file can be closed.
    with xr.open_dataset(red_paths[0]) as ref_source:
        ref = ref_source.squeeze(dim = ['band'], drop = True).rename({'band_data': 'red'}).sortby('y', ascending = False).sel(x = x_window, y = y_window).rio.write_crs(crs)
        # Resample the reference grid to the resolution requested in the config file
        ref = ref.rio.reproject(ref.rio.crs, shape = (int(ref.sizes['y'] * base_resolution / resolution), int(ref.sizes['x'] * base_resolution / resolution)), resampling = Resampling.bilinear, nodata = np.nan)
    
    # Recalculate transform
    ref.rio.write_transform(inplace = True)
    ref = ref.set_coords('spatial_ref')
    os.makedirs(run_dir, exist_ok = True)
    ref.rio.to_raster(run_dir + os.sep + run_name + '_grid_reference.tif')
    
    # Create save path
    ndvi_cube_path = run_dir + os.sep + run_name + '_NDVI_' + 'pre' * (mode == 'parcel') + 'cube_' + start_date + '_' + end_date + '.nc' * (mode == 'pixel') + '.tif' * (mode == 'parcel')
    dates_file = run_dir + os.sep + run_name + '_NDVI_precube_' + start_date + '_' + end_date + '_dates.npy'
    
    # Check if file exists and ndvi overwrite is false
    if os.path.exists(ndvi_cube_path) and not config_params.ndvi_overwrite:
        
        if mode == 'pixel':
            return ndvi_cube_path
        
        else:
            return ndvi_cube_path, dates_file

    # Change mask resolution
    if resolution != base_resolution:
        mask = mask.interp(x = ref.coords['x'], y = ref.coords['y'], method = 'nearest')
    
    # Create ndvi dataset and calculate ndvi
    ndvi = ((nir.nir - red.red - acorvi_corr) / (nir.nir + red.red + acorvi_corr)).to_dataset(name = 'NDVI')
    
    # Reproject NDVI
    if resolution != base_resolution:
        ndvi = ndvi.interp(x = ref.coords['x'], y = ref.coords['y'], method = 'linear')
    
    # Mask NDVI: values above 1 are invalid, negative values are set to 0
    ndvi['NDVI'] = ndvi.NDVI * mask.mask
    ndvi['NDVI'] = xr.where(ndvi.NDVI > 1, np.nan, ndvi.NDVI)
    ndvi['NDVI'] = xr.where(ndvi.NDVI < 0, 0, ndvi.NDVI)
    del red, nir, mask
    
    # Drop dates with only NaNs
    ndvi = ndvi.dropna(dim = 'time', how = 'all')
    
    # Write crs (needed by both output modes) and recalculate transform
    ndvi.rio.write_crs(crs, inplace = True)
    ndvi.rio.write_transform(inplace = True)
    ndvi = ndvi.set_coords('spatial_ref')
    
    if mode == 'pixel':
    
        # Interpolates on a daily frequency
        daily_index = pd.date_range(start = dates[0], end = dates[-1], freq = 'D')

        # Resample the dataset to a daily frequency and reindex with the new DateTimeIndex
        ndvi = ndvi.resample(time = '1D').asfreq().reindex(time = daily_index).chunk(chunks = interp_chunk)

        # Interpolate the dataset along the time dimension to fill nan values
        ndvi = ndvi.interpolate_na(dim = 'time', method = 'linear', fill_value = 'extrapolate')
        
        # Remove extra dates, only keep selected window
        ndvi = ndvi.sel({'time': slice(config_params.start_date, config_params.end_date)})
        
        # TODO: understand why dimensions are not in the right order anymore
        # Reorder dimensions
        ndvi = ndvi[['time', 'y', 'x', 'NDVI']]
        
        # Scale ndvi
        ndvi['NDVI'] = (ndvi.NDVI * scaling).round()
            
        # Write attributes
        ndvi['NDVI'].attrs['units'] = 'None'
        ndvi['NDVI'].attrs['standard_name'] = 'NDVI'
        ndvi['NDVI'].attrs['description'] = 'Normalized difference Vegetation Index (of the near infrared and red band). A value of one is a high vegetation presence.'
        ndvi['NDVI'].attrs['scale factor'] = str(scaling)
        
        # Save NDVI cube to netcdf (one time step per file chunk, spatially chunked for large cubes)
        if ndvi.sizes['y'] > 1000 and ndvi.sizes['x'] > 1000:
            file_chunksize = (1, interp_chunk['y'], interp_chunk['x'])
        else:
            file_chunksize = (1, ndvi.sizes['y'], ndvi.sizes['x'])
        
        write_job = ndvi.to_netcdf(ndvi_cube_path, encoding = {"NDVI": {"dtype": "u1", "_FillValue": 0, "chunksizes": file_chunksize}}, compute = False)
        write_job = write_job.persist()
        
        progress(write_job)
        
        ndvi.close()
        
        return ndvi_cube_path

    else:
        # Scale ndvi with the same encoding as the other pre-cube writers of this module:
        # NDVI 0 is stored as 1 and NDVI 1 as `scaling`, so 0 is only used for nodata
        ndvi['NDVI'] = ((ndvi.NDVI + 1/scaling) * (scaling - 1)).round()
        
        # Write attributes
        ndvi['NDVI'].attrs['units'] = 'None'
        ndvi['NDVI'].attrs['standard_name'] = 'NDVI'
        ndvi['NDVI'].attrs['description'] = 'Normalized difference Vegetation Index (of the near infrared and red band). A value of one is a high vegetation presence.'
        ndvi['NDVI'].attrs['scale factor'] = str(scaling)
    
        # Save dates as string formats in numpy file
        dates = pd.to_datetime(ndvi.time.values).strftime('%Y-%m-%d')
        np.save(dates_file, dates, allow_pickle = True)
        
        # Write crs
        ndvi.rio.write_crs(crs, inplace = True)

        # Write nodata
        ndvi.NDVI.rio.write_nodata(0, inplace = True)
        
        # Replace NaNs with 0
        ndvi = ndvi.fillna(0)

        # Save ndvi cube to multiband geotiff
        write_job = ndvi.NDVI.rio.to_raster(ndvi_cube_path, dtype = np.uint8, compute = False)
        write_job = write_job.persist()
        
        progress(write_job)
        
        return ndvi_cube_path, dates_file


def interpolate_ndvi(ndvi_path: str, config_file: str, chunk_size: dict = {'x': 400, 'y': 400, 'time': -1}, scaling: int = 255,) -> str:
    """
    Interpolate the ndvi cube to a daily frequency between the
    desired dates defined in the ``json`` config file. The extra
    month of data downloaded is used for a better interpolation,
    it is then discarded and the final NDVI cube has the dates
    defined in the config file.

    The pre-cube values are expected to be encoded as
    ``(ndvi + 1/scaling) * (scaling - 1)``. They are converted back to
    ``ndvi * scaling``, rounded and clipped to ``[0, scaling]``.

    Arguments
    =========

    1. ndvi_path: ``str``
        path to ndvi pre-cube
    2. config_file: ``str``
        path to ``json`` config file
    3. chunk_size: ``dict`` ``default = {'x': 400, 'y': 400, 'time': -1}``
        chunk size to use by dask for calculation,
        ``'time' = -1`` means the chunk has the whole
        time dimension in it. The Dataset can't be
        divided along the time axis for interpolation.
    4. scaling: ``int`` ``default = 255``
        integer scaling to save NDVI data as integer values

    Returns
    =======

    ndvi_cube_path: ``str``
        path to the saved daily ndvi cube

    Raises
    ======

    ValueError
        if the preferred provider of the config file is not
        ``'copernicus'``, ``'theia'`` or ``'usgs'``
    """
    
    # Open config_file
    config_params = config(config_file)
    
    # Load parameters from config file
    run_name = config_params.run_name
    provider_folders = {'copernicus': 'SCIHUB', 'theia': 'THEIA', 'usgs': 'USGS'}
    if config_params.preferred_provider not in provider_folders:
        # Fail early instead of crashing on an unbound save directory
        raise ValueError("preferred_provider must be one of 'copernicus', 'theia' or 'usgs', got %r" % (config_params.preferred_provider,))
    save_dir = config_params.data_path + os.sep + 'IMAGERY' + os.sep + provider_folders[config_params.preferred_provider] + os.sep + 'NDVI'
    
    # Create save path
    ndvi_cube_path = save_dir + os.sep + run_name + os.sep + run_name + '_NDVI_cube_' + config_params.start_date + '_' + config_params.end_date + '.nc'
    
    # Check if file exists and ndvi overwrite is false
    if os.path.exists(ndvi_cube_path) and not config_params.ndvi_overwrite:
        
        return ndvi_cube_path
    
    # Open NDVI pre-cube
    ndvi = xr.open_dataset(ndvi_path, chunks = chunk_size)
    
    # Get dates of NDVI precube
    dates = ndvi.time.values
    
    # Keep the spatial reference aside: dataset wide arithmetic below could alter it
    spatial_ref = ndvi.spatial_ref.load()
    
    # Interpolates on a daily frequency
    daily_index = pd.date_range(start = dates[0], end = dates[-1], freq = 'D')

    # Resample the dataset to a daily frequency and reindex with the new DateTimeIndex
    ndvi = ndvi.resample(time = '1D').asfreq().reindex(time = daily_index)

    # Interpolate the dataset along the time dimension to fill nan values
    ndvi = ndvi.interpolate_na(dim = 'time', method = 'linear', fill_value = 'extrapolate').round(decimals = 0)
    
    # Remove extra dates, only keep selected window
    ndvi = ndvi.sel({'time': slice(config_params.start_date, config_params.end_date)})
    
    # Rescale data: invert (ndvi + 1/scaling) * (scaling - 1) into ndvi * scaling
    ndvi = (ndvi * (scaling / (scaling - 1)) - 1).round()
    
    # Set negative values as 0 and values above scaling (ndvi > 1) as scaling (ndvi = 1),
    # NaN values are left untouched
    ndvi['NDVI'] = ndvi.NDVI.clip(min = 0, max = scaling)
    
    # Rewrite spatial reference
    ndvi['spatial_ref'] = spatial_ref
    
    # Set file chunk size (one time step per file chunk, spatially chunked for large cubes)
    if ndvi.sizes['y'] > 1000 and ndvi.sizes['x'] > 1000:
        file_chunksize = (1, chunk_size['y'], chunk_size['x'])
    else:
        file_chunksize = (1, ndvi.sizes['y'], ndvi.sizes['x'])
    
    # Save NDVI cube to netcdf
    write_job = ndvi.to_netcdf(ndvi_cube_path, encoding = {"NDVI": {"dtype": "u1", "_FillValue": 0, "chunksizes": file_chunksize}}, compute = False)
    write_job = write_job.persist()
        
    progress(write_job)
            
    ndvi.close()
    
    return ndvi_cube_path


def scale_ndvi_to_uint8(data: np.ndarray, scaling: int = 255) -> np.ndarray:
    """
    Encode NDVI values as ``uint8`` integers with ``round((ndvi + 1/scaling) * (scaling - 1))``.

    With this encoding an NDVI of 0 is stored as 1, so 0 stays available as the nodata
    value. ``NaN`` and infinite values are encoded as 0 (nodata). Results are clipped to the
    ``uint8`` range so out of range inputs cannot wrap around.

    Arguments
    =========

    1. data: ``np.ndarray``
        array of NDVI values (masked pixels as ``NaN``)
    2. scaling: ``int`` ``default = 255``
        scaling factor for saving NDVI images as integer arrays

    Returns
    =======

    encoded: ``np.ndarray``
        ``uint8`` array with the same shape as ``data``
    """

    # Round to the nearest integer (rounding to one decimal and letting the cast
    # truncate would send e.g. 100.6 to 100)
    scaled = np.round((data + np.float32(1 / scaling)) * (scaling - 1))

    # Casting NaN to an integer type is undefined: explicitly use the nodata value
    scaled = np.where(np.isfinite(scaled), scaled, 0)

    return np.clip(scaled, 0, np.iinfo(np.uint8).max).astype(np.uint8)


def write_geotiff(path: str, data: np.ndarray, transform: tuple, projection: str, scaling: int = 255) -> None:
    """
    Writes a GeoTiff image using ``rasterio``. Takes an array and georeferencing data (obtained by ``rasterio``
    on an image with same georeferencing) to write a well formatted image. Data format: ``uint8``,
    encoded with ``round((ndvi + 1/scaling) * (scaling - 1))`` and nodata value 0.

    Arguments
    =========

    1. path: ``str``
        true path of file to save data in (path = path + save name)
    2. data: ``np.ndarray``
        2D numpy array containing the NDVI data (masked pixels as ``NaN``)
    3. transform: ``tuple``
        tuple containing the geo-transform data
    4. projection: ``str``
        string containing the epsg projection data
    5. scaling: ``int`` ``default = 255``
        scaling factor for saving NDVI images as integer arrays

    Returns
    =======

    ``None``
    """

    # Apply scaling and convert to the output data type
    encoded = scale_ndvi_to_uint8(data, scaling)

    # Write NDVI image with rasterio
    with rio.open(path, mode = "w", driver = "GTiff", height = encoded.shape[0], width = encoded.shape[1], count = 1, dtype = np.uint8, crs = projection, transform = transform, nodata = 0) as new_dataset:
        new_dataset.write(encoded, 1)

    return None


def calculate_ndvi_image(args: tuple) -> str:
    """
    Opens the red, near-infrared and mask bands of the given (extracted) product
    and calculates the ndvi. It then applies a mask to keep only clear land data (no clouds, no snow,
    no water, no unclassified or erroneous pixel). This array is then saved in a GeoTiff image with
    the parent name file and ``'_NDVI.tif'`` extension. If overwrite is false, already existing files won't be
    calculated and saved again.

    This function is called in the calculate_ndvi function as a subprocess to allow multiprocessing. Variables
    that have served their purpose are directly deleted to reduce memory usage.

    Arguments (packed in args: ``tuple``)
    =====================================
    
    1. product_path: ``str``
        path to the product to extract for ndvi calculation
    2. save_dir: ``str``
        directory in which to save ndvi images
    3. shapefile_path: ``str``
        path to the shapefile (``.shp``) for which the data is calculated. Used to clip
        satellite imagery
    4. overwrite: ``bool``
        boolean to choose to rewrite already existing files
    5. ACORVI_correction: ``int``
        correction parameter to apply on the red band for stability of NDVI values

    Returns
    =======

    1. ndvi_path: ``str``
        path to the saved ndvi image

    Raises
    ======

    ValueError
        if the provider of the product is not ``'copernicus'``, ``'theia'`` or ``'usgs'``
    FileNotFoundError
        if the red, near infrared or classification band is missing from the product
    """

    # Unpack arguments
    product_path, save_dir, shapefile_path, overwrite, ACORVI_correction = args
    
    # Check provider
    _, _, provider, _ = read_product_info(product_path)

    # Create file name
    save_name = save_dir + os.sep + os.path.basename(product_path) + '_NDVI.tif'

    # If the file is already written and overwrite is false, no need to calculate it again
    if os.path.exists(save_name) and not overwrite:
        return save_name

    # File name patterns of the (red, near infrared, scene classification) bands of each provider
    band_patterns = {
        'copernicus': ('*_B04_10m.jp2', '*_B08_10m.jp2', '*_SCL_20m.jp2'),
        'theia': ('*_FRE_B4.tif', '*_FRE_B8.tif', '*_MG2_R1.tif'),
        'usgs': ('*_SR_B4.TIF', '*_SR_B5.TIF', '*_QA_PIXEL.TIF'),
    }
    if provider not in band_patterns:
        raise ValueError('Unsupported provider %r for product %s' % (provider, product_path))

    # Look for desired images in the product directory (first matching pattern per file)
    band_files = [None, None, None]
    for f in os.listdir(product_path):
        for index, pattern in enumerate(band_patterns[provider]):
            if fnmatch(f, pattern):
                band_files[index] = os.path.join(product_path, f)
                break
    if None in band_files:
        raise FileNotFoundError('Missing red, near infrared or classification band in %s' % product_path)
    red_file, nir_file, classif_file = band_files

    # Ignore numpy warnings when dividing by 0 or NaN, without changing the
    # process wide numpy error state
    with np.errstate(invalid = 'ignore', divide = 'ignore'):

        with xr.open_dataarray(red_file) as red_source, xr.open_dataarray(nir_file) as nir_source, xr.open_dataarray(classif_file) as classif_source:
            projection = red_source.rio.crs

            # Clip bands to the bounding box of the shapefile (y goes from top to bottom)
            bounds = gpd.read_file(shapefile_path).to_crs(projection).total_bounds
            window = {'x': slice(bounds[0], bounds[2]), 'y': slice(bounds[3], bounds[1])}

            red_band = red_source.sel(window)
            transform = red_band.rio.transform(recalc = True)

            # Read array data
            red_data = red_band.values
            nir_data = nir_source.sel(window).values

            # Adjust data for specific providers
            if provider == 'theia':
                # Flag nodata BEFORE the ACORVI correction shifts the -10000 value
                red_data = np.where((red_data == -10000), np.float32(np.nan), red_data)
                nir_data = np.where((nir_data == -10000), np.float32(np.nan), nir_data)
            red_data = red_data + ACORVI_correction
            if provider == 'usgs':
                red_data = red_data * 2.75e-5 - 0.2
                nir_data = nir_data * 2.75e-5 - 0.2

            # Calculate ndvi
            ndvi_data = np.divide((nir_data - red_data), (nir_data + red_data), dtype = np.float32)
            del red_data, nir_data

            # Build the validity mask (1 for valid pixels, NaN otherwise) based on provider
            classif_band = classif_source.sel(window)
            if provider == 'copernicus':
                # The SCL band is on a coarser (20 m) grid: resample it onto the red band grid
                classif_band = xr.where((classif_band == 4) | (classif_band == 5), 1, np.nan).interp(x = red_band.coords['x'], y = red_band.coords['y'], method = 'nearest')
            elif provider == 'theia':
                classif_band = xr.where((classif_band == 0), 1, np.nan)
            else:  # usgs
                classif_band = xr.where((classif_band == 21824), 1, np.nan)
            binary_mask = classif_band.values

        # Apply mask
        np.multiply(ndvi_data, binary_mask, out = ndvi_data, dtype = np.float32)
        del binary_mask

        # Clip out of range data
        np.minimum(ndvi_data, 1, out = ndvi_data, dtype = np.float32)
        np.maximum(ndvi_data, 0, out = ndvi_data, dtype = np.float32)

        # Write image (first band of the array)
        write_geotiff(save_name, ndvi_data[0], transform, projection)
        del ndvi_data

    return save_name


def calculate_ndvi_parcel(ndvi_path: Union[List[str], str], save_dir: str, save_path: str, shapefile_path: str, overwrite: bool = False, max_cpu: int = 4, ACORVI_correction: int = 0) -> List[str]:
    """
    Opens the red, near infrared and scene classification to calculate the ndvi, apply a mask to keep only clear land data (no clouds, no snow,
    no water, no unclassified or erroneous pixel). This array is then saved in a GeoTiff image with
    the parent name file and ``'_NDVI.tif'`` extension. If overwrite is false, already existing files won't be
    calculated and saved again (default value is ``false``).

    This function calls the calculate_ndvi_image function as a subprocess to allow multiprocessing. A modified
    version of the ``tqdm`` module (``p_tqdm``) is used for progress bars.

    Arguments
    =========
    
    1. ndvi_path: ``list[str]`` or ``str``
        list of paths to the products to extract for ndvi calculation or path to a ``csv`` file that contains this list
        (one path per row)
    2. save_dir: ``str``
        directory in which to save ndvi images
    3. save_path : ``str``
        path to a csv file containing the paths to the saved ndvi images
    4. shapefile_path: ``str``
        path to the shapefile (``.shp``) for which the data is calculated. Used to clip
        satellite imagery
    5. overwrite: ``bool`` ``default = False``
        boolean to choose to rewrite already existing files
    6. max_cpu: ``int`` ``default = 4``
        max number of CPUs that the pool can use, if max_cpu is bigger than available CPUs, the pool only uses availables CPUs
    7. ACORVI_correction: ``int`` ``default = 0``
        correction parameter to apply on the red band for stability of NDVI values

    Returns
    =======
    
    1. ndvi_path: ``list[str]``
        sorted list of paths to the saved ndvi images
    """

    # If a file name is provided instead of a list of paths, load the csv file that contains the list of paths
    if isinstance(ndvi_path, str):
        with open(ndvi_path, 'r', newline = '') as file:
            # One path per row, blank rows skipped. The default delimiter is used because
            # Python 3.13+ rejects '\n' as a csv delimiter; paths containing commas are
            # quoted by csv.writer and are still read whole.
            product_paths = [row[0] for row in csv.reader(file) if row]
    else:
        product_paths = list(ndvi_path)

    # Prepare arguments for multiprocessing (built from the input products, which must
    # not be discarded before this point)
    args = [(product, save_dir, shapefile_path, overwrite, ACORVI_correction) for product in product_paths]

    # Get number of CPU cores and limit max value (working on a cluster requires os.sched_getaffinity to get true number of available CPUs, 
    # this is not true on a "personnal" computer, hence the use of the min function).
    # psutil returns None when the number of physical cores cannot be determined.
    physical_cores = cpu_count(logical = False) or os.cpu_count() or 1
    try:
        nb_cores = min([max_cpu, physical_cores, len(os.sched_getaffinity(0))])
    except AttributeError:
        nb_cores = min([max_cpu, physical_cores])  # os.sched_getaffinity is not available on windows

    print('\nStarting NDVI calculations with %d cores for %d images...\n' %(nb_cores, len(product_paths)))

    # Start pool and get results (no pool needed when there is nothing to compute)
    results = p_map(calculate_ndvi_image, args, **{"num_cpus": nb_cores}) if args else []

    # Collect results and sort them
    saved_paths = sorted(results)

    # Save ndvi paths as a csv file, one path per row
    with open(save_path, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerows([ndvi_image] for ndvi_image in saved_paths)

    return saved_paths