Skip to content

Commit 74dad34

Browse files
committed
enh: fixed #182
Now BrillouinZone objects can return DataArray objects from xarray (if it can be imported). It is called asdatarray. This makes things much easier for certain users. By default the dimensions are named 'k', 'v1', ... but users may provide v1, ... names by an argument. Also, the name of the DataArray will be the function name Signed-off-by: Nick Papior <[email protected]>
1 parent bba46c8 commit 74dad34

File tree

6 files changed

+145
-5
lines changed

6 files changed

+145
-5
lines changed

CHANGELOG

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
0.9.9
2+
=====
3+
4+
- Enabled xarray.DataArray returning from BrillouinZone objects #182
5+
6+
- Several improvements to outSileSiesta.read_scf (thanks to Pol Febrer)
7+
8+
- A huge performance increase for data extraction in tbtncSileTbtrans
9+
(thanks to Gaetano Calogero for finding the bottleneck)
10+
11+
- Added preliminary usage of Mixers, primarily intented for extending
12+
sisl operations where SCF are used (may heavily change).
13+
14+
- Now sisl is Python >=3.6 only
15+
116
0.9.8
217
=====
318

sisl/io/siesta/bands.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class bandsSileSiesta(SileSiesta):
1616
@sile_fh_open()
1717
def read_data(self, as_dataarray=False):
1818
""" Returns data associated with the bands file
19-
19+
2020
Parameters
2121
--------
2222
as_dataarray: boolean, optional
@@ -84,12 +84,12 @@ def read_data(self, as_dataarray=False):
8484
l.shape = (ns, no)
8585
b[ik, :, :] = l[:, :] - Ef
8686
vals = k, b
87-
87+
8888
if as_dataarray:
8989
from xarray import DataArray
9090

9191
ticks = {"ticks": xlabels, "ticklabels": labels} if band_lines else {}
92-
92+
9393
return DataArray(b, name="Energy", attrs=ticks,
9494
coords=[("k", k),
9595
("spin", _a.arangei(0, b.shape[1])),

sisl/io/siesta/out.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def return_forces(Fs):
346346

347347
if last:
348348
return return_forces(Fs[-1])
349-
349+
350350
return return_forces(Fs)
351351

352352
return return_forces(next_force())

sisl/io/siesta/tests/test_bands.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def test_fe(sisl_files):
1616
assert eig.shape == (131, 2, 15)
1717
assert len(labels[0]) == 5
1818

19+
1920
def test_fe_ArgumentParser(sisl_files, sisl_tmp):
2021
try:
2122
import matplotlib

sisl/physics/brillouinzone.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,12 @@
101101
from sisl.supercell import SuperCell
102102
from sisl.grid import Grid
103103

104+
try:
105+
import xarray
106+
_has_xarray = True
107+
except ImportError:
108+
_has_xarray = False
109+
104110

105111
__all__ = ['BrillouinZone', 'MonkhorstPack', 'BandStructure']
106112

@@ -477,7 +483,7 @@ def asarray(self):
477483
... spin_moment = (es.spin_moment(E, distribution=dist) * occ.reshape(-1, 1)).sum(0)
478484
... return oplist([DOS, PDOS, spin_moment])
479485
>>> bz = BrillouinZone(hamiltonian)
480-
>>> DOS, PDOS, spin_moment = bz.asaverage().eigenstate(wrap=wrap)
486+
>>> DOS, PDOS, spin_moment = bz.asarray().eigenstate(wrap=wrap)
481487
482488
See Also
483489
--------
@@ -567,6 +573,100 @@ def _call(self, *args, **kwargs):
567573
setattr(self, '_bz_call', types.MethodType(_call, self))
568574
return self
569575

576+
if _has_xarray:
577+
def asdataarray(self):
578+
r""" Return `self` with `xarray.DataArray` returned quantities
579+
580+
This forces the `__call__` routine to return a single `xarray.DataArray`.
581+
582+
Notes
583+
-----
584+
If you wrap the sub-method to return multiple data-sets, you should use `asdataset`
585+
instead which returns a combination of data-arrays (so-called `xarray.Dataset`).
586+
587+
All invocations of sub-methods are added these keyword-only arguments:
588+
589+
eta : bool, optional
590+
if true a progress-bar is created, default false.
591+
wrap : callable, optional
592+
a function that accepts the output of the given routine and post-process
593+
it. Defaults to ``lambda x: x``.
594+
coords : list of str or list of (str, array), optional
595+
a list of coordinates used in ``xarray.DataArray(..., coords=coords)``.
596+
By default the coordinates are named ``['k', 'v1', 'v2', ...]``
597+
depending on the shape of the returned quantity.
598+
These may optionally be a list of tuples (not a dictionary)!
599+
600+
Examples
601+
--------
602+
>>> obj = BrillouinZone(...)
603+
>>> obj.asdataarray().eigh(eta=True)
604+
605+
See Also
606+
--------
607+
asyield : all output returned through an iterator
608+
asaverage : take the average (with k-weights) of the Brillouin zone
609+
assum : return the sum of values in the Brillouin zone
610+
aslist : all output returned as a Python list
611+
"""
612+
613+
def _call(self, *args, **kwargs):
614+
func = self._bz_get_func()
615+
616+
# xarray specific data (default to function name)
617+
name = kwargs.pop('name', func.__name__)
618+
coords = kwargs.pop('coords', None)
619+
620+
has_wrap = 'wrap' in kwargs
621+
if has_wrap:
622+
wrap = allow_kwargs('parent', 'k', 'weight')(kwargs.pop('wrap'))
623+
eta = tqdm_eta(len(self), self.__class__.__name__ + '.asarray',
624+
'k', kwargs.pop('eta', False))
625+
parent = self.parent
626+
k = self.k
627+
w = self.weight
628+
if has_wrap:
629+
v = wrap(func(*args, k=k[0], **kwargs), parent=parent, k=k[0], weight=w[0])
630+
else:
631+
v = func(*args, k=k[0], **kwargs)
632+
if v.ndim == 0:
633+
a = np.empty([len(self)], dtype=v.dtype)
634+
else:
635+
a = np.empty((len(self), ) + v.shape, dtype=v.dtype)
636+
a[0] = v
637+
del v
638+
eta.update()
639+
if has_wrap:
640+
for i in range(1, len(k)):
641+
a[i] = wrap(func(*args, k=k[i], **kwargs), parent=parent, k=k[i], weight=w[i])
642+
eta.update()
643+
else:
644+
for i in range(1, len(k)):
645+
a[i] = func(*args, k=k[i], **kwargs)
646+
eta.update()
647+
eta.close()
648+
649+
# Create coords
650+
if coords is None:
651+
coords = [('k', _a.arangei(len(self)))]
652+
for i, v in enumerate(a.shape[1:]):
653+
coords.append((f"v{i+1}", _a.arangei(v)))
654+
else:
655+
coords = list(coords)
656+
coords.insert(0, ('k', _a.arangei(len(self))))
657+
for i in range(1, len(coords)):
658+
if isinstance(coords[i], str):
659+
coords[i] = (coords[i], _a.arangei(a.shape[i]))
660+
attrs = {'bz': self,
661+
'parent': self.parent,
662+
}
663+
664+
return xarray.DataArray(a, coords=coords, name=name, attrs=attrs)
665+
666+
# Set instance __bz_call
667+
setattr(self, '_bz_call', types.MethodType(_call, self))
668+
return self
669+
570670
def aslist(self):
571671
""" Return `self` with `list` returned quantities
572672

