Skip to content

Commit bd388e6

Browse files
Fix CSVR temperature ramp (#414)
* Fix CSVR temperature ramp * Replace warnings with errors * Add temperature units Co-authored-by: ElliottKasoar <45317199+ElliottKasoar@users.noreply.github.com> * Fix MTK _set_target_temperature * Add heating tests for all ensembles * Rework NPT_MTK test skip * Fix tests * Fix tests --------- Co-authored-by: ElliottKasoar <45317199+ElliottKasoar@users.noreply.github.com>
1 parent 0debc49 commit bd388e6

File tree

3 files changed

+132
-30
lines changed

3 files changed

+132
-30
lines changed

janus_core/calculations/md.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,10 @@ def __init__(
423423
"heating to run",
424424
stacklevel=2,
425425
)
426+
if self.ramp_temp and self.ensemble in ("nve", "nph"):
427+
raise ValueError(
428+
"Temperature ramp requested for ensemble with no thermostat."
429+
)
426430

427431
# Check validate start and end temperatures
428432
if self.ramp_temp and (self.temp_start < 0 or self.temp_end < 0):
@@ -1021,6 +1025,28 @@ def _write_restart(self) -> None:
10211025
self.restart_files.append(self._restart_file)
10221026
self._rotate_restart_files()
10231027

1028+
def _set_target_temperature(self, temperature: float):
1029+
"""
1030+
Set the target temperature of the thermostat.
1031+
1032+
Parameters
1033+
----------
1034+
temperature
1035+
New target temperature, in K.
1036+
"""
1037+
if hasattr(self.dyn, "set_temperature"):
1038+
self.dyn.set_temperature(temperature_K=temperature)
1039+
elif isinstance(self, NVT_CSVR):
1040+
self.dyn.temp = temperature * units.kB
1041+
self.dyn.target_kinetic_energy = 0.5 * self.dyn.temp * self.dyn.ndof
1042+
elif isinstance(self, NPT_MTK):
1043+
kt = temperature * units.kB
1044+
self.dyn._kT = kt
1045+
self.dyn._thermostat._kT = kt
1046+
self.dyn._barostat._kT = kt
1047+
else:
1048+
raise ValueError("Temperature set for ensemble with no thermostat.")
1049+
10241050
def run(self) -> None:
10251051
"""Run molecular dynamics simulation and/or temperature ramp."""
10261052
unit_keys = (
@@ -1105,8 +1131,7 @@ def _run_dynamics(self) -> None:
11051131
self._write_final_state()
11061132
self.created_final_file = True
11071133
continue
1108-
if not isinstance(self, NVE):
1109-
self.dyn.set_temperature(temperature_K=self.temp)
1134+
self._set_target_temperature(temp)
11101135
self.dyn.run(heating_steps)
11111136
self._write_final_state()
11121137
self.created_final_file = True
@@ -1126,8 +1151,7 @@ def _run_dynamics(self) -> None:
11261151
self.temp = md_temp
11271152
if self.ramp_temp:
11281153
self._set_velocity_distribution()
1129-
if not isinstance(self, NVE):
1130-
self.dyn.set_temperature(temperature_K=self.temp)
1154+
self._set_target_temperature(self.temp)
11311155
self.dyn.run(self.steps - self.offset)
11321156
self._write_final_state()
11331157
self.created_final_file = True

tests/test_md.py

Lines changed: 83 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from ase import Atoms
88
from ase.io import read
99
import ase.md.nose_hoover_chain
10+
from ase.md.npt import NPT as ASE_NPT
1011
import numpy as np
1112
import pytest
1213

@@ -21,6 +22,7 @@
2122

2223
MTK_IMPORT_FAILED = False
2324
else:
25+
NPT_MTK = None
2426
MTK_IMPORT_FAILED = True
2527

2628
DATA_PATH = Path(__file__).parent / "data"
@@ -32,9 +34,29 @@
3234
(NPT, "npt"),
3335
(NVT_NH, "nvt-nh"),
3436
(NPH, "nph"),
37+
(NVT_CSVR, "nvt-csvr"),
38+
pytest.param(
39+
NPT_MTK,
40+
"npt-mtk",
41+
marks=pytest.mark.skipif(
42+
MTK_IMPORT_FAILED, reason="Requires updated version of ASE"
43+
),
44+
),
3545
]
36-
if not MTK_IMPORT_FAILED:
37-
test_data.append((NPT_MTK, "npt-mtk"))
46+
47+
ensembles_without_thermostat = (NVE, NPH)
48+
ensembles_with_thermostat = (
49+
NVT,
50+
NPT,
51+
NVT_NH,
52+
NVT_CSVR,
53+
pytest.param(
54+
NPT_MTK,
55+
marks=pytest.mark.skipif(
56+
MTK_IMPORT_FAILED, reason="Requires updated version of ASE"
57+
),
58+
),
59+
)
3860

3961

4062
@pytest.mark.parametrize("ensemble, expected", test_data)
@@ -694,19 +716,20 @@ def test_stats(tmp_path, ensemble, tag):
694716
assert stat_data.units[etot_index] == "eV"
695717

696718

697-
def test_heating(tmp_path):
719+
@pytest.mark.parametrize("ensemble", ensembles_with_thermostat)
720+
def test_heating(tmp_path, ensemble):
698721
"""Test heating with no MD."""
699722
file_prefix = tmp_path / "NaCl-heating"
700723
final_file = tmp_path / "NaCl-heating-final.extxyz"
701-
log_file = tmp_path / "nvt.log"
724+
log_file = tmp_path / "NaCl.log"
702725

703726
single_point = SinglePoint(
704727
struct_path=DATA_PATH / "NaCl.cif",
705728
arch="mace",
706729
calc_kwargs={"model": MODEL_PATH},
707730
)
708731

709-
nvt = NVT(
732+
md = ensemble(
710733
struct=single_point.struct,
711734
temp=300.0,
712735
steps=0,
@@ -719,7 +742,7 @@ def test_heating(tmp_path):
719742
temp_time=0.5,
720743
log_kwargs={"filename": log_file},
721744
)
722-
nvt.run()
745+
md.run()
723746
assert_log_contains(
724747
log_file,
725748
includes=[
@@ -728,10 +751,42 @@ def test_heating(tmp_path):
728751
],
729752
excludes=["Starting molecular dynamics simulation"],
730753
)
754+
731755
assert final_file.exists()
732756

733757

734-
def test_noramp_heating(tmp_path):
758+
@pytest.mark.parametrize("ensemble", ensembles_without_thermostat)
759+
def test_no_thermostat_heating(tmp_path, ensemble):
760+
"""Test that temperature ramp with no thermostat throws an error."""
761+
file_prefix = tmp_path / "NaCl-heating"
762+
final_file = tmp_path / "NaCl-heating-final.extxyz"
763+
log_file = tmp_path / "NaCl.log"
764+
765+
single_point = SinglePoint(
766+
struct_path=DATA_PATH / "NaCl.cif",
767+
arch="mace",
768+
calc_kwargs={"model": MODEL_PATH},
769+
)
770+
with pytest.raises(ValueError, match="no thermostat"):
771+
md = ensemble(
772+
struct=single_point.struct,
773+
temp=300.0,
774+
steps=0,
775+
traj_every=10,
776+
stats_every=10,
777+
file_prefix=file_prefix,
778+
temp_start=0.0,
779+
temp_end=20.0,
780+
temp_step=20,
781+
temp_time=0.5,
782+
log_kwargs={"filename": log_file},
783+
)
784+
md.run()
785+
assert not final_file.exists()
786+
787+
788+
@pytest.mark.parametrize("ensemble", ensembles_with_thermostat)
789+
def test_noramp_heating(tmp_path, ensemble):
735790
"""Test ValueError is thrown for invalid temperature ramp."""
736791
file_prefix = tmp_path / "NaCl-heating"
737792

@@ -742,7 +797,7 @@ def test_noramp_heating(tmp_path):
742797
)
743798

744799
with pytest.raises(ValueError):
745-
NVT(
800+
ensemble(
746801
struct=single_point.struct,
747802
file_prefix=file_prefix,
748803
temp_start=10,
@@ -751,18 +806,19 @@ def test_noramp_heating(tmp_path):
751806
)
752807

753808

754-
def test_heating_md(tmp_path):
809+
@pytest.mark.parametrize("ensemble", ensembles_with_thermostat)
810+
def test_heating_md(tmp_path, ensemble):
755811
"""Test heating followed by MD."""
756812
file_prefix = tmp_path / "NaCl-heating"
757813
stats_path = tmp_path / "NaCl-heating-stats.dat"
758-
log_file = tmp_path / "nvt.log"
814+
log_file = tmp_path / "NaCl.log"
759815

760816
single_point = SinglePoint(
761817
struct_path=DATA_PATH / "NaCl.cif",
762818
arch="mace",
763819
calc_kwargs={"model": MODEL_PATH},
764820
)
765-
nvt = NVT(
821+
md = ensemble(
766822
struct=single_point.struct,
767823
temp=25.0,
768824
steps=5,
@@ -775,7 +831,7 @@ def test_heating_md(tmp_path):
775831
temp_time=2,
776832
log_kwargs={"filename": log_file},
777833
)
778-
nvt.run()
834+
md.run()
779835
assert_log_contains(
780836
log_file,
781837
includes=[
@@ -786,15 +842,23 @@ def test_heating_md(tmp_path):
786842
],
787843
)
788844
stat_data = Stats(stats_path)
789-
assert stat_data.rows == 5
790-
assert stat_data.columns == 17
791-
assert stat_data.data[0, 16] == pytest.approx(10.0)
792-
assert stat_data.data[2, 16] == pytest.approx(20.0)
793-
assert stat_data.data[4, 16] == pytest.approx(25.0)
845+
target_t_col = stat_data.labels.index("Target_T")
846+
847+
is_ase_npt = isinstance(md.dyn, ASE_NPT)
848+
849+
# ASE_NPT skips first row of output - ASE merge request submitted to fix:
850+
# https://gitlab.com/ase/ase/-/merge_requests/3598
851+
assert stat_data.rows == 4 if is_ase_npt else 5
852+
assert stat_data.data[0, target_t_col] == pytest.approx(10.0)
853+
if is_ase_npt:
854+
assert stat_data.data[1, target_t_col] == pytest.approx(20.0)
855+
assert stat_data.data[3, target_t_col] == pytest.approx(25.0)
856+
else:
857+
assert stat_data.data[2, target_t_col] == pytest.approx(20.0)
858+
assert stat_data.data[4, target_t_col] == pytest.approx(25.0)
794859
assert stat_data.labels[0] == "# Step"
795860
assert stat_data.units[0] == ""
796-
assert stat_data.units[16] == "K"
797-
assert stat_data.labels[16] == "Target_T"
861+
assert stat_data.units[target_t_col] == "K"
798862

799863

800864
def test_heating_files():

tests/test_md_cli.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,20 @@ def test_md_help():
3232
assert "Usage: janus md [OPTIONS]" in strip_ansi_codes(result.stdout)
3333

3434

35-
test_data = [("nvt"), ("nve"), ("npt"), ("nvt-nh"), ("nph"), ("nvt-csvr"), ("npt-mtk")]
35+
test_data = [
36+
("nvt"),
37+
("nve"),
38+
("npt"),
39+
("nvt-nh"),
40+
("nph"),
41+
("nvt-csvr"),
42+
pytest.param(
43+
"npt-mtk",
44+
marks=pytest.mark.skipif(
45+
MTK_IMPORT_FAILED, reason="Requires updated version of ASE"
46+
),
47+
),
48+
]
3649

3750

3851
@pytest.mark.parametrize("ensemble", test_data)
@@ -49,9 +62,6 @@ def test_md(ensemble):
4962
"npt-mtk": "NaCl-npt-mtk-T300.0-p0.0-",
5063
}
5164

52-
if ensemble == "npt-mtk" and MTK_IMPORT_FAILED:
53-
pytest.skip(reason="Requires updated version of ASE")
54-
5565
final_path = Path(f"{file_prefix[ensemble]}final.extxyz").absolute()
5666
restart_path = Path(f"{file_prefix[ensemble]}res-2.extxyz").absolute()
5767
stats_path = Path(f"{file_prefix[ensemble]}stats.dat").absolute()
@@ -342,7 +352,8 @@ def test_config(tmp_path):
342352
assert_log_contains(log_path, includes=["hydrostatic_strain: True"])
343353

344354

345-
def test_heating(tmp_path):
355+
@pytest.mark.parametrize("ensemble", test_data)
356+
def test_heating(tmp_path, ensemble):
346357
"""Test heating before MD."""
347358
file_prefix = tmp_path / "nvt-T300"
348359

@@ -351,7 +362,7 @@ def test_heating(tmp_path):
351362
[
352363
"md",
353364
"--ensemble",
354-
"nvt",
365+
ensemble,
355366
"--struct",
356367
DATA_PATH / "NaCl.cif",
357368
"--file-prefix",
@@ -370,7 +381,10 @@ def test_heating(tmp_path):
370381
0.05,
371382
],
372383
)
373-
assert result.exit_code == 0
384+
if ensemble in ("nve", "nph"):
385+
assert result.exit_code != 0
386+
else:
387+
assert result.exit_code == 0
374388

375389

376390
def test_invalid_config():

0 commit comments

Comments
 (0)