Skip to content
Snippets Groups Projects
Commit 0bd7068e authored by Jeremy Auclair's avatar Jeremy Auclair
Browse files

Few new additions

parent 7cfcd030
No related branches found
No related tags found
No related merge requests found
......@@ -9,30 +9,87 @@ Calculate NDVI images with xarray
import os # for path exploration
import csv # open csv files
from fnmatch import fnmatch # for character string comparison
from typing import List, Union # to declare variables
import xarray as xr # to manage dataset
import dask.array as da # dask xarray
from dask.distributed import Client, LocalCluster # to parallelise calculations
import pandas as pd # to manage dataframes
# import dask.array as da # dask xarray
from code.toolbox import product_str_to_datetime
def calculate_ndvi(download_path: Union[List[str], str]) -> str:
def calculate_ndvi(extracted_paths: Union[List[str], str], save_dir: str, resolution: int = 20, chunk_size: dict = {'x': 4000, 'y': 4000, 'time': 2}, acorvi_corr: int = 500) -> str:
# Start local cluster client
client = Client()
client
# Check resolution for Sentinel-2
if not resolution in [10, 20]:
print('Resolution should be equal to 10 or 20 meters for sentinel-2')
return None
# If a file name is provided instead of a list of paths, load the csv file that contains the list of paths
if type(download_path) == str:
with open(download_path, 'r') as file:
download_path = []
if type(extracted_paths) == str:
with open(extracted_paths, 'r') as file:
extracted_paths = []
csvreader = csv.reader(file, delimiter='\n')
for row in csvreader:
download_path.append(row[0])
extracted_paths.append(row[0])
# Sort by band
red_paths = []
nir_paths = []
mask_paths = []
if resolution == 10:
for product in extracted_paths:
if fnmatch(product, '*_B04_10m*'):
red_paths.append(product)
elif fnmatch(product, '*_B8A_20m*'):
nir_paths.append(product)
elif fnmatch(product, '*_SCL_20m*'):
mask_paths.append(product)
else:
for product in extracted_paths:
if fnmatch(product, '*_B04_10m*'):
red_paths.append(product)
elif fnmatch(product, '*_B08_10m*'):
nir_paths.append(product)
elif fnmatch(product, '*_SCL_20m*'):
mask_paths.append(product)
# Sort and get dates
red_paths.sort()
nir_paths.sort()
mask_paths.sort()
dates = [product_str_to_datetime(prod) for prod in red_paths]
products = xr.open_mfdataset(download_path, parallel = True) #, chunks = {'latitude': 100, 'longitude': 100})
# Open datasets with xarray
red = xr.open_mfdataset(red_paths, combine = 'nested', concat_dim = 'time', chunks = chunk_size).squeeze(dim = ['band'], drop = True).rename({'band_data': 'red'}).astype('f4')
nir = xr.open_mfdataset(nir_paths, combine = 'nested', concat_dim = 'time', chunks = chunk_size).squeeze(dim = ['band'], drop = True).rename({'band_data': 'nir'}).astype('f4')
mask = xr.open_mfdataset(mask_paths, combine = 'nested', concat_dim = 'time', chunks = chunk_size).squeeze(dim = ['band'], drop = True).rename({'band_data': 'mask'}).astype('f4')
if resolution == 10:
mask = xr.where((mask == 4) | (mask == 5), 1, 0).interp(x = red.coords['x'], y = red.coords['y'], method = 'nearest')
else:
mask = xr.where((mask == 4) | (mask == 5), 1, 0)
# Set time coordinate
red['time'] = pd.to_datetime(dates)
nir['time'] = pd.to_datetime(dates)
mask['time'] = pd.to_datetime(dates)
# Create ndvi dataset and calculate ndvi
ndvi = red
ndvi = ndvi.drop('red')
ndvi['ndvi'] = (((nir.nir - red.red - acorvi_corr)/(nir.nir + red.red + acorvi_corr))*mask.mask)
del red, nir, mask
# Mask and scale ndvi
ndvi['ndvi'] = xr.where(ndvi.ndvi < 0, 0, ndvi.ndvi)
ndvi['ndvi'] = xr.where(ndvi.ndvi > 1, 1, ndvi.ndvi)
ndvi['ndvi'] = ndvi.ndvi*255
# Create save path
ndvi_cube_path = save_dir + os.sep + 'NDVI_precube_' + dates[0].strftime('%d-%m-%Y') + '_' + dates[-1].strftime('%d-%m-%Y') + '.nc'
ndvi_cube_path =''
# Save NDVI cude to netcdf
ndvi.to_netcdf(ndvi_cube_path, encoding = {"ndvi": {"dtype": "u8", "_FillValue": 0}})
ndvi.close()
return ndvi_cube_path
\ No newline at end of file
......@@ -130,7 +130,7 @@ def extract_zip_archives(download_path: str, list_paths: List[str], bands_to_ext
for file_path in list_paths:
# Change progress bar to print current file
progress_bar.set_description_str(desc = 'Extracting ' + os.path.basename(file_path) + '\ntotal progress')
progress_bar.set_description_str(desc = '\rExtracting ' + os.path.basename(file_path) + '\ntotal progress')
# Get path in which to extract the archive
extract_path = download_path + os.sep + os.path.basename(file_path)[:-4]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment