diff --git a/CHANGELOG.md b/CHANGELOG.md index 837d1d15..b789f440 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -52,7 +52,7 @@ ### Performance Improvements -* performanc tweaks ([#417](https://github.com/snakemake/snakemake-executor-plugin-slurm/issues/417)) ([a3f6abf](https://github.com/snakemake/snakemake-executor-plugin-slurm/commit/a3f6abf47a1d3baa51b987fc0fcdb1972fc2bdd6)) +* performance tweaks ([#417](https://github.com/snakemake/snakemake-executor-plugin-slurm/issues/417)) ([a3f6abf](https://github.com/snakemake/snakemake-executor-plugin-slurm/commit/a3f6abf47a1d3baa51b987fc0fcdb1972fc2bdd6)) ## [2.3.1](https://github.com/snakemake/snakemake-executor-plugin-slurm/compare/v2.3.0...v2.3.1) (2026-02-20) diff --git a/docs/further.md b/docs/further.md index 263a3ab9..320fc6d0 100644 --- a/docs/further.md +++ b/docs/further.md @@ -332,6 +332,19 @@ This node tracking works regardless of whether the `--slurm-requeue` flag is ena - **With `--slurm-requeue`**: SLURM will automatically requeue failed jobs, and they will be retried on different nodes - **Without `--slurm-requeue`**: Failed jobs will be reported as errors, but future retries (via `--retries` or other retry mechanisms) will avoid the problematic nodes +#### SLURM Job Arrays + +Using `--slurm-array-jobs`, SLURM job arrays can be submitted. `--slurm-array-jobs=rule1,rule2,...` lets you select specific rules by name to be submitted as array jobs. Alternatively, `--slurm-array-jobs=all` will submit all eligible rules as array jobs. + +Note: group jobs cannot be array jobs. + +> **Note:** Using array jobs does impose a synchronization overhead (all jobs of a particular rule need to be ready for execution). + +When submitting array jobs, the `--slurm-array-limit` flag defines the +maximum number of array tasks to be submitted in one job submission. +If the number of tasks exceeds this limit, multiple array job submissions will be performed. 
This is useful to avoid hitting cluster limits on the maximum number of array tasks per job. Please obey your cluster limits and set this flag accordingly. + + #### MPI-specific Resources Snakemake's SLURM executor plugin supports the execution of MPI ([Message Passing Interface](https://en.wikipedia.org/wiki/Message_Passing_Interface)) jobs. diff --git a/pyproject.toml b/pyproject.toml index a255370c..63441bcf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ keywords = ["snakemake", "plugin", "executor", "cluster", "slurm"] python = "^3.11" snakemake-interface-common = "^1.21.0" snakemake-interface-executor-plugins = "^9.3.9" -snakemake-executor-plugin-slurm-jobstep = "^0.4.0" +snakemake-executor-plugin-slurm-jobstep = "^0.6.0" pandas = "^2.2.3" numpy = ">=1.26.4, <3" throttler = "^1.2.2" @@ -28,7 +28,7 @@ black = "^23.7.0" flake8 = "^6.1.0" coverage = "^7.3.1" pytest = "^8.3.5" -snakemake = "^9.6.0" +snakemake = "^9.17.2" pandas = "^2.2.3" [tool.poetry.scripts] diff --git a/snakemake_executor_plugin_slurm/__init__.py b/snakemake_executor_plugin_slurm/__init__.py index 3f02fd26..6c71e1e9 100644 --- a/snakemake_executor_plugin_slurm/__init__.py +++ b/snakemake_executor_plugin_slurm/__init__.py @@ -4,6 +4,11 @@ __license__ = "MIT" import atexit +import asyncio +import base64 +import errno +from concurrent.futures import ThreadPoolExecutor +import json import os from pathlib import Path import re @@ -14,6 +19,7 @@ from dataclasses import dataclass, field from typing import List, Generator, Optional import uuid +import zlib from snakemake_interface_executor_plugins.executors.base import ( SubmittedJobInfo, @@ -36,6 +42,9 @@ ) from .utils import ( + get_max_array_size, + get_job_wildcards, + pending_jobs_for_rule, delete_slurm_environment, delete_empty_dirs, set_gres_string, @@ -50,7 +59,7 @@ ) from .job_cancellation import cancel_slurm_jobs from .efficiency_report import create_efficiency_report -from .submit_string import get_submit_command +from 
.submit_string import apply_mem_fudge, get_submit_command from .partitions import ( get_default_partition, read_partition_file, @@ -135,10 +144,56 @@ def _get_status_command_help(): ) +def _status_lookup_ids(external_jobid: str) -> List[str]: + """Return candidate IDs for status lookup. + + For array jobs, Snakemake tracks task IDs as ``<parent_id>_<task_id>``. + Depending on SLURM command/options (e.g. ``sacct -X``), status output may + only contain the parent array ID ``<parent_id>``. This helper returns IDs in + lookup order so callers can transparently fall back from task ID to parent + ID. + """ + candidates = [external_jobid] + if "_" in external_jobid: + parent_id, task_id = external_jobid.split("_", 1) + if parent_id.isdigit() and task_id.isdigit(): + candidates.append(parent_id) + return candidates + + @dataclass class ExecutorSettings(ExecutorSettingsBase): """Settings for the SLURM executor plugin.""" + array_jobs: Optional[str] = field( + default=None, + metadata={ + "help": "Will submit jobs as SLURM job arrays, if possible. " + "Use as: --slurm-array-jobs='rule1, rule2' to submit jobs of " + "rule1 and rule2 as array jobs. If a DAG contains only one job for " + "a rule, it cannot be submitted as an array job. Selecting " + "--slurm-array-jobs=all will submit all eligible jobs as array jobs. " + "Note: When choosing array job submission, the required jobs are " + "subject to a synchronization overhead.", + "env_var": False, + "required": False, + }, + ) + + array_limit: Optional[int] = field( + default=1000, + metadata={ + "help": "When submitting array jobs, this flag defines the maximum " + "number of array tasks to be submitted in one sbatch call. If the " + "number of tasks exceeds this limit, multiple array job submissions " + "will be performed. This is useful to avoid hitting cluster limits on " + "the maximum number of array tasks per job. 
" + "Please obey your cluster limits and set this flag accordingly.", + "env_var": False, + "required": False, + }, + ) + logdir: Optional[Path] = field( default=None, metadata={ @@ -406,12 +461,45 @@ def __post_init__(self, test_mode: bool = False): self._fallback_partition = None self._preemption_warning = False # no preemption warning has been issued self._submitted_job_clusters = set() # track clusters of submitted jobs + self._job_submission_executor = ThreadPoolExecutor( + max_workers=2, thread_name_prefix="slurm_job_submit" + ) + self._main_event_loop = None self._status_query_calls = 0 self._status_query_failures = 0 self._status_query_total_seconds = 0.0 self._status_query_min_seconds = None self._status_query_max_seconds = 0.0 self._status_query_cycle_rows = [] + array_job_setting = self.workflow.executor_settings.array_jobs + if array_job_setting: + normalized_setting = array_job_setting.replace(";", ",") + self.array_jobs = { + rule.strip() for rule in normalized_setting.split(",") if rule.strip() + } + else: + self.array_jobs = set() + self.max_array_size = min( + get_max_array_size(), int(self.workflow.executor_settings.array_limit) + ) + if self.max_array_size <= 10: + self.logger.warning( + "Array limit is set to " + f"{self.max_array_size}, " + "which is very low and may lead to excessive numbers of array " + "job submissions. Please consider increasing this limit if your " + "cluster allows it." + ) + if self.max_array_size < 2: + self.logger.error( + "Array job submission is effectively disabled due to " + f"max_array_size={self.max_array_size}. Consider increasing " + "the array_limit setting to enable array job submission." + ) + raise WorkflowError( + "Array job submission is effectively disabled due to " + "low array_limit." + ) self.slurm_logdir = _select_logdir(self.workflow) # Check the environment variable "SNAKEMAKE_SLURM_PARTITIONS", # if set, read the partitions from the given file. 
Let the CLI @@ -451,6 +539,9 @@ def shutdown(self) -> None: This method is overloaded, to include the cleaning of old log files and to optionally create an efficiency report. """ + # Ensure background submission tasks are finished before shutting down. + self._job_submission_executor.shutdown(wait=True) + # First, we invoke the original shutdown method super().shutdown() @@ -551,23 +642,409 @@ def additional_general_args(self): general_args += " --slurm-jobstep-pass-command-as-script" return general_args - def run_job(self, job: JobExecutorInterface): - # Implement here how to run a job. - # You can access the job's resources, etc. - # via the job object. - # After submitting the job, you have to call - # self.report_job_submission(job_info). - # with job_info being of type - # snakemake_interface_executor_plugins.executors.base.SubmittedJobInfo. + def run_jobs(self, jobs: List[JobExecutorInterface]): + """ + This is a meta rule to delegate the job execution to either + - `run_job` for individual job submission, or + - `run_array_jobs` for array job submission, or + - `run_pool_jobs` for pool job submission (to be implemented in the future). + """ + if self._main_event_loop is None: + try: + self._main_event_loop = asyncio.get_running_loop() + except RuntimeError: + self._main_event_loop = None + + ready_jobs_by_rule = {} + + # check whether any other job is a group job, as these cannot be + # submitted as array jobs and require special handling + for job in jobs: + if job.is_group(): + if job.name in self.array_jobs or "all" in self.array_jobs: + self.logger.warning( + f"Job '{job.name}' is a group job and cannot be " + "submitted as an array job. " + "Submitting it as a regular job instead." 
+ ) + self._job_submission_executor.submit(self.run_job, job) + else: + ready_jobs_by_rule.setdefault(job.rule.name, []).append(job) - group_or_rule = f"group_{job.name}" if job.is_group() else f"rule_{job.name}" + for rule_name, same_rule_jobs in ready_jobs_by_rule.items(): + array_selected_for_rule = ( + "all" in self.array_jobs or rule_name in self.array_jobs + ) + # TODO: use more sensible logging information, once finished + self.logger.debug( + f"Running jobs for rule: {rule_name}, " f"{same_rule_jobs}" + ) + self.logger.debug("Current array job settings: " f"{self.array_jobs}") + + if array_selected_for_rule: + dag = getattr(self.workflow, "dag", None) + if dag is not None: + eligible_jobs = pending_jobs_for_rule(dag, rule_name) + else: + eligible_jobs = len(same_rule_jobs) + self.logger.debug( + "workflow.dag unavailable in run_jobs(); " + "falling back to ready-job count for eligibility " + f"({rule_name}: {eligible_jobs})." + ) + + # Keep synchronization against DAG eligibility, but do not block + # once at least one full array chunk is ready. + chunk_size = self.max_array_size + + if len(same_rule_jobs) == 1: + if eligible_jobs <= 1: + self.logger.debug( + f"Array submission requested for rule {rule_name}, " + "but only one pending job is available; submitting " + "as a regular job." + ) + self._job_submission_executor.submit( + self.run_job, same_rule_jobs[0] + ) + else: + self.logger.debug( + "Array job collection incomplete for rule " + f"{rule_name}: 1/{eligible_jobs} arrived. Waiting " + "for at least one full chunk." + ) + else: + if ( + len(same_rule_jobs) < eligible_jobs + and len(same_rule_jobs) < chunk_size + ): + self.logger.debug( + "Array job collection incomplete for rule " + f"{rule_name}: {len(same_rule_jobs)}/{eligible_jobs} " + "arrived (< chunk size), waiting for more jobs." 
+ ) + continue + self.logger.debug( + "Submitting array-selected jobs for rule " + f"{rule_name}: {len(same_rule_jobs)} ready, " + f"{eligible_jobs} eligible, chunk_size={chunk_size}." + ) + self._job_submission_executor.submit( + self.run_array_jobs, same_rule_jobs + ) + continue + # Non-array mode: submit all ready jobs individually. + elif len(same_rule_jobs) == 1: + self.logger.debug( + f"Submitting single job for rule {rule_name} as " + "array mode is disabled." + ) + self._job_submission_executor.submit(self.run_job, same_rule_jobs[0]) + continue + else: + self.logger.debug( + f"Submitting {len(same_rule_jobs)} ready jobs for rule " + f"{rule_name} individually (array mode disabled)." + ) + for job in same_rule_jobs: + self._job_submission_executor.submit(self.run_job, job) + + def _report_job_submission_threadsafe(self, job_info: SubmittedJobInfo): + if self._main_event_loop is not None: + self._main_event_loop.call_soon_threadsafe( + self.report_job_submission, + job_info, + ) + else: + self.report_job_submission(job_info) + + def _report_job_error_threadsafe(self, job_info: SubmittedJobInfo, msg: str): + if self._main_event_loop is not None: + self._main_event_loop.call_soon_threadsafe( + self.report_job_error, + job_info, + msg, + ) + else: + self.report_job_error(job_info, msg=msg) + + def run_array_jobs(self, jobs: List[JobExecutorInterface]): try: - wildcard_str = ( - "_".join(job.wildcards).replace("/", "_") if job.wildcards else "" + self.logger.debug( + f"Preparing to submit array job for rule {jobs[0].rule.name} " + f"with {len(jobs)} tasks." + ) + group_or_rule = ( + f"group_{jobs[0].name}" + if jobs[0].is_group() + else f"rule_{jobs[0].name}" ) - except AttributeError: - wildcard_str = "" + + # in an array job `sbatch --output` gets a single value + # hence, we can only consider the first wildcard string + # to create the SLURM log file path. 
+ wildcard_strs = [get_job_wildcards(job) for job in jobs] + wildcard_str = wildcard_strs[0] + # the wildcard string shall be ignored for the SLURM log + # file path. + slurm_logfile = self.slurm_logdir / group_or_rule / r"%A_%a.log" + slurm_logfile.parent.mkdir(parents=True, exist_ok=True) + + # this behavior has been fixed in slurm 23.02, but there might be + # plenty of older versions around, hence we should rather be + # conservative here. + assert "%A" not in str(self.slurm_logdir), ( + "bug: jobid placeholder in parent dir of logfile. This does not " + "work as we have to create that dir before submission in order to " + "make sbatch happy. Otherwise we get silent fails without " + "logfiles being created." + ) + assert r"%a" not in str(self.slurm_logdir), ( + "bug: jobid placeholder in parent dir of logfile. This does not " + "work as we have to create that dir before submission in order to " + "make sbatch happy. Otherwise we get silent fails without " + "logfiles being created." + ) + + # generic part of a submission string: + # we use a run_uuid as the job-name, to allow `--name`-based + # filtering in the job status checks (`sacct --name` and + # `squeue --name`) + if wildcard_str == "": + comment_str = f"rule_{jobs[0].name}" + else: + self.logger.warning( + "Array job submission does not allow for multiple different " + "wildcard combinations in the comment string. Only the first " + "one will be used." + ) + comment_str = f"rule_{jobs[0].name}_wildcards_{wildcard_strs[0]}" + + for job in jobs: + # check whether the 'slurm_extra' parameter is used correctly + # prior to putatively setting in the sbatch call + validate_slurm_extra(job) + + self.logger.debug("Building job params for array job") + # Note: all jobs have the same resource requirement. + # Thus, we can simply take the first job to extract + # the relevant parameters for the sbatch call. 
+ job_params = { + "run_uuid": self.run_uuid, + "slurm_logfile": slurm_logfile, + "comment_str": comment_str, + "account": next(self.get_account_arg(jobs[0])), + "partition": self.get_partition_arg(jobs[0]), + "workdir": self.workflow.workdir_init, + } + + if not jobs[0].resources.get("runtime"): + self.logger.warning( + "No wall time information given. This might or might not " + "work on your cluster. " + "If not, specify the resource runtime in your rule or as " + "a reasonable default via --default-resources." + ) + + if not jobs[0].resources.get("mem_mb_per_cpu") and not jobs[ + 0 + ].resources.get("mem_mb"): + self.logger.warning( + "No job memory information ('mem_mb' or 'mem_mb_per_cpu') is " + "given - submitting without. This might or might not work on " + "your cluster." + ) + # Build a compressed map of array task id -> full execution string + # for all jobs. + array_execs = { + index: zlib.compress( + self.format_job_exec(job).encode("utf-8"), level=9 + ).hex() + for index, job in enumerate(jobs, start=1) + } + + call = get_submit_command( + jobs[0], + job_params, + settings=self.workflow.executor_settings, + failed_nodes=self._failed_nodes, + ) + if self._failed_nodes: + self.logger.debug( + "Excluding failed nodes from array job submission: " + f"{','.join(self._failed_nodes)}" + ) + call += set_gres_string(jobs[0]) + + # the actual array job call: + # we need to cycle over all jobs and submit up to `array_limit` jobs per + # submission, to avoid hitting cluster limits or oversaturating the + # command line limits + array_limit = min(self.max_array_size, len(jobs)) + for start_index in range(1, len(jobs) + 1, array_limit): + end_index = min(start_index + array_limit - 1, len(jobs)) + # The first task of each chunk runs via the plain base command. + # Remaining tasks are dispatched from --slurm-jobstep-array-execs. 
+ exec_job = self.format_job_exec(jobs[start_index - 1]) + sub_array_execs = { + str(i): array_execs[i] + for i in range(start_index + 1, end_index + 1) + } + array_execs_payload = base64.b64encode( + json.dumps(sub_array_execs).encode("utf-8") + ).decode() + + # add memory fudge factor to the base call, + # to account for the extra memory needed by the + # jobstep process to hold and parse the array execs payload. + call = apply_mem_fudge(call, array_execs_payload) + + use_script_submission = ( + self.workflow.executor_settings.pass_command_as_script + ) + submission_failed = False + while True: + call_with_array = call + f" --array={start_index}-{end_index}" + + if not use_script_submission: + # Use --wrap for the base execution command. + call_with_array += ( + f' --wrap="{exec_job}' + f" --slurm-jobstep-array-execs=" + f"{shlex.quote(array_execs_payload)}" + '"' + ) + subprocess_stdin = None + self.logger.debug(f"call with array: {call_with_array}") + else: + # Use /dev/stdin to pass the base execution command as a script. + sbatch_script = "\n".join( + [ + "#!/bin/sh", + f"{exec_job}", + "--slurm-jobstep-array-execs", + shlex.quote(array_execs_payload), + ] + ) + call_with_array += " /dev/stdin" + subprocess_stdin = sbatch_script + + self.logger.debug( + f"Submitting array job with sbatch call: {call_with_array}" + ) + try: + process = subprocess.Popen( + call_with_array, + shell=True, + text=True, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + out, err = process.communicate(input=subprocess_stdin) + if process.returncode != 0: + raise subprocess.CalledProcessError( + process.returncode, call_with_array, output=err + ) + break + except OSError as e: + if e.errno == errno.E2BIG and not use_script_submission: + self.logger.warning( + "Array sbatch command exceeds argument-length " + "limits; retrying via /dev/stdin script mode " + f"for tasks {start_index}-{end_index}." 
+ ) + use_script_submission = True + continue + raise + except subprocess.CalledProcessError as e: + error_msg = ( + "SLURM sbatch failed for array job submission " + f"(tasks {start_index}-{end_index}). " + f"The error message was '{e.output.strip()}'.\n" + f" sbatch call:\n {call_with_array}\n" + ) + self.logger.error(error_msg) + for job in jobs[start_index - 1 : end_index]: + self._report_job_error_threadsafe( + SubmittedJobInfo(job), + ( + f"Part of failed array sbatch submission " + f"(tasks {start_index}-{end_index}); " + "see log for details." + ), + ) + submission_failed = True + break + + if submission_failed: + continue + + # To extract the job id we split by semicolon and take the first + # element (this also works if no cluster name was provided) + slurm_jobid = out.strip().split(";")[0] + # this slurm_jobid might be wrong: some cluster admin give convoluted + # sbatch outputs. So we need to validate it properly (and replace it + # if necessary). + slurm_jobid = validate_or_get_slurm_job_id(slurm_jobid, out) + # here, however we are dealing with array jobs and the job id is of + # the form _, so we need to add the task ids + job_ids = list() # Snakemake interal ids + for index in range(start_index, end_index + 1): + # Calculate the actual logfile path for this array task + job = jobs[index - 1] + job_ids.append(job.jobid) + job_wildcard_str = get_job_wildcards(job) + job_logfile = ( + self.slurm_logdir + / group_or_rule + / job_wildcard_str + / f"{slurm_jobid}_{index}.log" + ) + + job_info = SubmittedJobInfo( + job, + external_jobid=f"{slurm_jobid}_{index}", + aux={"slurm_logfile": job_logfile}, + ) + self._report_job_submission_threadsafe(job_info) + self.logger.debug( + f"Registered array job task: " + f"external_jobid={slurm_jobid}_{index}, " + f"snakemake_jobid={job.jobid}" + ) + + job_ids_str = ",".join(map(str, job_ids)) + self.logger.info( + f"Submitted array job with Snakemake IDs {job_ids_str} and " + f"SLURM job ID {slurm_jobid}. 
The individual task IDs are " + f"{start_index}-{end_index}." + ) + + # Track cluster specification for later use in cancel_jobs + cluster_val = ( + jobs[0].resources.get("cluster") + or jobs[0].resources.get("clusters") + or jobs[0].resources.get("slurm_cluster") + ) + if cluster_val: + self._submitted_job_clusters.add(cluster_val) + except Exception as e: + self.logger.error( + f"Exception in run_array_jobs for rule {jobs[0].rule.name}: {e}", + exc_info=True, + ) + for job in jobs: + self._report_job_error_threadsafe( + SubmittedJobInfo(job), + f"Array job submission failed with exception: {e}", + ) + + def run_job(self, job: JobExecutorInterface): + group_or_rule = f"group_{job.name}" if job.is_group() else f"rule_{job.name}" + + wildcard_str = get_job_wildcards(job) self.slurm_logdir.mkdir(parents=True, exist_ok=True) slurm_logfile = self.slurm_logdir / group_or_rule / wildcard_str / "%j.log" @@ -605,23 +1082,16 @@ def run_job(self, job: JobExecutorInterface): "workdir": self.workflow.workdir_init, } - call = get_submit_command(job, job_params) - - if self.workflow.executor_settings.requeue: - call += " --requeue" - - if self.workflow.executor_settings.qos: - call += f" --qos={self.workflow.executor_settings.qos}" - - if self.workflow.executor_settings.reservation: - call += f" --reservation={self.workflow.executor_settings.reservation}" + call = get_submit_command( + job, + job_params, + settings=self.workflow.executor_settings, + failed_nodes=self._failed_nodes, + ) - # we exclude failed nodes from further job submissions, to avoid - # repeated failures. 
if self._failed_nodes: - call += f" --exclude={','.join(self._failed_nodes)}" self.logger.debug( - f"Excluding the following nodes from job submission: " + "Excluding failed nodes from job submission: " f"{','.join(self._failed_nodes)}" ) @@ -664,6 +1134,7 @@ def run_job(self, job: JobExecutorInterface): subprocess_stdin = sbatch_script self.logger.debug(f"sbatch call: {call}") + time.sleep(5) try: process = subprocess.Popen( call, @@ -681,9 +1152,9 @@ def run_job(self, job: JobExecutorInterface): process.returncode, call, output=err ) except subprocess.CalledProcessError as e: - self.report_job_error( + self._report_job_error_threadsafe( SubmittedJobInfo(job), - msg=( + ( "SLURM sbatch failed. " f"The error message was '{e.output.strip()}'.\n" f" sbatch call:\n {call}\n" @@ -726,7 +1197,7 @@ def run_job(self, job: JobExecutorInterface): ) if cluster_val: self._submitted_job_clusters.add(cluster_val) - self.report_job_submission( + self._report_job_submission_threadsafe( SubmittedJobInfo( job, external_jobid=slurm_jobid, @@ -848,10 +1319,17 @@ async def check_active_jobs( self._status_query_max_seconds, sacct_query_duration, ) - # only take jobs that are still active - active_jobs_ids_with_current_sacct_status = ( - set(status_of_jobs.keys()) & active_jobs_ids - ) + # only take jobs that are still active; for array tasks fall + # back from _ to parent if needed. 
+ status_keys = set(status_of_jobs.keys()) + active_jobs_ids_with_current_sacct_status = { + external_jobid + for external_jobid in active_jobs_ids + if any( + candidate in status_keys + for candidate in _status_lookup_ids(external_jobid) + ) + } missing_sacct_status = ( active_jobs_seen_by_sacct - active_jobs_ids_with_current_sacct_status @@ -868,7 +1346,7 @@ async def check_active_jobs( f"{active_jobs_ids_with_current_sacct_status}" ) self.logger.debug( - "active_jobs_seen_by_sacct are: " f"{active_jobs_seen_by_sacct}" + f"active_jobs_seen_by_sacct are: {active_jobs_seen_by_sacct}" ) self.logger.debug( f"missing_sacct_status are: {missing_sacct_status}" @@ -903,7 +1381,7 @@ async def check_active_jobs( cumulative_avg_duration = ( self._status_query_total_seconds / self._status_query_calls ) - self.logger.info( + self.logger.debug( "Status query timing (cumulative): " f"calls={self._status_query_calls}, " f"failures={self._status_query_failures}, " @@ -925,19 +1403,39 @@ async def check_active_jobs( if status_of_jobs is not None: any_finished = False + self.logger.debug( + f"Status query returned {len(status_of_jobs)} job IDs: " + f"{list(status_of_jobs.keys())[:10]}..." # Show first 10 + ) + self.logger.debug( + f"Checking {len(active_jobs)} active jobs with external IDs: " + f"{[j.external_jobid for j in active_jobs][:10]}..." 
# Show first 10 + ) for j in active_jobs: + slurm_logfile = j.aux.get("slurm_logfile") + slurm_logfile_str = ( + str(slurm_logfile) if slurm_logfile is not None else None + ) + status_lookup_id = next( + ( + candidate + for candidate in _status_lookup_ids(j.external_jobid) + if candidate in status_of_jobs + ), + None, + ) # the job probably didn't make it into slurmdbd yet, so # `sacct` doesn't return it - if j.external_jobid not in status_of_jobs: + if status_lookup_id is None: # but the job should still be queueing or running and # appear in slurmdbd (and thus `sacct` output) later yield j continue - status = status_of_jobs[j.external_jobid] + status = status_of_jobs[status_lookup_id] if status == "COMPLETED": self.report_job_success(j) any_finished = True - active_jobs_seen_by_sacct.remove(j.external_jobid) + active_jobs_seen_by_sacct.discard(j.external_jobid) if not self.workflow.executor_settings.keep_successful_logs: self.logger.debug( "removing SLURM log for successful job " @@ -967,7 +1465,7 @@ async def check_active_jobs( # so we assume it is finished self.report_job_success(j) any_finished = True - active_jobs_seen_by_sacct.remove(j.external_jobid) + active_jobs_seen_by_sacct.discard(j.external_jobid) elif status == "NODE_FAIL": # this is a special case: the job failed, but due to a node failure. # Always track the failed node so future submissions exclude it, @@ -1011,10 +1509,8 @@ async def check_active_jobs( "such cases by setting the 'requeue' flag in the " "executor settings." 
) - self.report_job_error( - j, msg=msg, aux_logs=[j.aux["slurm_logfile"]._str] - ) - active_jobs_seen_by_sacct.remove(j.external_jobid) + self.report_job_error(j, msg=msg, aux_logs=[slurm_logfile_str]) + active_jobs_seen_by_sacct.discard(j.external_jobid) elif status in fail_stati: # we can only check for the fail status, if `sacct` is available if status_command_name != "sacct": @@ -1024,9 +1520,7 @@ async def check_active_jobs( "Detailed failure reason unavailable " "(status command is not 'sacct')." ) - self.report_job_error( - j, msg=msg, aux_logs=[j.aux["slurm_logfile"]._str] - ) + self.report_job_error(j, msg=msg, aux_logs=[slurm_logfile_str]) active_jobs_seen_by_sacct.discard(j.external_jobid) continue reasons = [] @@ -1066,10 +1560,8 @@ async def check_active_jobs( f"SLURM status is: '{status}'. " f"Reasons: {', '.join(reasons)}." ) - self.report_job_error( - j, msg=msg, aux_logs=[j.aux["slurm_logfile"]._str] - ) - active_jobs_seen_by_sacct.remove(j.external_jobid) + self.report_job_error(j, msg=msg, aux_logs=[slurm_logfile_str]) + active_jobs_seen_by_sacct.discard(j.external_jobid) else: # still running? 
yield j diff --git a/snakemake_executor_plugin_slurm/job_cancellation.py b/snakemake_executor_plugin_slurm/job_cancellation.py index 2f55576f..a44a3703 100644 --- a/snakemake_executor_plugin_slurm/job_cancellation.py +++ b/snakemake_executor_plugin_slurm/job_cancellation.py @@ -26,7 +26,19 @@ def cancel_slurm_jobs( """ if active_jobs: # TODO chunk jobids in order to avoid too long command lines - jobids = " ".join([job_info.external_jobid for job_info in active_jobs]) + # Filter out None values in case some jobs haven't been assigned + # external IDs yet + jobids = " ".join( + [ + job_info.external_jobid + for job_info in active_jobs + if job_info.external_jobid is not None + ] + ) + + if not jobids: + # No valid job IDs to cancel + return try: # timeout set to 60, because a scheduler cycle usually is @@ -68,5 +80,5 @@ def cancel_slurm_jobs( "HPC administrator." ) raise WorkflowError( - "Unable to cancel jobs with scancel " f"(exit code {e.returncode}){msg}" + f"Unable to cancel jobs with scancel (exit code {e.returncode}){msg}" ) from e diff --git a/snakemake_executor_plugin_slurm/submit_string.py b/snakemake_executor_plugin_slurm/submit_string.py index 112eb46d..71810ec6 100644 --- a/snakemake_executor_plugin_slurm/submit_string.py +++ b/snakemake_executor_plugin_slurm/submit_string.py @@ -1,6 +1,7 @@ from snakemake_interface_common.exceptions import WorkflowError from snakemake_executor_plugin_slurm_jobstep import get_cpu_setting from types import SimpleNamespace +import re import shlex @@ -16,13 +17,70 @@ def safe_quote(value): return shlex.quote(str_value) -def get_submit_command(job, params): +def _compute_array_exec_fudge_mb(payload: str) -> int: + """ + Return the estimated extra memory (in MB) that the jobstep process will + need in order to hold and parse the array_exec payload + (the base64-encoded array-execs argument). + + Within the jobstep the array_exec payload goes through several in-memory + representations: + 1. 
The raw CLI string (~1x payload bytes) + 2. base64-decoded JSON bytes (~0.75x payload bytes) + 3. Python dict of hex strings (~same as the JSON bytes) + Together that is roughly 3x the payload size. We round up to the nearest + MiB and always allocate at least 1 MB so we never return 0. + """ + payload_bytes = len(payload) # payload is pure ASCII, so len == byte count + # Integer ceiling of (3 × payload_bytes) / 1 MiB + return max(1, -(-payload_bytes * 3 // (1024 * 1024))) + + +def apply_mem_fudge(call: str, payload: str) -> str: + """Increase the ``--mem`` or ``--mem-per-cpu`` value already present in + *call* by the payload-derived fudge amount (always at least 1 MB). + + If neither flag is present (user submitted without memory constraints) a + bare ``--mem <fudge>`` is appended so the extra memory is not + silently dropped. + """ + fudge_mb = _compute_array_exec_fudge_mb(payload) + + # --mem-per-cpu + new_call, n = re.subn( + r"(--mem-per-cpu\s+)(\d+)", + lambda m: f"{m.group(1)}{int(m.group(2)) + fudge_mb}", + call, + count=1, + ) + if n: + return new_call + + # --mem + new_call, n = re.subn( + r"(--mem\s+)(\d+)", + lambda m: f"{m.group(1)}{int(m.group(2)) + fudge_mb}", + call, + count=1, + ) + if n: + return new_call + + # No memory flag present — add one so the fudge is not silently dropped. + return call + f" --mem {fudge_mb}" + + +def get_submit_command( + job, params, settings=None, failed_nodes=None, array_job=False +) -> str: """ Return the submit command for the job. 
""" # Convert params dict to a SimpleNamespace for attribute-style access params = SimpleNamespace(**params) + failed_nodes = failed_nodes or set() + call = ( "sbatch " "--parsable " @@ -76,6 +134,20 @@ def get_submit_command(job, params): if job.resources.get("nodes", False): call += f" --nodes={job.resources.get('nodes', 1)}" + if settings and settings.requeue: + call += " --requeue" + + if settings and settings.qos: + call += f" --qos={safe_quote(settings.qos)}" + + if settings and settings.reservation: + call += f" --reservation={safe_quote(settings.reservation)}" + + # we exclude failed nodes from further job submissions, to avoid + # repeated failures. + if failed_nodes: + call += f" --exclude={','.join(failed_nodes)}" + gpu_job = job.resources.get("gpu") or "gpu" in job.resources.get("gres", "") if gpu_job: # fixes #316 - allow unsetting of tasks per gpu diff --git a/snakemake_executor_plugin_slurm/utils.py b/snakemake_executor_plugin_slurm/utils.py index 2cd5a0be..78c282b4 100644 --- a/snakemake_executor_plugin_slurm/utils.py +++ b/snakemake_executor_plugin_slurm/utils.py @@ -1,17 +1,108 @@ # utility functions for the SLURM executor plugin +from collections import Counter import math import os +import shlex +import subprocess import re from pathlib import Path from typing import Union +from snakemake_interface_executor_plugins.dag import DAGExecutorInterface from snakemake_interface_executor_plugins.jobs import ( JobExecutorInterface, ) from snakemake_interface_common.exceptions import WorkflowError +def get_max_array_size() -> int: + """ + Function to get the maximum array size for SLURM job arrays. This is used + to determine how many jobs can be submitted in a single array job. + + Returns: + The maximum array size for SLURM job arrays, as an integer. + Defaults to 1000 if the SLURM_ARRAY_MAX environment variable is not set + or cannot be parsed as an integer. 
+ """ + max_array_size_str = None + scontrol_cmd = "scontrol show config" + try: + res = subprocess.run( + shlex.split(scontrol_cmd), + capture_output=True, + text=True, + timeout=5, + ) + out = (res.stdout or "") + (res.stderr or "") + m = re.search(r"MaxArraySize\s*=?\s*(\d+)", out, re.IGNORECASE) + if m: + max_array_size_str = m.group(1) + except (subprocess.SubprocessError, OSError): + max_array_size_str = None + + try: + max_array_size = int(max_array_size_str) + except (ValueError, TypeError): + max_array_size = 1000 + # The SLURM_ARRAY_MAX limits to its value -1 + return max_array_size - 1 + + +def get_job_wildcards(job: JobExecutorInterface) -> str: + """ + Function to get the wildcards of a job as a string. This is used to + create the job name for the SLURM job submission. + + Args: + job: The JobExecutorInterface instance representing the job + Returns: + A string representation of the job's wildcards, with slashes replaced + by underscores. + """ + try: + wildcard_str = ( + "_".join(job.wildcards).replace("/", "_") if job.wildcards else "" + ) + except AttributeError: + wildcard_str = "" + + return wildcard_str + + +def pending_jobs_for_rule(dag: DAGExecutorInterface, rule_name: str) -> int: + """Count jobs of a rule that are currently eligible for scheduling. + + Prefer DAG ``ready_jobs`` if available because it reflects jobs that can + run now. Fall back to ``needrun_jobs`` for compatibility with interfaces + that do not expose ready jobs. 
+ """ + # Previous implementation (kept as requested for reference): + # counts = Counter(job.rule.name for job in dag.needrun_jobs()) + # return counts.get(rule_name, 0) + + jobs = None + + ready_jobs_attr = getattr(dag, "ready_jobs", None) + if callable(ready_jobs_attr): + jobs = ready_jobs_attr() + elif ready_jobs_attr is not None: + jobs = ready_jobs_attr + + if jobs is None: + needrun_jobs_attr = getattr(dag, "needrun_jobs", None) + if callable(needrun_jobs_attr): + jobs = needrun_jobs_attr() + elif needrun_jobs_attr is not None: + jobs = needrun_jobs_attr + else: + jobs = [] + + counts = Counter(job.rule.name for job in jobs) + return counts.get(rule_name, 0) + + def round_half_up(n): return int(math.floor(n + 0.5)) diff --git a/tests/test_array_jobs.py b/tests/test_array_jobs.py new file mode 100644 index 00000000..2868388c --- /dev/null +++ b/tests/test_array_jobs.py @@ -0,0 +1,643 @@ +"""Unit tests for SLURM array-job functionality. + +These tests do NOT require a live SLURM cluster and do NOT subclass +TestWorkflows. 
They cover: + + - ExecutorSettings array-job field defaults and parsing (TestArrayJobsSettings) + - run_jobs() dispatch routing (TestRunJobsRouting) + - run_array_jobs() sbatch construction and chunking (TestRunArrayJobs) + - _status_lookup_ids() helper edge cases (TestStatusLookupIds) + - check_active_jobs() status resolution for array tasks (TestCheckActiveArrayJobs) +""" + +import asyncio +import base64 +import errno +import json +import re +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from snakemake_executor_plugin_slurm import ( + Executor, + ExecutorSettings, + _status_lookup_ids, +) + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + + +class _Resources(dict): + """Dict-like resources with attribute access for known keys only.""" + + def __getattr__(self, name): + try: + return self[name] + except KeyError as e: + raise AttributeError(name) from e + + +def _make_mock_job( + rule_name="myrule", + name=None, + wildcards=None, + jobid=1, + is_group=False, + **resources, +): + """Return a minimal mock job compatible with run_jobs / run_array_jobs.""" + mock_resources = _Resources(resources) + + mock_rule = MagicMock() + mock_rule.name = rule_name + + job = MagicMock() + job.resources = mock_resources + job.rule = mock_rule + job.name = name if name is not None else rule_name + job.wildcards = wildcards if wildcards is not None else {} + job.is_group.return_value = is_group + job.threads = resources.get("threads", 1) + job.jobid = jobid + return job + + +def _make_executor_stub(array_jobs=None, array_limit=100): + """Return a minimal Executor stub (bypasses __post_init__ entirely).""" + executor = Executor.__new__(Executor) + executor.logger = MagicMock() + executor.run_uuid = "test-run-uuid" + executor._fallback_account_arg = None + 
executor._fallback_partition = None + executor._partitions = None + executor._failed_nodes = set() + executor._main_event_loop = None + executor._status_query_calls = 0 + executor._status_query_failures = 0 + executor._status_query_total_seconds = 0.0 + executor._status_query_min_seconds = None + executor._status_query_max_seconds = 0.0 + executor._status_query_cycle_rows = [] + executor._preemption_warning = False + executor._submitted_job_clusters = set() + + # Replicate the array_jobs parsing from Executor.__post_init__ + if array_jobs: + normalized = array_jobs.replace(";", ",") + executor.array_jobs = {r.strip() for r in normalized.split(",") if r.strip()} + else: + executor.array_jobs = set() + executor.max_array_size = int(array_limit) + + executor.slurm_logdir = Path("/tmp/test_slurm_logs") + executor.workflow = SimpleNamespace( + executor_settings=SimpleNamespace( + array_limit=array_limit, + status_attempts=1, + init_seconds_before_status_checks=40, + keep_successful_logs=False, + requeue=False, + qos=None, + reservation=None, + pass_command_as_script=False, + ), + workdir_init=Path("/tmp"), + ) + + executor._job_submission_executor = MagicMock() + executor.report_job_success = MagicMock() + executor.report_job_error = MagicMock() + executor._report_job_submission_threadsafe = MagicMock() + executor._report_job_error_threadsafe = MagicMock() + return executor + + +class TestArrayJobsSettings: + """Tests for ExecutorSettings array-job fields and their defaults.""" + + def test_array_jobs_default_is_none(self): + """array_jobs field defaults to None.""" + settings = ExecutorSettings() + assert settings.array_jobs is None + + def test_array_limit_default_is_1000(self): + """array_limit field defaults to 1000.""" + settings = ExecutorSettings() + assert settings.array_limit == 1000 + + def test_array_jobs_none_yields_empty_set_on_executor(self): + """Executor with array_jobs=None initialises self.array_jobs as empty set.""" + executor = 
_make_executor_stub(array_jobs=None) + assert executor.array_jobs == set() + + def test_array_jobs_comma_separated_parsed(self): + """Comma-separated rule names are split into a set.""" + executor = _make_executor_stub(array_jobs="rule1, rule2") + assert executor.array_jobs == {"rule1", "rule2"} + + def test_array_jobs_semicolons_normalised(self): + """Semicolons are normalised to commas before splitting.""" + executor = _make_executor_stub(array_jobs="rule1; rule2") + assert executor.array_jobs == {"rule1", "rule2"} + + def test_array_jobs_all_keyword_preserved(self): + """The magic keyword 'all' is preserved as a set member.""" + executor = _make_executor_stub(array_jobs="all") + assert executor.array_jobs == {"all"} + + def test_array_jobs_extra_whitespace_stripped(self): + """Leading/trailing whitespace is stripped from each rule name.""" + executor = _make_executor_stub(array_jobs=" rule1 , rule2 ") + assert executor.array_jobs == {"rule1", "rule2"} + + +class TestRunJobsRouting: + """Tests that run_jobs dispatches to run_job or run_array_jobs correctly.""" + + def test_single_non_array_job_uses_run_job(self): + """One job with no array setting → run_job is enqueued.""" + executor = _make_executor_stub() + job = _make_mock_job(rule_name="myrule") + executor.run_jobs([job]) + + calls = executor._job_submission_executor.submit.call_args_list + assert len(calls) == 1 + assert calls[0][0][0] == executor.run_job + + def test_multiple_non_array_jobs_each_get_run_job(self): + """Three jobs, no array setting → three individual run_job submissions.""" + executor = _make_executor_stub() + jobs = [_make_mock_job(rule_name="myrule", jobid=i) for i in range(3)] + executor.run_jobs(jobs) + + calls = executor._job_submission_executor.submit.call_args_list + assert len(calls) == 3 + for c in calls: + assert c[0][0] == executor.run_job + + def test_array_rule_single_ready_job_falls_back_to_run_job(self): + """Array selected for rule but only 1 ready job → run_job, debug log 
emitted.""" + executor = _make_executor_stub(array_jobs="myrule") + job = _make_mock_job(rule_name="myrule") + executor.run_jobs([job]) + + calls = executor._job_submission_executor.submit.call_args_list + assert len(calls) == 1 + assert calls[0][0][0] == executor.run_job + # A debug-level message explains the single-job fallback + executor.logger.debug.assert_called() + + def test_array_rule_multiple_jobs_use_run_array_jobs(self): + """Array selected + 3 ready jobs for the same rule → one run_array_jobs call.""" + executor = _make_executor_stub(array_jobs="myrule") + jobs = [_make_mock_job(rule_name="myrule", jobid=i) for i in range(3)] + executor.run_jobs(jobs) + + calls = executor._job_submission_executor.submit.call_args_list + assert len(calls) == 1 + assert calls[0][0][0] == executor.run_array_jobs + # All 3 jobs are forwarded together + assert calls[0][0][1] == jobs + + def test_group_job_for_array_rule_uses_run_job_with_warning(self): + """Group job whose rule is in array_jobs → run_job; logger.warning called.""" + executor = _make_executor_stub(array_jobs="myrule") + job = _make_mock_job(rule_name="myrule", is_group=True) + executor.run_jobs([job]) + + calls = executor._job_submission_executor.submit.call_args_list + assert len(calls) == 1 + assert calls[0][0][0] == executor.run_job + executor.logger.warning.assert_called() + + def test_all_keyword_routes_to_run_array_jobs(self): + """array_jobs='all' + 2 regular jobs for any rule → run_array_jobs.""" + executor = _make_executor_stub(array_jobs="all") + jobs = [_make_mock_job(rule_name="anyrule", jobid=i) for i in range(2)] + executor.run_jobs(jobs) + + calls = executor._job_submission_executor.submit.call_args_list + assert len(calls) == 1 + assert calls[0][0][0] == executor.run_array_jobs + + def test_mixed_group_and_regular_jobs_routed_independently(self): + """Group job → run_job; regular job pair for array rule → run_array_jobs.""" + executor = _make_executor_stub(array_jobs="myrule") + group_job = 
_make_mock_job(rule_name="myrule", jobid=0, is_group=True) + regular_jobs = [ + _make_mock_job(rule_name="myrule", jobid=i) for i in range(1, 3) + ] + executor.run_jobs([group_job] + regular_jobs) + + calls = executor._job_submission_executor.submit.call_args_list + assert len(calls) == 2 + methods = [c[0][0] for c in calls] + assert executor.run_job in methods + assert executor.run_array_jobs in methods + + def test_array_rule_waits_below_chunk_if_more_eligible_in_dag(self): + """ + With DAG showing more pending jobs, do not submit until chunk + size is reached. + """ + executor = _make_executor_stub(array_jobs="myrule", array_limit=10) + ready_jobs = [_make_mock_job(rule_name="myrule", jobid=i) for i in range(1, 6)] + pending_jobs = [ + _make_mock_job(rule_name="myrule", jobid=i) for i in range(1, 101) + ] + executor.workflow.dag = SimpleNamespace(needrun_jobs=lambda: pending_jobs) + + executor.run_jobs(ready_jobs) + + assert executor._job_submission_executor.submit.call_count == 0 + + def test_array_rule_submits_at_chunk_size_even_if_more_eligible_in_dag(self): + """ + With DAG showing more pending jobs, submit once at least one full + chunk is ready. 
+ """ + executor = _make_executor_stub(array_jobs="myrule", array_limit=10) + ready_jobs = [_make_mock_job(rule_name="myrule", jobid=i) for i in range(1, 11)] + pending_jobs = [ + _make_mock_job(rule_name="myrule", jobid=i) for i in range(1, 101) + ] + executor.workflow.dag = SimpleNamespace(needrun_jobs=lambda: pending_jobs) + + executor.run_jobs(ready_jobs) + + calls = executor._job_submission_executor.submit.call_args_list + assert len(calls) == 1 + assert calls[0][0][0] == executor.run_array_jobs + assert calls[0][0][1] == ready_jobs + + +class TestRunArrayJobs: + """Tests for run_array_jobs: sbatch command structure, chunking, error handling.""" + + # --- fixtures & helpers ------------------------------------------------ + + @pytest.fixture + def mock_popen_success(self): + """Popen mock that returns a successful sbatch response with job ID 987654.""" + with patch("snakemake_executor_plugin_slurm.subprocess.Popen") as mock_popen: + proc = MagicMock() + proc.communicate.return_value = ("987654", "") + proc.returncode = 0 + mock_popen.return_value = proc + yield mock_popen + + def _build_executor(self, tmp_path, array_limit=1000): + executor = _make_executor_stub(array_limit=array_limit) + executor.slurm_logdir = tmp_path / "slurm_logs" + executor.get_account_arg = MagicMock( + side_effect=lambda job: iter(["-A testaccount"]) + ) + executor.get_partition_arg = MagicMock(return_value="-p main") + executor.format_job_exec = MagicMock( + side_effect=lambda job: f"snakemake_exec_{job.jobid}" + ) + return executor + + def _make_jobs(self, n=3, rule_name="myrule"): + return [_make_mock_job(rule_name=rule_name, jobid=i) for i in range(1, n + 1)] + + # --- tests ------------------------------------------------------------- + + def test_logfile_per_task_uses_resolved_jobid_and_index( + self, tmp_path, mock_popen_success + ): + """Reported logfile for each task is '_.log' (1-based).""" + executor = self._build_executor(tmp_path) + jobs = self._make_jobs(n=2) + 
executor.run_array_jobs(jobs) + + calls = executor._report_job_submission_threadsafe.call_args_list + assert len(calls) == 2 + for idx, c in enumerate(calls, start=1): + job_info = c[0][0] + assert job_info.aux["slurm_logfile"].name == f"987654_{idx}.log" + + def test_external_jobid_per_task_is_jobid_underscore_index( + self, tmp_path, mock_popen_success + ): + """external_jobid for each task is '_' (1-based).""" + executor = self._build_executor(tmp_path) + jobs = self._make_jobs(n=3) + executor.run_array_jobs(jobs) + + calls = executor._report_job_submission_threadsafe.call_args_list + assert len(calls) == 3 + external_ids = [c[0][0].external_jobid for c in calls] + assert external_ids == ["987654_1", "987654_2", "987654_3"] + + def test_array_execs_task_1_absent_tasks_2_plus_present( + self, tmp_path, mock_popen_success + ): + """The --slurm-jobstep-array-execs map has keys 2,3,… but never 1. + + Task 1 executes via the plain base exec_job; tasks 2+ are encoded + in the compressed map so the job-step can dispatch them. 
+ """ + executor = self._build_executor(tmp_path) + jobs = self._make_jobs(n=3) + executor.run_array_jobs(jobs) + popen_call_str = mock_popen_success.call_args_list[0][0][0] + match = re.search( + r"--slurm-jobstep-array-execs=(?:'([A-Za-z0-9+/=]+)'|([A-Za-z0-9+/=]+))", + popen_call_str, + ) + assert match, ( + "Could not find --slurm-jobstep-array-execs in sbatch call.\n" + f"Call was: {popen_call_str!r}" + ) + encoded_payload = match.group(1) or match.group(2) + array_execs = json.loads(base64.b64decode(encoded_payload).decode("utf-8")) + assert "1" not in array_execs + assert "2" in array_execs + assert "3" in array_execs + + def test_array_execs_omits_first_task_of_each_chunk(self, tmp_path): + """For each chunk, first task uses base exec command and is absent from map.""" + executor = self._build_executor(tmp_path, array_limit=3) + jobs = self._make_jobs(n=5) + + with patch("snakemake_executor_plugin_slurm.subprocess.Popen") as mock_popen: + proc = MagicMock() + proc.communicate.return_value = ("333333", "") + proc.returncode = 0 + mock_popen.return_value = proc + executor.run_array_jobs(jobs) + + first_call_str = mock_popen.call_args_list[0][0][0] + second_call_str = mock_popen.call_args_list[1][0][0] + + assert '--wrap="snakemake_exec_1 ' in first_call_str + assert '--wrap="snakemake_exec_4 ' in second_call_str + + first_match = re.search( + r"--slurm-jobstep-array-execs=(?:'([A-Za-z0-9+/=]+)'|([A-Za-z0-9+/=]+))", + first_call_str, + ) + second_match = re.search( + r"--slurm-jobstep-array-execs=(?:'([A-Za-z0-9+/=]+)'|([A-Za-z0-9+/=]+))", + second_call_str, + ) + assert first_match and second_match + + first_payload = first_match.group(1) or first_match.group(2) + second_payload = second_match.group(1) or second_match.group(2) + first_map = json.loads(base64.b64decode(first_payload).decode("utf-8")) + second_map = json.loads(base64.b64decode(second_payload).decode("utf-8")) + + assert "1" not in first_map + assert "2" in first_map + assert "3" in first_map + 
assert "4" not in second_map + assert "5" in second_map + + def test_array_limit_produces_chunked_sbatch_calls(self, tmp_path): + """5 jobs with array_limit=3 → 2 Popen calls: --array=1-3 and --array=4-5.""" + executor = self._build_executor(tmp_path, array_limit=3) + jobs = self._make_jobs(n=5) + + with patch("snakemake_executor_plugin_slurm.subprocess.Popen") as mock_popen: + proc = MagicMock() + proc.communicate.return_value = ("111111", "") + proc.returncode = 0 + mock_popen.return_value = proc + executor.run_array_jobs(jobs) + + assert mock_popen.call_count == 2 + first_call_str = mock_popen.call_args_list[0][0][0] + second_call_str = mock_popen.call_args_list[1][0][0] + assert "--array=1-3" in first_call_str + assert "--array=4-5" in second_call_str + + def test_e2big_retries_with_stdin_script_mode(self, tmp_path): + """If --wrap exceeds argv size, retry once via /dev/stdin script mode.""" + executor = self._build_executor(tmp_path) + jobs = self._make_jobs(n=3) + + with patch("snakemake_executor_plugin_slurm.subprocess.Popen") as mock_popen: + proc = MagicMock() + proc.communicate.return_value = ("222222", "") + proc.returncode = 0 + mock_popen.side_effect = [ + OSError(errno.E2BIG, "Argument list too long"), + proc, + ] + + executor.run_array_jobs(jobs) + + assert mock_popen.call_count == 2 + first_call_str = mock_popen.call_args_list[0][0][0] + second_call_str = mock_popen.call_args_list[1][0][0] + assert "--wrap=" in first_call_str + assert "/dev/stdin" in second_call_str + assert proc.communicate.call_args.kwargs["input"].startswith("#!/bin/sh") + + def test_non_empty_wildcards_in_comment_triggers_warning( + self, tmp_path, mock_popen_success + ): + """ + When wildcards are non-empty, a warning + about comment limitations is logged. 
+ """ + executor = self._build_executor(tmp_path) + jobs = self._make_jobs(n=2) + + with patch( + "snakemake_executor_plugin_slurm.get_job_wildcards", + side_effect=["sample_A", "sample_B"], + ): + executor.run_array_jobs(jobs) + + executor.logger.warning.assert_called() + warning_msgs = " ".join(str(c) for c in executor.logger.warning.call_args_list) + assert "wildcard" in warning_msgs.lower() + + def test_no_wildcards_comment_is_plain_rule_name( + self, tmp_path, mock_popen_success + ): + """Empty wildcards → comment is 'rule_'; no wildcard warning.""" + executor = self._build_executor(tmp_path) + jobs = self._make_jobs(n=2) + + # Real get_job_wildcards returns "" for jobs with empty wildcards dict + executor.run_array_jobs(jobs) + + popen_call_str = mock_popen_success.call_args_list[0][0][0] + assert "rule_myrule" in popen_call_str + + # No wildcard-specific warning should have been issued + for c in executor.logger.warning.call_args_list: + assert "wildcard" not in str(c).lower() + + def test_failed_nodes_exclusion_propagated_to_sbatch_call( + self, tmp_path, mock_popen_success + ): + """_failed_nodes set is propagated as --exclude= in the sbatch call.""" + executor = self._build_executor(tmp_path) + executor._failed_nodes = {"bad_node01"} + jobs = self._make_jobs(n=2) + executor.run_array_jobs(jobs) + + popen_call_str = mock_popen_success.call_args_list[0][0][0] + assert "--exclude=bad_node01" in popen_call_str + + +class TestStatusLookupIds: + """Edge-case unit tests for _status_lookup_ids.""" + + def test_plain_numeric_id_returns_single_entry(self): + """A plain numeric job ID returns only itself — no parent appended.""" + assert _status_lookup_ids("12345") == ["12345"] + + def test_array_task_appends_parent_id(self): + """'_' with all-numeric parts appends the parent ID.""" + assert _status_lookup_ids("12345_3") == ["12345_3", "12345"] + + def test_non_numeric_parent_not_treated_as_array(self): + """Non-numeric parent prevents parent-ID fallback.""" + 
assert _status_lookup_ids("abc_123") == ["abc_123"] + + def test_non_numeric_task_not_treated_as_array(self): + """Non-numeric task index prevents parent-ID fallback.""" + assert _status_lookup_ids("abc_1") == ["abc_1"] + + def test_multiple_underscores_first_split_only(self): + """Only the first underscore is used; extra parts make task non-numeric.""" + # parent="12345" (digits), task="3_extra" (not digits) → no fallback + result = _status_lookup_ids("12345_3_extra") + assert result == ["12345_3_extra"] + + +class _NoopAsyncContext: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + +def _make_check_executor(): + """Return an Executor stub wired for check_active_jobs tests.""" + executor = Executor.__new__(Executor) + executor.logger = MagicMock() + executor.run_uuid = "run-uuid" + executor.status_rate_limiter = _NoopAsyncContext() + executor.get_status_command = lambda: "sacct" + executor.next_seconds_between_status_checks = 40 + executor._status_query_calls = 0 + executor._status_query_failures = 0 + executor._status_query_total_seconds = 0.0 + executor._status_query_min_seconds = None + executor._status_query_max_seconds = 0.0 + executor._status_query_cycle_rows = [] + executor._preemption_warning = False + executor._failed_nodes = set() + executor.report_job_success = MagicMock() + executor.report_job_error = MagicMock() + executor.workflow = SimpleNamespace( + executor_settings=SimpleNamespace( + status_attempts=1, + init_seconds_before_status_checks=40, + keep_successful_logs=False, + requeue=False, + ) + ) + return executor + + +def _run_check(executor, active_jobs): + """Drain check_active_jobs into a list synchronously.""" + + async def _collect(): + remaining = [] + async for job in executor.check_active_jobs(active_jobs): + remaining.append(job) + return remaining + + return asyncio.run(_collect()) + + +class TestCheckActiveArrayJobs: + """Tests for check_active_jobs status resolution with 
array tasks.""" + + def _patch_all(self, monkeypatch, status_dict): + """Patch all external dependencies used by check_active_jobs.""" + + async def _mock_query(command, logger): + return (status_dict, 0.01) + + monkeypatch.setattr( + "snakemake_executor_plugin_slurm.query_job_status", _mock_query + ) + monkeypatch.setattr( + "snakemake_executor_plugin_slurm.query_job_status_sacct", + lambda run_uuid: "mock_sacct_cmd", + ) + monkeypatch.setattr( + "snakemake_executor_plugin_slurm.get_min_job_age", lambda: 300 + ) + monkeypatch.setattr( + "snakemake_executor_plugin_slurm.is_query_tool_available", + lambda tool: True, + ) + + def test_task_level_status_takes_precedence_over_parent( + self, monkeypatch, tmp_path + ): + """Task-specific 'COMPLETED' wins over parent-array 'FAILED'.""" + executor = _make_check_executor() + self._patch_all(monkeypatch, {"123_2": "COMPLETED", "123": "FAILED"}) + + log = tmp_path / "123_2.log" + log.write_text("content") + active_job = SimpleNamespace(external_jobid="123_2", aux={"slurm_logfile": log}) + + remaining = _run_check(executor, [active_job]) + + assert remaining == [] + executor.report_job_success.assert_called_once() + executor.report_job_error.assert_not_called() + + def test_all_tasks_of_array_resolved_via_parent_status(self, monkeypatch, tmp_path): + """Multiple array tasks all resolved via parent 'COMPLETED' in one cycle.""" + executor = _make_check_executor() + self._patch_all(monkeypatch, {"123": "COMPLETED"}) + + active_jobs = [ + SimpleNamespace( + external_jobid=f"123_{i}", + aux={"slurm_logfile": tmp_path / f"123_{i}.log"}, + ) + for i in range(1, 4) + ] + + remaining = _run_check(executor, active_jobs) + + assert remaining == [] + assert executor.report_job_success.call_count == 3 + executor.report_job_error.assert_not_called() + + def test_non_terminal_status_keeps_job_active(self, monkeypatch, tmp_path): + """A status not in the terminal set (e.g. 
'RUNNING') keeps the job active.""" + executor = _make_check_executor() + self._patch_all(monkeypatch, {"123_1": "RUNNING"}) + + active_job = SimpleNamespace( + external_jobid="123_1", + aux={"slurm_logfile": tmp_path / "123_1.log"}, + ) + + remaining = _run_check(executor, [active_job]) + + assert remaining == [active_job] + executor.report_job_success.assert_not_called() + executor.report_job_error.assert_not_called() diff --git a/tests/test_array_status_lookup.py b/tests/test_array_status_lookup.py new file mode 100644 index 00000000..7fa3dc82 --- /dev/null +++ b/tests/test_array_status_lookup.py @@ -0,0 +1,155 @@ +import asyncio +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import MagicMock + +from snakemake_executor_plugin_slurm import Executor, _status_lookup_ids + + +class _NoopAsyncContext: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + +def test_status_lookup_ids_for_array_task(): + assert _status_lookup_ids("1057651_6") == ["1057651_6", "1057651"] + assert _status_lookup_ids("1057651") == ["1057651"] + assert _status_lookup_ids("abc_1") == ["abc_1"] + + +def test_check_active_jobs_uses_parent_array_status(monkeypatch): + executor = Executor.__new__(Executor) + executor.logger = MagicMock() + executor.run_uuid = "run-uuid" + executor.status_rate_limiter = _NoopAsyncContext() + executor.get_status_command = lambda: "sacct" + executor.next_seconds_between_status_checks = 40 + executor._status_query_calls = 0 + executor._status_query_failures = 0 + executor._status_query_total_seconds = 0.0 + executor._status_query_min_seconds = None + executor._status_query_max_seconds = 0.0 + executor._status_query_cycle_rows = [] + executor._preemption_warning = False + executor._failed_nodes = set() + executor.report_job_success = MagicMock() + executor.report_job_error = MagicMock() + + executor.workflow = SimpleNamespace( + executor_settings=SimpleNamespace( + 
status_attempts=1, + init_seconds_before_status_checks=40, + keep_successful_logs=True, + requeue=False, + ) + ) + + async def _mock_query_job_status(command, logger): + return ({"1057651": "FAILED"}, 0.01) + + monkeypatch.setattr( + "snakemake_executor_plugin_slurm.query_job_status", + _mock_query_job_status, + ) + monkeypatch.setattr( + "snakemake_executor_plugin_slurm.query_job_status_sacct", + lambda run_uuid: "mock_sacct_command", + ) + monkeypatch.setattr( + "snakemake_executor_plugin_slurm.get_min_job_age", + lambda: 300, + ) + monkeypatch.setattr( + "snakemake_executor_plugin_slurm.is_query_tool_available", + lambda tool: True, + ) + + active_job = SimpleNamespace( + external_jobid="1057651_1", + aux={"slurm_logfile": Path("fake.log")}, + ) + + async def _collect_remaining_jobs(): + remaining = [] + async for job in executor.check_active_jobs([active_job]): + remaining.append(job) + return remaining + + remaining = asyncio.run(_collect_remaining_jobs()) + + assert remaining == [] + executor.report_job_success.assert_not_called() + executor.report_job_error.assert_called_once() + + +def test_parent_fallback_completed_keeps_array_task_log(monkeypatch, tmp_path): + executor = Executor.__new__(Executor) + executor.logger = MagicMock() + executor.run_uuid = "run-uuid" + executor.status_rate_limiter = _NoopAsyncContext() + executor.get_status_command = lambda: "sacct" + executor.next_seconds_between_status_checks = 40 + executor._status_query_calls = 0 + executor._status_query_failures = 0 + executor._status_query_total_seconds = 0.0 + executor._status_query_min_seconds = None + executor._status_query_max_seconds = 0.0 + executor._status_query_cycle_rows = [] + executor._preemption_warning = False + # ensures working, even if cleanup fails + executor._keep_successful_logs = True + executor._failed_nodes = set() + executor.report_job_success = MagicMock() + executor.report_job_error = MagicMock() + + executor.workflow = SimpleNamespace( + 
executor_settings=SimpleNamespace( + status_attempts=1, + init_seconds_before_status_checks=40, + keep_successful_logs=False, + requeue=False, + ) + ) + + async def _mock_query_job_status(command, logger): + return ({"1057651": "COMPLETED"}, 0.01) + + monkeypatch.setattr( + "snakemake_executor_plugin_slurm.query_job_status", + _mock_query_job_status, + ) + monkeypatch.setattr( + "snakemake_executor_plugin_slurm.query_job_status_sacct", + lambda run_uuid: "mock_sacct_command", + ) + monkeypatch.setattr( + "snakemake_executor_plugin_slurm.get_min_job_age", + lambda: 300, + ) + monkeypatch.setattr( + "snakemake_executor_plugin_slurm.is_query_tool_available", + lambda tool: True, + ) + + log_path = tmp_path / "1057651_1.log" + log_path.write_text("content") + + active_job = SimpleNamespace( + external_jobid="1057651_1", + aux={"slurm_logfile": log_path}, + ) + + async def _collect_remaining_jobs(): + remaining = [] + async for job in executor.check_active_jobs([active_job]): + remaining.append(job) + return remaining + + remaining = asyncio.run(_collect_remaining_jobs()) + + assert remaining == [] + executor.report_job_success.assert_called_once() + executor.report_job_error.assert_not_called() diff --git a/tests/test_cli.py b/tests/test_cli.py index ec93fd85..7e2f65b2 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,9 +1,9 @@ """ - Tests for CLI-related executor settings. +Tests for CLI-related executor settings. - This suite will only test settings, which - are to be tested separately from the full - executor functionality. +This suite will only test settings, which +are to be tested separately from the full +executor functionality. """ from unittest.mock import MagicMock, patch diff --git a/tests/testcases/array_jobs/Snakefile b/tests/testcases/array_jobs/Snakefile new file mode 100644 index 00000000..14b1bb23 --- /dev/null +++ b/tests/testcases/array_jobs/Snakefile @@ -0,0 +1,57 @@ +""" +Testcase workflow for SLURM array job integration tests. 
+ +generate_numbers is a localrule that produces 4 input files. +copy_number runs as a SLURM job for each of the 4 inputs, +so the test always submits exactly 4 cluster jobs regardless of +the array-limit setting. +""" + +N = 4 + +localrules: + generate_numbers, + collect + + +rule all: + input: + "results/collected.txt" + + +rule generate_numbers: + output: + expand("numbers/{i}.txt", i=range(1, N + 1)), + shell: + """ + mkdir -p numbers + for i in $(seq 1 4); do + printf "%s\\n" "$i" > "numbers/$i.txt" + done + """ + + +rule copy_number: + input: + "numbers/{i}.txt", + output: + "copied/{i}.txt", + resources: + runtime=5, + shell: + """ + mkdir -p copied + cp {input} {output} + """ + + +rule collect: + input: + expand("copied/{i}.txt", i=range(1, N + 1)), + output: + "results/collected.txt", + shell: + """ + mkdir -p results + cat {input} | sort -n > {output} + """ diff --git a/tests/tests.py b/tests/tests.py index 06c7bee6..3153e717 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -1044,3 +1044,122 @@ def test_numbers_in_filenames_excluded(self): Job ID: 999000""" result = validate_or_get_slurm_job_id("999000", output) assert result == "999000" + + +class _LocalTestcasesBase(snakemake.common.tests.TestWorkflowsLocalStorageBase): + """Mixin that resolves testcase paths relative to this plugin's tests/ directory. + + Overrides run_workflow so that testcases not shipped with snakemake itself + (e.g. tests/testcases/array_jobs/) are found correctly. 
+ """ + + def get_config_settings(self) -> Optional[ConfigSettings]: + """Provide default config settings for local testcase workflows.""" + return None + + def run_workflow(self, test_name, tmp_path, deployment_method=frozenset()): + test_path = Path(__file__).parent / "testcases" / test_name + if not test_path.exists(): + return super().run_workflow(test_name, tmp_path, deployment_method) + + if self.omit_tmp: + tmp_path = test_path + else: + tmp_path = Path(tmp_path) / test_name + self._copy_test_files(test_path, tmp_path) + + resource_settings = self.get_resource_settings() + + if self._common_settings().local_exec: + resource_settings.cores = 3 + resource_settings.nodes = None + else: + resource_settings.cores = 1 + resource_settings.nodes = 3 + + with api.SnakemakeApi( + settings.OutputSettings( + verbose=True, + show_failed_logs=True, + ), + ) as snakemake_api: + workflow_api = snakemake_api.workflow( + config_settings=self.get_config_settings(), + resource_settings=resource_settings, + storage_settings=settings.StorageSettings( + default_storage_provider=self.get_default_storage_provider(), + default_storage_prefix=self.get_default_storage_prefix(), + shared_fs_usage=( + settings.SharedFSUsage.all() + if self.get_assume_shared_fs() + else frozenset() + ), + ), + deployment_settings=self.get_deployment_settings(deployment_method), + storage_provider_settings=self.get_default_storage_provider_settings(), + workdir=Path(tmp_path), + snakefile=tmp_path / "Snakefile", + ) + + dag_api = workflow_api.dag() + + if self.create_report: + dag_api.create_report( + reporter=self.get_reporter(), + report_settings=self.get_report_settings(), + ) + else: + dag_api.execute_workflow( + executor=self.get_executor(), + executor_settings=self.get_executor_settings(), + execution_settings=settings.ExecutionSettings( + latency_wait=self.latency_wait, + ), + remote_execution_settings=self.get_remote_execution_settings(), + ) + + +class TestArrayJobsAll(_LocalTestcasesBase): + 
"""Integration test: submit 4 copy_number jobs as one array (array_jobs='all'). + + Uses the testcases/array_jobs workflow which has exactly 4 SLURM jobs + (one copy_number instance per number). With no limit override the default + limit (1000) applies, so all 4 tasks land in a single array submission. + """ + + __test__ = True + + def get_executor(self) -> str: + return "slurm" + + def get_executor_settings(self) -> Optional[ExecutorSettingsBase]: + return ExecutorSettings( + array_jobs="all", + init_seconds_before_status_checks=2, + ) + + def test_array_jobs_all(self, tmp_path): + self.run_workflow("array_jobs", tmp_path) + + +class TestArrayJobsAllWithLimit(_LocalTestcasesBase): + """Integration test: submit 4 copy_number jobs as two arrays of size 2. + + array_limit=2 means run_array_jobs will chunk the 4 ready jobs into + two sbatch submissions of --array=1-2 and --array=3-4. + """ + + __test__ = True + + def get_executor(self) -> str: + return "slurm" + + def get_executor_settings(self) -> Optional[ExecutorSettingsBase]: + return ExecutorSettings( + array_jobs="all", + array_limit=2, + init_seconds_before_status_checks=2, + ) + + def test_array_jobs_all_with_limit(self, tmp_path): + self.run_workflow("array_jobs", tmp_path)