sisl/physics/tests/test_brillouinzone.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,30 @@ def test_as_simple(self):
248248
assert np.allclose((asarray / len(bz)).sum(0), asaverage)
249249
bz.asnone().eigh()
250250

251+
def test_as_dataarray(self):
252+
try:
253+
import xarray
254+
except ImportError:
255+
pytest.skip('xarray not available')
256+
257+
from sisl import geom, Hamiltonian
258+
g = geom.graphene()
259+
H = Hamiltonian(g)
260+
H.construct([[0.1, 1.44], [0, -2.7]])
261+
262+
bz = MonkhorstPack(H, [2, 2, 2], trs=False)
263+
264+
# Assert that as* all does the same
265+
asarray = bz.asarray().eigh()
266+
asdarray = bz.asdataarray().eigh()
267+
assert np.allclose(asarray, asdarray.values)
268+
assert isinstance(asdarray.bz, MonkhorstPack)
269+
assert isinstance(asdarray.parent, Hamiltonian)
270+
assert asdarray.dims == ('k', 'v1')
271+
272+
asdarray = bz.asdataarray().eigh(coords=['orb'])
273+
assert asdarray.dims == ('k', 'orb')
274+
251275
def test_as_single(self):
252276
from sisl import geom, Hamiltonian
253277
g = geom.graphene()

0 commit comments

Comments
 (0)