Skip to content

Commit 015a7d8

Browse files
authored
fix: decompose cycles which are too close (#31)
1 parent f5f429a commit 015a7d8

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

src/epsearch/_cycle.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@ def argmatch_from_closest_masked(
132132
def get_cycles(
133133
eigvals: Sequence[Sequence[TNumber]] | np.ndarray[tuple[int, int], np.dtype[TNumber]],
134134
/,
135+
*,
136+
decompose_threshold_diff_factor: float = 2,
137+
decompose_threshold_rate: float = 0.5,
135138
) -> Cycles[TNumber]:
136139
"""
137140
Get cycles from the eigenvalues for each point.
@@ -140,6 +143,16 @@ def get_cycles(
140143
----------
141144
eigvals : Sequence[Sequence[TNumber]] | np.ndarray[tuple[int, int], np.dtype[TNumber]]
142145
A (ordered) sequence which contains the eigenvalues for each point.
146+
decompose_threshold_diff_factor : float, optional
147+
If the rate of elements which are closer than
148+
max(mean difference) * decompose_threshold_diff_factor
149+
is higher than decompose_threshold_rate,
150+
the cycle is decomposed, by default 0.5.
151+
decompose_threshold_rate : float, optional
152+
If the rate of elements which are closer than
153+
max(mean difference) * decompose_threshold_diff_factor
154+
is higher than decompose_threshold_rate,
155+
the cycle is decomposed, by default 0.5.
143156
144157
Returns
145158
-------
@@ -181,6 +194,23 @@ def get_cycles(
181194
arg = argmatch_from_closest(eigvals_c[-1, :], eigvals_c[0, :])
182195
G = nx.DiGraph(list(enumerate(arg[~arg.mask])))
183196
cycles = list(nx.simple_cycles(G))
197+
198+
# decompose cycles which are too close
199+
for cycle in cycles:
200+
mean_diff = np.max(np.mean(np.abs(np.diff(eigvals_c[:, cycle], axis=0)), axis=0))
201+
if (
202+
np.mean(
203+
np.abs(eigvals_c[:, cycle[0]][:, None] - eigvals_c)
204+
< mean_diff * decompose_threshold_diff_factor
205+
)
206+
> decompose_threshold_rate
207+
):
208+
G.remove_edges_from(
209+
[(cycle[i], cycle[(i + 1) % len(cycle)]) for i in range(len(cycle))]
210+
)
211+
G.add_edges_from([(cycle[i], cycle[i]) for i in range(len(cycle))])
212+
cycles = list(nx.simple_cycles(G))
213+
184214
# order by length
185215
cycles = sorted(cycles, key=len, reverse=True)
186216
return Cycles(

0 commit comments

Comments
 (0)