Source code for a3fe.configuration.engine_config

"""Configuration classes for SOMD engine configuration."""

__all__ = [
    "SomdConfig",
]

import math as _math
import os as _os
from abc import ABC as _ABC
from abc import abstractmethod as _abstractmethod
from decimal import Decimal as _Decimal
from typing import (
    Dict as _Dict,
)
from typing import (
    List as _List,
)
from typing import (
    Literal as _Literal,
)
from typing import (
    Optional as _Optional,
)
from typing import (
    Union as _Union,
)

import yaml as _yaml
from pydantic import (
    BaseModel as _BaseModel,
)
from pydantic import (
    Field as _Field,
)
from pydantic import (
    model_validator as _model_validator,
)


class _EngineConfig(_BaseModel, _ABC):
    """
    Base class for engine runner configurations.

    Concrete subclasses are pydantic models that name their own
    configuration file (``get_file_name``), write the engine input file
    (``write_config``) and provide the command used to launch the engine
    (``get_run_cmd``). ``dump`` / ``load`` serialise the model to and from
    a YAML file with that name.
    """

    model_config = {
        # Accept field names as well as aliases when populating the model.
        "populate_by_name": True,
        # Re-validate the model whenever an attribute is assigned.
        "validate_assignment": True,
    }

    @staticmethod
    @_abstractmethod
    def get_file_name() -> str:
        """
        Get the name of the configuration file.

        Returns
        -------
        file_name : str
            Base name (no directory) of the configuration file.
        """
        pass

    def dump(self, save_dir: str) -> None:
        """
        Dump the configuration to a YAML file using `self.model_dump()`.

        Parameters
        ----------
        save_dir : str
            Directory to dump the configuration to.
        """
        model_dict = self.model_dump()

        # os.path.join avoids targeting the filesystem root when save_dir
        # is empty and does not double up separators.
        save_path = _os.path.join(save_dir, self.get_file_name())
        with open(save_path, "w") as f:
            _yaml.dump(model_dict, f, default_flow_style=False)

    @classmethod
    def load(cls, load_dir: str) -> "_EngineConfig":
        """
        Load a configuration from a YAML file.

        Parameters
        ----------
        load_dir : str
            Directory to load the configuration from.

        Returns
        -------
        config : EngineConfig
            The loaded configuration.

        Raises
        ------
        ValueError
            If the file does not contain a YAML mapping (for example, if it
            is empty).
        """
        load_path = _os.path.join(load_dir, cls.get_file_name())
        with open(load_path, "r") as f:
            model_dict = _yaml.safe_load(f)

        # safe_load returns None for an empty file and may return a list or
        # scalar; fail with a clear message instead of an opaque TypeError
        # from the ** unpacking below.
        if not isinstance(model_dict, dict):
            raise ValueError(
                f"Expected a YAML mapping in {load_path}, "
                f"but found {type(model_dict).__name__}."
            )

        return cls(**model_dict)

    @_abstractmethod
    def write_config(
        self,
        run_dir: str,
        lambda_val: float,
        runtime: float,
        top_file: str,
        coord_file: str,
        morph_file: str,
    ) -> None:
        """
        Write the configuration to a file.

        Parameters
        ----------
        run_dir : str
            Directory to write the configuration file to.
        lambda_val : float
            Lambda value of the simulation.
        runtime : float
            Runtime of the simulation.
        top_file : str
            Name of the topology file.
        coord_file : str
            Name of the coordinate file.
        morph_file : str
            Name of the morph file.
        """
        pass

    @_abstractmethod
    def get_run_cmd(self, lam: float) -> str:
        """
        Get the command to run the simulation.

        Parameters
        ----------
        lam : float
            Lambda value to run the simulation at.

        Returns
        -------
        cmd : str
            The command used to launch the engine.
        """
        pass


