Skip to content

Commit c39200e

Browse files
committed
maint: fixed minor things in bandsXarray
Changed names for attributes (matches matplotlib) Changed names for coordinates to match names used throughout sisl. Also added use of sisl._array Fixed a bug in PDOS for empty data. Signed-off-by: Nick Papior <[email protected]>
1 parent cea63a8 commit c39200e

File tree

3 files changed

+27
-25
lines changed

3 files changed

+27
-25
lines changed

sisl/io/siesta/bands.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22

3+
import sisl._array as _a
34
from sisl.utils import strmap
45
from sisl.utils.cmd import default_ArgumentParser, default_namespace
56
from ..sile import add_sile, sile_fh_open
@@ -21,7 +22,7 @@ def read_data(self, as_dataarray=False):
2122
as_dataarray: boolean, optional
2223
if `True`, the information is returned as an `xarray.DataArray`
2324
Ticks (if read) are stored as an attribute of the DataArray
24-
(under `array.tick_vals` and `array.tick_labels`)
25+
(under `array.ticks` and `array.ticklabels`)
2526
"""
2627
band_lines = False
2728

@@ -43,19 +44,19 @@ def read_data(self, as_dataarray=False):
4344
no, ns, nk = map(int, l.split())
4445

4546
# Create the data to contain all band points
46-
b = np.empty([nk, ns, no], np.float64)
47+
b = _a.emptyd([nk, ns, no])
4748

4849
# for band-lines
4950
if band_lines:
50-
k = np.empty([nk], np.float64)
51+
k = _a.emptyd([nk])
5152
for ik in range(nk):
5253
l = [float(x) for x in self.readline().split()]
5354
k[ik] = l[0]
5455
del l[0]
5556
# Now populate the eigenvalues
5657
while len(l) < ns * no:
5758
l.extend([float(x) for x in self.readline().split()])
58-
l = np.array(l, np.float64)
59+
l = _a.arrayd(l)
5960
l.shape = (ns, no)
6061
b[ik, :, :] = l[:, :] - Ef
6162
# Now we need to read the labels for the points
@@ -69,7 +70,7 @@ def read_data(self, as_dataarray=False):
6970
vals = (xlabels, labels), k, b
7071

7172
else:
72-
k = np.empty([nk, 3], np.float64)
73+
k = _a.emptyd([nk, 3])
7374
for ik in range(nk):
7475
l = [float(x) for x in self.readline().split()]
7576
k[ik, :] = l[0:3]
@@ -79,26 +80,20 @@ def read_data(self, as_dataarray=False):
7980
# Now populate the eigenvalues
8081
while len(l) < ns * no:
8182
l.extend([float(x) for x in self.readline().split()])
82-
l = np.array(l, np.float64)
83+
l = _a.arrayd(l)
8384
l.shape = (ns, no)
8485
b[ik, :, :] = l[:, :] - Ef
8586
vals = k, b
8687

8788
if as_dataarray:
8889
from xarray import DataArray
8990

90-
ticks = {"tick_vals": xlabels, "tick_labels": labels} if band_lines else {}
91+
ticks = {"ticks": xlabels, "ticklabels": labels} if band_lines else {}
9192

92-
return DataArray(
93-
name="Energy",
94-
data=b,
95-
coords=[
96-
("K", k),
97-
("spin", np.arange(0,b.shape[1])),
98-
("iBand", np.arange(0,b.shape[2]) + 1)
99-
],
100-
attrs= {**ticks}
101-
)
93+
return DataArray(b, name="Energy", attrs=ticks,
94+
coords=[("k", k),
95+
("spin", _a.arangei(0, b.shape[1])),
96+
("band", _a.arangei(0, b.shape[2]))])
10297

10398
return vals
10499

sisl/io/siesta/pdos.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def to(o, DOS):
106106
return xr.DataArray(data=process(DOS).reshape(shape),
107107
dims=dims, coords=coords, name='PDOS')
108108

109-
D = xr.DataArray()
109+
D = xr.DataArray([])
110110
else:
111111
def to(o, DOS):
112112
return process(DOS)

sisl/io/siesta/tests/test_bands.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,6 @@ def test_fe(sisl_files):
1515
assert k.shape == (131, )
1616
assert eig.shape == (131, 2, 15)
1717
assert len(labels[0]) == 5
18-
#Test the dataarray implementation
19-
bands = si.read_data(as_dataarray=True)
20-
assert bands['K'].shape == (131,)
21-
assert bands['spin'].shape == (2,)
22-
assert bands['iBand'].shape == (15,)
23-
assert len(bands.tick_vals) == len(bands.tick_labels) == 5
24-
2518

2619
def test_fe_ArgumentParser(sisl_files, sisl_tmp):
2720
try:
@@ -34,3 +27,17 @@ def test_fe_ArgumentParser(sisl_files, sisl_tmp):
3427
p.parse_args([], namespace=ns)
3528
p.parse_args(['--energy', ' -2:2'], namespace=ns)
3629
p.parse_args(['--energy', ' -2:2', '--plot', png], namespace=ns)
30+
31+
32+
def test_fe_xarray(sisl_files, sisl_tmp):
33+
try:
34+
import xarray
35+
except ImportError:
36+
pytest.skip('xarray not available')
37+
si = sisl.get_sile(sisl_files(_dir, 'fe.bands'))
38+
39+
bands = si.read_data(as_dataarray=True)
40+
assert len(bands['k']) == 131
41+
assert len(bands['spin']) == 2
42+
assert len(bands['band']) == 15
43+
assert len(bands.ticks) == len(bands.ticklabels) == 5

0 commit comments

Comments
 (0)