Skip to content

Commit f8fd5c9

Browse files
Fix temperature saved during MD (#469)
* Rename stats units property * Fix writing current MD temperature * Test stats and traj data is consistent
1 parent 47957fd commit f8fd5c9

File tree

3 files changed

+157
-33
lines changed

3 files changed

+157
-33
lines changed

janus_core/calculations/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
"time": "fs",
2727
"real_time": "s",
2828
"temperature": "K",
29+
"target_temperature": "K",
2930
"pressure": "GPa",
3031
"momenta": "(eV*u)^0.5",
3132
"density": "g/cm^3",

janus_core/calculations/md.py

Lines changed: 94 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,7 @@ def __init__(
517517
self._parse_correlations()
518518

519519
def _set_info(self) -> None:
520-
"""Set time in fs, current dynamics step, and density to info."""
520+
"""Set time in fs, current dynamics step, density, and temperature to info."""
521521
time = (self.offset * self.timestep + self.dyn.get_time()) / units.fs
522522
step = self.offset + self.dyn.nsteps
523523
self.dyn.atoms.info["time"] = time
@@ -532,6 +532,13 @@ def _set_info(self) -> None:
532532
except ValueError:
533533
self.dyn.atoms.info["density"] = 0.0
534534

535+
e_kin = self.dyn.atoms.get_kinetic_energy() / self.n_atoms
536+
current_temp = e_kin / (1.5 * units.kB)
537+
self.struct.info["temperature"] = current_temp
538+
539+
if hasattr(self.dyn, "set_temperature") or isinstance(self, NVT_CSVR | NPT_MTK):
540+
self.struct.info["target_temperature"] = self.temp
541+
535542
def _prepare_restart(self) -> None:
536543
"""Prepare restart files, structure and offset."""
537544
# Check offset can be read from steps
@@ -727,6 +734,28 @@ def _write_correlations(self) -> None:
727734
data[str(cor)] = {"value": value.tolist(), "lags": lags.tolist()}
728735
yaml.dump(data, out_file, default_flow_style=None)
729736

737+
@property
738+
def info_unit_keys(self) -> tuple[str]:
739+
"""
740+
Get Atoms.info keys to save units for.
741+
742+
Returns
743+
-------
744+
tuple[str]
745+
Keys of Atoms.info.
746+
"""
747+
return (
748+
"energy",
749+
"forces",
750+
"stress",
751+
"time",
752+
"real_time",
753+
"temperature",
754+
"pressure",
755+
"density",
756+
"momenta",
757+
)
758+
730759
def get_stats(self) -> dict[str, float]:
731760
"""
732761
Get thermodynamical statistics to be written to file.
@@ -738,7 +767,6 @@ def get_stats(self) -> dict[str, float]:
738767
"""
739768
e_pot = self.dyn.atoms.get_potential_energy() / self.n_atoms
740769
e_kin = self.dyn.atoms.get_kinetic_energy() / self.n_atoms
741-
current_temp = e_kin / (1.5 * units.kB)
742770

743771
self._set_info()
744772

@@ -770,7 +798,7 @@ def get_stats(self) -> dict[str, float]:
770798
"Time": self.dyn.atoms.info["time"],
771799
"Epot/N": e_pot,
772800
"EKin/N": e_kin,
773-
"T": current_temp,
801+
"T": self.dyn.atoms.info["temperature"],
774802
"ETot/N": e_pot + e_kin,
775803
"Density": self.dyn.atoms.info["density"],
776804
"Volume": volume,
@@ -784,7 +812,7 @@ def get_stats(self) -> dict[str, float]:
784812
}
785813

786814
@property
787-
def unit_info(self) -> dict[str, str]:
815+
def stats_units(self) -> dict[str, str]:
788816
"""
789817
Get units of returned statistics.
790818
@@ -847,8 +875,8 @@ def _write_header(self) -> None:
847875
write_table(
848876
"ascii",
849877
file=stats_file,
850-
units=self.unit_info,
851-
**{key: () for key in self.unit_info},
878+
units=self.stats_units,
879+
**{key: () for key in self.stats_units},
852880
)
853881

854882
def _write_stats_file(self) -> None:
@@ -863,7 +891,7 @@ def _write_stats_file(self) -> None:
863891
write_table(
864892
"ascii",
865893
file=stats_file,
866-
units=self.unit_info,
894+
units=self.stats_units,
867895
formats=self.default_formats,
868896
print_header=False,
869897
**stats,
@@ -895,10 +923,6 @@ def _write_traj(self) -> None:
895923

896924
def _write_final_state(self) -> None:
897925
"""Write the final system state."""
898-
self.struct.info["temperature"] = self.temp
899-
if isinstance(self, NPT) and not isinstance(self, NVT_NH):
900-
self.struct.info["pressure"] = self.pressure
901-
902926
# Append if final file has been created
903927
append = self.created_final_file
904928

@@ -1046,18 +1070,7 @@ def _set_target_temperature(self, temperature: float):
10461070

10471071
def run(self) -> None:
10481072
"""Run molecular dynamics simulation and/or temperature ramp."""
1049-
unit_keys = (
1050-
"energy",
1051-
"forces",
1052-
"stress",
1053-
"time",
1054-
"real_time",
1055-
"temperature",
1056-
"pressure",
1057-
"density",
1058-
"momenta",
1059-
)
1060-
self._set_info_units(unit_keys)
1073+
self._set_info_units(self.info_unit_keys)
10611074

10621075
if not self.restart:
10631076
if self.minimize:
@@ -1299,6 +1312,18 @@ def _set_param_prefix(self, file_prefix: PathLike | None = None) -> str:
12991312

13001313
return f"{super()._set_param_prefix(file_prefix)}-p{self.pressure}"
13011314

1315+
@property
1316+
def info_unit_keys(self) -> tuple[str]:
1317+
"""
1318+
Get Atoms.info keys to save units for.
1319+
1320+
Returns
1321+
-------
1322+
tuple[str]
1323+
Keys of Atoms.info.
1324+
"""
1325+
return super().info_unit_keys + ("target_temperature",)
1326+
13021327
def get_stats(self) -> dict[str, float]:
13031328
"""
13041329
Get thermodynamical statistics to be written to file.
@@ -1313,7 +1338,7 @@ def get_stats(self) -> dict[str, float]:
13131338
return stats
13141339

13151340
@property
1316-
def unit_info(self) -> dict[str, str]:
1341+
def stats_units(self) -> dict[str, str]:
13171342
"""
13181343
Get units of returned statistics.
13191344
@@ -1322,7 +1347,7 @@ def unit_info(self) -> dict[str, str]:
13221347
dict[str, str]
13231348
Units attached to statistical properties.
13241349
"""
1325-
return super().unit_info | {
1350+
return super().stats_units | {
13261351
"Target_P": JANUS_UNITS["pressure"],
13271352
"Target_T": JANUS_UNITS["temperature"],
13281353
}
@@ -1394,6 +1419,18 @@ def __init__(
13941419
**ensemble_kwargs,
13951420
)
13961421

1422+
@property
1423+
def info_unit_keys(self) -> tuple[str]:
1424+
"""
1425+
Get Atoms.info keys to save units for.
1426+
1427+
Returns
1428+
-------
1429+
tuple[str]
1430+
Keys of Atoms.info.
1431+
"""
1432+
return super().info_unit_keys + ("target_temperature",)
1433+
13971434
def get_stats(self) -> dict[str, float]:
13981435
"""
13991436
Get thermodynamical statistics to be written to file.
@@ -1408,7 +1445,7 @@ def get_stats(self) -> dict[str, float]:
14081445
return stats
14091446

14101447
@property
1411-
def unit_info(self) -> dict[str, str]:
1448+
def stats_units(self) -> dict[str, str]:
14121449
"""
14131450
Get units of returned statistics.
14141451
@@ -1417,7 +1454,7 @@ def unit_info(self) -> dict[str, str]:
14171454
dict[str, str]
14181455
Units attached to statistical properties.
14191456
"""
1420-
return super().unit_info | {"Target_T": JANUS_UNITS["temperature"]}
1457+
return super().stats_units | {"Target_T": JANUS_UNITS["temperature"]}
14211458

14221459
@property
14231460
def default_formats(self) -> dict[str, str]:
@@ -1543,6 +1580,18 @@ def __init__(
15431580
**ensemble_kwargs,
15441581
)
15451582

1583+
@property
1584+
def info_unit_keys(self) -> tuple[str]:
1585+
"""
1586+
Get Atoms.info keys to save units for.
1587+
1588+
Returns
1589+
-------
1590+
tuple[str]
1591+
Keys of Atoms.info.
1592+
"""
1593+
return super().info_unit_keys + ("target_temperature",)
1594+
15461595
def get_stats(self) -> dict[str, float]:
15471596
"""
15481597
Get thermodynamical statistics to be written to file.
@@ -1557,7 +1606,7 @@ def get_stats(self) -> dict[str, float]:
15571606
return stats
15581607

15591608
@property
1560-
def unit_info(self) -> dict[str, str]:
1609+
def stats_units(self) -> dict[str, str]:
15611610
"""
15621611
Get units of returned statistics.
15631612
@@ -1566,7 +1615,7 @@ def unit_info(self) -> dict[str, str]:
15661615
dict[str, str]
15671616
Units attached to statistical properties.
15681617
"""
1569-
return super().unit_info | {"Target_T": JANUS_UNITS["temperature"]}
1618+
return super().stats_units | {"Target_T": JANUS_UNITS["temperature"]}
15701619

15711620
@property
15721621
def default_formats(self) -> dict[str, str]:
@@ -1757,7 +1806,7 @@ def get_stats(self) -> dict[str, float]:
17571806
return stats
17581807

17591808
@property
1760-
def unit_info(self) -> dict[str, str]:
1809+
def stats_units(self) -> dict[str, str]:
17611810
"""
17621811
Get units of returned statistics.
17631812
@@ -1766,7 +1815,7 @@ def unit_info(self) -> dict[str, str]:
17661815
dict[str, str]
17671816
Units attached to statistical properties.
17681817
"""
1769-
return super().unit_info | {
1818+
return super().stats_units | {
17701819
"Target_P": JANUS_UNITS["pressure"],
17711820
}
17721821

@@ -1906,6 +1955,18 @@ def _set_param_prefix(self, file_prefix: PathLike | None = None) -> str:
19061955
pressure = f"-p{self.pressure}"
19071956
return f"{super()._set_param_prefix(file_prefix)}{pressure}"
19081957

1958+
@property
1959+
def info_unit_keys(self) -> tuple[str]:
1960+
"""
1961+
Get Atoms.info keys to save units for.
1962+
1963+
Returns
1964+
-------
1965+
tuple[str]
1966+
Keys of Atoms.info.
1967+
"""
1968+
return super().info_unit_keys + ("target_temperature",)
1969+
19091970
def get_stats(self) -> dict[str, float]:
19101971
"""
19111972
Get thermodynamical statistics to be written to file.
@@ -1920,7 +1981,7 @@ def get_stats(self) -> dict[str, float]:
19201981
return stats
19211982

19221983
@property
1923-
def unit_info(self) -> dict[str, str]:
1984+
def stats_units(self) -> dict[str, str]:
19241985
"""
19251986
Get units of returned statistics.
19261987
@@ -1929,7 +1990,7 @@ def unit_info(self) -> dict[str, str]:
19291990
dict[str, str]
19301991
Units attached to statistical properties.
19311992
"""
1932-
return super().unit_info | {
1993+
return super().stats_units | {
19331994
"Target_P": JANUS_UNITS["pressure"],
19341995
"Target_T": JANUS_UNITS["temperature"],
19351996
}

tests/test_md_cli.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import yaml
1414

1515
from janus_core.cli.janus import app
16+
from janus_core.helpers.stats import Stats
1617
from tests.utils import assert_log_contains, clear_log_handlers, strip_ansi_codes
1718

1819
if hasattr(ase.md.nose_hoover_chain, "IsotropicMTKNPT"):
@@ -815,3 +816,64 @@ def test_no_carbon(tmp_path):
815816
with open(summary_path, encoding="utf8") as file:
816817
summary = yaml.safe_load(file)
817818
assert "emissions" not in summary
819+
820+
821+
@pytest.mark.parametrize("ensemble", ("nvt", "npt", "nvt-csvr"))
822+
@pytest.mark.parametrize("output_every", (1, 2))
823+
@pytest.mark.parametrize("heating", (True, False))
824+
def test_consistent_stats_traj(tmp_path, ensemble, output_every, heating):
825+
"""Test data saved to statistics is consistent with trajectory info."""
826+
file_prefix = tmp_path / ensemble
827+
stats_path = tmp_path / f"{ensemble}-stats.dat"
828+
traj_path = tmp_path / f"{ensemble}-traj.extxyz"
829+
830+
inputs = [
831+
"md",
832+
"--ensemble",
833+
"nvt",
834+
"--struct",
835+
DATA_PATH / "NaCl.cif",
836+
"--no-tracker",
837+
"--file-prefix",
838+
file_prefix,
839+
"--steps",
840+
4,
841+
"--stats-every",
842+
output_every,
843+
"--traj-every",
844+
output_every,
845+
]
846+
847+
if heating:
848+
inputs = inputs + [
849+
"--temp-start",
850+
10,
851+
"--temp-end",
852+
20,
853+
"--temp-step",
854+
10,
855+
"--temp-time",
856+
2,
857+
]
858+
859+
result = runner.invoke(app, inputs)
860+
assert result.exit_code == 0
861+
862+
atoms = read(traj_path, index=":")
863+
temps = [a.info["temperature"] for a in atoms]
864+
865+
data = Stats(stats_path)
866+
temp_col = data.labels.index("T")
867+
stats_temps = data.data[:, temp_col]
868+
869+
assert len(temps) == len(stats_temps)
870+
assert temps == pytest.approx(stats_temps, rel=1e5)
871+
872+
if ensemble in ("nvt", "npt"):
873+
target_temps = [a.info["target_temperature"] for a in atoms]
874+
875+
target_temp_col = data.labels.index("Target_T")
876+
stats_target_temps = data.data[:, target_temp_col]
877+
878+
assert len(target_temps) == len(stats_target_temps)
879+
assert target_temps == pytest.approx(stats_target_temps, rel=1e5)

0 commit comments

Comments
 (0)