Skip to content

Commit ddd7370

Browse files
Allow MACE-OFF model download failure (#684)
1 parent a70c9b4 commit ddd7370

File tree

2 files changed

+33
-21
lines changed

2 files changed

+33
-21
lines changed

tests/test_mlip_calculators.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@
7979
("mace_mp", "cpu", {"model": "small"}),
8080
("mace_mp", "cpu", {"model": MACE_MP_PATH}),
8181
("mace_off", "cpu", {}),
82-
("mace_off", "cpu", {"model": "small"}),
82+
("mace_off", "cpu", {"model": "medium"}),
8383
("mace_off", "cpu", {"model": MACE_OFF_PATH}),
8484
("mace_omol", "cpu", {}),
8585
("mace_omol", "cpu", {"model": "extra_large"}),
@@ -118,6 +118,9 @@ def test_mlips(arch, device, kwargs):
118118
except URLError as err:
119119
if "Connection timed out" in err.reason:
120120
pytest.skip("Model download failed")
121+
except RuntimeError as err:
122+
if "Model download failed" in str(err):
123+
pytest.skip("Model download failed")
121124
raise err
122125

123126

tests/test_single_point.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,10 @@ def test_extras(arch, device, expected_energy, struct, kwargs):
144144
if "Connection timed out" in err.reason:
145145
pytest.skip("Model download failed")
146146
raise err
147+
except RuntimeError as err:
148+
if "Model download failed" in str(err):
149+
pytest.skip("Model download failed")
150+
raise err
147151

148152

149153
def test_single_point_none():
@@ -467,28 +471,33 @@ def test_dispersion(arch, kwargs, pred):
467471
skip_extras(arch)
468472
pytest.importorskip("torch_dftd")
469473

470-
data_path = DATA_PATH / "benzene.xyz"
471-
sp_no_d3 = SinglePoint(
472-
struct=data_path,
473-
arch=arch,
474-
properties="energy",
475-
calc_kwargs={"dispersion": False},
476-
)
477-
assert not isinstance(sp_no_d3.struct.calc, SumCalculator)
478-
no_d3_results = sp_no_d3.run()
474+
try:
475+
data_path = DATA_PATH / "benzene.xyz"
476+
sp_no_d3 = SinglePoint(
477+
struct=data_path,
478+
arch=arch,
479+
properties="energy",
480+
calc_kwargs={"dispersion": False},
481+
)
482+
assert not isinstance(sp_no_d3.struct.calc, SumCalculator)
483+
no_d3_results = sp_no_d3.run()
479484

480-
sp_d3 = SinglePoint(
481-
struct=data_path,
482-
arch=arch,
483-
properties="energy",
484-
calc_kwargs={"dispersion": True, "dispersion_kwargs": {**kwargs}},
485-
)
486-
assert isinstance(sp_d3.struct.calc, SumCalculator)
487-
d3_results = sp_d3.run()
485+
sp_d3 = SinglePoint(
486+
struct=data_path,
487+
arch=arch,
488+
properties="energy",
489+
calc_kwargs={"dispersion": True, "dispersion_kwargs": {**kwargs}},
490+
)
491+
assert isinstance(sp_d3.struct.calc, SumCalculator)
492+
d3_results = sp_d3.run()
488493

489-
assert (d3_results["energy"] - no_d3_results["energy"]) == pytest.approx(
490-
pred, rel=1e-5
491-
)
494+
assert (d3_results["energy"] - no_d3_results["energy"]) == pytest.approx(
495+
pred, rel=1e-5
496+
)
497+
except RuntimeError as err:
498+
if "Model download failed" in str(err):
499+
pytest.skip("Model download failed")
500+
raise err
492501

493502

494503
def test_mace_mp_dispersion():

0 commit comments

Comments
 (0)