Skip to content

Commit 5b84510

Browse files
committed
Merge branch 'master' of github.com:zerothi/sisl
2 parents 35aae59 + e8b53f3 commit 5b84510

File tree

4 files changed

+51
-17
lines changed

4 files changed

+51
-17
lines changed

sisl/io/siesta/bands.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22

3+
import sisl._array as _a
34
from sisl.utils import strmap
45
from sisl.utils.cmd import default_ArgumentParser, default_namespace
56
from ..sile import add_sile, sile_fh_open
@@ -13,8 +14,16 @@ class bandsSileSiesta(SileSiesta):
1314
""" Bandstructure information """
1415

1516
@sile_fh_open()
16-
def read_data(self):
17-
""" Returns data associated with the bands file """
17+
def read_data(self, as_dataarray=False):
18+
""" Returns data associated with the bands file
19+
20+
Parameters
21+
--------
22+
as_dataarray: boolean, optional
23+
if `True`, the information is returned as an `xarray.DataArray`
24+
Ticks (if read) are stored as an attribute of the DataArray
25+
(under `array.ticks` and `array.ticklabels`)
26+
"""
1827
band_lines = False
1928

2029
# Luckily the data is in eV
@@ -35,19 +44,19 @@ def read_data(self):
3544
no, ns, nk = map(int, l.split())
3645

3746
# Create the data to contain all band points
38-
b = np.empty([nk, ns, no], np.float64)
47+
b = _a.emptyd([nk, ns, no])
3948

4049
# for band-lines
4150
if band_lines:
42-
k = np.empty([nk], np.float64)
51+
k = _a.emptyd([nk])
4352
for ik in range(nk):
4453
l = [float(x) for x in self.readline().split()]
4554
k[ik] = l[0]
4655
del l[0]
4756
# Now populate the eigenvalues
4857
while len(l) < ns * no:
4958
l.extend([float(x) for x in self.readline().split()])
50-
l = np.array(l, np.float64)
59+
l = _a.arrayd(l)
5160
l.shape = (ns, no)
5261
b[ik, :, :] = l[:, :] - Ef
5362
# Now we need to read the labels for the points
@@ -61,7 +70,7 @@ def read_data(self):
6170
vals = (xlabels, labels), k, b
6271

6372
else:
64-
k = np.empty([nk, 3], np.float64)
73+
k = _a.emptyd([nk, 3])
6574
for ik in range(nk):
6675
l = [float(x) for x in self.readline().split()]
6776
k[ik, :] = l[0:3]
@@ -71,10 +80,21 @@ def read_data(self):
7180
# Now populate the eigenvalues
7281
while len(l) < ns * no:
7382
l.extend([float(x) for x in self.readline().split()])
74-
l = np.array(l, np.float64)
83+
l = _a.arrayd(l)
7584
l.shape = (ns, no)
7685
b[ik, :, :] = l[:, :] - Ef
7786
vals = k, b
87+
88+
if as_dataarray:
89+
from xarray import DataArray
90+
91+
ticks = {"ticks": xlabels, "ticklabels": labels} if band_lines else {}
92+
93+
return DataArray(b, name="Energy", attrs=ticks,
94+
coords=[("k", k),
95+
("spin", _a.arangei(0, b.shape[1])),
96+
("band", _a.arangei(0, b.shape[2]))])
97+
7898
return vals
7999

80100
@default_ArgumentParser(description="Manipulate bands file in sisl.")

sisl/io/siesta/out.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -291,11 +291,15 @@ def next_force():
291291
# Now read data
292292
F = []
293293
line = self.readline()
294+
if 'siesta:' in line:
295+
# This is the final summary, we don't need to read it as it does not contain new information
296+
# and also it make break things since max forces are not written there
297+
return None
294298

295299
# First, we encounter the atomic forces
296300
while '---' not in line:
297301
line = line.split()
298-
if not total or max:
302+
if not (total or max):
299303
F.append([float(x) for x in line[-3:]])
300304
line = self.readline()
301305
if line == '':
@@ -328,8 +332,8 @@ def return_forces(Fs):
328332
if max and total:
329333
return (Fs[..., :-1], Fs[..., -1])
330334
elif max and not all:
331-
# This will return a float (or actually a numpy.dtype)
332-
return Fs[0]
335+
# This will return a float
336+
return np.atleast_1d(Fs)[0]
333337
return Fs
334338

335339
if all or last:
@@ -342,11 +346,8 @@ def return_forces(Fs):
342346
Fs.append(F)
343347

344348
if last:
345-
return return_forces(Fs[-2])
346-
# F[-2] is really the same as F[-1], the last forces are stated twice
347-
# However, the maxForce is not stated in the final summary, that's why we use F[-2]
348-
if self.job_completed:
349-
return return_forces(Fs[:-1])
349+
return return_forces(Fs[-1])
350+
350351
return return_forces(Fs)
351352

352353
return return_forces(next_force())

sisl/io/siesta/pdos.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def to(o, DOS):
106106
return xr.DataArray(data=process(DOS).reshape(shape),
107107
dims=dims, coords=coords, name='PDOS')
108108

109-
D = xr.DataArray()
109+
D = xr.DataArray([])
110110
else:
111111
def to(o, DOS):
112112
return process(DOS)

sisl/io/siesta/tests/test_bands.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ def test_fe(sisl_files):
1616
assert eig.shape == (131, 2, 15)
1717
assert len(labels[0]) == 5
1818

19-
2019
def test_fe_ArgumentParser(sisl_files, sisl_tmp):
2120
try:
2221
import matplotlib
@@ -28,3 +27,17 @@ def test_fe_ArgumentParser(sisl_files, sisl_tmp):
2827
p.parse_args([], namespace=ns)
2928
p.parse_args(['--energy', ' -2:2'], namespace=ns)
3029
p.parse_args(['--energy', ' -2:2', '--plot', png], namespace=ns)
30+
31+
32+
def test_fe_xarray(sisl_files, sisl_tmp):
33+
try:
34+
import xarray
35+
except ImportError:
36+
pytest.skip('xarray not available')
37+
si = sisl.get_sile(sisl_files(_dir, 'fe.bands'))
38+
39+
bands = si.read_data(as_dataarray=True)
40+
assert len(bands['k']) == 131
41+
assert len(bands['spin']) == 2
42+
assert len(bands['band']) == 15
43+
assert len(bands.ticks) == len(bands.ticklabels) == 5

0 commit comments

Comments
 (0)