# slurm_utils.py — SLURM submission helpers for a Snakemake cluster profile
#!/usr/bin/env python3
import os
import sys
from os.path import dirname
import re
import math
import argparse
import subprocess as sp
from io import StringIO

from snakemake import io
from snakemake.io import Wildcards
from snakemake.utils import SequenceFormatter
from snakemake.utils import AlwaysQuotedFormatter
from snakemake.utils import QuotedFormatter
from snakemake.exceptions import WorkflowError
from snakemake.logging import logger

from CookieCutter import CookieCutter


def _convert_units_to_mb(memory):
    """If memory is specified with SI unit, convert to MB"""
    if isinstance(memory, int) or isinstance(memory, float):
        return int(memory)
    siunits = {"K": 1e-3, "M": 1, "G": 1e3, "T": 1e6}
    regex = re.compile(r"(\d+)({})$".format("|".join(siunits.keys())))
    m = regex.match(memory)
    if m is None:
        logger.error(
            (
                f"unsupported memory specification '{memory}';"
                "  allowed suffixes: [K|M|G|T]"
            )
        )
        sys.exit(1)
    factor = siunits[m.group(2)]
    return int(int(m.group(1)) * factor)


def parse_jobscript():
    """Parse the command line, requiring exactly one positional jobscript path."""
    parser = argparse.ArgumentParser(description="SLURM snakemake submit script")
    parser.add_argument("jobscript", help="Snakemake jobscript with job properties.")
    args = parser.parse_args()
    return args.jobscript


def parse_sbatch_defaults(parsed):
    """Unpack SBATCH_DEFAULTS into an ``{option: value}`` dict.

    ``parsed`` may be a whitespace-separated string ("--time=60 --qos=low")
    or an iterable of such tokens.  Leading dashes are stripped from option
    names; flag-style options without "=" map to ``None``.
    """
    tokens = parsed.split() if isinstance(parsed, str) else parsed
    args = {}
    for token in tokens:
        # Split on the FIRST "=" only, so values that themselves contain
        # "=" (e.g. --export=ALL,FOO=bar) are kept intact instead of being
        # silently dropped to None.
        key, sep, value = token.partition("=")
        args[key.strip().strip("-")] = value.strip() if sep else None
    return args


def load_cluster_config(path):
    """Load config to dict.

    Load configuration to dict either from absolute path or relative
    to profile dir (after environment-variable expansion).  A falsy path
    yields an empty configuration; a "__default__" entry is always present.
    """
    if not path:
        config = {}
    else:
        resolved = os.path.join(dirname(__file__), os.path.expandvars(path))
        config = io.load_configfile(resolved)
    config.setdefault("__default__", {})
    return config


# adapted from format function in snakemake.utils
# adapted from format function in snakemake.utils
def format(_pattern, _quote_all=False, **kwargs):  # noqa: A001
    """Format a pattern in Snakemake style.

    Keywords embedded in braces are replaced by any variable values
    available in the current namespace; unknown names raise NameError.
    """
    formatter = SequenceFormatter(separator=" ")
    formatter.element_formatter = (
        AlwaysQuotedFormatter() if _quote_all else QuotedFormatter()
    )
    try:
        return formatter.format(_pattern, **kwargs)
    except KeyError as ex:
        raise NameError(
            f"The name {ex} is unknown in this context. Please "
            "make sure that you defined that variable. "
            "Also note that braces not used for variable access "
            "have to be escaped by repeating them "
        )


#  adapted from Job.format_wildcards in snakemake.jobs
#  adapted from Job.format_wildcards in snakemake.jobs
def format_wildcards(string, job_properties):
    """Format a string with variables from the job.

    Exposes ``params``, ``wildcards`` and (when present) ``rule`` from the
    job properties to the template; formatting failures are wrapped in a
    WorkflowError carrying the job id.
    """

    class Job(object):
        def __init__(self, props):
            for key, value in props.items():
                setattr(self, key, value)

    job = Job(job_properties)
    job._format_params = (
        Wildcards(fromdict=job_properties["params"])
        if "params" in job_properties
        else None
    )
    job._format_wildcards = (
        Wildcards(fromdict=job_properties["wildcards"])
        if "wildcards" in job_properties
        else None
    )
    _variables = {
        "params": job._format_params,
        "wildcards": job._format_wildcards,
    }
    if hasattr(job, "rule"):
        _variables["rule"] = job.rule
    try:
        return format(string, **_variables)
    except NameError as ex:
        raise WorkflowError(
            "NameError with group job {}: {}".format(job.jobid, str(ex))
        )
    except IndexError as ex:
        raise WorkflowError(
            "IndexError with group job {}: {}".format(job.jobid, str(ex))
        )


