Skip to content

Commit 12af876

Browse files
Set defaults, test in test_correlator
1 parent e09636c commit 12af876

File tree

4 files changed

+44
-40
lines changed

4 files changed

+44
-40
lines changed

janus_core/cli/utils.py

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -328,23 +328,23 @@ def parse_correlation_kwargs(kwargs: CorrelationKwargs) -> list[dict]:
328328
The parsed correlation_kwargs for md.
329329
"""
330330
parsed_kwargs = []
331-
for name, kwarg in kwargs.value.items():
332-
if "a" not in kwarg and "b" not in kwarg:
331+
for name, cli_kwargs in kwargs.value.items():
332+
if "a" not in cli_kwargs and "b" not in cli_kwargs:
333333
raise ValueError("At least on observable must be supplied as 'a' or 'b'")
334334

335335
# Accept on Observable to be replicated.
336-
if "b" not in kwarg:
337-
a = kwarg["a"]
336+
if "b" not in cli_kwargs:
337+
a = cli_kwargs["a"]
338338
b = a
339-
elif "a" not in kwarg:
340-
a = kwarg["b"]
339+
elif "a" not in cli_kwargs:
340+
a = cli_kwargs["b"]
341341
b = a
342342
else:
343-
a = kwarg["a"]
344-
b = kwarg["b"]
343+
a = cli_kwargs["a"]
344+
b = cli_kwargs["b"]
345345

346-
a_kwargs = kwarg["a_kwargs"] if "a_kwargs" in kwarg else {}
347-
b_kwargs = kwarg["b_kwargs"] if "b_kwargs" in kwarg else {}
346+
a_kwargs = cli_kwargs["a_kwargs"] if "a_kwargs" in cli_kwargs else {}
347+
b_kwargs = cli_kwargs["b_kwargs"] if "b_kwargs" in cli_kwargs else {}
348348

349349
# Accept "." in place of one kwargs to repeat.
350350
if a_kwargs == "." and b_kwargs == ".":
@@ -354,22 +354,15 @@ def parse_correlation_kwargs(kwargs: CorrelationKwargs) -> list[dict]:
354354
elif b_kwargs and a_kwargs == ".":
355355
a_kwargs = b_kwargs
356356

357-
blocks = kwarg["blocks"] if "blocks" in kwarg else 1
358-
points = kwarg["points"] if "points" in kwarg else 1
359-
averaging = kwarg["averaging"] if "averaging" in kwarg else 1
360-
update_frequency = (
361-
kwarg["update_frequency"] if "update_frequency" in kwarg else 1
362-
)
357+
cor_kwargs = {
358+
"name": name,
359+
"a": getattr(observables, a)(**a_kwargs),
360+
"b": getattr(observables, b)(**b_kwargs),
361+
}
363362

364-
parsed_kwargs.append(
365-
{
366-
"name": name,
367-
"a": getattr(observables, a)(**a_kwargs),
368-
"b": getattr(observables, b)(**b_kwargs),
369-
"blocks": blocks,
370-
"points": points,
371-
"averaging": averaging,
372-
"update_frequency": update_frequency,
373-
}
374-
)
363+
for optional in ["blocks", "points", "averaging", "update_frequency"]:
364+
if optional in cli_kwargs:
365+
cor_kwargs[optional] = cli_kwargs[optional]
366+
367+
parsed_kwargs.append(cor_kwargs)
375368
return parsed_kwargs

janus_core/processing/correlator.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -253,13 +253,13 @@ class Correlation:
253253
Observable for b.
254254
name
255255
Name of correlation.
256-
blocks
256+
blocks : int, default 1
257257
Number of correlation blocks.
258-
points
258+
points : int, default 1
259259
Number of points per block.
260-
averaging
260+
averaging : int, default 1
261261
Averaging window per block level.
262-
update_frequency
262+
update_frequency : int, default 1
263263
Frequency to update the correlation, md steps.
264264
"""
265265

@@ -269,10 +269,10 @@ def __init__(
269269
a: Observable,
270270
b: Observable,
271271
name: str,
272-
blocks: int,
273-
points: int,
274-
averaging: int,
275-
update_frequency: int,
272+
blocks: int = 1,
273+
points: int = 1,
274+
averaging: int = 1,
275+
update_frequency: int = 1,
276276
) -> None:
277277
"""
278278
Initialise a correlation.
@@ -285,13 +285,13 @@ def __init__(
285285
Observable for b.
286286
name
287287
Name of correlation.
288-
blocks
288+
blocks : int, default 1
289289
Number of correlation blocks.
290-
points
290+
points : int, default 1
291291
Number of points per block.
292-
averaging
292+
averaging : int, default 1
293293
Averaging window per block level.
294-
update_frequency
294+
update_frequency : int, default 1
295295
Frequency to update the correlation, md steps.
296296
"""
297297
self.name = name

tests/test_correlator.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,11 @@ def test_vaf(tmp_path):
113113
"averaging": 1,
114114
"update_frequency": 1,
115115
},
116+
{
117+
"a": Velocity(),
118+
"b": Velocity(),
119+
"name": "vaf_default",
120+
},
116121
],
117122
write_kwargs={"invalidate_calc": False},
118123
)
@@ -133,6 +138,12 @@ def test_vaf(tmp_path):
133138
assert vaf_na * 3 == approx(vaf_post[1][0], rel=1e-5)
134139
assert vaf_cl * 3 == approx(vaf_post[1][1], rel=1e-5)
135140

141+
# Default arguments are equivalent to mean square velocities.
142+
v = np.mean([np.mean(atoms.get_velocities() ** 2) for atoms in traj])
143+
vaf_default = vaf["vaf_default"]
144+
assert len(vaf_default["value"]) == 1
145+
assert v == approx(vaf_default["value"][0], rel=1e-5)
146+
136147

137148
def test_md_correlations(tmp_path):
138149
"""Test correlations as part of MD cycle."""

tests/test_md_cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def test_md(ensemble):
104104
"--correlation-kwargs",
105105
(
106106
"{'vaf': {'a': 'Velocity'},"
107-
" 'vaf_x': {'a': 'velocity',"
107+
" 'vaf_x': {'a': 'Velocity',"
108108
"'a_kwargs': {'components': ['x']}, 'b_kwargs': '.'}}"
109109
),
110110
],

0 commit comments

Comments
 (0)