Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3,023 changes: 3,023 additions & 0 deletions examples/aceff_examples/alanine-dipeptide-explicit.pdb

Large diffs are not rendered by default.

58 changes: 58 additions & 0 deletions examples/aceff_examples/ase_aceff_PBC.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# This script shows how to use the Atomic Simulation Environment Calculator (ASE)
# inferface of TorchMD-Net with an AceFF model
import time
import sys
import ase
from ase.io import read
from torchmdnet.calculators import TMDNETCalculator

# The AceFFs models are available from HuggingFace under an Apache2.0 license
from huggingface_hub import hf_hub_download

model_file_path = hf_hub_download(
repo_id="Acellera/AceFF-2.0", filename="aceff_v2.0.ckpt"
)

# We create the ASE calculator by supplying the path to the model and specifying the device and dtype
# we provided a cutoff for the coulomb term so we can use PBCs
calc = TMDNETCalculator(model_file_path, device="cuda", coulomb_cutoff=10.0)
atoms = read("alanine-dipeptide-explicit.pdb")

print(atoms)

atoms.calc = calc

# The total molecular charge must be set
atoms.info["charge"] = 0

energy = atoms.get_potential_energy()
print(energy)
forces = atoms.get_forces()
print(forces)


# Molecular dynamics
from ase import units
from ase.md.langevin import Langevin
from ase.md import MDLogger


# setup MD
temperature_K: float = 300
timestep: float = 1.0 * units.fs
friction: float = 0.01 / units.fs
traj_interval: int = 10
log_interval: int = 10
nsteps: int = 100

dyn = Langevin(atoms, timestep, temperature_K=temperature_K, friction=friction)
dyn.attach(lambda: ase.io.write("traj.xyz", atoms, append=True), interval=traj_interval)
dyn.attach(MDLogger(dyn, atoms, sys.stdout), interval=log_interval)


# Run the dynamics
t1 = time.perf_counter()
dyn.run(steps=nsteps)
t2 = time.perf_counter()

print(f"Completed MD in {t2 - t1:.1f} s ({(t2 - t1)*1000 / nsteps:.3f} ms/step)")
146 changes: 146 additions & 0 deletions tests/test_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,149 @@ def test_ase_calculator(device):
atoms.calc = calc
# Run more dynamics
dyn.run(steps=nsteps)


@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_aceff2_coulomb_cutoff(device):
"""Test AceFF-2 on caffeine (no PBC) with TMDNETCalculator.

1. Compute energy with no coulomb_cutoff (all-to-all N^2 Coulomb).
2. Compute energy with very large cutoff
3. Compute energy with decreasing large coulomb_cutoffs (200 → 20 Å).
Caffeine is ~10 Å across, so all pairs are within cutoff in every case.
The Reaction Field correction is small at large cutoffs, so energies
should be close to the no-cutoff reference.
"""
try:
from huggingface_hub import hf_hub_download
except ImportError:
pytest.skip("huggingface_hub not available")

try:
model_file_path = hf_hub_download(
repo_id="Acellera/AceFF-2.0", filename="aceff_v2.0.ckpt"
)
except Exception:
pytest.skip("Could not download AceFF-2.0 model from HuggingFace")

from torchmdnet.calculators import TMDNETCalculator
from ase.io import read
import numpy as np
import os

curr_dir = os.path.dirname(__file__)
caffeine_pdb = os.path.join(curr_dir, "caffeine.pdb")

# --- reference: no coulomb_cutoff, all-to-all Coulomb, no PBC ---
atoms_ref = read(caffeine_pdb)
atoms_ref.info["charge"] = 0
calc_ref = TMDNETCalculator(model_file_path, device=device)
atoms_ref.calc = calc_ref
energy_no_cutoff = atoms_ref.get_potential_energy()

# in the limit of infinite cutoff the reaction field calculation should give the same energy
calc_inf = TMDNETCalculator(
model_file_path, device=device, coulomb_cutoff=1e6, coulomb_max_num_neighbors=20
)
atoms_inf = read(caffeine_pdb)
atoms_inf.info["charge"] = 0
atoms_inf.calc = calc_inf
energy_inf = atoms_inf.get_potential_energy()
np.testing.assert_allclose(
energy_inf,
energy_no_cutoff,
rtol=1e-5,
atol=1e-5,
)

# try changing the cutoff
for cutoff in [2000.0, 200.0, 100.0, 50.0, 20.0]:
atoms_c = read(caffeine_pdb)
atoms_c.info["charge"] = 0
calc_c = TMDNETCalculator(
model_file_path,
device=device,
coulomb_cutoff=cutoff,
coulomb_max_num_neighbors=64,
)
atoms_c.calc = calc_c
energy_cutoff = atoms_c.get_potential_energy()
# energy should increase as cutoff decreases as we are adding the reaction field energy
assert energy_cutoff > energy_no_cutoff

# the difference should be small
assert (energy_cutoff - energy_no_cutoff) / energy_no_cutoff < 0.01


@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_aceff2_pbc_vs_no_pbc(device):
"""Test AceFF-2 on explicit-solvent alanine-dipeptide with TMDNETCalculator.

The PDB has a CRYST1 record (cell ~32.8 x 32.9 x 31.9 Å, 2269 atoms).

1. Evaluate with PBC disabled – atoms see only intra-box neighbours.
2. Evaluate with PBC enabled – atoms see periodic images via the box.

The two energies must differ, which proves that PBC is actually applied
rather than silently ignored.
"""
try:
from huggingface_hub import hf_hub_download
except ImportError:
pytest.skip("huggingface_hub not available")