# adapted from ClusterExecutor.cluster_params function in snakemake.executor
# adapted from ClusterExecutor.cluster_params function in snakemake.executor
def format_values(dictionary, job_properties):
    """Return a copy of ``dictionary`` with wildcard placeholders expanded.

    A "mem" entry is first normalised to a megabyte string; every string
    value is then formatted against the job's params/wildcards.
    """
    formatted = dict(dictionary)
    for key in list(formatted):
        value = formatted[key]
        if key == "mem":
            value = str(_convert_units_to_mb(value))
        if not isinstance(value, str):
            continue
        try:
            formatted[key] = format_wildcards(value, job_properties)
        except NameError as e:
            msg = "Failed to format cluster config " "entry for job {}.".format(
                job_properties["rule"]
            )
            raise WorkflowError(msg, e)
    return formatted


def convert_job_properties(job_properties, resource_mapping=None):
    """Translate snakemake job properties into sbatch option values.

    ``resource_mapping`` maps an sbatch option name to a list of resource
    keys; when several keys are present in the job's resources, the last
    one listed wins.  "threads" becomes "cpus-per-task".
    """
    if resource_mapping is None:
        resource_mapping = {}
    resources = job_properties.get("resources", {})
    options = {}
    for option, resource_keys in resource_mapping.items():
        for resource_key in resource_keys:
            if resource_key in resources:
                options[option] = resources[resource_key]
    if "threads" in job_properties:
        options["cpus-per-task"] = job_properties["threads"]
    return options


def ensure_dirs_exist(path):
    """Ensure the output folder for Slurm log files exists.

    Creates the parent directory of ``path`` (including intermediates) when
    ``path`` has one; a bare filename is a no-op.
    """
    parent = dirname(path)
    if parent:
        # exist_ok makes the pre-existence check unnecessary and race-free.
        os.makedirs(parent, exist_ok=True)
    return


def format_sbatch_options(**sbatch_options):
    """Render keyword options as a list of "--key[=value]" sbatch flags.

    A value of None yields a bare flag ("--exclusive"); anything else is
    attached with "=".
    """
    return [
        f"--{key}" if value is None else f"--{key}={value}"
        for key, value in sbatch_options.items()
    ]


def submit_job(jobscript, **sbatch_options):
    """Submit ``jobscript`` via sbatch and return the job id as a string.

    Raises:
        sp.CalledProcessError: if sbatch exits non-zero.
        AttributeError: if no numeric job id can be found in the output.
    """
    options = format_sbatch_options(**sbatch_options)
    # --parsable makes sbatch print "<jobid>[;<cluster>]" on success.
    cmd = ["sbatch", "--parsable"] + options + [jobscript]
    res = sp.check_output(cmd).decode()
    # Note: the original try/except blocks around both calls only re-raised
    # the caught exception, so they were removed without changing behavior.
    jobid = re.search(r"(\d+)", res).group(1)
    return jobid


def advanced_argument_conversion(arg_dict):
    """Experimental adjustment of sbatch arguments to the given or default partition.

    Queries the cluster configuration (via sinfo) for the target partition
    and scales memory, cpus-per-task and time so the request fits what the
    partition can actually provide.  Multi-node jobs are returned unchanged.
    """
    # Currently not adjusting for multiple node jobs
    nodes = int(arg_dict.get("nodes", 1))
    if nodes > 1:
        return arg_dict
    partition = arg_dict.get("partition", None) or _get_default_partition()
    constraint = arg_dict.get("constraint", None)
    ncpus = int(arg_dict.get("cpus-per-task", 1))
    runtime = arg_dict.get("time", None)
    memory = _convert_units_to_mb(arg_dict.get("mem", 0))
    config = _get_cluster_configuration(partition, constraint, memory)
    # Default memory request: requested cpus times the smallest per-cpu memory.
    mem = arg_dict.get("mem", ncpus * min(config["MEMORY_PER_CPU"]))
    mem = _convert_units_to_mb(mem)
    if mem > max(config["MEMORY"]):
        logger.info(
            f"requested memory ({mem}) > max memory ({max(config['MEMORY'])}); "
            "adjusting memory settings"
        )
        mem = max(config["MEMORY"])

    # Calculate available memory as defined by the number of requested
    # cpus times memory per cpu
    AVAILABLE_MEM = ncpus * min(config["MEMORY_PER_CPU"])
    # Add additional cpus if memory is larger than AVAILABLE_MEM
    if mem > AVAILABLE_MEM:
        logger.info(
            f"requested memory ({mem}) > "
            f"ncpus x MEMORY_PER_CPU ({AVAILABLE_MEM}); "
            "trying to adjust number of cpus up"
        )
        ncpus = int(math.ceil(mem / min(config["MEMORY_PER_CPU"])))
    if ncpus > max(config["CPUS"]):
        logger.info(
            f"ncpus ({ncpus}) > available cpus ({max(config['CPUS'])}); "
            "adjusting number of cpus down"
        )
        ncpus = min(int(max(config["CPUS"])), ncpus)
    adjusted_args = {"mem": int(mem), "cpus-per-task": ncpus}

    # Update time. If requested time is larger than maximum allowed time, reset
    if runtime:
        runtime = time_to_minutes(runtime)
        time_limit = max(config["TIMELIMIT_MINUTES"])
        if runtime > time_limit:
            # BUGFIX: interpolate the requested runtime into the message;
            # the original logged the literal word "runtime".
            logger.info(
                f"time ({runtime}) > time limit {time_limit}; "
                "adjusting time down"
            )
            adjusted_args["time"] = time_limit

    # update and return
    arg_dict.update(adjusted_args)
    return arg_dict


