Skip to content

Commit d66bb53

Browse files
committed
Merge branch 'bandsXarray'
2 parents 6a97732 + c39200e commit d66bb53

File tree

3 files changed

+42
-9
lines changed

3 files changed

+42
-9
lines changed

sisl/io/siesta/bands.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22

3+
import sisl._array as _a
34
from sisl.utils import strmap
45
from sisl.utils.cmd import default_ArgumentParser, default_namespace
56
from ..sile import add_sile, sile_fh_open
@@ -13,8 +14,16 @@ class bandsSileSiesta(SileSiesta):
1314
""" Bandstructure information """
1415

1516
@sile_fh_open()
16-
def read_data(self):
17-
""" Returns data associated with the bands file """
17+
def read_data(self, as_dataarray=False):
18+
""" Returns data associated with the bands file
19+
20+
Parameters
21+
--------
22+
as_dataarray: boolean, optional
23+
if `True`, the information is returned as an `xarray.DataArray`
24+
Ticks (if read) are stored as an attribute of the DataArray
25+
(under `array.ticks` and `array.ticklabels`)
26+
"""
1827
band_lines = False
1928

2029
# Luckily the data is in eV
@@ -35,19 +44,19 @@ def read_data(self):
3544
no, ns, nk = map(int, l.split())
3645

3746
# Create the data to contain all band points
38-
b = np.empty([nk, ns, no], np.float64)
47+
b = _a.emptyd([nk, ns, no])
3948

4049
# for band-lines
4150
if band_lines:
42-
k = np.empty([nk], np.float64)
51+
k = _a.emptyd([nk])
4352
for ik in range(nk):
4453
l = [float(x) for x in self.readline().split()]
4554
k[ik] = l[0]
4655
del l[0]
4756
# Now populate the eigenvalues
4857
while len(l) < ns * no:
4958
l.extend([float(x) for x in self.readline().split()])
50-
l = np.array(l, np.float64)
59+
l = _a.arrayd(l)
5160
l.shape = (ns, no)
5261
b[ik, :, :] = l[:, :] - Ef
5362
# Now we need to read the labels for the points
@@ -61,7 +70,7 @@ def read_data(self):
6170
vals = (xlabels, labels), k, b
6271

6372
else:
64-
k = np.empty([nk, 3], np.float64)
73+
k = _a.emptyd([nk, 3])
6574
for ik in range(nk):
6675
l = [float(x) for x in self.readline().split()]
6776
k[ik, :] = l[0:3]
@@ -71,10 +80,21 @@ def read_data(self):
7180
# Now populate the eigenvalues
7281
while len(l) < ns * no:
7382
l.extend([float(x) for x in self.readline().split()])
74-
l = np.array(l, np.float64)
83+
l = _a.arrayd(l)
7584
l.shape = (ns, no)
7685
b[ik, :, :] = l[:, :] - Ef
7786
vals = k, b
87+
88+
if as_dataarray:
89+
from xarray import DataArray
90+
91+
ticks = {"ticks": xlabels, "ticklabels": labels} if band_lines else {}
92+
93+
return DataArray(b, name="Energy", attrs=ticks,
94+
coords=[("k", k),
95+
("spin", _a.arangei(0, b.shape[1])),
96+
("band", _a.arangei(0, b.shape[2]))])
97+
7898
return vals
7999

80100
@default_ArgumentParser(description="Manipulate bands file in sisl.")

sisl/io/siesta/pdos.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def to(o, DOS):
106106
return xr.DataArray(data=process(DOS).reshape(shape),
107107
dims=dims, coords=coords, name='PDOS')
108108

109-
D = xr.DataArray()
109+
D = xr.DataArray([])
110110
else:
111111
def to(o, DOS):
112112
return process(DOS)

sisl/io/siesta/tests/test_bands.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ def test_fe(sisl_files):
1616
assert eig.shape == (131, 2, 15)
1717
assert len(labels[0]) == 5
1818

19-
2019
def test_fe_ArgumentParser(sisl_files, sisl_tmp):
2120
try:
2221
import matplotlib
@@ -28,3 +27,17 @@ def test_fe_ArgumentParser(sisl_files, sisl_tmp):
2827
p.parse_args([], namespace=ns)
2928
p.parse_args(['--energy', ' -2:2'], namespace=ns)
3029
p.parse_args(['--energy', ' -2:2', '--plot', png], namespace=ns)
30+
31+
32+
def test_fe_xarray(sisl_files, sisl_tmp):
33+
try:
34+
import xarray
35+
except ImportError:
36+
pytest.skip('xarray not available')
37+
si = sisl.get_sile(sisl_files(_dir, 'fe.bands'))
38+
39+
bands = si.read_data(as_dataarray=True)
40+
assert len(bands['k']) == 131
41+
assert len(bands['spin']) == 2
42+
assert len(bands['band']) == 15
43+
assert len(bands.ticks) == len(bands.ticklabels) == 5

0 commit comments

Comments
 (0)