Skip to content
Snippets Groups Projects
input_toolbox.py 18.1 KiB
Newer Older
# -*- coding: UTF-8 -*-
# Python
"""
Contains functions to facilitate code workflow, directory creation, Notebook readability, logfile creation, etc.
04-07-2023
@author: jeremy auclair
"""

import datetime  # for date management
import os  # for path and directory management
import re  # for character string search
from fnmatch import fnmatch  # to match character strings
from typing import List, Tuple  # to declare argument types

import numpy as np  # vectorized math
import ruamel.yaml as yml  # for yaml file modification
from numba import jit, njit  # to compile functions for faster execution
from scipy.signal import medfilt  # scipy median filter

from modspa_pixel.config.config import config  # to import config file


def product_str_to_datetime(product_name: str) -> datetime.date:
    """
    product_str_to_datetime returns a ``datetime.date`` object for the date of the given product.

    The product name (or path) is searched first for a ``yyyymmdd`` date, then, only if no valid
    one is found, for a ``yymmdd`` date. Digit runs that are not valid calendar dates are skipped
    and the next candidate is tried.

    Arguments
    =========

    1. product_name: ``str``
        name or path of the product

    Returns
    =======

    1. date: ``datetime.date``
        date of the product, or ``None`` if no valid date is found (or if
        ``product_name`` is not a string)
    """

    # Non-string input never contains a searchable date
    if not isinstance(product_name, str):
        return None

    # Long format first, so a yyyymmdd date is not misread as a yymmdd date
    for pattern, date_format in ((r'\d{8}', '%Y%m%d'), (r'\d{6}', '%y%m%d')):
        for match in re.finditer(pattern, product_name):
            try:
                return datetime.datetime.strptime(match.group(0), date_format).date()
            except ValueError:
                # Digit run is not a valid calendar date: try the next one
                continue

    return None


def _find_tile(pattern: str, path_to_product: str):
    """Return the first capture group of ``pattern`` found in ``path_to_product``, or ``None`` (after printing a warning) when it does not match."""
    matches = re.findall(pattern, path_to_product)
    if not matches:
        print('Error in provided path, no tile found in string')
        return None
    return matches[0]


def read_product_info(path_to_product: str) -> Tuple[datetime.date, str, str, str]:
    """
    Read_product_info detects and returns the date, tile, provider and satellite of a given product.

    The product type is recognised by matching the path against ``fnmatch`` glob patterns;
    the first matching pattern wins.

    Arguments
    =========

    1. path_to_product: ``str``
        path to the product

    Returns
    =======

    1. date: ``datetime.date``
        date of the product, as returned by ``product_str_to_datetime``
    2. tile: ``str``
        tile name of product; ``None`` if the product type is recognised but its tile pattern
        is not found, ``''`` for generic ``SP_`` products and unrecognised products
    3. provider: ``str``
        provider of product (``''`` if the product is not recognised)
    4. satellite: ``str``
        name of the satellite (ex: Sentinel-2, LandSat-8, SPOT-4), ``''`` if not recognised
    """

    # Collect date of product
    date = product_str_to_datetime(path_to_product)

    # (glob pattern, provider, satellite, tile regex or None when no tile is extracted)
    # NOTE(review): for standard Landsat names ('LC08_L2SP_PPPRRR_...') the first '_..._' group
    # captured by the usgs regex is 'L2SP', not the path/row -- confirm this is intended.
    product_types = (
        ('*S2?_MSIL2A_20*', 'copernicus', 'Sentinel-2', r'_T(.?.?.?.?.?)_'),
        ('*SENTINEL2?_20*', 'theia', 'Sentinel-2', r'_T(.?.?.?.?.?)_'),
        ('*LC08_L2SP_*', 'usgs', 'LandSat-8', r'_(.?.?.?.?.?.?)_'),
        ('*LE07_L2SP_*', 'usgs', 'LandSat-7', r'_(.?.?.?.?.?.?)_'),
        ('*SP4_OPER_*', 'swh', 'SPOT-4', r'T(.?.?.?.?.?.?)_'),
        ('*SP5_OPER_*', 'swh', 'SPOT-5', r'T(.?.?.?.?.?.?)_'),
        ('*SP_*', 'swh', 'SP-5', None),
    )

    for glob_pattern, provider, satellite, tile_regex in product_types:
        if fnmatch(path_to_product, glob_pattern):
            tile = '' if tile_regex is None else _find_tile(tile_regex, path_to_product)
            return date, tile, provider, satellite

    # Unrecognised product type
    return date, '', '', ''
    

