Skip to content

Commit 2189b9f

Browse files
committed
python: added folded pos to ase
1 parent 3a70cb5 commit 2189b9f

File tree

3 files changed

+18
-14
lines changed

3 files changed

+18
-14
lines changed

src/python/espressomd/plugins/ase.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,10 @@ def register_system(self, system):
4545
def __getstate__(self):
4646
return {"type_mapping": self.type_mapping}
4747

48-
def get(self) -> ase.Atoms:
48+
def get(self, folded=False) -> ase.Atoms:
4949
"""Export the ESPResSo system particle data to an ASE atoms object."""
5050
particles = self._system.part.all()
51-
positions = np.copy(particles.pos)
51+
positions = np.copy(particles.pos_folded if folded else particles.pos)
5252
types = np.copy(particles.type)
5353
forces = np.copy(particles.f)
5454
unknown_types = set(types) - set(self.type_mapping)

src/python/espressomd/zn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def encode(self, system) -> zndraw.utils.ASEDict:
6666

6767
self.numbers = self.num_particles * [1]
6868

69-
if self.params["folded"] is True:
69+
if self.params["folded"]:
7070
self.positions = self.particles.pos_folded
7171
else:
7272
self.positions = self.particles.pos
@@ -81,7 +81,7 @@ def encode(self, system) -> zndraw.utils.ASEDict:
8181
else:
8282
self.radii = self.set_radii(self.params["radii"])
8383

84-
if self.params["bonds"] is True:
84+
if self.params["bonds"]:
8585
bonds = self.get_bonds()
8686
else:
8787
bonds = []
@@ -542,7 +542,7 @@ def update(self):
542542
Asedata = ase.ASEInterface(
543543
{x: "X" for x in set(all_types)})
544544
Asedata.register_system(self.system)
545-
data = Asedata.get()
545+
data = Asedata.get(folded=self.params["folded"])
546546
if self.params["colors"] is not None:
547547
data.arrays['colors'] = [self.params["colors"].get(
548548
z, "white") for z in all_types]

testsuite/python/ase_interface.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class ASEInterfaceTest(ut.TestCase):
3333
def setUp(self):
3434
self.system.part.add(pos=[0., 0., 0.], f=[1., -1., 0.], type=0)
3535
self.system.part.add(pos=[0., 0., 1.], f=[0., 12., 0.], type=1)
36+
self.system.part.add(pos=[11., 13., 12.], f=[0., 0., -8.], type=1)
3637
self.system.ase = espressomd.plugins.ase.ASEInterface(
3738
type_mapping={0: "H", 1: "O"},
3839
)
@@ -43,15 +44,18 @@ def tearDown(self):
4344
def test_ase_get(self):
4445
"""Test the ``ASEInterface.get()`` method."""
4546
# Create a simple ASE atoms object
46-
atoms = self.system.ase.get()
47-
self.assertIsInstance(atoms, ase.Atoms)
48-
self.assertEqual(set(atoms.get_chemical_symbols()), {"H", "O"})
49-
np.testing.assert_equal(atoms.pbc, np.copy(self.system.periodicity))
50-
np.testing.assert_allclose(atoms.cell, np.diag(self.system.box_l))
51-
np.testing.assert_allclose(atoms.get_positions(),
52-
[[0., 0., 0.], [0., 0., 1.]])
53-
np.testing.assert_allclose(atoms.get_forces(),
54-
[[1., -1., 0.], [0., 12., 0.]])
47+
for folded in [True, False]:
48+
atoms = self.system.ase.get(folded=folded)
49+
self.assertIsInstance(atoms, ase.Atoms)
50+
self.assertEqual(set(atoms.get_chemical_symbols()), {"H", "O"})
51+
np.testing.assert_equal(
52+
atoms.pbc, np.copy(self.system.periodicity))
53+
np.testing.assert_allclose(atoms.cell, np.diag(self.system.box_l))
54+
positions_ref = self.system.part.all(
55+
).pos_folded if folded else self.system.part.all().pos
56+
np.testing.assert_allclose(atoms.get_positions(), positions_ref)
57+
np.testing.assert_allclose(atoms.get_forces(),
58+
[[1., -1., 0.], [0., 12., 0.], [0., 0., -8.]])
5559

5660
@utx.skipIfMissingFeatures("VIRTUAL_SITES_RELATIVE")
5761
def test_exceptions(self):

0 commit comments

Comments
 (0)