diff --git a/snakemake_executor_plugin_slurm/utils.py b/snakemake_executor_plugin_slurm/utils.py index 97df1faf..2cd5a0be 100644 --- a/snakemake_executor_plugin_slurm/utils.py +++ b/snakemake_executor_plugin_slurm/utils.py @@ -251,8 +251,9 @@ def set_gres_string(job: JobExecutorInterface) -> str: based on the resources requested in the job. """ # generic resources (GRES) arguments can be of type - # "string:int" or "string:string:int" - gres_re = re.compile(r"^[a-zA-Z0-9_]+(:[a-zA-Z0-9_\.]+)?:\d+$") + # "string:int" or "string:string:int" with optional postfix 'T' or 'G' or 'M' + gres_re = re.compile(r"^[a-zA-Z0-9_]+(:[a-zA-Z0-9_\.]+)?:\d+[TGM]?$") + # gpu model arguments can be of type "string" # The model string may contain a dot for variants, see # https://github.com/snakemake/snakemake-executor-plugin-slurm/issues/387 @@ -288,14 +289,16 @@ def set_gres_string(job: JobExecutorInterface) -> str: "GRES format should not be a nested string (start " "and end with ticks or quotation marks). " "Expected format: " - "':' or '::' " + "':' or '::' with an optional " + "'T' 'M' or 'G' postfix " "(e.g., 'gpu:1' or 'gpu:tesla:2')" ) else: raise WorkflowError( f"Invalid GRES format: {gres}. Expected format: " - "':' or '::' " - "(e.g., 'gpu:1' or 'gpu:tesla:2')" + "':' or '::' with an optional " + "'T' 'M' or 'G' postfix " + "(e.g., 'gpu:1' or 'gpu:tesla:2') " ) return f" --gres={job.resources.gres}" diff --git a/tests/tests.py b/tests/tests.py index 301b2742..06c7bee6 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -368,6 +368,20 @@ def test_gpu_model_without_gpu(self, mock_job): ): set_gres_string(job) + def test_tmpspace_gres_10G(self, mock_job): + """Test with valid GRES format (simple).""" + job = mock_job(gres="tmpspace:10G") + + # Patch subprocess.Popen to capture the sbatch command + with patch("subprocess.Popen") as mock_popen: + # Configure the mock to return successful submission + process_mock = MagicMock() + process_mock.communicate.return_value = ("123", "") + process_mock.returncode = 0 + mock_popen.return_value = process_mock + + assert set_gres_string(job) == " --gres=tmpspace:10G" + def test_both_gres_and_gpu_set(self, mock_job): """Test error case when both GRES and GPU are specified.""" job = mock_job(gres="gpu:1", gpu="2")