Skip to content

Commit 33c0d69

Browse files
committed
update tests
1 parent 095a912 commit 33c0d69

File tree

2 files changed

+92
-16
lines changed

2 files changed

+92
-16
lines changed

tests/test_load_graph.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pandas as pd
99
import pytest
1010

11+
from risk.cluster import define_domains
1112
from risk.network.graph._summary import Summary
1213

1314

@@ -831,6 +832,49 @@ def test_primary_domain_labels_are_disjoint(graph):
831832
)
832833

833834

835+
def test_define_domains_handles_safeguard_row_drop_without_global_fallback():
836+
"""
837+
Ensure dropped linkage rows do not desynchronize annotation-domain assignment.
838+
A zero-variance significant annotation should be assigned a deterministic unique
839+
domain, while non-significant annotations remain unassigned (domain 0).
840+
"""
841+
# Two significant annotations: one degenerate (zero-variance) and one clusterable.
842+
top_annotation = pd.DataFrame(
843+
{
844+
"significant_annotation": [True, True, False],
845+
"full_terms": ["term_a", "term_b", "term_c"],
846+
"significant_cluster_significance_sums": [1.0, 2.0, 0.0],
847+
"significant_significance_score": [1.0, 2.0, 0.0],
848+
},
849+
index=["term_a", "term_b", "term_c"],
850+
)
851+
# term_a is dropped by safeguard (constant column), term_b is retained, term_c is non-significant.
852+
significant_clusters_significance = np.array(
853+
[
854+
[5.0, 1.0, 0.0],
855+
[5.0, 2.0, 0.0],
856+
[5.0, 3.0, 0.0],
857+
[5.0, 4.0, 0.0],
858+
],
859+
dtype=float,
860+
)
861+
862+
domains = define_domains(
863+
top_annotation=top_annotation,
864+
significant_clusters_significance=significant_clusters_significance,
865+
linkage_criterion="distance",
866+
linkage_method="average",
867+
linkage_metric="euclidean",
868+
linkage_threshold=0.2,
869+
)
870+
871+
assert top_annotation.loc["term_c", "domain"] == 0
872+
assert top_annotation.loc["term_a", "domain"] > 0
873+
assert top_annotation.loc["term_b", "domain"] > 0
874+
assert top_annotation.loc["term_a", "domain"] != top_annotation.loc["term_b", "domain"]
875+
assert (domains["primary_domain"] > 0).any()
876+
877+
834878
def _validate_graph(graph):
835879
"""
836880
Validate that the graph is not None and contains nodes and edges.

tests/test_load_network.py

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def test_initialize_risk(risk, verbose_setting):
1818
Test RISK instance initialization with verbose parameter.
1919
2020
Args:
21+
risk: Factory fixture that returns a configured RISK instance.
2122
verbose_setting: Boolean value to set verbosity of the RISK instance.
2223
"""
2324
try:
@@ -127,20 +128,17 @@ def test_load_network_networkx(risk_obj, dummy_network):
127128
128129
Args:
129130
risk_obj: The RISK object instance used for loading the network.
130-
network: The NetworkX graph object to be loaded into the RISK network.
131+
dummy_network: The NetworkX graph object to be loaded into the RISK network.
131132
"""
132133
network = risk_obj.load_network_networkx(network=dummy_network)
133134

134135
assert network is not None
135136
assert len(network.nodes) > 0 # Check that the graph has nodes
136137
assert len(network.edges) > 0 # Check that the graph has edges
137-
# Additional checks to verify the properties of the loaded graph
138138
for node in network.nodes:
139-
# Check that each node in the original network is in the RISK network
140139
assert node in network.nodes
141140

142141
for edge in network.edges:
143-
# Check that each edge in the original network is in the RISK network
144142
assert edge in network.edges
145143

146144

@@ -151,36 +149,30 @@ def test_round_trip_io(risk_obj):
151149
Args:
152150
risk_obj: The RISK object instance used for loading the network.
153151
"""
154-
# Create a small test graph
155152
G = nx.Graph()
156153
G.add_edges_from([(0, 1), (1, 2)])
157-
# Add node positions as required for network loading
158154
G.nodes[0]["x"] = 0.0
159155
G.nodes[0]["y"] = 0.0
160156
G.nodes[1]["x"] = 1.0
161157
G.nodes[1]["y"] = 1.0
162158
G.nodes[2]["x"] = 2.0
163159
G.nodes[2]["y"] = 2.0
164160

165-
# Ensure the tmp directory exists under data/tmp
166161
tmp_dir = os.path.join("data", "tmp")
167162
os.makedirs(tmp_dir, exist_ok=True)
168163
tmp_path = os.path.join(tmp_dir, "test_round_trip_io.gpickle")
169164

