11use std:: net:: IpAddr ;
2- use std:: time:: Duration ;
32
43use bloom:: { CountingBloomFilter , ASMS } ;
54use tokio:: time:: Instant ;
65use url:: Url ;
76
87use crate :: servers:: udp:: UDP_TRACKER_LOG_TARGET ;
98
10- /// The maximum number of connection id errors per ip. Clients will be banned if
11- /// they exceed this limit.
12- pub const MAX_CONNECTION_ID_ERRORS_PER_IP : u32 = 10 ;
13- pub const RESET_CONNECTION_ID_ERRORS_COUNTER_FREQUENCY_IN_SECS : u64 = 3600 ;
14-
159pub struct BanService {
1610 max_connection_id_errors_per_ip : u32 ,
17- ban_duration : Duration ,
1811 connection_id_errors_per_ip : CountingBloomFilter ,
1912 local_addr : Url ,
2013 last_connection_id_errors_reset : Instant ,
2114}
2215
2316impl BanService {
2417 #[ must_use]
25- pub fn new ( max_connection_id_errors_per_ip : u32 , duration_in_seconds : u64 , local_addr : Url ) -> Self {
18+ pub fn new ( max_connection_id_errors_per_ip : u32 , local_addr : Url ) -> Self {
2619 Self {
2720 max_connection_id_errors_per_ip,
28- ban_duration : Duration :: from_secs ( duration_in_seconds) ,
2921 local_addr,
3022 connection_id_errors_per_ip : CountingBloomFilter :: with_rate ( 4 , 0.01 , 100 ) ,
3123 last_connection_id_errors_reset : tokio:: time:: Instant :: now ( ) ,
@@ -48,14 +40,8 @@ impl BanService {
4840 connection_id_errors_from_ip > self . max_connection_id_errors_per_ip
4941 }
5042
51- pub fn run_bans_cleaner ( & mut self ) {
52- if self . last_connection_id_errors_reset . elapsed ( ) >= self . ban_duration {
53- self . reset_filter ( ) ;
54- }
55- }
56-
5743 /// Resets the filter and updates the reset timestamp.
58- pub fn reset_filter ( & mut self ) {
44+ pub fn reset_bans ( & mut self ) {
5945 self . connection_id_errors_per_ip . clear ( ) ;
6046
6147 self . last_connection_id_errors_reset = Instant :: now ( ) ;
@@ -68,23 +54,18 @@ impl BanService {
6854#[ cfg( test) ]
6955mod tests {
7056 use std:: net:: IpAddr ;
71- use std:: time:: Duration ;
72-
73- use tokio:: time:: sleep;
7457
7558 use super :: BanService ;
7659
7760 /// Sample service with one day ban duration.
78- fn service_with_one_day_ban ( counter_limit : u32 ) -> BanService {
79- let one_day_in_seconds = 86400 ;
61+ fn ban_service ( counter_limit : u32 ) -> BanService {
8062 let udp_tracker_url = "udp://127.0.0.1" . parse ( ) . unwrap ( ) ;
81-
82- BanService :: new ( counter_limit, one_day_in_seconds, udp_tracker_url)
63+ BanService :: new ( counter_limit, udp_tracker_url)
8364 }
8465
8566 #[ test]
8667 fn it_should_increase_the_ip_counter ( ) {
87- let mut ban_service = service_with_one_day_ban ( 1 ) ;
68+ let mut ban_service = ban_service ( 1 ) ;
8869
8970 let ip: IpAddr = "127.0.0.2" . parse ( ) . unwrap ( ) ;
9071
@@ -95,7 +76,7 @@ mod tests {
9576
9677 #[ test]
9778 fn it_should_ban_ips_with_counters_exceeding_a_predefined_limit ( ) {
98- let mut ban_service = service_with_one_day_ban ( 1 ) ;
79+ let mut ban_service = ban_service ( 1 ) ;
9980
10081 let ip: IpAddr = "127.0.0.2" . parse ( ) . unwrap ( ) ;
10182
@@ -107,7 +88,7 @@ mod tests {
10788
10889 #[ test]
10990 fn it_should_not_ban_ips_whose_counters_do_not_exceed_the_predefined_limit ( ) {
110- let mut ban_service = service_with_one_day_ban ( 1 ) ;
91+ let mut ban_service = ban_service ( 1 ) ;
11192
11293 let ip: IpAddr = "127.0.0.2" . parse ( ) . unwrap ( ) ;
11394
@@ -118,31 +99,13 @@ mod tests {
11899
119100 #[ test]
120101 fn it_should_allow_resetting_all_the_counters ( ) {
121- let mut ban_service = service_with_one_day_ban ( 1 ) ;
122-
123- let ip: IpAddr = "127.0.0.2" . parse ( ) . unwrap ( ) ;
124-
125- ban_service. increase_counter ( & ip) ; // Counter = 1
126-
127- ban_service. reset_filter ( ) ;
128-
129- assert_eq ! ( ban_service. get_counter( & ip) , 0 ) ;
130- }
131-
132- #[ tokio:: test]
133- async fn it_should_allow_run_a_bans_cleaner_to_reset_the_counters_periodically ( ) {
134- let udp_tracker_url = "udp://127.0.0.1" . parse ( ) . unwrap ( ) ;
135- let ban_duration_in_secs = 1 ;
136-
137- let mut ban_service = BanService :: new ( 1 , ban_duration_in_secs, udp_tracker_url) ;
102+ let mut ban_service = ban_service ( 1 ) ;
138103
139104 let ip: IpAddr = "127.0.0.2" . parse ( ) . unwrap ( ) ;
140105
141106 ban_service. increase_counter ( & ip) ; // Counter = 1
142107
143- sleep ( Duration :: from_secs ( 2 ) ) . await ;
144-
145- ban_service. run_bans_cleaner ( ) ;
108+ ban_service. reset_bans ( ) ;
146109
147110 assert_eq ! ( ban_service. get_counter( & ip) , 0 ) ;
148111 }
0 commit comments