try:
model_file_path = hf_hub_download(
repo_id="Acellera/AceFF-2.0", filename="aceff_v2.0.ckpt"
)
except Exception:
pytest.skip("Could not download AceFF-2.0 model from HuggingFace")

from torchmdnet.calculators import TMDNETCalculator
from ase.io import read
import numpy as np
import os

pdb_path = os.path.join(
os.path.dirname(os.path.dirname(__file__)),
"examples",
"aceff_examples",
"alanine-dipeptide-explicit.pdb",
)
coulomb_cutoff = 10.0

# --- no PBC: disable the periodic box read from CRYST1 ---
atoms_nopbc = read(pdb_path)
atoms_nopbc.info["charge"] = 0
atoms_nopbc.pbc = False
calc_nopbc = TMDNETCalculator(
model_file_path, device=device, coulomb_cutoff=coulomb_cutoff
)
atoms_nopbc.calc = calc_nopbc
energy_nopbc = atoms_nopbc.get_potential_energy()
print(f"No PBC: {energy_nopbc}")

# --- with PBC: keep the cell and PBC flags from the CRYST1 record ---
atoms_pbc = read(pdb_path)
atoms_pbc.info["charge"] = 0
assert atoms_pbc.pbc.all(), "Expected full PBC from CRYST1 record"
calc_pbc = TMDNETCalculator(
model_file_path, device=device, coulomb_cutoff=coulomb_cutoff
)
atoms_pbc.calc = calc_pbc
energy_pbc = atoms_pbc.get_potential_energy()
print(f"PBC: {energy_pbc}")

# PBC adds periodic-image contributions to both the NN and Coulomb terms,
# so the energies must be meaningfully different.
assert not np.isclose(energy_pbc, energy_nopbc, rtol=1e-3), (
f"PBC energy ({energy_pbc:.4f}) is too close to no-PBC energy ({energy_nopbc:.4f}); "
"PBC may not be applied"
)


if __name__ == "__main__":
test_aceff2_coulomb_cutoff("cpu")
test_aceff2_coulomb_cutoff("cuda")
test_aceff2_pbc_vs_no_pbc("cpu")
test_aceff2_pbc_vs_no_pbc("cuda")
26 changes: 19 additions & 7 deletions torchmdnet/calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,12 +210,17 @@ def __init__(

ase_calc.Calculator.__init__(self)
self.device = device
self.remove_ref_energy = kwargs.pop("remove_ref_energy", True)
self.max_num_neighbors = kwargs.pop("max_num_neighbors", 64)
self.static_shapes = True if compile else False
self.model_file = model_file
self.model_kwargs = kwargs
self.model = load_model(
model_file,
self.model_file,
derivative=False,
remove_ref_energy=kwargs.get("remove_ref_energy", True),
max_num_neighbors=kwargs.get("max_num_neighbors", 64),
static_shapes=True if compile else False,
remove_ref_energy=self.remove_ref_energy,
max_num_neighbors=self.max_num_neighbors,
static_shapes=self.static_shapes,
**kwargs,
)
for parameter in self.model.parameters():
Expand Down Expand Up @@ -262,6 +267,11 @@ def calculate(
positions = atoms.positions
total_charge = atoms.info["charge"]
batch = [0 for _ in range(len(numbers))]
if atoms.pbc.any():
cell = atoms.cell.array
box = torch.tensor(cell, device=self.device, dtype=torch.float32)
else:
box = None

batch = torch.tensor(batch, device=self.device, dtype=torch.long)
numbers = torch.tensor(numbers, device=self.device, dtype=torch.long)
Expand All @@ -282,7 +292,7 @@ def calculate(
# This is needed because torch.compile doesn't support .item() calls
self.model.to(self.device)
with torch.no_grad():
_ = self.model(numbers, positions, batch=batch, q=total_charge)
_ = self.model(numbers, positions, batch=batch, q=total_charge, box=box)

self.compiled_model = torch.compile(
self.model,
Expand All @@ -295,10 +305,12 @@ def calculate(

if self.compiled:
energy, _ = self.compiled_model(
numbers, positions, batch=batch, q=total_charge
numbers, positions, batch=batch, q=total_charge, box=box
)
else:
energy, _ = self.model(numbers, positions, batch=batch, q=total_charge)
energy, _ = self.model(
numbers, positions, batch=batch, q=total_charge, box=box
)

energy.backward()
forces = -positions.grad
Expand Down
13 changes: 12 additions & 1 deletion torchmdnet/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ def create_model(args, prior_model=None, mean=None, std=None):
), # https://github.com/torchmd/torchmd-net/issues/343
q_dim=args.get("q_dim", 0),
q_weights=args.get("q_weights", []),
coulomb_cutoff=args.get("coulomb_cutoff", None),
coulomb_max_num_neighbors=args.get("coulomb_max_num_neighbors", None),
coulomb_neighbor_strategy=args.get("coulomb_neighbor_strategy", "brute"),
)

# combine representation and output network
Expand Down Expand Up @@ -281,6 +284,14 @@ def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs):
(3, 3), device="cpu"
)

# same for this one:
if (
"coulomb_cutoff" in args
and args["coulomb_cutoff"] is not None
and "output_model.distance.box" not in state_dict
):
state_dict["output_model.distance.box"] = torch.zeros((3, 3), device="cpu")

model.load_state_dict(state_dict)
return model.to(device)

Expand Down Expand Up @@ -493,7 +504,7 @@ def forward(
z, pos, batch, box=box, q=q, s=s
)
# apply the output network
x = self.output_model.pre_reduce(x, v, z, pos, batch)
x = self.output_model.pre_reduce(x, v, z, pos, batch, box=box)

# scale by data standard deviation
if self.std is not None:
Expand Down
Loading