170165
try:
171-
# Save the graph using pickle
172166
with open(tmp_path, "wb") as f:
173167
pickle.dump(G, f)
174168

175-
# Load it back using risk_obj's load_network_gpickle
176169
G_loaded = risk_obj.load_network_gpickle(filepath=tmp_path)
177170

178-
# Compare properties of the graphs - RISK sets node IDs to 'label' attribute when no label is present
171+
# Graph structure should survive a simple pickle round-trip via loader.
179172
assert set(G.nodes()) == set(G_loaded.nodes())
180173
assert set(G.edges()) == set(G_loaded.edges())
181174

182175
finally:
183-
# Always remove the temporary file at the end
184176
if os.path.exists(tmp_path):
185177
os.remove(tmp_path)
186178

@@ -192,17 +184,14 @@ def test_node_positions_constant_after_networkx_load(risk_obj, dummy_network):
192184
193185
Args:
194186
risk_obj: The RISK object instance used for loading the network.
195-
network: The NetworkX graph object to be loaded into the RISK network.
187+
dummy_network: The NetworkX graph object to be loaded into the RISK network.
196188
"""
197-
# Store the original positions of nodes from the dummy_network
198189
original_positions = {
199190
node: (dummy_network.nodes[node]["x"], dummy_network.nodes[node]["y"])
200191
for node in dummy_network.nodes
201192
}
202-
# Pass the network to the load function, and ignore the returned network
203193
_ = risk_obj.load_network_networkx(network=dummy_network)
204194

205-
# Ensure that the original network (dummy_network) still has the same node positions
206195
for node in dummy_network.nodes:
207196
assert (
208197
"x" in dummy_network.nodes[node]
@@ -353,13 +342,56 @@ def test_sphere_unfolding(risk_obj, data_path):
353342
assert -1 <= attrs["y"] <= 1, f"Node {node} 'y' coordinate is out of bounds"
354343

355344

345+
@pytest.mark.parametrize(
346+
"compute_sphere,node_coords",
347+
[
348+
# Flat x-axis (all x identical): exercises safe normalization without sphere mapping.
349+
(False, {"a": (1.0, 0.0), "b": (1.0, 1.0), "c": (1.0, 2.0)}),
350+
# Flat y-axis (all y identical): exercises safe normalization with sphere mapping enabled.
351+
(True, {"a": (0.0, 2.0), "b": (1.0, 2.0), "c": (2.0, 2.0)}),
352+
],
353+
)
354+
def test_flat_axis_coordinates_remain_finite_after_loading(risk_obj, compute_sphere, node_coords):
355+
"""
356+
Ensure coordinate normalization is stable when x or y has zero range.
357+
358+
Args:
359+
risk_obj: The RISK object instance used for loading the network.
360+
compute_sphere: Whether to enable spherical coordinate processing.
361+
node_coords: Mapping of node ids to (x, y) coordinates.
362+
"""
363+
network = nx.Graph()
364+
network.add_edges_from([("a", "b"), ("b", "c")])
365+
for node, (x_coord, y_coord) in node_coords.items():
366+
network.nodes[node]["x"] = x_coord
367+
network.nodes[node]["y"] = y_coord
368+
network.nodes[node]["label"] = node
369+
370+
loaded = risk_obj.load_network_networkx(
371+
network=network,
372+
compute_sphere=compute_sphere,
373+
surface_depth=0.2,
374+
min_edges_per_node=0,
375+
)
376+
377+
# Guard against NaN/Inf coordinate propagation from zero-range normalization.
378+
for node, attrs in loaded.nodes(data=True):
379+
assert np.isfinite(attrs["x"]), f"Node {node} has non-finite x coordinate"
380+
assert np.isfinite(attrs["y"]), f"Node {node} has non-finite y coordinate"
381+
382+
# Edge lengths must remain finite and strictly positive for downstream clustering/layout.
383+
for u, v, attrs in loaded.edges(data=True):
384+
assert np.isfinite(attrs["length"]), f"Edge ({u}, {v}) has non-finite length"
385+
assert attrs["length"] > 0, f"Edge ({u}, {v}) has non-positive length"
386+
387+
356388
def test_edge_attribute_fallback(risk_obj, dummy_network):
357389
"""
358390
Test fallback when edges are missing 'length' or 'weight' attributes.
359391
360392
Args:
361393
risk_obj: The RISK object instance used for loading the network.
362-
dummy_network: The Cytoscape network to be loaded into the R
394+
dummy_network: The network used to validate edge attribute fallback behavior.
363395
"""
364396
# Remove 'length' and 'weight' attributes
365397
for u, v in dummy_network.edges():

0 commit comments

Comments
 (0)