@@ -341,13 +341,13 @@ def from_causalfusion_json(cls, data: Mapping[str, Any]) -> NxMixedGraph:
341341 raise ValueError (f'unhandled edge type: { edge ["type" ]} ' )
342342 return rv
343343
344- def subgraph (self , vertices : Collection [Variable ]) -> NxMixedGraph :
344+ def subgraph (self , vertices : Union [Variable , Iterable [ Variable ] ]) -> NxMixedGraph :
345345 """Return a subgraph given a set of vertices.
346346
347347 :param vertices: a subset of nodes
348348 :returns: A NxMixedGraph subgraph
349349 """
350- vertices = set (vertices )
350+ vertices = _ensure_set (vertices )
351351 return self .from_edges (
352352 nodes = vertices ,
353353 directed = _include_adjacent (self .directed , vertices ),
@@ -393,8 +393,9 @@ def remove_out_edges(self, vertices: Union[Variable, Iterable[Variable]]) -> NxM
393393 undirected = self .undirected .edges (),
394394 )
395395
396- def ancestors_inclusive (self , sources : Iterable [Variable ]) -> set [Variable ]:
396+ def ancestors_inclusive (self , sources : Union [ Variable , Iterable [Variable ] ]) -> set [Variable ]:
397397 """Ancestors of a set include the set itself."""
398+ sources = _ensure_set (sources )
398399 return _ancestors_inclusive (self .directed , sources )
399400
400401 def topological_sort (self ) -> Iterable [Variable ]:
@@ -414,33 +415,34 @@ def is_connected(self) -> bool:
414415 return nx .is_connected (self .undirected )
415416
416417
417- def _ancestors_inclusive (graph : nx .DiGraph , sources : Iterable [Variable ]) -> set [Variable ]:
418- rv = set (sources )
419- for source in sources :
420- rv . update ( nx . algorithms . dag . ancestors ( graph , source ) )
421- return rv
418+ def _ancestors_inclusive (graph : nx .DiGraph , sources : set [Variable ]) -> set [Variable ]:
419+ ancestors = set (
420+ itt . chain . from_iterable ( nx . algorithms . dag . ancestors ( graph , source ) for source in sources )
421+ )
422+ return sources | ancestors
422423
423424
424425def _include_adjacent (
425- graph : nx .Graph , vertices : Collection [Variable ]
426+ graph : nx .Graph , vertices : set [Variable ]
426427) -> Collection [Tuple [Variable , Variable ]]:
428+ vertices = _ensure_set (vertices )
427429 return [(u , v ) for u , v in graph .edges () if u in vertices and v in vertices ]
428430
429431
430432def _exclude_source (
431- graph : nx .Graph , vertices : Collection [Variable ]
433+ graph : nx .Graph , vertices : set [Variable ]
432434) -> Collection [Tuple [Variable , Variable ]]:
433435 return [(u , v ) for u , v in graph .edges () if u not in vertices ]
434436
435437
436438def _exclude_target (
437- graph : nx .Graph , vertices : Collection [Variable ]
439+ graph : nx .Graph , vertices : set [Variable ]
438440) -> Collection [Tuple [Variable , Variable ]]:
439441 return [(u , v ) for u , v in graph .edges () if v not in vertices ]
440442
441443
442444def _exclude_adjacent (
443- graph : nx .Graph , vertices : Collection [Variable ]
445+ graph : nx .Graph , vertices : set [Variable ]
444446) -> Collection [Tuple [Variable , Variable ]]:
445447 return [(u , v ) for u , v in graph .edges () if u not in vertices and v not in vertices ]
446448
0 commit comments