Skip to content

Commit ec0b317

Browse files
committed
feat: add mcstas h5 file loader
1 parent 10ac274 commit ec0b317

File tree

2 files changed

+54
-9
lines changed

2 files changed

+54
-9
lines changed

src/ess/estia/mcstas.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
import h5py
12
import numpy as np
23
import scipp as sc
34

5+
from ess.reflectometry.load import load_h5
6+
47

58
def parse_metadata_ascii(lines):
69
data = {}
@@ -50,12 +53,26 @@ def parse_events_ascii(lines):
5053

5154

5255
def parse_events_h5(f):
53-
pass
54-
56+
if isinstance(f, str):
57+
with h5py.File(f) as ff:
58+
return parse_events_h5(ff)
5559

56-
'''
57-
def parse_events_h5(f):
58-
f['entry1/data']
59-
events = load_nx(f, 'NXentry/NXdetector/NXdata')
60-
parameters = load_nx(f, 'NXentry/simulation/Param')
61-
'''
60+
data, events, params = load_h5(
61+
f,
62+
'NXentry/NXdetector/NXdata',
63+
'NXentry/NXdetector/NXdata/events',
64+
'NXentry/simulation/Param',
65+
)
66+
da = sc.DataArray(
67+
sc.array(dims=['events'], values=events[:, 0], variances=events[:, 0] ** 2),
68+
)
69+
for i, label in enumerate(data.attrs["ylabel"].decode().strip().split(' ')):
70+
if label == 'p':
71+
continue
72+
da.coords[label] = sc.array(dims=['events'], values=events[:, i])
73+
for k, v in params.items():
74+
v = v[0]
75+
if isinstance(v, bytes):
76+
v = v.decode()
77+
da.coords[k] = sc.scalar(v)
78+
return da

src/ess/reflectometry/load.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# SPDX-License-Identifier: BSD-3-Clause
22
# Copyright (c) 2024 Scipp contributors (https://github.com/scipp)
3-
3+
import h5py
44
import sciline
55
import scipp as sc
66
import scippnexus as snx
@@ -37,6 +37,34 @@ def _unique_child_group(
3737
return next(iter(children.values())) # type: ignore[return-value]
3838

3939

40+
def load_h5(group: h5py.Group | str, *paths: str):
41+
if isinstance(group, str):
42+
with h5py.File(group) as group:
43+
yield from load_h5(group, *paths)
44+
return
45+
for path in paths:
46+
g = group
47+
for p in path.strip('/').split('/'):
48+
g = _unique_child_group_h5(g, p) if p.startswith('NX') else g.get(p)
49+
yield g
50+
51+
52+
def _unique_child_group_h5(
53+
group: h5py.Group,
54+
nx_class: str,
55+
) -> h5py.Group | None:
56+
out = None
57+
for v in group.values():
58+
if v.attrs.get("NX_class") == nx_class.encode():
59+
if out is None:
60+
out = v
61+
else:
62+
raise ValueError(
63+
f'Expected exactly one {nx_class} group, but found more'
64+
)
65+
return out
66+
67+
4068
def save_reference(pl: sciline.Pipeline, fname: str):
4169
pl.compute(ReducedReference).save_hdf5(fname)
4270
return fname

0 commit comments

Comments
 (0)