Skip to content

Commit e425009

Browse files
committed
Increase function flexibility
1 parent cf51066 commit e425009

File tree

2 files changed

+15
-12
lines changed

2 files changed

+15
-12
lines changed

src/y0/graph.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -341,13 +341,13 @@ def from_causalfusion_json(cls, data: Mapping[str, Any]) -> NxMixedGraph:
341341
raise ValueError(f'unhandled edge type: {edge["type"]}')
342342
return rv
343343

344-
def subgraph(self, vertices: Collection[Variable]) -> NxMixedGraph:
344+
def subgraph(self, vertices: Union[Variable, Iterable[Variable]]) -> NxMixedGraph:
345345
"""Return a subgraph given a set of vertices.
346346
347347
:param vertices: a subset of nodes
348348
:returns: A NxMixedGraph subgraph
349349
"""
350-
vertices = set(vertices)
350+
vertices = _ensure_set(vertices)
351351
return self.from_edges(
352352
nodes=vertices,
353353
directed=_include_adjacent(self.directed, vertices),
@@ -393,8 +393,9 @@ def remove_out_edges(self, vertices: Union[Variable, Iterable[Variable]]) -> NxM
393393
undirected=self.undirected.edges(),
394394
)
395395

396-
def ancestors_inclusive(self, sources: Iterable[Variable]) -> set[Variable]:
396+
def ancestors_inclusive(self, sources: Union[Variable, Iterable[Variable]]) -> set[Variable]:
397397
"""Ancestors of a set include the set itself."""
398+
sources = _ensure_set(sources)
398399
return _ancestors_inclusive(self.directed, sources)
399400

400401
def topological_sort(self) -> Iterable[Variable]:
@@ -414,33 +415,34 @@ def is_connected(self) -> bool:
414415
return nx.is_connected(self.undirected)
415416

416417

417-
def _ancestors_inclusive(graph: nx.DiGraph, sources: Iterable[Variable]) -> set[Variable]:
418-
rv = set(sources)
419-
for source in sources:
420-
rv.update(nx.algorithms.dag.ancestors(graph, source))
421-
return rv
418+
def _ancestors_inclusive(graph: nx.DiGraph, sources: set[Variable]) -> set[Variable]:
419+
ancestors = set(
420+
itt.chain.from_iterable(nx.algorithms.dag.ancestors(graph, source) for source in sources)
421+
)
422+
return sources | ancestors
422423

423424

424425
def _include_adjacent(
425-
graph: nx.Graph, vertices: Collection[Variable]
426+
graph: nx.Graph, vertices: set[Variable]
426427
) -> Collection[Tuple[Variable, Variable]]:
428+
vertices = _ensure_set(vertices)
427429
return [(u, v) for u, v in graph.edges() if u in vertices and v in vertices]
428430

429431

430432
def _exclude_source(
431-
graph: nx.Graph, vertices: Collection[Variable]
433+
graph: nx.Graph, vertices: set[Variable]
432434
) -> Collection[Tuple[Variable, Variable]]:
433435
return [(u, v) for u, v in graph.edges() if u not in vertices]
434436

435437

436438
def _exclude_target(
437-
graph: nx.Graph, vertices: Collection[Variable]
439+
graph: nx.Graph, vertices: set[Variable]
438440
) -> Collection[Tuple[Variable, Variable]]:
439441
return [(u, v) for u, v in graph.edges() if v not in vertices]
440442

441443

442444
def _exclude_adjacent(
443-
graph: nx.Graph, vertices: Collection[Variable]
445+
graph: nx.Graph, vertices: set[Variable]
444446
) -> Collection[Tuple[Variable, Variable]]:
445447
return [(u, v) for u, v in graph.edges() if u not in vertices and v not in vertices]
446448

tests/test_graph.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def test_ancestors_inclusive(self):
189189
directed=[("X", "Z"), ("Z", "Y")], undirected=[("X", "Y")]
190190
)
191191
self.assertEqual({X, Y, Z}, graph.ancestors_inclusive({Y}))
192+
self.assertEqual({X, Y, Z}, graph.ancestors_inclusive(Y))
192193
self.assertEqual({X, Z}, graph.ancestors_inclusive({Z}))
193194
self.assertEqual({X}, graph.ancestors_inclusive({X}))
194195

0 commit comments

Comments
 (0)