Skip to content

Commit 1021a47

Browse files
committed
Add train_batch support to train multiple messages
1 parent 1954353 commit 1021a47

File tree

4 files changed

+50
-6
lines changed

4 files changed

+50
-6
lines changed

app/jobs/classifier_trainer_job.rb

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,23 @@ class ClassifierTrainerJob < ApplicationJob
22
# Job to train classifier asynchronously
33
queue_as :training
44

5-
def perform(trained_message)
5+
def perform(trained_messages)
66
Rails.logger.info "Retrain all the classifiers for public"
7-
if trained_message.user_name?
7+
# Separate messages by their training target
8+
user_name_messages = trained_messages.select(&:user_name?)
9+
message_content_messages = trained_messages.select(&:message_content?)
10+
11+
if user_name_messages.any?
812
GroupClassifierState.username.find_each do |classifier|
913
spam_classifier = SpamClassifierService.new(classifier.group_id, classifier.group_name)
10-
spam_classifier.train(trained_message)
14+
spam_classifier.train_batch(user_name_messages)
1115
end
12-
elsif trained_message.message_content?
16+
end
17+
18+
if message_content_messages.any?
1319
GroupClassifierState.for_group.find_each do |classifier|
1420
spam_classifier = SpamClassifierService.new(classifier.group_id, classifier.group_name)
15-
spam_classifier.train(trained_message)
21+
spam_classifier.train_batch(message_content_messages)
1622
end
1723
end
1824
end

app/models/trained_message.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def should_ban_user
5757
def retrain_classifier
5858
return if untrained?
5959

60-
ClassifierTrainerJob.perform_later(self)
60+
ClassifierTrainerJob.perform_later([ self ])
6161
end
6262

6363
private

app/services/spam_classifier_service.rb

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,13 @@ def train(trained_message)
6161
@classifier_state.save!
6262
end
6363

64+
def train_batch(trained_messages)
65+
trained_messages.each do |trained_message|
66+
train_only(trained_message)
67+
end
68+
@classifier_state.save!
69+
end
70+
6471
def classify(message_text)
6572
# P(Spam|Words) = P(Words|Spam) * P(Spam) / P(Words)
6673
# Return false if the model isn't trained enough

test/services/spam_classifier_service_test.rb

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,37 @@ def setup
188188
assert spam_score > ham_score, "Spam score should be higher than ham score"
189189
end
190190

191+
test "#train_batch train a list of messages and identify spam message correctly" do
192+
service = SpamClassifierService.new(@group_id, @group_name)
193+
service.train_batch([
194+
TrainedMessage.new(
195+
group_id: @group_id,
196+
message: "便宜的伟哥现在买",
197+
message_type: :spam,
198+
sender_chat_id: 1,
199+
sender_user_name: "s"
200+
),
201+
TrainedMessage.new(
202+
group_id: @group_id,
203+
message: "免费点击这里",
204+
message_type: :spam,
205+
sender_chat_id: 1,
206+
sender_user_name: "s"
207+
),
208+
TrainedMessage.new(
209+
group_id: @group_id,
210+
message: "你好,今天天气不错",
211+
message_type: :ham,
212+
sender_chat_id: 2,
213+
sender_user_name: "s"
214+
)
215+
])
216+
is_spam, spam_score, ham_score = service.classify("点击这里买伟哥")
217+
218+
assert is_spam, "Message should be classified as spam"
219+
assert spam_score > ham_score, "Spam score should be higher than ham score"
220+
end
221+
191222
test "#classify should correctly identify a message as ham" do
192223
service = SpamClassifierService.new(@group_id, @group_name)
193224
service.train(TrainedMessage.new(

0 commit comments

Comments
 (0)