3
3
4
4
use std:: sync:: atomic:: Ordering ;
5
5
6
- use crate :: error:: Result ;
7
- use sea_orm:: { ConnectionTrait , DatabaseConnection , DbErr , Statement , UpdateResult } ;
8
- use std:: time:: Duration ;
9
- use tokio:: time:: sleep;
6
+ use crate :: error:: { Error , Result } ;
7
+ use sea_orm:: {
8
+ ConnectionTrait , DatabaseConnection , DbErr , ExecResult , QueryResult , Statement ,
9
+ TransactionTrait , UpdateResult ,
10
+ } ;
11
+ use std:: time:: { Duration , Instant } ;
12
+
13
+ type DbResult < T > = std:: result:: Result < T , DbErr > ;
14
+
15
+ async fn exec_without_timeout ( pool : & DatabaseConnection , stmt : Statement ) -> DbResult < ExecResult > {
16
+ let increase_timeout = Statement :: from_string (
17
+ pool. get_database_backend ( ) ,
18
+ "SET LOCAL statement_timeout=0;" ,
19
+ ) ;
20
+ let tx = pool. begin ( ) . await ?;
21
+ let _ = tx. execute ( increase_timeout) . await ?;
22
+ let res = tx. execute ( stmt) . await ?;
23
+ tx. commit ( ) . await ?;
24
+ Ok ( res)
25
+ }
26
+ async fn query_one_without_timeout (
27
+ pool : & DatabaseConnection ,
28
+ stmt : Statement ,
29
+ ) -> DbResult < Option < QueryResult > > {
30
+ let increase_timeout = Statement :: from_string (
31
+ pool. get_database_backend ( ) ,
32
+ "SET LOCAL statement_timeout=0;" ,
33
+ ) ;
34
+ let tx = pool. begin ( ) . await ?;
35
+ let _ = tx. execute ( increase_timeout) . await ?;
36
+ let res = tx. query_one ( stmt) . await ?;
37
+ tx. commit ( ) . await ?;
38
+ Ok ( res)
39
+ }
10
40
11
41
/// Nullifies the payload column for expired messages,
12
42
/// `limit` sets how many rows to update at a time.
13
43
pub async fn clean_expired_messages (
14
44
pool : & DatabaseConnection ,
15
45
limit : u32 ,
16
- ) -> std:: result:: Result < UpdateResult , DbErr > {
17
- let legacy_stmt = Statement :: from_sql_and_values (
18
- pool. get_database_backend ( ) ,
19
- r#"
46
+ enable_legacy_message_cleaner : bool ,
47
+ ) -> DbResult < UpdateResult > {
48
+ // See the docs for [`has_message_payloads_pending_expiry`] for background on the legacy cleaner.
49
+ let legacy_row_count = if enable_legacy_message_cleaner {
50
+ let legacy_res = {
51
+ let legacy_stmt = Statement :: from_sql_and_values (
52
+ pool. get_database_backend ( ) ,
53
+ r#"
20
54
UPDATE message SET payload = NULL WHERE id IN (
21
55
SELECT id FROM message
22
56
WHERE
@@ -26,9 +60,15 @@ pub async fn clean_expired_messages(
26
60
FOR UPDATE SKIP LOCKED
27
61
)
28
62
"# ,
29
- [ limit. into ( ) ] ,
30
- ) ;
31
- let legacy_res = pool. execute ( legacy_stmt) . await ?;
63
+ [ limit. into ( ) ] ,
64
+ ) ;
65
+
66
+ exec_without_timeout ( pool, legacy_stmt) . await ?
67
+ } ;
68
+ legacy_res. rows_affected ( )
69
+ } else {
70
+ 0
71
+ } ;
32
72
33
73
let stmt = Statement :: from_sql_and_values (
34
74
pool. get_database_backend ( ) ,
@@ -48,32 +88,79 @@ pub async fn clean_expired_messages(
48
88
let res = pool. execute ( stmt) . await ?;
49
89
50
90
Ok ( UpdateResult {
51
- rows_affected : legacy_res . rows_affected ( ) + res. rows_affected ( ) ,
91
+ rows_affected : legacy_row_count + res. rows_affected ( ) ,
52
92
} )
53
93
}
54
94
95
+ /// Checks to see if the message table has any non-null payloads requiring expiry.
96
+ ///
97
+ /// ## Background
98
+ ///
99
+ /// Initially payloads were modeled as a field in `message`, but later migrated to a separate
100
+ /// table (`messagecontent`). In cases where there are no longer any payloads to expire in `message` we
101
+ /// can avoid the expense of running the cleaner on the `message` table since all new messages should now be using
102
+ /// `messagecontent`.
103
+ async fn has_message_payloads_pending_expiry ( pool : & DatabaseConnection ) -> Result < bool > {
104
+ query_one_without_timeout (
105
+ pool,
106
+ Statement :: from_string (
107
+ pool. get_database_backend ( ) ,
108
+ r#"SELECT EXISTS (SELECT 1 FROM message WHERE payload IS NOT NULL LIMIT 1)"# ,
109
+ ) ,
110
+ )
111
+ . await ?
112
+ . ok_or_else ( || Error :: generic ( "failed to check for message payloads" ) ) ?
113
+ . try_get_by_index ( 0 )
114
+ . map_err ( |e| Error :: generic ( format ! ( "failed to check for message payloads: {e}" ) ) )
115
+ }
116
+
55
117
/// Polls the database for expired messages to nullify payloads for.
56
118
///
57
119
/// Uses a variable polling schedule, based on affected row counts each iteration of the loop.
58
120
pub async fn expired_message_cleaner_loop ( pool : & DatabaseConnection ) -> Result < ( ) > {
121
+ let message_table_needs_cleaning = has_message_payloads_pending_expiry ( pool) . await ?;
122
+ if !message_table_needs_cleaning {
123
+ tracing:: info!( "No payloads pending expiry found in `message` table. Skipping the cleaner for this table." ) ;
124
+ }
125
+
59
126
// When no rows have been updated, widen the interval.
60
- const IDLE : Duration = Duration :: from_secs ( 10 ) ;
127
+ const IDLE : Duration = Duration :: from_secs ( 60 * 60 * 12 ) ;
61
128
// When the affected row count dips below this, switch to the `SLOWING` interval.
62
- const SLOWING_THRESHOLD : u64 = 1_000 ;
63
- const SLOWING : Duration = Duration :: from_secs ( 3 ) ;
129
+ const SLOWING_THRESHOLD : u64 = 5_000 ;
130
+ const SLOWING : Duration = Duration :: from_secs ( 60 * 60 * 12 ) ;
131
+ const ON_ERROR : Duration = Duration :: from_secs ( 10 ) ;
64
132
const BATCH_SIZE : u32 = 5_000 ;
65
- let mut sleep_time = Some ( IDLE ) ;
133
+ let mut sleep_time = None ;
66
134
loop {
67
135
if let Some ( duration) = sleep_time {
68
- sleep ( duration) . await ;
136
+ let sleep_start = Instant :: now ( ) ;
137
+ let mut interval = tokio:: time:: interval ( Duration :: from_secs ( 10 ) ) ;
138
+ interval. tick ( ) . await ;
139
+ // Doing a plain sleep() was fine when the polling frequency was mere seconds, but since we're doing wider
140
+ // periods now (hours, not seconds), we need to be a little more careful about not preventing the process
141
+ // from shutting down.
142
+ // Using `interval()` so we can track how long we've been sleeping for, while still checking for the
143
+ // shutdown signal.
144
+ ' inner: loop {
145
+ if crate :: SHUTTING_DOWN . load ( Ordering :: SeqCst ) {
146
+ return Ok ( ( ) ) ;
147
+ }
148
+ interval. tick ( ) . await ;
149
+ if sleep_start. elapsed ( ) > duration {
150
+ break ' inner;
151
+ }
152
+ }
69
153
}
70
- match clean_expired_messages ( pool, BATCH_SIZE ) . await {
154
+
155
+ let start = Instant :: now ( ) ;
156
+ match clean_expired_messages ( pool, BATCH_SIZE , message_table_needs_cleaning) . await {
71
157
Err ( err) => {
72
158
tracing:: error!( "{}" , err) ;
159
+ sleep_time = Some ( ON_ERROR ) ;
73
160
}
74
161
Ok ( UpdateResult { rows_affected } ) => {
75
162
if rows_affected > 0 {
76
- tracing:: trace! ( "expired {} payloads" , rows_affected) ;
163
+ tracing:: debug! ( elapsed =? start . elapsed ( ) , "expired {} payloads" , rows_affected) ;
77
164
}
78
165
79
166
sleep_time = match rows_affected {
0 commit comments