Skip to content

Commit dda8b97

Browse files
committed
test: add output shape test
1 parent 9a49e9a commit dda8b97

File tree

2 files changed

+58
-0
lines changed

2 files changed

+58
-0
lines changed

tests/reflectometry/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# SPDX-License-Identifier: BSD-3-Clause
2+
# Copyright (c) 2025 Scipp contributors (https://github.com/scipp)
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# SPDX-License-Identifier: BSD-3-Clause
2+
# Copyright (c) 2025 Scipp contributors (https://github.com/scipp)
3+
import numpy as np
4+
import pytest
5+
import scipp as sc
6+
7+
from ess.reflectometry.normalization import reduce_sample_over_zw
8+
9+
10+
@pytest.fixture
11+
def sample(request):
12+
n = 50
13+
da = sc.DataArray(
14+
data=sc.ones(dims=('events',), shape=(n,)),
15+
coords={
16+
'wavelength': sc.linspace('events', 1, 5, n),
17+
'wire': sc.array(dims=('events',), values=np.random.randint(0, 5, n)),
18+
'stripe': sc.array(dims=('events',), values=np.random.randint(0, 10, n)),
19+
},
20+
)
21+
return da.group('wire', 'stripe')
22+
23+
24+
@pytest.fixture
25+
def reference(request):
26+
n = 50
27+
da = sc.DataArray(
28+
data=sc.ones(dims=('events',), shape=(n,)),
29+
coords={
30+
'wavelength': sc.linspace('events', 1, 5, n),
31+
'wire': sc.array(dims=('events',), values=np.random.randint(0, 5, n)),
32+
'stripe': sc.array(dims=('events',), values=np.random.randint(0, 10, n)),
33+
},
34+
)
35+
return da.group('wire').bin(wavelength=2).bins.sum()
36+
37+
38+
def test_reduce_sample_over_zw_when_data_not_dimensionless(sample, reference):
39+
sample = sample.copy(deep=True)
40+
sample.bins.unit = '1/s'
41+
reduce_sample_over_zw(
42+
sample,
43+
reference,
44+
reference.coords['wavelength'],
45+
)
46+
47+
48+
def test_reduce_sample_over_zw_has_expected_coords(sample, reference):
49+
r = reduce_sample_over_zw(
50+
sample,
51+
reference,
52+
reference.coords['wavelength'],
53+
)
54+
assert r.dims == reference.dims
55+
assert r.shape == reference.shape
56+
assert not (sample.bins is None) ^ (r.bins is None)

0 commit comments

Comments
 (0)