Skip to content

Commit f7dbd10

Browse files
authored
Bugfix for MD restarts with D3+mace (#657)
1 parent b3c7bc4 commit f7dbd10

File tree

2 files changed

+54
-1
lines changed

2 files changed

+54
-1
lines changed

janus_core/calculations/md.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -817,7 +817,7 @@ def _prepare_restart(self) -> None:
817817
struct=last_restart,
818818
read_kwargs=self.read_kwargs,
819819
sequence_allowed=False,
820-
arch=self.arch,
820+
arch=self.arch.removesuffix("_d3"),
821821
device=self.device,
822822
model=self.model,
823823
calc_kwargs=self.calc_kwargs,

tests/test_md.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,59 @@ def test_restart(tmp_path):
369369
assert len(traj) == 9
370370

371371

372+
def test_restart_with_d3(tmp_path):
373+
"""Test restarting molecular dynamics simulation (with D3)."""
374+
file_prefix = tmp_path / "Cl4Na4-nvt-T300.0"
375+
traj_path = tmp_path / "Cl4Na4-nvt-T300.0-traj.extxyz"
376+
stats_path = tmp_path / "Cl4Na4-nvt-T300.0-stats.dat"
377+
378+
single_point = SinglePoint(
379+
struct=DATA_PATH / "NaCl.cif",
380+
arch="mace",
381+
model=MODEL_PATH,
382+
calc_kwargs={"dispersion": True},
383+
)
384+
nvt = NVT(
385+
struct=single_point.struct,
386+
temp=300.0,
387+
steps=4,
388+
traj_every=1,
389+
restart_every=4,
390+
stats_every=1,
391+
file_prefix=file_prefix,
392+
calc_kwargs={"dispersion": True},
393+
)
394+
nvt.run()
395+
396+
assert nvt.dyn.nsteps == 4
397+
398+
nvt_restart = NVT(
399+
struct=single_point.struct,
400+
temp=300.0,
401+
steps=8,
402+
traj_every=1,
403+
restart_every=4,
404+
stats_every=1,
405+
restart=True,
406+
restart_auto=False,
407+
file_prefix=file_prefix,
408+
calc_kwargs={"dispersion": True},
409+
)
410+
nvt_restart.run()
411+
assert nvt_restart.offset == 4
412+
413+
with open(stats_path, encoding="utf8") as stats_file:
414+
lines = stats_file.readlines()
415+
assert " | Target_T [K]" in lines[0]
416+
# Includes step 0, and step 4 from restart
417+
assert len(lines) == 10
418+
assert int(lines[-1].split()[0]) == 8
419+
420+
traj = read(traj_path, index=":")
421+
assert all(isinstance(image, Atoms) for image in traj)
422+
assert len(traj) == 9
423+
424+
372425
def test_minimize(tmp_path):
373426
"""Test geometry optimzation before dynamics."""
374427
file_prefix = tmp_path / "Cl4Na4-nvt-T300.0"

0 commit comments

Comments
 (0)