Skip to content

Commit e174577

Browse files
committed
TST: migrate out of answer test framework and to pytest
1 parent 8fc6796 commit e174577

File tree

7 files changed

+216
-280
lines changed

7 files changed

+216
-280
lines changed

yt_astro_analysis/cosmological_observation/light_cone/tests/test_light_cone.py

Lines changed: 61 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -14,96 +14,79 @@
1414
# -----------------------------------------------------------------------------
1515

1616
import os
17-
import shutil
18-
import tempfile
1917

18+
import h5py
2019
import numpy as np
20+
import numpy.testing as npt
21+
import pytest
22+
import unyt as un
2123

22-
from yt.testing import assert_equal
23-
from yt.units.yt_array import YTQuantity
24-
from yt.utilities.answer_testing.framework import AnswerTestingTest
25-
from yt.utilities.on_demand_imports import _h5py as h5py
24+
import yt # noqa
25+
from yt.testing import requires_file
2626
from yt_astro_analysis.cosmological_observation.api import LightCone
27-
from yt_astro_analysis.utilities.testing import requires_sim
2827

2928
ETC = "enzo_tiny_cosmology/32Mpc_32.enzo"
3029
_funits = {
31-
"density": YTQuantity(1, "g/cm**3"),
32-
"temperature": YTQuantity(1, "K"),
33-
"length": YTQuantity(1, "cm"),
30+
"density": un.unyt_quantity(1, "g/cm**3"),
31+
"temperature": un.unyt_quantity(1, "K"),
32+
"length": un.unyt_quantity(1, "cm"),
3433
}
3534

3635

37-
class LightConeProjectionTest(AnswerTestingTest):
38-
_type_name = "LightConeProjection"
39-
_attrs = ()
40-
41-
def __init__(self, parameter_file, simulation_type, field, weight_field=None):
42-
self.parameter_file = parameter_file
43-
self.simulation_type = simulation_type
44-
self.ds = os.path.basename(self.parameter_file)
45-
self.field = field
46-
self.weight_field = weight_field
47-
48-
@property
49-
def storage_name(self):
50-
return "_".join(
51-
(os.path.basename(self.parameter_file), self.field, str(self.weight_field))
52-
)
53-
54-
def run(self):
55-
# Set up in a temp dir
56-
tmpdir = tempfile.mkdtemp()
57-
curdir = os.getcwd()
58-
os.chdir(tmpdir)
59-
60-
lc = LightCone(
61-
self.parameter_file,
62-
self.simulation_type,
63-
0.0,
64-
0.1,
65-
observer_redshift=0.0,
66-
time_data=False,
67-
)
68-
lc.calculate_light_cone_solution(seed=123456789, filename="LC/solution.txt")
69-
lc.project_light_cone(
70-
(600.0, "arcmin"),
71-
(60.0, "arcsec"),
72-
self.field,
73-
weight_field=self.weight_field,
74-
save_stack=True,
75-
)
76-
77-
dname = f"{self.field}_{self.weight_field}"
78-
fh = h5py.File("LC/LightCone.h5", mode="r")
36+
@requires_file(ETC)
37+
@pytest.mark.parametrize(
38+
"field, weight_field, expected",
39+
[
40+
(
41+
"density",
42+
None,
43+
[6.0000463633868075e-05, 1.1336502301470154e-05, 0.08970763360935877],
44+
),
45+
(
46+
"temperature",
47+
"density",
48+
[37.79481498628398, 0.018410545597485613, 543702.4613479003],
49+
),
50+
],
51+
)
52+
def test_light_cone_projection(tmp_path, field, weight_field, expected):
53+
parameter_file = ETC
54+
simulation_type = "Enzo"
55+
field = field
56+
weight_field = weight_field
57+
58+
os.chdir(tmp_path)
59+
lc = LightCone(
60+
parameter_file,
61+
simulation_type,
62+
near_redshift=0.0,
63+
far_redshift=0.1,
64+
observer_redshift=0.0,
65+
time_data=False,
66+
)
67+
lc.calculate_light_cone_solution(seed=123456789, filename="LC/solution.txt")
68+
lc.project_light_cone(
69+
(600.0, "arcmin"),
70+
(60.0, "arcsec"),
71+
field,
72+
weight_field=weight_field,
73+
save_stack=True,
74+
)
75+
76+
dname = f"{field}_{weight_field}"
77+
with h5py.File("LC/LightCone.h5", mode="r") as fh:
7978
data = fh[dname][()]
8079
units = fh[dname].attrs["units"]
81-
if self.weight_field is None:
82-
punits = _funits[self.field] * _funits["length"]
80+
if weight_field is None:
81+
punits = _funits[field] * _funits["length"]
8382
else:
84-
punits = (
85-
_funits[self.field] * _funits[self.weight_field] * _funits["length"]
86-
)
87-
wunits = fh["weight_field_%s" % self.weight_field].attrs["units"]
88-
pwunits = _funits[self.weight_field] * _funits["length"]
83+
punits = _funits[field] * _funits[weight_field] * _funits["length"]
84+
wunits = fh[f"weight_field_{weight_field}"].attrs["units"]
85+
pwunits = _funits[weight_field] * _funits["length"]
8986
assert wunits == str(pwunits.units)
90-
assert units == str(punits.units)
91-
fh.close()
92-
93-
# clean up
94-
os.chdir(curdir)
95-
shutil.rmtree(tmpdir)
96-
97-
mean = data.mean()
98-
mi = data[data.nonzero()].min()
99-
ma = data.max()
100-
return np.array([mean, mi, ma])
101-
102-
def compare(self, new_result, old_result):
103-
assert_equal(new_result, old_result, verbose=True)
104-
87+
assert units == str(punits.units)
10588

