Commit e34ff8c

Apriori: implement join step of apriori-gen
The apriori-gen function described in section 2.1.1 of the Apriori paper has two steps. First, the join step looks for itemsets with the same prefix and creates new candidates by appending every pair of their last items to this prefix. Here is the pseudocode copied from the paper:

    select p.1, p.2, ..., p.k-1, q.k-1
    from p in L(k-1), q in L(k-1)
    where p.1 = q.1, ..., p.k-2 = q.k-2, p.k-1 < q.k-1

The reason is that if an itemset q with the same prefix as p does not belong to L(k-1), the itemset p+(q.k-1,) cannot be frequent. Before this commit, we were considering p+(q.k-1,) for any q.k-1 > p.k-1.

The second step of apriori-gen is called the prune step; it will be implemented in a separate commit.

See discussion in #644.
1 parent 1308f6b commit e34ff8c
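
The join step described in the commit message can be sketched independently of mlxtend. The helper below is hypothetical (the name `join_step` and the tuple-based representation are assumptions made for illustration, not the code in this commit); it assumes each itemset is a tuple of items in ascending order and that the input contains no duplicates.

```python
from itertools import groupby


def join_step(frequent_prev):
    """Hypothetical helper: join step over (k-1)-itemsets given as tuples."""
    candidates = []
    # Lexicographic order places itemsets sharing a prefix next to each other.
    ordered = sorted(frequent_prev)
    for prefix, group in groupby(ordered, key=lambda itemset: itemset[:-1]):
        tails = [itemset[-1] for itemset in group]
        # Pair every tail with each larger tail from the same prefix group.
        for a in range(len(tails)):
            for b in range(a + 1, len(tails)):
                candidates.append(prefix + (tails[a], tails[b]))
    return candidates


l2 = [(0, 1), (0, 2), (1, 2), (1, 3)]
print(join_step(l2))  # [(0, 1, 2), (1, 2, 3)]
```

Note that (0, 1, 3) is not produced, because (0, 3) is not in the input: this is exactly the case the commit message says was previously generated.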

1 file changed: +15 -11 lines changed

mlxtend/frequent_patterns/apriori.py

Lines changed: 15 additions & 11 deletions
@@ -30,8 +30,10 @@ def generate_new_combinations(old_combinations):
 
     Returns
     -----------
-    Generator of all combinations from the last step x items
-    from the previous step.
+    Generator of combinations based on the last state of Apriori algorithm.
+    In order to reduce number of candidates, this function implements the
+    join step of apriori-gen described in section 2.1.1 of Apriori paper.
+    Prune step is not yet implemented.
 
     Examples
     -----------
@@ -40,15 +42,17 @@ def generate_new_combinations(old_combinations):
 
     """
 
-    items_types_in_previous_step = np.unique(old_combinations.flatten())
-    for old_combination in old_combinations:
-        max_combination = old_combination[-1]
-        mask = items_types_in_previous_step > max_combination
-        valid_items = items_types_in_previous_step[mask]
-        old_tuple = tuple(old_combination)
-        for item in valid_items:
-            yield from old_tuple
-            yield item
+    length = len(old_combinations)
+    for i, old_combination in enumerate(old_combinations):
+        head_i = list(old_combination[:-1])
+        j = i + 1
+        while j < length:
+            *head_j, tail_j = old_combinations[j]
+            if head_i != head_j:
+                break
+            yield from old_combination
+            yield tail_j
+            j = j + 1
 
 
 def generate_new_combinations_low_memory(old_combinations, X, min_support,
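
To see the effect of the change, the removed and added loop bodies from the hunk above can be run side by side on a small input. This is only a sketch: the wrapper names `candidates_before`, `candidates_after`, and `as_itemsets` are hypothetical and not part of mlxtend. The new loop breaks at the first row whose prefix differs, so it appears to rely on the rows of `old_combinations` being in lexicographic order; the example input is sorted accordingly.

```python
import numpy as np


def candidates_before(old_combinations):
    # Loop body removed by this commit (hypothetical wrapper name).
    items_types_in_previous_step = np.unique(old_combinations.flatten())
    for old_combination in old_combinations:
        max_combination = old_combination[-1]
        mask = items_types_in_previous_step > max_combination
        valid_items = items_types_in_previous_step[mask]
        old_tuple = tuple(old_combination)
        for item in valid_items:
            yield from old_tuple
            yield item


def candidates_after(old_combinations):
    # Loop body added by this commit (hypothetical wrapper name).
    length = len(old_combinations)
    for i, old_combination in enumerate(old_combinations):
        head_i = list(old_combination[:-1])
        j = i + 1
        while j < length:
            *head_j, tail_j = old_combinations[j]
            if head_i != head_j:
                break
            yield from old_combination
            yield tail_j
            j = j + 1


def as_itemsets(flat, k):
    # Both generators yield a flat stream of items; regroup it into k-tuples.
    values = np.fromiter(flat, dtype=int)
    return [tuple(int(v) for v in row) for row in values.reshape(-1, k)]


# Frequent 2-itemsets, rows sorted lexicographically (assumed by the new loop).
l2 = np.array([[0, 1], [0, 2], [1, 2], [1, 3]])
before = as_itemsets(candidates_before(l2), 3)
after = as_itemsets(candidates_after(l2), 3)
print(before)  # [(0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3)]
print(after)   # [(0, 1, 2), (1, 2, 3)]
```

On this input the join step halves the number of candidates: (0, 1, 3) and (0, 2, 3) are dropped because (0, 3) is not among the frequent 2-itemsets.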
