# Source code for a3fe.configuration.slurm_config

"""Configuration classes for SLURM configuration."""

__all__ = [
    "SlurmConfig",
]

import os as _os
import re as _re
import subprocess as _subprocess
from typing import Dict as _Dict
from typing import List as _List

import yaml as _yaml
from pydantic import BaseModel as _BaseModel
from pydantic import ConfigDict as _ConfigDict
from pydantic import Field as _Field


class SlurmConfig(_BaseModel):
    """
    Pydantic model for holding a SLURM configuration.

    The ``partition``, ``time``, ``gres``, ``nodes``, ``ntasks_per_node``,
    ``output`` and ``extra_options`` fields are written as ``#SBATCH``
    directives by :meth:`get_submission_cmds`. ``queue_check_interval`` and
    ``job_submission_wait`` are stored on the model but are not used by any
    method of this class.

    Field assignments are validated (``validate_assignment=True``).
    """

    partition: str = _Field("default", description="SLURM partition to submit to.")
    time: str = _Field("24:00:00", description="Time limit for the SLURM job.")
    gres: str = _Field("gpu:1", description="Resources to request - normally one GPU.")
    nodes: int = _Field(1, ge=1)
    ntasks_per_node: int = _Field(1, ge=1)
    output: str = _Field(
        "slurm-%A.%a.out", description="Output file for the SLURM job."
    )
    # A literal {} is safe as a default here: pydantic copies mutable
    # defaults for each new instance.
    extra_options: _Dict[str, str] = _Field(
        {}, description="Extra options to pass to SLURM. For example, {'account': 'qt'}"
    )
    queue_check_interval: int = _Field(
        30, ge=1, description="Interval in seconds between SLURM queue status checks."
    )
    job_submission_wait: int = _Field(
        300, ge=1, description="Wait time in seconds for job submission to SLURM queue."
    )

    model_config = _ConfigDict(validate_assignment=True)

    def get_submission_cmds(
        self, cmd: str, run_dir: str, script_name: str = "a3fe"
    ) -> _List[str]:
        """
        Write a SLURM batch script and return the command that submits it.

        The script is written to ``<run_dir>/<script_name>.sh`` (overwriting
        any existing file of that name). It contains one ``#SBATCH`` line per
        configured option, followed by one line per entry in
        ``extra_options``, a blank line, and finally ``cmd``.

        Parameters
        ----------
        cmd : str
            Command to run during the SLURM job.
        run_dir : str
            Directory to submit the SLURM job from. The script is written
            here, so the directory must already exist.
        script_name : str, optional, default="a3fe"
            Name of the script file to write (without the ``.sh`` suffix).
            Note that when running many jobs from the same directory, this
            should be unique to avoid overwriting the script file.

        Returns
        -------
        List[str]
            The ``sbatch`` command as an argument list:
            ``["sbatch", "--chdir=<run_dir>", "<script path>"]``.
        """
        script_path = _os.path.join(run_dir, f"{script_name}.sh")

        script_lines = [
            "#!/bin/bash",
            f"#SBATCH --partition={self.partition}",
            f"#SBATCH --time={self.time}",
            f"#SBATCH --gres={self.gres}",
            f"#SBATCH --nodes={self.nodes}",
            f"#SBATCH --ntasks-per-node={self.ntasks_per_node}",
            f"#SBATCH --output={self.output}",
        ]
        script_lines.extend(
            f"#SBATCH --{key}={value}" for key, value in self.extra_options.items()
        )
        # Directives, then a blank separator line, then the command itself.
        script = "\n".join(script_lines) + f"\n\n{cmd}\n"

        with open(script_path, "w") as f:
            f.write(script)

        return ["sbatch", f"--chdir={run_dir}", script_path]

    def get_slurm_output_file_base(self, run_dir: str) -> str:
        """
        Get the base name of the SLURM output file.

        The base is everything in ``output`` before the first ``%``
        (i.e. before the first filename-pattern placeholder), joined onto
        ``run_dir``.

        Parameters
        ----------
        run_dir : str
            Directory the job was submitted from.

        Returns
        -------
        str
            The base name of the SLURM output file.
        """
        output_prefix = self.output.split("%")[0]
        # os.path.join (rather than run_dir + "/" + prefix) keeps this
        # consistent with get_submission_cmds, avoids a doubled separator or
        # a root-anchored path when run_dir ends in "/" or is empty, and
        # leaves an absolute ``output`` pattern untouched instead of
        # prefixing it with run_dir.
        return _os.path.join(run_dir, output_prefix)

    @classmethod
    def get_default_partition(cls) -> str:
        """
        Get the default SLURM partition.

        Runs ``sinfo -o %P -h`` and returns the first partition name that is
        immediately followed by ``*`` in its output.

        Returns
        -------
        str
            Name of the default partition, without the trailing ``*``.

        Raises
        ------
        FileNotFoundError
            If the ``sinfo`` executable cannot be found.
        AttributeError
            If the ``sinfo`` output contains no partition marked with ``*``
            (the regular-expression search returns ``None``).
        """
        sinfo = _subprocess.run(
            ["sinfo", "-o", "%P", "-h"], stdout=_subprocess.PIPE, text=True
        )
        # Search for the default queue (marked with "*", then throw away the "*")
        return _re.search(r"([^\s]+)(?=\*)", sinfo.stdout).group(1)

    def dump(self, save_dir: str) -> None:
        """
        Dumps the configuration to a YAML file.

        The file is named by :meth:`get_file_name` and any existing file of
        that name in ``save_dir`` is overwritten.

        Parameters
        ----------
        save_dir : str
            Directory to save the YAML file to.
        """
        model_dict = self.model_dump()
        save_path = _os.path.join(save_dir, self.get_file_name())
        with open(save_path, "w") as f:
            _yaml.dump(model_dict, f, default_flow_style=False)

    @classmethod
    def load(cls, load_dir: str) -> "SlurmConfig":
        """
        Loads the configuration from a YAML file.

        Reads the file named by :meth:`get_file_name` in ``load_dir`` and
        passes its top-level mapping to the class constructor, so the
        loaded values are validated by pydantic.

        Parameters
        ----------
        load_dir : str
            Directory to load the YAML file from.

        Returns
        -------
        SlurmConfig
            The loaded configuration.
        """
        load_path = _os.path.join(load_dir, cls.get_file_name())
        with open(load_path, "r") as f:
            model_dict = _yaml.safe_load(f)
        return cls(**model_dict)

    @staticmethod
    def get_file_name() -> str:
        """
        Get the name of the SLURM configuration file.

        Returns
        -------
        str
            The file name used by :meth:`dump` and :meth:`load`.
        """
        return "slurm_config.yaml"