# Compiled patterns for the time formats Slurm's --time option accepts,
# from most to least specific; a valid spec matches exactly one of them.
timeformats = [
    re.compile(r"^(?P<days>\d+)-(?P<hours>\d+):(?P<minutes>\d+):(?P<seconds>\d+)$"),
    re.compile(r"^(?P<days>\d+)-(?P<hours>\d+):(?P<minutes>\d+)$"),
    re.compile(r"^(?P<days>\d+)-(?P<hours>\d+)$"),
    re.compile(r"^(?P<hours>\d+):(?P<minutes>\d+):(?P<seconds>\d+)$"),
    re.compile(r"^(?P<minutes>\d+):(?P<seconds>\d+)$"),
    re.compile(r"^(?P<minutes>\d+)$"),
]


def time_to_minutes(time):
    """Convert time string to minutes.

    According to slurm:

      Acceptable time formats include "minutes", "minutes:seconds",
      "hours:minutes:seconds", "days-hours", "days-hours:minutes"
      and "days-hours:minutes:seconds".

    Returns None when the specification matches none of the known formats;
    seconds are rounded up to a whole minute.
    """
    spec = time if isinstance(time, str) else str(time)
    hits = [m for m in (fmt.match(spec) for fmt in timeformats) if m is not None]
    if not hits:
        return
    assert len(hits) == 1, "multiple time formats match"
    parts = {"days": 0, "hours": 0, "minutes": 0, "seconds": 0}
    parts.update(hits[0].groupdict())
    minutes = (
        int(parts["days"]) * 24 * 60
        + int(parts["hours"]) * 60
        + int(parts["minutes"])
        + math.ceil(int(parts["seconds"]) / 60)
    )
    assert minutes > 0, "minutes has to be greater than 0"
    return minutes


def _get_default_partition():
    """Return the cluster's default partition.

    Parses sinfo output, where the default partition is the one whose
    name is suffixed with "*".
    """
    cluster = CookieCutter.get_cluster_option()
    output = sp.check_output(f"sinfo -O partition {cluster}".split()).decode()
    match = re.search(r"(?P<partition>\S+)\*", output, re.M)
    return match.group("partition")


def _get_cluster_configuration(partition, constraints=None, memory=0):
    """Retrieve cluster configuration.

    Query sinfo for ``partition`` and return a pandas DataFrame of node
    configurations, filtered by comma-separated feature ``constraints``
    (nodes must advertise at least one requested feature) and a minimum
    ``memory`` requirement (any SI-suffixed spec accepted).

    Exits with status 1 when pandas is unavailable.
    """
    try:
        import pandas as pd
    except ImportError:
        print(
            "Error: currently advanced argument conversion "
            "depends on 'pandas'.", file=sys.stderr
        )
        sys.exit(1)

    cluster = CookieCutter.get_cluster_option()
    cmd = f"sinfo -e -o %all -p {partition} {cluster}".split()
    output = sp.Popen(" ".join(cmd), shell=True, stdout=sp.PIPE).communicate()
    # Normalise the "|"-separated sinfo table: strip the optional
    # "CLUSTER:" banner line and the space preceding each separator.
    data = re.sub("^CLUSTER:.+\n", "", re.sub(" \\|", "|", output[0].decode()))
    df = pd.read_csv(StringIO(data), sep="|")
    df["TIMELIMIT_MINUTES"] = df["TIMELIMIT"].apply(time_to_minutes)
    df["MEMORY_PER_CPU"] = df["MEMORY"] / df["CPUS"]
    df["FEATURE_SET"] = df["AVAIL_FEATURES"].str.split(",").apply(set)
    # NOTE: the original computed constraint_set twice (the first result was
    # dead) and wrapped calls in print-then-reraise blocks; both removed.
    if constraints:
        constraint_set = set(constraints.split(","))
        i = df["FEATURE_SET"].apply(lambda x: len(x.intersection(constraint_set)) > 0)
        df = df.loc[i]
    memory = min(_convert_units_to_mb(memory), max(df["MEMORY"]))
    df = df.loc[df["MEMORY"] >= memory]
    return df