|
| 1 | +import math |
| 2 | + |
| 3 | +from typing import ( |
| 4 | + Callable, |
| 5 | + Collection, |
| 6 | + Dict, |
| 7 | + FrozenSet, |
| 8 | + Generic, |
| 9 | + ItemsView, |
| 10 | + Iterable, |
| 11 | + KeysView, |
| 12 | + List, |
| 13 | + Optional, |
| 14 | + Set, |
| 15 | + TypeVar, |
| 16 | + Union, |
| 17 | +) |
| 18 | + |
| 19 | +import graphviz |
| 20 | +import networkx as nx |
| 21 | + |
| 22 | +N = TypeVar("N") |
| 23 | + |
| 24 | + |
| 25 | +class DiGraph(nx.DiGraph, Generic[N]): |
| 26 | + def __init__(self, *args, **kwargs): |
| 27 | + super().__init__(*args, **kwargs) |
| 28 | + self._dominator_forest: Optional[DiGraph[N]] = None |
| 29 | + self._roots: Optional[Collection[N]] = None |
| 30 | + self._path_lengths: Optional[Dict[N, Dict[N, int]]] = None |
| 31 | + |
| 32 | + def path_length(self, from_node: N, to_node: N) -> Union[int, float]: |
| 33 | + if self._path_lengths is None: |
| 34 | + self._path_lengths = dict(nx.all_pairs_shortest_path_length(self, cutoff=None)) |
| 35 | + if from_node not in self._path_lengths or to_node not in self._path_lengths[from_node]: |
| 36 | + return math.inf |
| 37 | + else: |
| 38 | + return self._path_lengths[from_node][to_node] |
| 39 | + |
| 40 | + def set_roots(self, roots: Collection[N]): |
| 41 | + self._roots = roots |
| 42 | + |
| 43 | + def _find_roots(self) -> Iterable[N]: |
| 44 | + return (n for n, d in self.in_degree() if d == 0) |
| 45 | + |
| 46 | + @property |
| 47 | + def roots(self) -> Collection[N]: |
| 48 | + if self._roots is None: |
| 49 | + self._roots = tuple(self._find_roots()) |
| 50 | + return self._roots |
| 51 | + |
| 52 | + def depth(self, node: N) -> Union[int, float]: |
| 53 | + return min(self.path_length(root, node) for root in self.roots) |
| 54 | + |
| 55 | + def ancestors(self, node: N) -> Set[N]: |
| 56 | + return nx.ancestors(self, node) |
| 57 | + |
| 58 | + def descendants(self, node: N) -> FrozenSet[N]: |
| 59 | + return frozenset(nx.dfs_successors(self, node).keys()) |
| 60 | + |
| 61 | + @property |
| 62 | + def dominator_forest(self) -> "DAG[N]": |
| 63 | + if self._dominator_forest is not None: |
| 64 | + return self._dominator_forest |
| 65 | + self._dominator_forest = DAG() |
| 66 | + for root in self.roots: |
| 67 | + for node, dominated_by in nx.immediate_dominators(self, root).items(): |
| 68 | + if node != dominated_by: |
| 69 | + self._dominator_forest.add_edge(dominated_by, node) |
| 70 | + return self._dominator_forest |
| 71 | + |
| 72 | + def to_dot( |
| 73 | + self, comment: Optional[str] = None, labeler: Optional[Callable[[N], str]] = None, node_filter=None |
| 74 | + ) -> graphviz.Digraph: |
| 75 | + if comment is not None: |
| 76 | + dot = graphviz.Digraph(comment=comment) |
| 77 | + else: |
| 78 | + dot = graphviz.Digraph() |
| 79 | + if labeler is None: |
| 80 | + labeler = str |
| 81 | + node_ids = {node: i for i, node in enumerate(self.nodes)} |
| 82 | + for node in self.nodes: |
| 83 | + if node_filter is None or node_filter(node): |
| 84 | + dot.node(f"func{node_ids[node]}", label=labeler(node)) |
| 85 | + for caller, callee in self.edges: |
| 86 | + if node_filter is None or (node_filter(caller) and node_filter(callee)): |
| 87 | + dot.edge(f"func{node_ids[caller]}", f"func{node_ids[callee]}") |
| 88 | + return dot |
| 89 | + |
| 90 | + |
| 91 | +class DAG(DiGraph[N], Generic[N]): |
| 92 | + def vertex_induced_subgraph(self, vertices: Iterable[N]) -> "DAG[N]": |
| 93 | + vertices = frozenset(vertices) |
| 94 | + subgraph = self.copy() |
| 95 | + to_remove = set(self.nodes) - vertices |
| 96 | + for v in vertices: |
| 97 | + node = v |
| 98 | + parent = None |
| 99 | + while True: |
| 100 | + parents = tuple(subgraph.predecessors(node)) |
| 101 | + if not parents: |
| 102 | + if parent is not None: |
| 103 | + subgraph.remove_edge(parent, v) |
| 104 | + subgraph.add_edge(node, v) |
| 105 | + break |
| 106 | + assert len(parents) == 1 |
| 107 | + ancestor = parents[0] |
| 108 | + if parent is None: |
| 109 | + parent = ancestor |
| 110 | + if ancestor in vertices: |
| 111 | + to_remove.add(v) |
| 112 | + break |
| 113 | + node = ancestor |
| 114 | + subgraph.remove_nodes_from(to_remove) |
| 115 | + return subgraph |
| 116 | + |
| 117 | + |
| 118 | +class FunctionInfo: |
| 119 | + def __init__( |
| 120 | + self, |
| 121 | + name: str, |
| 122 | + cmp_bytes: Dict[str, List[int]], |
| 123 | + input_bytes: Dict[str, List[int]] = None, |
| 124 | + called_from: Iterable[str] = (), |
| 125 | + ): |
| 126 | + self.name: str = name |
| 127 | + self.called_from: FrozenSet[str] = frozenset(called_from) |
| 128 | + self.cmp_bytes: Dict[str, List[int]] = cmp_bytes |
| 129 | + if input_bytes is None: |
| 130 | + self.input_bytes: Dict[str, List[int]] = cmp_bytes |
| 131 | + else: |
| 132 | + self.input_bytes = input_bytes |
| 133 | + |
| 134 | + @property |
| 135 | + def taint_sources(self) -> KeysView[str]: |
| 136 | + return self.input_bytes.keys() |
| 137 | + |
| 138 | + def __getitem__(self, input_source_name: str) -> List[int]: |
| 139 | + return self.input_bytes[input_source_name] |
| 140 | + |
| 141 | + def __iter__(self) -> Iterable[str]: |
| 142 | + return self.taint_sources |
| 143 | + |
| 144 | + def items(self) -> ItemsView[str, List[int]]: |
| 145 | + return self.input_bytes.items() |
| 146 | + |
| 147 | + def __hash__(self): |
| 148 | + return hash(self.name) |
| 149 | + |
| 150 | + def __str__(self): |
| 151 | + return self.name |
| 152 | + |
| 153 | + def __repr__(self): |
| 154 | + return f"{self.__class__.__name__}(name={self.name!r}, cmp_bytes={self.cmp_bytes!r}, input_bytes={self.input_bytes!r}, called_from={self.called_from!r})" |
| 155 | + |
| 156 | + |
| 157 | +class CFG(DiGraph[FunctionInfo]): |
| 158 | + def __init__(self): |
| 159 | + super().__init__() |
| 160 | + |
| 161 | + def to_dot( |
| 162 | + self, |
| 163 | + comment: Optional[str] = "PolyTracker Program Trace", |
| 164 | + labeler: Optional[Callable[[FunctionInfo], str]] = None, |
| 165 | + node_filter=None, |
| 166 | + ) -> graphviz.Digraph: |
| 167 | + function_labels: Dict[str, str] = {} |
| 168 | + |
| 169 | + def func_labeler(f): |
| 170 | + if labeler is not None: |
| 171 | + return labeler(f) |
| 172 | + elif f.name in function_labels: |
| 173 | + return f"{f.name} ({function_labels[f.name]})" |
| 174 | + else: |
| 175 | + return f.name |
| 176 | + |
| 177 | + return super().to_dot(comment, labeler=func_labeler, node_filter=node_filter) |
0 commit comments