File tree Expand file tree Collapse file tree 4 files changed +50
-6
lines changed Expand file tree Collapse file tree 4 files changed +50
-6
lines changed Original file line number Diff line number Diff line change @@ -2,17 +2,23 @@ class ClassifierTrainerJob < ApplicationJob
2
2
# Job to train classifier asynchronously
3
3
queue_as :training
4
4
5
# Retrains the classifiers with a batch of trained messages.
#
# The batch is split by training target: messages answering true to
# +user_name?+ are fed to every classifier yielded by
# GroupClassifierState.username, and messages answering true to
# +message_content?+ are fed to every classifier yielded by
# GroupClassifierState.for_group. Each classifier gets its whole batch in
# one SpamClassifierService#train_batch call.
#
# trained_messages - a list of trained messages. A single message (the
#                    argument shape this job took before batching) or nil
#                    is also accepted and treated as a one-element or
#                    empty batch.
def perform(trained_messages)
  # Kernel#Array keeps jobs that were enqueued with a single message
  # working, and turns nil into an empty batch instead of raising.
  trained_messages = Array(trained_messages)

  Rails.logger.info "Retrain all the classifiers for public"

  # Separate messages by their training target
  user_name_messages = trained_messages.select(&:user_name?)
  message_content_messages = trained_messages.select(&:message_content?)

  if user_name_messages.any?
    GroupClassifierState.username.find_each do |classifier|
      spam_classifier = SpamClassifierService.new(classifier.group_id, classifier.group_name)
      spam_classifier.train_batch(user_name_messages)
    end
  end

  if message_content_messages.any?
    GroupClassifierState.for_group.find_each do |classifier|
      spam_classifier = SpamClassifierService.new(classifier.group_id, classifier.group_name)
      spam_classifier.train_batch(message_content_messages)
    end
  end
end
Original file line number Diff line number Diff line change @@ -57,7 +57,7 @@ def should_ban_user
57
57
# Enqueues an asynchronous retraining job for this record.
#
# Does nothing when +untrained?+ is true. The record is wrapped in a
# one-element array because ClassifierTrainerJob#perform takes a list of
# trained messages.
def retrain_classifier
  return if untrained?

  ClassifierTrainerJob.perform_later([self])
end
62
62
63
63
private
Original file line number Diff line number Diff line change @@ -61,6 +61,13 @@ def train(trained_message)
61
61
@classifier_state . save!
62
62
end
63
63
64
# Trains the classifier on several messages and saves the state once.
#
# Every message is passed to +train_only+ in the given order; the
# classifier state is then saved with a single +save!+ after the whole
# batch, rather than once per message.
#
# trained_messages - a list of trained messages. nil or a single message
#                    is coerced to a list, so the call never fails on
#                    +each+.
#
# Raises whatever +train_only+ or @classifier_state.save! raise.
def train_batch(trained_messages)
  Array(trained_messages).each do |trained_message|
    train_only(trained_message)
  end
  @classifier_state.save!
end
70
+
64
71
def classify ( message_text )
65
72
# P(Spam|Words) = P(Words|Spam) * P(Spam) / P(Words)
66
73
# Return false if the model isn't trained enough
Original file line number Diff line number Diff line change @@ -188,6 +188,37 @@ def setup
188
188
assert spam_score > ham_score , "Spam score should be higher than ham score"
189
189
end
190
190
191
# Trains one service with two spam messages and one ham message in a
# single train_batch call, then checks that a text reusing words from the
# spam messages is classified as spam with a higher spam score.
test "#train_batch train a list of messages and identify spam message correctly" do
  service = SpamClassifierService.new(@group_id, @group_name)
  service.train_batch([
    TrainedMessage.new(
      group_id: @group_id,
      message: "便宜的伟哥现在买",
      message_type: :spam,
      sender_chat_id: 1,
      sender_user_name: "s"
    ),
    TrainedMessage.new(
      group_id: @group_id,
      message: "免费点击这里",
      message_type: :spam,
      sender_chat_id: 1,
      sender_user_name: "s"
    ),
    TrainedMessage.new(
      group_id: @group_id,
      message: "你好,今天天气不错",
      message_type: :ham,
      sender_chat_id: 2,
      sender_user_name: "s"
    )
  ])
  is_spam, spam_score, ham_score = service.classify("点击这里买伟哥")

  assert is_spam, "Message should be classified as spam"
  assert spam_score > ham_score, "Spam score should be higher than ham score"
end
221
+
191
222
test "#classify should correctly identify a message as ham" do
192
223
service = SpamClassifierService . new ( @group_id , @group_name )
193
224
service . train ( TrainedMessage . new (
You can’t perform that action at this time.
0 commit comments