# -*- coding: UTF-8 -*-
# Python
"""
04-07-2023
@author: jeremy auclair

Calculate NDVI images with xarray
"""

import os  # for path exploration
import csv  # open csv files
from fnmatch import fnmatch  # for character string comparison
from typing import List, Union  # to declare variables
import xarray as xr  # to manage dataset
import pandas as pd  # to manage dataframes
import rasterio as rio  # to open geotiff files
import geopandas as gpd  # to manage shapefile crs projections
from shapely.geometry import box  # to create boundary box
from input.input_toolbox import product_str_to_datetime
def calculate_ndvi(
    extracted_paths: Union[List[str], str],
    save_dir: str,
    boundary_shapefile_path: str,
    resolution: int = 20,
    chunk_size: Union[dict, None] = None,
    acorvi_corr: int = 500,
) -> Union[str, None]:
    """Compute a Sentinel-2 NDVI cube from red, NIR and SCL rasters and save it as NetCDF.

    As a side effect, the footprint of the first red image (reprojected to
    EPSG:4326) is written to ``boundary_shapefile_path`` for the weather download.

    Args:
        extracted_paths: list of raster file paths, or the path of a text/csv file
            listing one raster path per line (blank lines are ignored).
        save_dir: directory in which the NetCDF file is written.
        boundary_shapefile_path: output path of the footprint shapefile.
        resolution: output resolution in meters, 10 or 20.
        chunk_size: dask chunk sizes passed to ``xarray.open_mfdataset``;
            ``None`` means ``{'x': 4000, 'y': 4000, 'time': 8}``.
        acorvi_corr: ACORVI correction added to the red band in both the numerator
            and the denominator of the NDVI ratio (same units as the band data).

    Returns:
        Path of the written ``NDVI_precube_<first date>_<last date>.nc`` file, or
        ``None`` when ``resolution`` is not supported.

    Raises:
        ValueError: if no red band file matches ``resolution`` or if the red, NIR
            and mask file counts differ.

    Note:
        NDVI is clipped to [0, 1], scaled by 255 and stored as uint8 with 0 as the
        fill value, so masked pixels and pixels with NDVI <= 0 both read back as
        missing.
    """
    # Check resolution for Sentinel-2
    if resolution not in _BAND_PATTERNS:
        print('Resolution should be equal to 10 or 20 meters for sentinel-2')
        return None

    # Fresh dict per call instead of a shared mutable default argument
    if chunk_size is None:
        chunk_size = {'x': 4000, 'y': 4000, 'time': 8}

    # A string is the path of a file that lists the products, one per line
    if isinstance(extracted_paths, str):
        extracted_paths = _read_path_list(extracted_paths)

    red_paths, nir_paths, mask_paths = _sort_band_paths(extracted_paths, resolution)
    if not red_paths:
        raise ValueError(f'No red band file found for a {resolution} m resolution')
    # Dates come from the red files only, so every band needs exactly one file per date
    if not (len(red_paths) == len(nir_paths) == len(mask_paths)):
        raise ValueError(
            f'Band file counts differ: {len(red_paths)} red, '
            f'{len(nir_paths)} nir, {len(mask_paths)} mask'
        )

    # Create boundary shapefile from Sentinel-2 image for weather download
    _write_boundary_shapefile(red_paths[0], boundary_shapefile_path)

    dates = [product_str_to_datetime(path) for path in red_paths]
    time_index = pd.to_datetime(dates)

    handles = []  # raw datasets from open_mfdataset; closing them releases the files
    try:
        red = _open_band_stack(red_paths, 'red', chunk_size, handles)
        nir = _open_band_stack(nir_paths, 'nir', chunk_size, handles)
        scl = _open_band_stack(mask_paths, 'mask', chunk_size, handles)

        # Weight 1 for SCL classes 4 and 5 (L2A 'vegetation' / 'not vegetated'), 0 otherwise
        mask = xr.where((scl == 4) | (scl == 5), 1, 0)
        if resolution == 10:
            # SCL is only read at 20 m: resample it onto the 10 m red grid
            mask = mask.interp(x=red.coords['x'], y=red.coords['y'], method='nearest')

        # Set time coordinate (assumes the sorted red, NIR and SCL lists are date-aligned)
        red['time'] = time_index
        nir['time'] = time_index
        mask['time'] = time_index

        # Create ndvi dataset (reusing the red coordinates) and calculate ndvi
        ndvi = red.drop_vars('red')
        ndvi['ndvi'] = ((nir.nir - red.red - acorvi_corr) / (nir.nir + red.red + acorvi_corr)) * mask.mask

        # Clip to [0, 1] (NaN is preserved) and scale to the 0-255 range of the uint8 output
        ndvi['ndvi'] = ndvi.ndvi.clip(0, 1) * 255

        # Write attributes
        ndvi['ndvi'].attrs.update({
            'long_name': 'Normalized Difference Vegetation Index',
            'units': 'None',
            'standard_name': 'NDVI',
            'comment': 'Normalized difference of the near infrared and red band. A value of one is a high vegetation presence.',
        })

        # Name the file after the covered date range; path order need not be chronological
        ndvi_cube_path = os.path.join(
            save_dir,
            'NDVI_precube_' + min(dates).strftime('%d-%m-%Y') + '_' + max(dates).strftime('%d-%m-%Y') + '.nc',
        )
        # Save NDVI cube to netcdf; 0 doubles as the fill value (masked or NDVI <= 0)
        ndvi.to_netcdf(ndvi_cube_path, encoding={'ndvi': {'dtype': 'u1', '_FillValue': 0}})
    finally:
        # The computation is lazy, so the inputs may only be closed once the cube is written
        for handle in handles:
            handle.close()

    return ndvi_cube_path


