Skip to content

Commit 961c8eb

Browse files
cthoytdjinnome
andcommitted
Give functions better names
Co-Authored-By: Jeremy Zucker <[email protected]>
1 parent 59f7094 commit 961c8eb

File tree

4 files changed

+18
-15
lines changed

4 files changed

+18
-15
lines changed

src/y0/algorithm/identify/id_c.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,7 @@ def rule_2_of_do_calculus_applies(identification: Identification, condition: Var
4949
treatments = identification.treatments
5050
conditions = treatments | (identification.conditions - {condition})
5151

52-
# TODO give a better name
53-
graph_mod = graph.intervene(treatments).remove_outgoing_edges_from([condition])
52+
graph_mod = graph.remove_in_edges(treatments).remove_out_edges(condition)
5453

5554
judgements = [
5655
are_d_separated(graph_mod, outcome, condition, conditions=conditions)

src/y0/algorithm/identify/id_std.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def identify(identification: Identification) -> Expression:
3131
return identify(line_2(identification))
3232

3333
# line 3
34-
intervened_graph = graph.intervene(treatments)
34+
intervened_graph = graph.remove_in_edges(treatments)
3535
no_effect_on_outcome = (vertices - treatments) - intervened_graph.ancestors_inclusive(outcomes)
3636
if no_effect_on_outcome:
3737
return identify(line_3(identification))
@@ -142,7 +142,7 @@ def line_3(identification: Identification) -> Identification:
142142
graph = identification.graph
143143
vertices = set(graph.nodes())
144144

145-
intervened_graph = graph.intervene(treatments)
145+
intervened_graph = graph.remove_in_edges(treatments)
146146
no_effect_on_outcome = (vertices - treatments) - intervened_graph.ancestors_inclusive(outcomes)
147147
if not no_effect_on_outcome:
148148
raise ValueError(

src/y0/graph.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -354,39 +354,39 @@ def subgraph(self, vertices: Collection[Variable]) -> NxMixedGraph:
354354
undirected=_include_adjacent(self.undirected, vertices),
355355
)
356356

357-
def intervene(self, vertices: Collection[Variable]) -> NxMixedGraph:
357+
def remove_in_edges(self, vertices: Union[Variable, Iterable[Variable]]) -> NxMixedGraph:
358358
"""Return a mutilated graph given a set of interventions.
359359
360360
:param vertices: a subset of nodes from which to remove incoming edges
361361
:returns: A NxMixedGraph subgraph
362362
"""
363-
vertices = set(vertices)
363+
vertices = _ensure_set(vertices)
364364
return self.from_edges(
365365
nodes=vertices,
366366
directed=_exclude_target(self.directed, vertices),
367367
undirected=_exclude_adjacent(self.undirected, vertices),
368368
)
369369

370-
def remove_nodes_from(self, vertices: Collection[Variable]) -> NxMixedGraph:
370+
def remove_nodes_from(self, vertices: Union[Variable, Iterable[Variable]]) -> NxMixedGraph:
371371
"""Return a subgraph that does not contain any of the specified vertices.
372372
373373
:param vertices: a set of nodes to remove from graph
374374
:returns: A NxMixedGraph subgraph
375375
"""
376-
vertices = set(vertices)
376+
vertices = _ensure_set(vertices)
377377
return self.from_edges(
378378
nodes=self.nodes() - vertices,
379379
directed=_exclude_adjacent(self.directed, vertices),
380380
undirected=_exclude_adjacent(self.undirected, vertices),
381381
)
382382

383-
def remove_outgoing_edges_from(self, vertices: Collection[Variable]) -> NxMixedGraph:
383+
def remove_out_edges(self, vertices: Union[Variable, Iterable[Variable]]) -> NxMixedGraph:
384384
"""Return a subgraph that does not have any outgoing edges from any of the given vertices.
385385
386386
:param vertices: a set of nodes whose outgoing edges get removed from the graph
387387
:returns: NxMixedGraph subgraph
388388
"""
389-
vertices = set(vertices)
389+
vertices = _ensure_set(vertices)
390390
return self.from_edges(
391391
nodes=self.nodes(),
392392
directed=_exclude_source(self.directed, vertices),
@@ -552,5 +552,9 @@ def _get_latex(node) -> str:
552552
raise TypeError
553553

554554

555+
def _ensure_set(vertices: Union[Variable, Iterable[Variable]]) -> set[Variable]:
556+
return {vertices} if isinstance(vertices, Variable) else set(vertices)
557+
558+
555559
class NoAnankeError(TypeError):
556560
"""Thrown when an :mod:`ananke` graph was used but a y0 NxMixedGraph should have been used."""

tests/test_graph.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -143,13 +143,13 @@ def test_intervention(self):
143143
directed=[("X", "Y"), ("Z", "X")],
144144
undirected=[("X", "Z"), ("X", "Y"), ("Y", "Z")],
145145
)
146-
self.assertEqual(graph, graph.intervene(set()))
146+
self.assertEqual(graph, graph.remove_in_edges(set()))
147147

148148
intervened_graph = NxMixedGraph.from_str_edges(
149149
directed=[("X", "Y")],
150150
undirected=[("Z", "Y")],
151151
)
152-
self.assertEqual(intervened_graph, graph.intervene({X}))
152+
self.assertEqual(intervened_graph, graph.remove_in_edges({X}))
153153

154154
def test_remove_nodes_from(self):
155155
"""Test generating a new graph without the given nodes."""
@@ -165,14 +165,14 @@ def test_remove_nodes_from(self):
165165
def test_remove_outgoing_edges_from(self):
166166
"""Test generating a new graph without the outgoing edgs from the given nodes."""
167167
graph = NxMixedGraph.from_str_edges(directed=[("X", "Y")])
168-
self.assertEqual(graph, graph.remove_outgoing_edges_from(set()))
168+
self.assertEqual(graph, graph.remove_out_edges(set()))
169169

170170
graph = NxMixedGraph.from_str_edges(undirected=[("X", "Y")])
171-
self.assertEqual(graph, graph.remove_outgoing_edges_from(set()))
171+
self.assertEqual(graph, graph.remove_out_edges(set()))
172172

173173
graph = NxMixedGraph.from_str_edges(directed=[("W", "X"), ("X", "Y"), ("Y", "Z")])
174174
expected = NxMixedGraph.from_str_edges(directed=[("W", "X"), ("Y", "Z")])
175-
self.assertEqual(expected, graph.remove_outgoing_edges_from({X}))
175+
self.assertEqual(expected, graph.remove_out_edges({X}))
176176

177177
def test_ancestors_inclusive(self):
178178
"""Test getting ancestors, inclusive."""

0 commit comments

Comments
 (0)