Skip to content

Commit 7eae3aa

Browse files
authored
Fix NPH and NVT_NH (#429)
* Fix NPH and NVT_NH logging * Add stats file test
1 parent 88a3508 commit 7eae3aa

File tree

3 files changed

+137
-32
lines changed

3 files changed

+137
-32
lines changed

janus_core/calculations/md.py

Lines changed: 106 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1250,8 +1250,7 @@ def _set_param_prefix(self, file_prefix: PathLike | None = None) -> str:
12501250
if file_prefix is not None:
12511251
return ""
12521252

1253-
pressure = f"-p{self.pressure}" if not isinstance(self, NVT_NH) else ""
1254-
return f"{super()._set_param_prefix(file_prefix)}{pressure}"
1253+
return f"{super()._set_param_prefix(file_prefix)}-p{self.pressure}"
12551254

12561255
def get_stats(self) -> dict[str, float]:
12571256
"""
@@ -1262,7 +1261,7 @@ def get_stats(self) -> dict[str, float]:
12621261
dict[str, float]
12631262
Thermodynamical statistics to be written out.
12641263
"""
1265-
stats = MolecularDynamics.get_stats(self)
1264+
stats = super().get_stats()
12661265
stats |= {"Target_P": self.pressure, "Target_T": self.temp}
12671266
return stats
12681267

@@ -1434,7 +1433,7 @@ def __init__(
14341433
)
14351434

14361435

1437-
class NVT_NH(NPT): # noqa: N801 (invalid-class-name)
1436+
class NVT_NH(MolecularDynamics): # noqa: N801 (invalid-class-name)
14381437
"""
14391438
Configure NVT Nosé-Hoover simulation.
14401439
@@ -1446,6 +1445,9 @@ class NVT_NH(NPT): # noqa: N801 (invalid-class-name)
14461445
Thermostat time, in fs. Default is 50.0.
14471446
ensemble
14481447
Name for thermodynamic ensemble. Default is "nvt-nh".
1448+
file_prefix
1449+
Prefix for output filenames. Default is inferred from structure, ensemble,
1450+
temperature, and pressure.
14491451
ensemble_kwargs
14501452
Keyword arguments to pass to ensemble initialization. Default is {}.
14511453
**kwargs
@@ -1457,6 +1459,7 @@ def __init__(
14571459
*args,
14581460
thermostat_time: float = 50.0,
14591461
ensemble: Ensembles = "nvt-nh",
1462+
file_prefix: PathLike | None = None,
14601463
ensemble_kwargs: dict[str, Any] | None = None,
14611464
**kwargs,
14621465
) -> None:
@@ -1471,19 +1474,26 @@ def __init__(
14711474
Thermostat time, in fs. Default is 50.0.
14721475
ensemble
14731476
Name for thermodynamic ensemble. Default is "nvt-nh".
1477+
file_prefix
1478+
Prefix for output filenames. Default is inferred from structure, ensemble,
1479+
temperature, and pressure.
14741480
ensemble_kwargs
14751481
Keyword arguments to pass to ensemble initialization. Default is {}.
14761482
**kwargs
14771483
Additional keyword arguments.
14781484
"""
1485+
super().__init__(*args, ensemble=ensemble, file_prefix=file_prefix, **kwargs)
1486+
14791487
(ensemble_kwargs,) = none_to_dict(ensemble_kwargs)
1480-
super().__init__(
1481-
*args,
1482-
ensemble=ensemble,
1483-
thermostat_time=thermostat_time,
1484-
barostat_time=None,
1485-
ensemble_kwargs=ensemble_kwargs,
1486-
**kwargs,
1488+
self.ttime = thermostat_time * units.fs
1489+
1490+
self.dyn = ASE_NPT(
1491+
self.struct,
1492+
timestep=self.timestep,
1493+
temperature_K=self.temp,
1494+
ttime=self.ttime,
1495+
append_trajectory=self.traj_append,
1496+
**ensemble_kwargs,
14871497
)
14881498

14891499
def get_stats(self) -> dict[str, float]:
@@ -1495,7 +1505,7 @@ def get_stats(self) -> dict[str, float]:
14951505
dict[str, float]
14961506
Thermodynamical statistics to be written out.
14971507
"""
1498-
stats = MolecularDynamics.get_stats(self)
1508+
stats = super().get_stats()
14991509
stats |= {"Target_T": self.temp}
15001510
return stats
15011511

@@ -1585,16 +1595,16 @@ def __init__(
15851595
)
15861596

15871597

1588-
class NPH(NPT):
1598+
class NPH(MolecularDynamics):
15891599
"""
15901600
Configure NPH simulation.
15911601
15921602
Parameters
15931603
----------
15941604
*args
15951605
Additional arguments.
1596-
thermostat_time
1597-
Thermostat time, in fs. Default is 50.0.
1606+
barostat_time
1607+
Barostat time, in fs. Default is 75.0.
15981608
bulk_modulus
15991609
Bulk modulus, in GPa. Default is 2.0.
16001610
pressure
@@ -1613,7 +1623,7 @@ class NPH(NPT):
16131623
def __init__(
16141624
self,
16151625
*args,
1616-
thermostat_time: float = 50.0,
1626+
barostat_time: float = 75.0,
16171627
bulk_modulus: float = 2.0,
16181628
pressure: float = 0.0,
16191629
ensemble: Ensembles = "nph",
@@ -1628,8 +1638,8 @@ def __init__(
16281638
----------
16291639
*args
16301640
Additional arguments.
1631-
thermostat_time
1632-
Thermostat time, in fs. Default is 50.0.
1641+
barostat_time
1642+
Barostat time, in fs. Default is 75.0.
16331643
bulk_modulus
16341644
Bulk modulus, in GPa. Default is 2.0.
16351645
pressure
@@ -1644,19 +1654,87 @@ def __init__(
16441654
**kwargs
16451655
Additional keyword arguments.
16461656
"""
1657+
self.pressure = pressure
1658+
super().__init__(*args, ensemble=ensemble, file_prefix=file_prefix, **kwargs)
1659+
16471660
(ensemble_kwargs,) = none_to_dict(ensemble_kwargs)
1648-
super().__init__(
1649-
*args,
1650-
thermostat_time=thermostat_time,
1651-
barostat_time=None,
1652-
bulk_modulus=bulk_modulus,
1653-
pressure=pressure,
1654-
ensemble=ensemble,
1655-
file_prefix=file_prefix,
1656-
ensemble_kwargs=ensemble_kwargs,
1657-
**kwargs,
1661+
1662+
pfactor = barostat_time**2 * bulk_modulus
1663+
if self.logger:
1664+
self.logger.info("NPT pfactor=%s GPa fs^2", pfactor)
1665+
1666+
# convert the pfactor to ASE internal units
1667+
pfactor *= units.fs**2 * units.GPa
1668+
1669+
self.dyn = ASE_NPT(
1670+
self.struct,
1671+
timestep=self.timestep,
1672+
temperature_K=self.temp,
1673+
ttime=None,
1674+
pfactor=pfactor,
1675+
append_trajectory=self.traj_append,
1676+
externalstress=self.pressure * units.GPa,
1677+
**ensemble_kwargs,
16581678
)
16591679

1680+
def _set_param_prefix(self, file_prefix: PathLike | None = None) -> str:
1681+
"""
1682+
Set ensemble parameters for output files.
1683+
1684+
Parameters
1685+
----------
1686+
file_prefix
1687+
Prefix for output filenames on class init. If not None, param_prefix = "".
1688+
1689+
Returns
1690+
-------
1691+
str
1692+
Formatted ensemble parameters, including pressure and temperature(s).
1693+
"""
1694+
if file_prefix is not None:
1695+
return ""
1696+
1697+
return f"{super()._set_param_prefix(file_prefix)}-p{self.pressure}"
1698+
1699+
def get_stats(self) -> dict[str, float]:
1700+
"""
1701+
Get thermodynamical statistics to be written to file.
1702+
1703+
Returns
1704+
-------
1705+
dict[str, float]
1706+
Thermodynamical statistics to be written out.
1707+
"""
1708+
stats = super().get_stats()
1709+
stats |= {"Target_P": self.pressure}
1710+
return stats
1711+
1712+
@property
1713+
def unit_info(self) -> dict[str, str]:
1714+
"""
1715+
Get units of returned statistics.
1716+
1717+
Returns
1718+
-------
1719+
dict[str, str]
1720+
Units attached to statistical properties.
1721+
"""
1722+
return super().unit_info | {
1723+
"Target_P": JANUS_UNITS["pressure"],
1724+
}
1725+
1726+
@property
1727+
def default_formats(self) -> dict[str, str]:
1728+
"""
1729+
Default format of returned statistics.
1730+
1731+
Returns
1732+
-------
1733+
dict[str, str]
1734+
Default formats attached to statistical properties.
1735+
"""
1736+
return super().default_formats | {"Target_P": ".5f"}
1737+
16601738

16611739
class NPT_MTK(MolecularDynamics): # noqa: N801 (invalid-class-name)
16621740
"""

janus_core/cli/md.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def md(
4242
Option(
4343
help=(
4444
"""
45-
Thermostat time for NPT, NPT-MTK, NVT Nosé-Hoover, or NPH simulation,
45+
Thermostat time for NPT, NPT-MTK or NVT Nosé-Hoover simulation,
4646
in fs. Default is 50 fs for NPT and NVT Nosé-Hoover, or 100 fs for
4747
NPT-MTK.
4848
"""
@@ -235,8 +235,8 @@ def md(
235235
temp
236236
Temperature, in K. Default is 300.
237237
thermostat_time
238-
Thermostat time for NPT, NPT-MTK, NVT Nosé-Hoover or NPH simulation,
239-
in fs. Default is 50 fs for NPT, NVT Nosé-Hoover and NPH, or 100 fs for NPT-MTK.
238+
Thermostat time for NPT, NPT-MTK or NVT Nosé-Hoover simulation,
239+
in fs. Default is 50 fs for NPT and NVT Nosé-Hoover, or 100 fs for NPT-MTK.
240240
barostat_time
241241
Barostat time for NPT, NPT-MTK or NPH simulation, in fs.
242242
Default is 75 fs for NPT and NPH, or 1000 fs for NPT-MTK.
@@ -487,7 +487,7 @@ def md(
487487
elif ensemble == "nph":
488488
for key in (
489489
"friction",
490-
"barostat_time",
490+
"thermostat_time",
491491
"taut",
492492
"thermostat_chain",
493493
"barostat_chain",

tests/test_md.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -667,6 +667,33 @@ def test_atoms_struct(tmp_path):
667667
assert len(lines) == 6
668668

669669

670+
@pytest.mark.parametrize("ensemble, tag", test_data)
671+
def test_stats(tmp_path, ensemble, tag):
672+
"""Test stats file has correct structure and entries for all ensembles."""
673+
file_prefix = tmp_path / tag / "NaCl"
674+
single_point = SinglePoint(
675+
struct_path=DATA_PATH / "NaCl.cif",
676+
arch="mace",
677+
calc_kwargs={"model": MODEL_PATH},
678+
)
679+
md = ensemble(
680+
struct=single_point.struct,
681+
steps=2,
682+
stats_every=1,
683+
file_prefix=file_prefix,
684+
)
685+
md.run()
686+
687+
stat_data = Stats(md.stats_file)
688+
689+
etot_index = stat_data.labels.index("ETot/N")
690+
691+
assert stat_data.columns == len(stat_data.labels)
692+
assert stat_data.columns == len(stat_data.units)
693+
assert stat_data.columns >= 16
694+
assert stat_data.units[etot_index] == "eV"
695+
696+
670697
def test_heating(tmp_path):
671698
"""Test heating with no MD."""
672699
file_prefix = tmp_path / "NaCl-heating"

0 commit comments

Comments
 (0)