def prepare_directories(config_file: str) -> None:
    """
    Creates the directories for the data inputs and outputs.

    Directories created (when missing), relative to the paths read from the config file:

    - ``download_path/SCIHUB/NDVI/run_name``
    - ``download_path/THEIA/NDVI/run_name``
    - ``download_path/USGS/NDVI``
    - ``era5_path/run_name``
    - ``download_path/SOIL/run_name``
    - ``download_path/LAND_COVER/run_name``
    - ``output_path/run_name``

    Missing parent directories are created as well; existing directories are left untouched.

    Arguments
    =========

    1. config_file: ``str``
        path to the code configuration file

    Returns
    =======

    ``None``
    """

    # Open config file
    config_params = config(config_file)

    # Get parameters
    download_path = config_params.download_path
    era5_path = config_params.era5_path
    run_name = config_params.run_name
    output_path = config_params.output_path

    # Leaf directories, joined with os.sep exactly like plain string concatenation would
    leaf_directories = [
        (download_path, 'SCIHUB', 'NDVI', run_name),  # scihub data
        (download_path, 'THEIA', 'NDVI', run_name),  # theia data
        (download_path, 'USGS', 'NDVI'),  # usgs data
        (era5_path, run_name),  # weather data
        (download_path, 'SOIL', run_name),  # soil data
        (download_path, 'LAND_COVER', run_name),  # land cover data
        (output_path, run_name),  # outputs
    ]

    # makedirs creates every missing parent and tolerates already-existing directories,
    # avoiding the check-then-create race of os.path.exists + os.mkdir
    for parts in leaf_directories:
        os.makedirs(os.sep.join(parts), exist_ok=True)
def intersection(list1: list, list2: list) -> list:
    """
    Returns the intersection of two lists. Objects in lists need to be comparable, and the
    common values need to be hashable.

    The result keeps the order of first appearance in ``list1`` and contains no duplicates.
    ``list2`` only needs to support the ``in`` operator (e.g. a list or a numpy array).

    Arguments
    =========

    1. list1: ``list``
        first list
    2. list2: ``list``
        second list

    Returns
    =======

    1. list3: ``list``
        list of the intersection of list1 and list2
    """

    # dict.fromkeys drops duplicates while preserving insertion (first-seen) order
    return list(dict.fromkeys(value for value in list1 if value in list2))
    

def avg_2dates(date1: datetime.datetime, date2: datetime.datetime) -> datetime.datetime:
    """
    Get the average date between two dates. For example, the average between 01/01/2020 and 07/01/2020 is 04/01/2020.
    If the number of days between the two dates is odd, the average date has 12 hours added to it.
    The order of the two arguments does not matter.

    Only the whole-day part of the gap is averaged (``timedelta.days``): any extra hours/minutes
    in the gap are ignored. For ``datetime.date`` inputs the extra 12 hours are dropped by
    date arithmetic, so the earlier of the two middle days is returned.

    Arguments
    =========

    1. date1: ``datetime.datetime``
        first date
    2. date2: ``datetime.datetime``
        second date

    Returns
    =======

    1. averaged date: ``datetime.datetime``
        averaged date, of the same type as the inputs
    """
    earlier, later = (date1, date2) if date2 >= date1 else (date2, date1)

    # Half of the whole-day gap (12 hours extra when the gap is an odd number of days)
    half_gap = datetime.timedelta((later - earlier).days) / 2

    return earlier + half_gap
    

def calculate_time_derivative(values: np.ndarray, dates: List[datetime.datetime]) -> Tuple[np.ndarray, List[datetime.datetime]]:
    """
    Get a time derivative of a series of dated values (one value per day maximum). If values is of length N,
    the returned sequences are of length N-1.

    ``derivative[i] = (values[i+1] - values[i]) / day_gap[i]`` where ``day_gap[i]`` is the number of whole
    days between ``dates[i]`` and ``dates[i+1]``.

    Arguments
    =========

    1. values: ``np.ndarray`` or ``list[float]``
        timed values; converted with ``np.asarray`` so that differences are positional
    2. dates: ``list[datetime.date | datetime.datetime]`` or ``pandas.DatetimeIndex``
        dates for the values; the difference of two consecutive elements must have a
        ``.days`` attribute (``datetime.timedelta`` or ``pandas.Timedelta``)

    Returns
    =======

    1. derivative: ``np.ndarray``
        derivative values (a zero day gap, i.e. two dates on the same day, yields inf/nan)
    2. deriv_dates: ``list``
        midpoint dates (computed with ``avg_2dates``) corresponding to the derivative values
    """

    # Positional arrays: avoids index-label alignment (pandas) and allows plain lists
    values = np.asarray(values)

    # Pair each date with its successor: (dates[0], dates[1]), (dates[1], dates[2]), ...
    date_pairs = list(zip(dates[:-1], dates[1:]))

    # Whole-day gaps; ``.days`` keeps only the whole days of each gap
    day_gaps = np.array([(later - earlier).days for earlier, later in date_pairs])

    # Get derivative values
    derivative = np.array((values[1:] - values[:-1]) / day_gaps)

    # Get new date list for derivative values
    deriv_dates = [avg_2dates(later, earlier) for earlier, later in date_pairs]

    return derivative, deriv_dates


