Skip to content

Commit b83e78c

Browse files
committed
optimize method to efficiently retrain all classifiers
1 parent 70699f9 commit b83e78c

File tree

2 files changed

+64
-44
lines changed

2 files changed

+64
-44
lines changed

app/services/spam_classifier_service.rb

Lines changed: 63 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@ class SpamClassifierService
55

66
attr_reader :group_id, :classifier_state, :group_name
77

8-
def initialize(group_id, group_name)
8+
def initialize(group_id, group_name, classifier_state: nil)
99
@group_id = group_id
1010
@group_name = group_name
11-
@classifier_state = GroupClassifierState.find_or_create_by!(group_id: @group_id) do |new_state|
11+
@classifier_state = classifier_state || GroupClassifierState.find_or_create_by!(group_id: @group_id) do |new_state|
1212
# Find the most recently updated classifier for group to use as a template.
1313
template = GroupClassifierState.for_group.order(updated_at: :desc).first
1414
if template
@@ -35,25 +35,28 @@ def initialize(group_id, group_name)
3535
end
3636

3737
# Folds one trained message into the in-memory classifier state.
#
# Only the attributes of +@classifier_state+ are updated; nothing in this
# method saves the record, so persisting it is left to the caller.
#
# @param trained_message [#message, #spam?] message whose text is tokenized
#   and counted as spam when +spam?+ is truthy, as ham otherwise
# @return [Integer] the vocabulary size that was just assigned to the state
def train_only(trained_message)
  state = @classifier_state

  # The vocabulary is seeded from the existing count keys only on the first
  # call for this service instance, then reused and grown by later calls.
  @vocabulary ||= Set.new(state.spam_counts.keys + state.ham_counts.keys)

  words = tokenize(trained_message.message)

  # Bump the per-class totals and pick the count table this message feeds.
  counts =
    if trained_message.spam?
      state.total_spam_messages += 1
      state.total_spam_words += words.size
      state.spam_counts
    else # :ham
      state.total_ham_messages += 1
      state.total_ham_words += words.size
      state.ham_counts
    end

  words.each do |word|
    counts[word] = counts.fetch(word, 0) + 1
    @vocabulary.add(word)
  end

  state.vocabulary_size = @vocabulary.size
end
5861

5962
def train(trained_message)
@@ -110,41 +113,6 @@ def classify(message_text)
110113
[ is_spam, spam_score, ham_score ]
111114
end
112115

113-
class << self
114-
def rebuild_for_group(group_id, group_name)
115-
service = new(group_id, group_name)
116-
service.rebuild_classifier
117-
end
118-
end
119-
120-
def rebuild_classifier
121-
Rails.logger.info "Rebuild classifier for group_id: #{group_id}"
122-
messages_to_train = if group_id == GroupClassifierState::USER_NAME_CLASSIFIER_GROUP_ID
123-
TrainedMessage.trainable.for_user_name
124-
else
125-
TrainedMessage.trainable.for_message_content
126-
end
127-
128-
ActiveRecord::Base.transaction do
129-
classifier_state.update!(
130-
group_name: group_name,
131-
spam_counts: {},
132-
ham_counts: {},
133-
total_spam_words: 0,
134-
total_ham_words: 0,
135-
total_spam_messages: 0,
136-
total_ham_messages: 0,
137-
vocabulary_size: 0
138-
)
139-
140-
# Retrain from all trainable messages
141-
messages_to_train.find_each do |message|
142-
train_only(message)
143-
end
144-
classifier_state.save!
145-
end
146-
end
147-
148116
def tokenize(text)
149117
cleaned_text = clean_text(text)
150118
# This regex pre-tokenizes the string into 4 groups:
@@ -219,4 +187,58 @@ def pure_numbers?(token)
219187
# Check if token contains only numbers (Arabic or Chinese)
220188
token.match?(/^[0-9一二三四五六七八九十百千万亿零]+$/)
221189
end
190+
191+
class << self
  # Rebuilds, in memory, every classifier state returned by
  # +GroupClassifierState.for_public+ and then saves them together.
  #
  # Steps:
  # 1. Load the states once and zero their counters.
  # 2. Wrap each state in a service instance, injecting the state so the
  #    constructor does not look it up again.
  # 3. Stream the trainable user-name messages into the user-name classifier
  #    and the trainable message-content messages into every other classifier.
  # 4. Save all states inside a single transaction.
  #
  # NOTE(review): every non-user-name service receives the same messages, so
  # each message is tokenized once per service inside +train_only+ — confirm
  # whether that repetition is intended.
  #
  # @return [Object] whatever the final +Rails.logger.info+ call returns
  def rebuild_all_public
    Rails.logger.info "Starting rebuild for all public classifiers..."

    # 1. Load all classifier states from the DB and reset them in memory.
    states_by_group = GroupClassifierState.for_public.index_by(&:group_id)
    states_by_group.each_value { |state| reset_counts(state) }

    # 2. One service per state; passing the state avoids redundant lookups.
    services = states_by_group.transform_values do |state|
      new(state.group_id, state.group_name, classifier_state: state)
    end

    # 3. Process each category of messages ONCE.
    user_name_id = GroupClassifierState::USER_NAME_CLASSIFIER_GROUP_ID
    user_name_service = services[user_name_id]
    group_services = services.values.reject { |service| service.group_id == user_name_id }

    if user_name_service
      TrainedMessage.trainable.for_user_name.find_each do |message|
        user_name_service.train_only(message)
      end
    end

    # Skip the scan entirely when no service would consume the messages;
    # otherwise every trainable message would be loaded just to be dropped.
    unless group_services.empty?
      TrainedMessage.trainable.for_message_content.find_each do |message|
        group_services.each { |service| service.train_only(message) }
      end
    end

    # 4. Save everything in one transaction so a failed save rolls back all.
    ActiveRecord::Base.transaction do
      services.each_value do |service|
        Rails.logger.info "Saving classifier for group_id: #{service.group_id}"
        service.classifier_state.save!
      end
    end

    Rails.logger.info "Classifier rebuild completed!"
  end

  private

  # Zeroes every counter on +state+ in memory; does not save the record.
  def reset_counts(state)
    state.spam_counts = {}
    state.ham_counts = {}
    state.total_spam_words = 0
    state.total_ham_words = 0
    state.total_spam_messages = 0
    state.total_ham_messages = 0
    state.vocabulary_size = 0
  end
end
222244
end

lib/tasks/data_migration.rake

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@ namespace :data_migration do
3636
end
3737
desc "Retrain all classifier"
3838
# Delegates the whole retraining run to SpamClassifierService.rebuild_all_public.
task retrain_all_classifier: [:environment] do
  SpamClassifierService.rebuild_all_public
end
4341
end

0 commit comments

Comments
 (0)