Skip to content

Commit 3a84c35

Browse files
committed
Raise if intervention variable is used in the wrong place
1 parent 0f0b9e4 commit 3a84c35

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

src/y0/graph.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from networkx.classes.reportviews import NodeView
1515
from networkx.utils import open_file
1616

17-
from .dsl import CounterfactualVariable, Variable, vmap_adj, vmap_pairs
17+
from .dsl import CounterfactualVariable, Intervention, Variable, vmap_adj, vmap_pairs
1818

1919
__all__ = [
2020
"NxMixedGraph",
@@ -553,7 +553,10 @@ def _get_latex(node) -> str:
553553

554554

555555
def _ensure_set(vertices: Union[Variable, Iterable[Variable]]) -> set[Variable]:
556-
return {vertices} if isinstance(vertices, Variable) else set(vertices)
556+
rv = {vertices} if isinstance(vertices, Variable) else set(vertices)
557+
if any(isinstance(v, Intervention) for v in rv):
558+
raise TypeError("can not use interventions here")
559+
return rv
557560

558561

559562
class NoAnankeError(TypeError):

tests/test_graph.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,10 @@ def test_intervention(self):
150150
undirected=[("Z", "Y")],
151151
)
152152
self.assertEqual(intervened_graph, graph.remove_in_edges({X}))
153+
self.assertEqual(intervened_graph, graph.remove_in_edges(X))
154+
155+
with self.assertRaises(TypeError):
156+
self.assertEqual(intervened_graph, graph.remove_in_edges({-X}))
153157

154158
def test_remove_nodes_from(self):
155159
"""Test generating a new graph without the given nodes."""

0 commit comments

Comments
 (0)