@njit
def detect_anomalies_deriv(derivative: np.ndarray, threshold: float) -> List[int]:
    """
    Detects anomalies in a series of derivative values, based on a threshold. Index ``i`` is
    flagged when the jump between ``derivative[i]`` and ``derivative[i+1]`` is large and the two
    values have different signs (the underlying series changes direction). Runs of three
    consecutive flagged indexes are then reduced to their middle one, and pairs of consecutive
    flagged indexes are reduced to a single index.

    Compiled in numba nopython mode (``@njit``).

    Arguments
    =========

    1. derivative: ``np.ndarray``
        derivative values (e.g. the first output of ``calculate_time_derivative``)
    2. threshold: ``float``
        minimum absolute difference between two consecutive derivative values for a point to
        be flagged; the following derivative value must also exceed ``threshold/5`` in
        absolute value

    Returns
    =======

    1. list_anomalies: ``list[int]``
        indexes into ``derivative`` of the flagged values. Only indexes with
        ``1 <= i <= len(derivative) - 2`` can be flagged.
    """
    
    list_anomalies = []
    
    # Detect anomalies based on derivative value: large jump between derivative[i] and derivative[i+1]
    for i in range(1, len(derivative)-1):
        if abs(derivative[i+1] - derivative[i]) > threshold:
            # Condition 1: derivative[i+1] is at least as close to derivative[i-1] as to derivative[i],
            #   i.e. derivative[i] stands out instead of being part of a steady upward/downward curve
            # Condition 2: derivative[i] and derivative[i+1] have different signs (direction change)
            # Condition 3: abs(derivative[i+1]) > threshold/5, so the following slope is not negligible
            #   (avoids selecting a "normal" curve)
            if abs(derivative[i+1] - derivative[i-1]) <= abs(derivative[i+1] - derivative[i]) and np.sign(derivative[i+1]) != np.sign(derivative[i]) and abs(derivative[i+1]) > threshold/5:
                list_anomalies.append(i)

    # Find series of 3 consecutive anomalies, keep the middle one (positions i-1 and i+1 are dropped)
    # NOTE(review): with 4+ consecutive indexes the overlapping windows mark every position for removal
    #   (e.g. [1, 2, 3, 4] -> []); confirm this is intended.
    to_pop = []
    for i in range(1, len(list_anomalies)-1):
        if list_anomalies[i+1] - list_anomalies[i] == 1 and list_anomalies[i] - list_anomalies[i-1] == 1:
            to_pop.append(i+1)
            to_pop.append(i-1)
    list_anomalies = [list_anomalies[i] for i in range(len(list_anomalies)) if i not in to_pop]
    
    # Find pairs of consecutive anomalies (a, a+1) and keep only one: a is kept when its jump
    # abs(derivative[a+1] - derivative[a]) is larger than both neighbouring jumps, otherwise a+1 is kept.
    # Indexing is in bounds: a >= 1 (so a-1 >= 0) and a+1 <= len(derivative)-2 (so a+2 <= len(derivative)-1).
    to_pop = []
    for i in range(len(list_anomalies)-1):
        if list_anomalies[i+1] - list_anomalies[i] == 1:
            if abs(derivative[list_anomalies[i]+1] - derivative[list_anomalies[i]]) > max(abs(derivative[list_anomalies[i]] - derivative[list_anomalies[i]-1]), abs(derivative[list_anomalies[i]+2] - derivative[list_anomalies[i]+1])):
                to_pop.append(i+1)
            else:
                to_pop.append(i)
    list_anomalies = [list_anomalies[i] for i in range(len(list_anomalies)) if i not in to_pop]
    
    return list_anomalies


def detect_anomalies_median(values: List[float], threshold_ratio: float, window: int = 3) -> List[int]:
    """
    Detects anomalies in a NDVI time series based on a moving median filter. If one point is too
    far from its corresponding point in the median filtered list, it is flagged as an anomaly:
    ``abs(values[i] - median_filtered[i]) > threshold_ratio*median_filtered[i]``.
    With the default ``window = 3``: ``median_filtered[i] = median(values[i-1], values[i], values[i+1])``.

    Arguments
    =========

    1. values: ``List[float]``
        values in which to detect the anomalies
    2. threshold_ratio: ``float``
        ratio to apply to the current median value to get a threshold over which to flag values
    3. window: ``int`` ``default = 3``
        size of window for moving median filter (must be odd, as required by
        ``scipy.signal.medfilt``). For series shorter than ``window`` it is reduced to the
        largest odd size not exceeding the series length.

    Returns
    =======

    1. list_anomalies: ``List[int]``
        indexes of flagged values
    """

    values = np.asarray(values, dtype=float)
    if values.size == 0:
        return []

    # Shrink the moving window for short series, keeping it odd for medfilt
    n_values = len(values)
    if n_values < window:
        window = n_values if n_values % 2 == 1 else n_values - 1

    # Get median filter
    # NOTE(review): scipy's medfilt zero-pads the input, so the first/last medians may be
    # pulled towards 0 -- confirm this is acceptable at the series edges.
    median = medfilt(values, window)

    # Distance of each point to its median-filtered counterpart
    delta = np.abs(values - median)

    # Return indexes where condition is met
    return np.flatnonzero(delta > threshold_ratio * median).tolist()


