Skip to content

Commit 1ca9c48

Browse files
committed
ac automata in python
1 parent 6b50ac0 commit 1ca9c48

File tree

1 file changed

+80
-0
lines changed

1 file changed

+80
-0
lines changed
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
"""
2+
Aho-Corasick Algorithm
3+
4+
Author: Wenru Dong
5+
"""
6+
7+
from collections import deque
8+
from typing import List
9+
10+
class ACNode:
11+
def __init__(self, data: str):
12+
self._data = data
13+
self._children = [None] * 26
14+
self._is_ending_char = False
15+
self._length = -1
16+
self._suffix = None
17+
18+
19+
class ACAutomata:
20+
def __init__(self):
21+
self._root = ACNode("/")
22+
23+
def _build_suffix_link(self) -> None:
24+
q = deque()
25+
q.append(self._root)
26+
while q:
27+
node = q.popleft()
28+
for child in node._children:
29+
if child:
30+
if node == self._root:
31+
child._suffix = self._root
32+
else:
33+
suffix = node._suffix
34+
while suffix:
35+
suffix_child = suffix._children[ord(child._data) - ord("a")]
36+
if suffix_child:
37+
child._suffix = suffix_child
38+
break
39+
suffix = suffix._suffix
40+
if not suffix:
41+
child._suffix = self._root
42+
q.append(child)
43+
44+
def _insert(self, text: str) -> None:
45+
node = self._root
46+
for index, char in map(lambda x: (ord(x) - ord("a"), x), text):
47+
if not node._children[index]:
48+
node._children[index] = ACNode(char)
49+
node = node._children[index]
50+
node._is_ending_char = True
51+
node._length = len(text)
52+
53+
def insert(self, patterns: List[str]) -> None:
54+
for pattern in patterns:
55+
self._insert(pattern)
56+
self._build_suffix_link()
57+
58+
def match(self, text: str) -> None:
59+
node = self._root
60+
for i, char in enumerate(text):
61+
index = ord(char) - ord("a")
62+
while not node._children[index] and node != self._root:
63+
node = node._suffix
64+
node = node._children[index]
65+
if not node:
66+
node = self._root
67+
tmp = node
68+
while tmp != self._root:
69+
if tmp._is_ending_char:
70+
print(f"匹配起始下标{i - tmp._length + 1},长度{tmp._length}")
71+
tmp = tmp._suffix
72+
73+
74+
if __name__ == "__main__":
75+
76+
patterns = ["at", "art", "oars", "soar"]
77+
ac = ACAutomata()
78+
ac.insert(patterns)
79+
80+
ac.match("soarsoars")

0 commit comments

Comments
 (0)