# -*- coding: UTF-8 -*-
# Python
"""
04-07-2023
@author: jeremy auclair

Calculate NDVI images with xarray
"""

import os  # for path exploration
import csv  # open csv files
from fnmatch import fnmatch  # for character string comparison
from typing import List, Union  # to declare variables
import xarray as xr  # to manage datasets
import pandas as pd  # to manage dataframes
# import dask.array as da  # dask xarray
from code.toolbox import product_str_to_datetime


def calculate_ndvi(extracted_paths: Union[List[str], str], save_dir: str, resolution: int = 20, chunk_size: dict = {'x': 4000, 'y': 4000, 'time': 2}, acorvi_corr: int = 500) -> str:
    """
    Calculate an NDVI cube from extracted Sentinel-2 band images and save it as a netCDF file.

    extracted_paths: list of paths to the extracted band images, or path to a csv file listing them (one path per line)
    save_dir: directory in which the NDVI cube is saved
    resolution: output resolution in meters (10 or 20)
    chunk_size: dask chunk sizes used when opening the images
    acorvi_corr: ACORVI correction term added to the red band

    Returns the path of the saved netCDF file, or None if the resolution is invalid.
    """

    # Check resolution for Sentinel-2
    if resolution not in [10, 20]:
        print('Resolution should be equal to 10 or 20 meters for Sentinel-2')
        return None

    # If a file name is provided instead of a list of paths, load the csv file that contains the list of paths
    if isinstance(extracted_paths, str):
        with open(extracted_paths, 'r') as file:
            csvreader = csv.reader(file)
            extracted_paths = [row[0] for row in csvreader]

    # Sort by band
    red_paths = []
    nir_paths = []
    mask_paths = []
    if resolution == 10:
        # 10 m case: red (B04) and NIR (B08) are native 10 m bands, the scene classification (SCL) only exists at 20 m
        for product in extracted_paths:
            if fnmatch(product, '*_B04_10m*'):
                red_paths.append(product)
            elif fnmatch(product, '*_B08_10m*'):
                nir_paths.append(product)
            elif fnmatch(product, '*_SCL_20m*'):
                mask_paths.append(product)
    else:
        # 20 m case: use the 20 m red (B04), NIR (B8A) and SCL products so that all grids match
        for product in extracted_paths:
            if fnmatch(product, '*_B04_20m*'):
                red_paths.append(product)
            elif fnmatch(product, '*_B8A_20m*'):
                nir_paths.append(product)
            elif fnmatch(product, '*_SCL_20m*'):
                mask_paths.append(product)

    # Sort and get dates
    red_paths.sort()
    nir_paths.sort()
    mask_paths.sort()
    dates = [product_str_to_datetime(prod) for prod in red_paths]

    # Open datasets with xarray
    red = xr.open_mfdataset(red_paths, combine='nested', concat_dim='time', chunks=chunk_size).squeeze(dim=['band'], drop=True).rename({'band_data': 'red'}).astype('f4')
    nir = xr.open_mfdataset(nir_paths, combine='nested', concat_dim='time', chunks=chunk_size).squeeze(dim=['band'], drop=True).rename({'band_data': 'nir'}).astype('f4')
    mask = xr.open_mfdataset(mask_paths, combine='nested', concat_dim='time', chunks=chunk_size).squeeze(dim=['band'], drop=True).rename({'band_data': 'mask'}).astype('f4')

    # Keep the vegetation (4) and not-vegetated (5) scene classification classes,
    # and resample the 20 m mask onto the 10 m grid when needed
    if resolution == 10:
        mask = xr.where((mask == 4) | (mask == 5), 1, 0).interp(x=red.coords['x'], y=red.coords['y'], method='nearest')
    else:
        mask = xr.where((mask == 4) | (mask == 5), 1, 0)

    # Set time coordinate
    red['time'] = pd.to_datetime(dates)
    nir['time'] = pd.to_datetime(dates)
    mask['time'] = pd.to_datetime(dates)

    # Create ndvi dataset and calculate ndvi with the ACORVI correction added to the red band
    ndvi = red
    ndvi = ndvi.drop_vars('red')
    ndvi['ndvi'] = ((nir.nir - red.red - acorvi_corr) / (nir.nir + red.red + acorvi_corr)) * mask.mask
    del red, nir, mask

    # Mask and scale ndvi
    ndvi['ndvi'] = xr.where(ndvi.ndvi < 0, 0, ndvi.ndvi)
    ndvi['ndvi'] = xr.where(ndvi.ndvi > 1, 1, ndvi.ndvi)
    ndvi['ndvi'] = ndvi.ndvi * 255

    # Create save path
    ndvi_cube_path = save_dir + os.sep + 'NDVI_precube_' + dates[0].strftime('%d-%m-%Y') + '_' + dates[-1].strftime('%d-%m-%Y') + '.nc'

    # Save NDVI cube to netcdf; values were scaled to [0, 255] so they fit in an unsigned 8-bit integer ('u1')
    ndvi.to_netcdf(ndvi_cube_path, encoding={"ndvi": {"dtype": "u1", "_FillValue": 0}})
    ndvi.close()

    return ndvi_cube_path
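

# Minimal usage sketch (not part of the original module): shows how calculate_ndvi could be
# called on a csv produced by the extraction step of this project. The csv path and output
# directory below are hypothetical placeholders.
if __name__ == '__main__':

    # Hypothetical inputs: a csv listing the extracted Sentinel-2 band paths (one per line)
    # and an existing directory in which the NDVI cube will be written
    extracted_csv = '/path/to/extracted_paths.csv'
    output_dir = '/path/to/output'

    # Build the NDVI precube at 20 m with the default chunking and ACORVI correction
    cube_path = calculate_ndvi(extracted_csv, output_dir, resolution=20)
    print('NDVI cube saved to:', cube_path)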