def find_anomalies(values: List[float], dates: List[datetime.datetime], deriv_threshold_ratio: float = 25, median_threshold_ratio: float = 0.1, median_window: int = 3) -> List[int]:
    """
    Flags anomalies as the points detected by BOTH a derivative filter and a median filter.

    The derivative filter (``detect_anomalies_deriv``) looks for consecutive high derivative
    values of opposite sign, with a few additional conditions. The median filter
    (``detect_anomalies_median``) flags points whose distance to the median-filtered series is
    over the given threshold. Only indexes flagged by both filters are returned.

    Arguments
    =========

    1. values: ``List[float]``
        series in which to detect anomalies
    2. dates: ``List[datetime.datetime]``
        dates of ``values``
    3. deriv_threshold_ratio: ``float`` ``default = 25``
        the derivative threshold is ``max(values)/deriv_threshold_ratio``.
        25 is arbitrary but seems to work well for NDVI values
    4. median_threshold_ratio: ``float`` ``default = 0.1``
        threshold ratio passed to the median filter
    5. median_window: ``int`` ``default = 3``
        window size of the median filter

    Returns
    =======

    1. list_anomalies: ``List[int]``
        indexes of ``values`` flagged as anomalies
    """

    derivative = calculate_time_derivative(values, dates)[0]

    # Derivative threshold scaled to the largest value of the series
    threshold = max(values) / deriv_threshold_ratio

    # derivative[i] is built from values[i] and values[i+1]; an index flagged from
    # derivative[i] and derivative[i+1] points at their shared value, values[i+1]
    shifted = np.array(detect_anomalies_deriv(derivative, threshold)) + 1
    by_derivative = list(shifted)
    by_median = detect_anomalies_median(values, median_threshold_ratio, median_window)

    return intersection(by_derivative, by_median)


def set_eodag_config_file(path_to_config_file: str, download_dir: str, provider: str) -> None:
    """
    Modifies the download path in the eodag `yaml` configuration file according to the code configuration file if needed.
    The file is only rewritten when the `outputs_prefix` of the given provider actually changes.
    If you add additional providers this function needs to be modified: add an entry to `provider_settings`.

    ## Arguments
    1. path_to_config_file: `str`
        path to the eodag config file
    2. download_dir: `str`
        path to the download directory declared in the code config file
    3. provider: `str`
        provider for which to update/change download directory (`'copernicus'`, `'theia'` or
        `'usgs'`; any other value leaves the file untouched)

    ## Returns
    `None`
    """

    # Check if eodag config files exists (said file is created on the first run of eodag)
    if not os.path.exists(path_to_config_file):
        print("First run of eodag, eodag config file doesn't exist yet, run script again.\n")
        return None

    # provider -> (eodag provider key, config section holding `outputs_prefix`, download sub-directory)
    # The misspelled 'coperniucs' is also accepted so existing callers passing it keep working.
    provider_settings = {
        'copernicus': ('scihub', 'api', 'SCIHUB'),
        'coperniucs': ('scihub', 'api', 'SCIHUB'),
        'theia': ('theia', 'download', 'THEIA'),
        'usgs': ('usgs', 'api', 'USGS'),
    }
    if provider not in provider_settings:
        return None

    eodag_provider, section, sub_directory = provider_settings[provider]
    new_prefix = download_dir + os.sep + sub_directory + os.sep

    # Open eodag config file; the handle is closed as soon as parsing is done
    with open(path_to_config_file) as config_stream:
        parameter_list, ind, bsi = yml.util.load_yaml_guess_indent(config_stream)

    # Nothing to do when the download path is already up to date
    provider_section = parameter_list[eodag_provider][section]
    if provider_section['outputs_prefix'] == new_prefix:
        return None
    provider_section['outputs_prefix'] = new_prefix

    # Save modified data, reusing the indentation detected in the original file
    yaml = yml.YAML()
    yaml.indent(mapping=ind, sequence=ind, offset=bsi)
    with open(path_to_config_file, 'w') as fp:
        yaml.dump(parameter_list, fp)

    return None