# Filename glob patterns (red, near infrared, scene-classification mask) per output
# resolution in meters. The SCL mask is always read at 20 m.
_BAND_PATTERNS = {
    10: ('*_B04_10m*', '*_B08_10m*', '*_SCL_20m*'),
    20: ('*_B04_20m*', '*_B8A_20m*', '*_SCL_20m*'),
}


def _read_path_list(list_path: str) -> List[str]:
    """Return the non-empty lines of ``list_path`` (one raster path per line).

    Lines are read directly rather than through ``csv`` with a newline delimiter:
    a newline always ends a csv record, so that delimiter never split anything,
    and recent Python versions may reject it.
    """
    with open(list_path, 'r') as file:
        lines = (line.rstrip('\r\n') for line in file)
        return [line for line in lines if line]


def _sort_band_paths(paths: List[str], resolution: int) -> tuple:
    """Split ``paths`` into sorted ``(red, nir, mask)`` lists for ``resolution``."""
    red_pattern, nir_pattern, mask_pattern = _BAND_PATTERNS[resolution]
    red_paths, nir_paths, mask_paths = [], [], []
    for path in paths:
        # Patterns are tested in this order; each path lands in at most one list
        if fnmatch(path, red_pattern):
            red_paths.append(path)
        elif fnmatch(path, nir_pattern):
            nir_paths.append(path)
        elif fnmatch(path, mask_pattern):
            mask_paths.append(path)
    return sorted(red_paths), sorted(nir_paths), sorted(mask_paths)


def _write_boundary_shapefile(image_path: str, boundary_shapefile_path: str) -> None:
    """Write the bounding box of ``image_path``, reprojected to EPSG:4326, as a shapefile."""
    with rio.open(image_path) as src:
        footprint = gpd.GeoDataFrame({'id': [1]}, geometry=[box(*src.bounds)], crs=src.crs)
    footprint.to_crs('epsg:4326').to_file(boundary_shapefile_path)


def _open_band_stack(paths: List[str], var_name: str, chunk_size: dict, handles: list):
    """Open single-band rasters as one lazy float32 dataset stacked along 'time'.

    The stacked ``band_data`` variable is renamed to ``var_name`` and the size-1
    ``band`` dimension is dropped. The raw dataset returned by ``open_mfdataset``
    is appended to ``handles`` so the caller can close the underlying files.
    """
    stack = xr.open_mfdataset(paths, combine='nested', concat_dim='time', chunks=chunk_size, parallel=True)
    handles.append(stack)
    return stack.squeeze(dim=['band'], drop=True).rename({'band_data': var_name}).astype('f4')