106-
@requires_sim(ETC, "Enzo")
107-
def test_light_cone_projection():
108-
yield LightConeProjectionTest(ETC, "Enzo", "density")
109-
yield LightConeProjectionTest(ETC, "Enzo", "temperature", weight_field="density")
89+
mean = np.nanmean(data)
90+
mi = np.nanmin(data[data.nonzero()])
91+
ma = np.nanmax(data)
92+
npt.assert_equal([mean, mi, ma], expected, verbose=True)

yt_astro_analysis/halo_analysis/halo_catalog/analysis_operators.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ def add_quantity(name, function):
6363
quantity_registry[name] = AnalysisQuantity(function)
6464

6565

66+
def _remove_quantity(name):
67+
# this is useful to avoid test pollution when using add_quantity in tests
68+
# but it's not meant as public API
69+
quantity_registry.pop(name)
70+
71+
6672
class AnalysisQuantity(AnalysisCallback):
6773
r"""
6874
An AnalysisQuantity is a function that takes minimally a target object,

yt_astro_analysis/halo_analysis/tests/test_halo_catalog.py

Lines changed: 42 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -13,73 +13,57 @@
1313
# The full license is in the file COPYING.txt, distributed with this software.
1414
# -----------------------------------------------------------------------------
1515

16-
import os
17-
import shutil
18-
import tempfile
19-
20-
import numpy as np
16+
import numpy.testing as npt
17+
import pytest
18+
import unyt as un
2119

2220
from yt.loaders import load
23-
from yt.testing import assert_equal
24-
from yt.utilities.answer_testing.framework import (
25-
AnswerTestingTest,
26-
data_dir_load,
27-
requires_ds,
21+
from yt.testing import requires_file
22+
from yt_astro_analysis.halo_analysis import HaloCatalog
23+
from yt_astro_analysis.halo_analysis.halo_catalog.analysis_operators import (
24+
_remove_quantity,
25+
add_quantity,
2826
)
29-
from yt_astro_analysis.halo_analysis import HaloCatalog, add_quantity
30-
31-
32-
def _nstars(halo):
33-
sp = halo.data_object
34-
return (sp["all", "creation_time"] > 0).sum()
35-
36-
37-
add_quantity("nstars", _nstars)
38-
39-
40-
class HaloQuantityTest(AnswerTestingTest):
41-
_type_name = "HaloQuantity"
42-
_attrs = ()
43-
44-
def __init__(self, data_ds_fn, halos_ds_fn):
45-
self.data_ds_fn = data_ds_fn
46-
self.halos_ds_fn = halos_ds_fn
47-
self.ds = data_dir_load(data_ds_fn)
48-
49-
def run(self):
50-
curdir = os.getcwd()
51-
tmpdir = tempfile.mkdtemp()
52-
os.chdir(tmpdir)
53-
54-
dds = data_dir_load(self.data_ds_fn)
55-
hds = data_dir_load(self.halos_ds_fn)
56-
hc = HaloCatalog(
57-
data_ds=dds, halos_ds=hds, output_dir=os.path.join(tmpdir, str(dds))
58-
)
59-
hc.add_callback("sphere")
60-
hc.add_quantity("nstars")
61-
hc.create()
62-
63-
fn = os.path.join(tmpdir, str(dds), "%s.0.h5" % str(dds))
64-
ds = load(fn)
65-
ad = ds.all_data()
66-
mi, ma = ad.quantities.extrema("nstars")
67-
mean = ad.quantities.weighted_average_quantity("nstars", "particle_ones")
27+
from yt_astro_analysis.utilities.testing import data_dir_load
6828

69-
os.chdir(curdir)
70-
shutil.rmtree(tmpdir)
7129

72-
return np.array([mean, mi, ma])
30+
@pytest.fixture
31+
def nstars_defined():
32+
def _nstars(halo):
33+
sp = halo.data_object
34+
return (sp["all", "creation_time"] > 0).sum()
7335

74-
def compare(self, new_result, old_result):
75-
assert_equal(new_result, old_result, verbose=True)
36+
add_quantity("nstars", _nstars)
37+
yield
38+
_remove_quantity("nstars")
7639

7740

7841
rh0 = "rockstar_halos/halos_0.0.bin"
7942
e64 = "Enzo_64/DD0043/data0043"
8043

8144

82-
@requires_ds(rh0)
83-
@requires_ds(e64)
84-
def test_halo_quantity():
85-
yield HaloQuantityTest(e64, rh0)
45+
@requires_file(rh0)
46+
@requires_file(e64)
47+
@pytest.mark.usefixtures("nstars_defined")
48+
def test_halo_quantity(tmp_path):
49+
data_ds_fn = e64
50+
halos_ds_fn = rh0
51+
ds = data_dir_load(data_ds_fn)
52+
53+
dds = data_dir_load(data_ds_fn)
54+
hds = data_dir_load(halos_ds_fn)
55+
hc = HaloCatalog(data_ds=dds, halos_ds=hds, output_dir=str(tmp_path))
56+
hc.add_callback("sphere")
57+
hc.add_quantity("nstars")
58+
hc.create()
59+
60+
fn = tmp_path / str(dds) / f"{dds}.0.h5"
61+
ds = load(fn)
62+
ad = ds.all_data()
63+
mi, ma = ad.quantities.extrema("nstars")
64+
mean = ad.quantities.weighted_average_quantity("nstars", "particle_ones")
65+
66+
npt.assert_equal(
67+
un.unyt_array([mean, mi, ma]),
68+
[28.533783783783782, 0.0, 628.0] * un.dimensionless,
69+
)

yt_astro_analysis/halo_analysis/tests/test_halo_finders.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import os
2-
import shutil
32
import sys
4-
import tempfile
3+
4+
import pytest
5+
from unyt.testing import assert_allclose_units
56

67
from yt.frontends.halo_catalog.data_structures import YTHaloCatalogDataset
78
from yt.frontends.rockstar.data_structures import RockstarDataset
89
from yt.loaders import load
9-
from yt.utilities.answer_testing.framework import FieldValuesTest, requires_ds
10+
from yt.testing import requires_file
1011

1112
_fields = (
1213
("halos", "particle_position_x"),
@@ -21,28 +22,28 @@
2122
etiny = "enzo_tiny_cosmology/DD0046/DD0046"
2223

2324

24-
@requires_ds(etiny, big_data=True)
25-
def test_halo_finders_single():
25+
@requires_file(etiny)
26+
def test_halo_finders_single(tmp_path):
27+
pytest.importorskip("mpi4py")
2628
from mpi4py import MPI
2729

28-
tmpdir = tempfile.mkdtemp()
29-
curdir = os.getcwd()
30-
os.chdir(tmpdir)
30+
os.chdir(tmp_path)
3131

3232
filename = os.path.join(os.path.dirname(__file__), "run_halo_finder.py")
3333
for method in methods:
3434
comm = MPI.COMM_SELF.Spawn(
35-
sys.executable, args=[filename, method, tmpdir], maxprocs=methods[method]
35+
sys.executable,
36+
args=[filename, method, str(tmp_path)],
37+
maxprocs=methods[method],
3638
)
3739
comm.Disconnect()
3840

3941
if method == "rockstar":
4042
hcfn = "halos_0.0.bin"
4143
else:
4244
hcfn = os.path.join("DD0046", "DD0046.0.h5")
43-
fn = os.path.join(tmpdir, "halo_catalogs", method, hcfn)
4445

45-
ds = load(fn)
46+
ds = load(tmp_path / "halo_catalogs" / method / hcfn)
4647
if method == "rockstar":
4748
ds.parameters["format_revision"] = 2
4849
ds_type = RockstarDataset
@@ -51,11 +52,16 @@ def test_halo_finders_single():
5152
assert isinstance(ds, ds_type)
5253

5354
for field in _fields:
54-
my_test = FieldValuesTest(
55-
ds, field, particle_type=True, decimals=decimals[method]
55+
obj = ds.all_data()
56+
field = obj._determine_fields(field)[0]
57+
# fd = ds.field_info[field]
58+
weight_field = (field[0], "particle_ones")
59+
avg = obj.quantities.weighted_average_quantity(field, weight=weight_field)
60+
mi, ma = obj.quantities.extrema(field)
61+
assert_allclose_units(
62+
[avg, mi, ma],
63+
[1, 2, 3],
64+
10.0 ** (-decimals[method]),
65+
err_msg=f"Field values for {field} not equal.",
66+
verbose=True,
5667
)
57-
my_test.suffix = method
58-
yield my_test
59-
60-
os.chdir(curdir)
61-
shutil.rmtree(tmpdir)

0 commit comments

Comments
 (0)