class SomdConfig(_EngineConfig):
    """
    Pydantic model for holding SOMD engine configuration.

    The fields mirror the options written to the SOMD configuration file by
    ``write_config``. ``nmoves`` and ``ncycles`` are not stored; they are
    derived from ``runtime``, ``timestep``, ``max_nmoves`` and
    ``energy_frequency``.
    """

    ### Integrator - ncycles modified as required by a3fe ###
    timestep: float = _Field(4.0, description="Timestep in femtoseconds(fs)")
    max_nmoves: int = _Field(
        250000,
        description="Maximum number of moves per cycle. nmoves and ncycles are computed from runtime, timestep, and max_nmoves.",
    )
    runtime: _Union[int, float] = _Field(
        5.0,
        description="Runtime in nanoseconds(ns), must be a multiple of timestep and ncycles will be calculated from runtime and nmoves",
    )

    ### Constraints ###
    constraint: _Literal["hbonds", "allbonds"] = _Field(
        "hbonds", description="Constraint type, must be hbonds or allbonds"
    )
    hydrogen_mass_factor: float = _Field(
        3.0,
        ge=1.0,
        le=4.0,
        alias="hydrogen mass repartitioning factor",
        description="Hydrogen mass repartitioning factor",
    )
    integrator: _Literal["langevinmiddle", "leapfrogverlet"] = _Field(
        "langevinmiddle", description="Integration algorithm"
    )

    ### Thermostatting already handled by langevin integrator
    thermostat: bool = _Field(
        False,
        description="Enable the thermodstat. Not required if using langevinmiddle integrator",
    )
    inverse_friction: float = _Field(
        1.0,
        ge=0.1,
        lt=10.0,
        description="Inverse friction in picoseconds",
        alias="inverse friction",
    )
    temperature: float = _Field(
        25.0, ge=-200.0, le=1000.0, description="Temperature in Celsius"
    )

    ### Barostat ###
    barostat: bool = _Field(True, description="Enable barostat")
    pressure: float = _Field(1.0, gt=0.0, lt=1000.0, description="Pressure in atm")

    ### Non-Bonded Interactions ###
    cutoff_type: _Literal["cutoffperiodic", "PME"] = _Field(
        "cutoffperiodic",
        description="Type of cutoff to use. Options: PME, cutoffperiodic",
    )
    cutoff_distance: float = _Field(
        12.0,
        alias="cutoff distance",
        ge=6.0,
        le=18.0,
        description="Cutoff distance in angstroms (6-18). Default 12.0 for cutoffperiodic.",
    )
    reaction_field_dielectric: float = _Field(
        78.3,
        alias="reaction field dielectric",
        description="Reaction field dielectric constant (only for cutoffperiodic). "
        "If cutoff type is PME, this value is ignored",
    )

    ### Trajectory ###
    buffered_coords_freq: int = _Field(
        5000,
        alias="buffered coordinates frequency",
        description="Frequency of buffered coordinates output",
    )
    center_solute: bool = _Field(
        True, alias="center solute", description="Center solute in box"
    )

    ### Minimisation ###
    minimise: bool = _Field(True, description="Perform energy minimisation")

    ### Restraints ###
    use_boresch_restraints: bool = _Field(
        False, description="Use Boresch restraints mode"
    )
    turn_on_receptor_ligand_restraints: bool = _Field(
        False, description="Turn on receptor-ligand restraints mode"
    )

    ### Alchemistry - restraints added by a3fe ###
    lambda_values: _Optional[_List[float]] = _Field(
        None,
        description="Lambda array for alchemical perturbation, varies from 0.0 to 1.0 across stage",
    )
    perturbed_residue_number: int = _Field(
        1,
        alias="perturbed residue number",
        ge=1,
        description="Residue number to perturb. Must be >= 1",
    )
    energy_frequency: int = _Field(
        200, alias="energy frequency", description="Frequency of energy output"
    )
    ligand_charge: int = _Field(
        0,
        description="Net charge of the ligand. If non-zero, must use PME for electrostatics.",
    )
    boresch_restraints_dictionary: _Optional[str] = _Field(
        None,
        description="Optional string to hold boresch restraints dictionary content",
    )

    ### Extra options ###
    extra_options: _Dict[str, str] = _Field(
        default_factory=dict, description="Extra options to pass to the SOMD engine"
    )

    def _get_total_nmoves(self) -> int:
        """Calculate total number of moves from runtime and timestep."""
        # Decimal (via str) keeps the ns -> fs conversion and the division
        # free of binary floating point error.
        runtime_fs = _Decimal(str(self.runtime)) * _Decimal("1_000_000")
        timestep = _Decimal(str(self.timestep))
        return int(runtime_fs / timestep)

    @property
    def nmoves(self) -> int:
        """
        Number of moves per cycle.

        If total_nmoves <= max_nmoves, returns total_nmoves (ncycles=1).
        Otherwise returns the largest factor of total_nmoves that is both
        <= max_nmoves and divisible by energy_frequency (default 200),
        ensuring that energy output points align with every cycle boundary.
        """
        total_nmoves = self._get_total_nmoves()
        if total_nmoves <= self.max_nmoves:
            return total_nmoves

        best = 1
        # Enumerate factor pairs (d, total_nmoves // d) up to sqrt(total).
        for d in range(1, _math.isqrt(total_nmoves) + 1):
            if total_nmoves % d != 0:
                continue
            for candidate in (d, total_nmoves // d):
                if (
                    candidate <= self.max_nmoves
                    and candidate % self.energy_frequency == 0
                ):
                    best = max(best, candidate)
        return best

    @property
    def ncycles(self) -> int:
        """Number of cycles, computed as total_nmoves / nmoves (at least 1)."""
        nmoves = self.nmoves
        if nmoves == 0:
            # A zero runtime gives zero moves; avoid dividing by zero.
            return 1
        return max(1, self._get_total_nmoves() // nmoves)

    @_model_validator(mode="after")
    def _validate_runtime_timestep_nmoves(self):
        """
        Validate that runtime is a multiple of both timestep and
        energy_frequency * timestep.

        Raises
        ------
        ValueError
            If timestep or energy_frequency is not positive, if max_nmoves is
            smaller than energy_frequency, or if runtime is not a multiple
            of timestep and of energy_frequency * timestep.
        """
        # Reject non-positive values up front: zero would otherwise surface
        # as decimal.InvalidOperation from the modulo operations below
        # rather than as a validation error.
        if self.timestep <= 0:
            raise ValueError(f"timestep must be positive, got {self.timestep} fs.")
        if self.energy_frequency <= 0:
            raise ValueError(
                f"energy_frequency must be a positive integer, got {self.energy_frequency}."
            )

        if self.max_nmoves < self.energy_frequency:
            raise ValueError(
                f"max_nmoves ({self.max_nmoves}) must be >= energy_frequency "
                f"({self.energy_frequency}) so that each cycle contains at least one energy output."
            )

        runtime_fs = _Decimal(str(self.runtime)) * _Decimal("1_000_000")
        timestep = _Decimal(str(self.timestep))

        if round(float(runtime_fs % timestep), 4) != 0:
            raise ValueError(
                f"Runtime must be a multiple of timestep. "
                f"Runtime is {self.runtime} ns ({runtime_fs} fs), "
                f"timestep is {self.timestep} fs."
            )

        energy_block_fs = timestep * _Decimal(str(self.energy_frequency))
        if round(float(runtime_fs % energy_block_fs), 4) != 0:
            raise ValueError(
                f"Runtime must be a multiple of energy_frequency * timestep "
                f"({self.energy_frequency} * {self.timestep} fs = {float(energy_block_fs)} fs). "
                f"Runtime is {self.runtime} ns ({runtime_fs} fs)."
            )
        return self

    @_model_validator(mode="after")
    def _check_rf_dielectric(self):
        """Warn the user if they've changed the rf dielectric constant but are using PME"""
        # The dielectric is only written to the config for cutoffperiodic
        # (see write_config), so a non-default value is only suspicious
        # when PME is selected.
        if self.cutoff_type == "PME" and self.reaction_field_dielectric != 78.3:
            # Function-scope import: the model defines no logger attribute.
            import logging as _logging

            _logging.getLogger(__name__).warning(
                "You have cutoff_type=PME but set a reaction_field_dielectric. "
                f"PME will be used and this value ({self.reaction_field_dielectric}) "
                "will be ignored by the engine."
            )
        return self

    @_model_validator(mode="after")
    def _check_ligand_charge(self):
        """Require PME whenever the ligand carries a net charge."""
        if self.ligand_charge != 0 and self.cutoff_type != "PME":
            raise ValueError(
                "Charge difference is non-zero but cutoff type is not PME."
            )
        return self

    @_model_validator(mode="after")
    def _validate_integrator_and_thermo(self):
        """Check that the thermostat setting is compatible with the integrator."""
        integrator = self.integrator
        thermostat = self.thermostat

        # 1) integrator='langevinmiddle' => thermostat must be False
        # 2) integrator='leapfrogverlet' => thermostat must be True
        if integrator == "langevinmiddle" and thermostat is True:
            raise ValueError(
                "If integrator is 'langevinmiddle', thermostat must be False."
            )
        elif integrator == "leapfrogverlet" and thermostat is False:
            raise ValueError(
                "If integrator is 'leapfrogverlet', thermostat must be True."
            )
        return self

    @staticmethod
    def get_file_name() -> str:
        """
        Get the name of the SOMD configuration file.
        """
        return "somd.cfg"

    def write_config(
        self,
        run_dir: str,
        lambda_val: float,
        runtime: float,
        top_file: str,
        coord_file: str,
        morph_file: str,
    ) -> None:
        """
        Generate the SOMD configuration file in ``run_dir``.

        Note that this assigns ``runtime`` to ``self.runtime`` before
        writing, so the model is updated (and re-validated) as a side effect.

        Parameters
        ----------
        run_dir : str
            Directory to write the configuration file to.
        lambda_val : float
            Current lambda value
        runtime : float
            Total runtime in nanoseconds.
        top_file : str
            Name of the topology file; written joined onto run_dir.
        coord_file : str
            Name of the coordinate file; written joined onto run_dir.
        morph_file : str
            Name of the morph file; written joined onto run_dir.

        Raises
        ------
        ValueError
            If ``lambda_values`` has not been set, or if the new runtime
            fails validation.
        """
        self.runtime = runtime
        if self.lambda_values is None:
            raise ValueError(
                "lambda_array must be set before writing the configuration."
            )

        config_lines = [
            "### Integrator ###",
            f"timestep = {self.timestep} * femtosecond",
            f"ncycles = {self.ncycles}",
            f"nmoves = {self.nmoves}",
            f"constraint = {self.constraint}",
            f"hydrogen mass repartitioning factor = {self.hydrogen_mass_factor}",
            f"integrator = {self.integrator}",
            f"inverse friction = {self.inverse_friction} * picosecond",
            f"temperature = {self.temperature} * celsius",
            f"thermostat = {self.thermostat}",
            "\n\n### Barostat ###",
            f"barostat = {self.barostat}",
            f"pressure = {self.pressure} * atm",
            "\n\n### Non-Bonded Interactions ###",
            f"cutoff type = {self.cutoff_type}",
            f"cutoff distance = {self.cutoff_distance} * angstrom",
        ]

        # The reaction field dielectric only applies to cutoffperiodic.
        if (
            self.cutoff_type == "cutoffperiodic"
            and self.reaction_field_dielectric is not None
        ):
            config_lines.append(
                f"reaction field dielectric = {self.reaction_field_dielectric}"
            )

        config_lines.extend(
            [
                "\n\n### Trajectory ###",
                f"buffered coordinates frequency = {self.buffered_coords_freq}",
                f"center solute = {self.center_solute}",
                "\n\n### Minimisation ###",
                f"minimise = {self.minimise}",
                "\n\n### Alchemistry ###",
                f"perturbed residue number = {self.perturbed_residue_number}",
                f"energy frequency = {self.energy_frequency}",
                f"ligand charge = {self.ligand_charge}",
                f"lambda array = {', '.join(str(x) for x in self.lambda_values)}",
                f"lambda_val = {lambda_val}",
                "\n\n### Restraints ###",
                f"use boresch restraints = {self.use_boresch_restraints}",
                f"turn on receptor-ligand restraints mode = {self.turn_on_receptor_ligand_restraints}",
                "\n\n###Paths###",
                f"morphfile = {_os.path.join(run_dir, morph_file)}",
                f"topfile = {_os.path.join(run_dir, top_file)}",
                f"crdfile = {_os.path.join(run_dir, coord_file)}",
            ]
        )

        # Boresch restraints
        if self.boresch_restraints_dictionary is not None:
            config_lines.extend(["", "### Boresch Restraints Dictionary ###"])
            config_lines.append(
                f"boresch restraints dictionary = {self.boresch_restraints_dictionary}"
            )

        # Add any extra options
        if self.extra_options:
            config_lines.extend(["", "### Extra Options ###"])
            for key, value in self.extra_options.items():
                config_lines.append(f"{key} = {value}")

        # Write the configuration to a file
        config_filename = self.get_file_name()
        config_path = _os.path.join(run_dir, config_filename)
        with open(config_path, "w") as f:
            f.write("\n".join(config_lines) + "\n")

    @classmethod
    def _from_config_file(cls, config_path: str) -> "SomdConfig":
        """
        Create a SomdConfig instance from an existing configuration file.

        Lines are expected in ``key = value`` form; blank lines, lines
        starting with ``#`` and lines without ``=`` are skipped. Values of
        the form ``number * unit`` keep only the number.

        NOTE: ``write_config`` does not write the runtime itself, so a
        config read back from such a file keeps the default ``runtime``.

        Parameters
        ----------
        config_path : str
            Path to the SOMD configuration file.

        Returns
        -------
        config : SomdConfig
            Configuration populated from the file.
        """
        # Keys written by write_config whose field name is NOT obtained by
        # simply replacing spaces with underscores. Without this mapping
        # these options would be silently dropped on load.
        key_to_field = {
            "hydrogen mass repartitioning factor": "hydrogen_mass_factor",
            "buffered coordinates frequency": "buffered_coords_freq",
            "turn on receptor-ligand restraints mode": "turn_on_receptor_ligand_restraints",
        }

        config_dict = {}
        with open(config_path, "r") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line or line.startswith("#") or "=" not in line:
                    continue

                key, value = (part.strip() for part in line.split("=", 1))

                if key == "lambda array":
                    config_dict["lambda_values"] = [
                        float(x) for x in value.split(",")
                    ]
                    continue

                if "*" in value:
                    # "<number> * <unit>": keep the numeric part only.
                    value = value.split("*")[0].strip()
                    try:
                        value = float(value)
                    except ValueError:
                        pass
                elif value.lower() == "true":
                    value = True
                elif value.lower() == "false":
                    value = False
                else:
                    # Anything left as a string is handed to pydantic
                    # unchanged for coercion/validation.
                    try:
                        if "." in value:
                            value = float(value)
                        elif value.isdigit():
                            value = int(value)
                    except ValueError:
                        pass

                field_name = key_to_field.get(key, key.replace(" ", "_"))
                config_dict[field_name] = value

        return cls(**config_dict)

    def get_run_cmd(self, lam: float) -> str:
        """
        Get the command to run the simulation.

        Parameters
        ----------
        lam : float
            Lambda value passed to ``somd-freenrg`` via ``-l``.

        Returns
        -------
        cmd : str
            The ``somd-freenrg`` command line, using the CUDA platform.
        """
        return f"somd-freenrg -C {self.get_file_name()} -l {lam} -p CUDA"