diff --git a/Cargo.lock b/Cargo.lock
index 7bd2d7037..c9f388f48 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -547,6 +547,12 @@ dependencies = [
  "syn 2.0.90",
 ]
 
+[[package]]
+name = "bit-vec"
+version = "0.4.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "02b4ff8b16e6076c3e14220b39fbc1fabb6737522281a388998046859400895f"
+
 [[package]]
 name = "bitflags"
 version = "2.6.0"
@@ -634,6 +640,15 @@ dependencies = [
  "piper",
 ]
 
+[[package]]
+name = "bloom"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d00ac8e5056d6d65376a3c1aa5c7c34850d6949ace17f0266953a254eb3d6fe8"
+dependencies = [
+ "bit-vec",
+]
+
 [[package]]
 name = "blowfish"
 version = "0.9.1"
@@ -3949,6 +3964,7 @@ dependencies = [
  "bittorrent-http-protocol",
  "bittorrent-primitives",
  "bittorrent-tracker-client",
+ "bloom",
  "blowfish",
  "camino",
  "chrono",
diff --git a/Cargo.toml b/Cargo.toml
index f512dca92..6832f17f2 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -39,6 +39,7 @@ axum-server = { version = "0", features = ["tls-rustls-no-provider"] }
 bittorrent-http-protocol = { version = "3.0.0-develop", path = "packages/http-protocol" }
 bittorrent-primitives = "0.1.0"
 bittorrent-tracker-client = { version = "3.0.0-develop", path = "packages/tracker-client" }
+bloom = "0.3.2"
 blowfish = "0"
 camino = { version = "1", features = ["serde", "serde1"] }
 chrono = { version = "0", default-features = false, features = ["clock"] }
diff --git a/cSpell.json b/cSpell.json
index 090a2b0e3..a21e69b9f 100644
--- a/cSpell.json
+++ b/cSpell.json
@@ -5,6 +5,7 @@
     "alekitto",
     "appuser",
     "Arvid",
+    "ASMS",
     "asyn",
     "autoclean",
     "AUTOINCREMENT",
diff --git a/share/default/config/tracker.udp.benchmarking.toml b/share/default/config/tracker.udp.benchmarking.toml
index c6644d8dc..8a898153a 100644
--- a/share/default/config/tracker.udp.benchmarking.toml
+++ b/share/default/config/tracker.udp.benchmarking.toml
@@ -18,4 +18,4 @@ persistent_torrent_completed_stat = false
 remove_peerless_torrents = false
 
 [[udp_trackers]]
-bind_address = "0.0.0.0:6969"
+bind_address = "0.0.0.0:3000"
diff --git a/src/servers/udp/handlers.rs b/src/servers/udp/handlers.rs
index af22b263d..1fb450e1a 100644
--- a/src/servers/udp/handlers.rs
+++ b/src/servers/udp/handlers.rs
@@ -11,12 +11,14 @@ use aquatic_udp_protocol::{
     ResponsePeer, ScrapeRequest, ScrapeResponse, TorrentScrapeStatistics, TransactionId,
 };
 use bittorrent_primitives::info_hash::InfoHash;
+use tokio::sync::RwLock;
 use torrust_tracker_clock::clock::Time as _;
 use tracing::{instrument, Level};
 use uuid::Uuid;
 use zerocopy::network_endian::I32;
 
 use super::connection_cookie::{check, make};
+use super::server::banning::BanService;
 use super::RawRequest;
 use crate::core::{statistics, PeersWanted, Tracker};
 use crate::servers::udp::error::Error;
@@ -51,12 +53,13 @@ impl CookieTimeValues {
 /// - Delegating the request to the correct handler depending on the request type.
 ///
 /// It will return an `Error` response if the request is invalid.
-#[instrument(fields(request_id), skip(udp_request, tracker, cookie_time_values), ret(level = Level::TRACE))]
+#[instrument(fields(request_id), skip(udp_request, tracker, cookie_time_values, ban_service), ret(level = Level::TRACE))]
 pub(crate) async fn handle_packet(
     udp_request: RawRequest,
     tracker: &Tracker,
     local_addr: SocketAddr,
     cookie_time_values: CookieTimeValues,
+    ban_service: Arc<RwLock<BanService>>,
 ) -> Response {
     tracing::Span::current().record("request_id", Uuid::new_v4().to_string());
     tracing::debug!("Handling Packets: {udp_request:?}");
@@ -68,6 +71,17 @@ pub(crate) async fn handle_packet(
         Ok(request) => match handle_request(request, udp_request.from, tracker, cookie_time_values.clone()).await {
             Ok(response) => return response,
             Err((e, transaction_id)) => {
+                match &e {
+                    Error::CookieValueNotNormal { .. }
+                    | Error::CookieValueExpired { .. }
+                    | Error::CookieValueFromFuture { .. } => {
+                        // code-review: should we include `RequestParseError` and `BadRequest`?
+                        let mut ban_service = ban_service.write().await;
+                        ban_service.increase_counter(&udp_request.from.ip());
+                    }
+                    _ => {}
+                }
+
                 handle_error(
                     udp_request.from,
                     tracker,
diff --git a/src/servers/udp/server/banning.rs b/src/servers/udp/server/banning.rs
new file mode 100644
index 000000000..df236820c
--- /dev/null
+++ b/src/servers/udp/server/banning.rs
@@ -0,0 +1,150 @@
+//! Banning service for the UDP tracker.
+//!
+//! It bans clients that send invalid connection IDs.
+//!
+//! It uses two levels of filtering:
+//!
+//! 1. First, it uses a Counting Bloom Filter to keep track of the number of
+//!    connection ID errors per IP. That means there can be false positives,
+//!    but not false negatives: a small fraction of lookups may over-count,
+//!    which on its own could ban (and stop responding to) a client that did
+//!    not misbehave.
+//! 2. Since we want to avoid such false positives (banning a client that is
+//!    not sending invalid connection IDs), we also use a `HashMap` to keep
+//!    track of the exact number of connection ID errors per IP.
+//!
+//! This two-level filtering avoids false positives: the Counting Bloom Filter
+//! keeps the check fast and produces no false negatives, while the exact
+//! counter confirms the ban, at the cost of some extra memory usage.
+use std::collections::HashMap;
+use std::net::IpAddr;
+
+use bloom::{CountingBloomFilter, ASMS};
+use tokio::time::Instant;
+use url::Url;
+
+use crate::servers::udp::UDP_TRACKER_LOG_TARGET;
+
+pub struct BanService {
+    max_connection_id_errors_per_ip: u32,
+    fuzzy_error_counter: CountingBloomFilter,
+    accurate_error_counter: HashMap<IpAddr, u32>,
+    local_addr: Url,
+    last_connection_id_errors_reset: Instant,
+}
+
+impl BanService {
+    #[must_use]
+    pub fn new(max_connection_id_errors_per_ip: u32, local_addr: Url) -> Self {
+        Self {
+            max_connection_id_errors_per_ip,
+            local_addr,
+            fuzzy_error_counter: CountingBloomFilter::with_rate(4, 0.01, 100),
+            accurate_error_counter: HashMap::new(),
+            last_connection_id_errors_reset: tokio::time::Instant::now(),
+        }
+    }
+
+    pub fn increase_counter(&mut self, ip: &IpAddr) {
+        self.fuzzy_error_counter.insert(&ip.to_string());
+        *self.accurate_error_counter.entry(*ip).or_insert(0) += 1;
+    }
+
+    #[must_use]
+    pub fn get_count(&self, ip: &IpAddr) -> Option<u32> {
+        self.accurate_error_counter.get(ip).copied()
+    }
+
+    #[must_use]
+    pub fn get_estimate_count(&self, ip: &IpAddr) -> u32 {
+        self.fuzzy_error_counter.estimate_count(&ip.to_string())
+    }
+
+    /// Returns `true` if the given IP address is banned.
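+    ///
+    /// # Example (illustrative sketch)
+    ///
+    /// The limit of `2` below is an arbitrary value chosen for the example;
+    /// the launcher configures the real limit.
+    ///
+    /// ```rust,ignore
+    /// let mut ban_service = BanService::new(2, "udp://127.0.0.1:6969".parse().unwrap());
+    /// let ip: std::net::IpAddr = "127.0.0.2".parse().unwrap();
+    ///
+    /// ban_service.increase_counter(&ip); // counter = 1
+    /// assert!(!ban_service.is_banned(&ip));
+    ///
+    /// ban_service.increase_counter(&ip); // counter = 2 (at the limit, still allowed)
+    /// assert!(!ban_service.is_banned(&ip));
+    ///
+    /// ban_service.increase_counter(&ip); // counter = 3 (exceeds the limit)
+    /// assert!(ban_service.is_banned(&ip));
+    /// ```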
+    #[must_use]
+    pub fn is_banned(&self, ip: &IpAddr) -> bool {
+        // First check if the IP is in the bloom filter (fast check)
+        if self.fuzzy_error_counter.estimate_count(&ip.to_string()) <= self.max_connection_id_errors_per_ip {
+            return false;
+        }
+
+        // Check with the exact counter (to avoid false positives)
+        match self.get_count(ip) {
+            Some(count) => count > self.max_connection_id_errors_per_ip,
+            None => false,
+        }
+    }
+
+    /// Resets the filters and updates the reset timestamp.
+    pub fn reset_bans(&mut self) {
+        self.fuzzy_error_counter.clear();
+
+        self.accurate_error_counter.clear();
+
+        self.last_connection_id_errors_reset = Instant::now();
+
+        let local_addr = self.local_addr.to_string();
+        tracing::info!(target: UDP_TRACKER_LOG_TARGET, local_addr, "Udp::run_udp_server::loop (connection id errors filter cleared)");
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::net::IpAddr;
+
+    use super::BanService;
+
+    /// Builds a sample ban service with the given counter limit.
+    fn ban_service(counter_limit: u32) -> BanService {
+        let udp_tracker_url = "udp://127.0.0.1".parse().unwrap();
+        BanService::new(counter_limit, udp_tracker_url)
+    }
+
+    #[test]
+    fn it_should_increase_the_errors_counter_for_a_given_ip() {
+        let mut ban_service = ban_service(1);
+
+        let ip: IpAddr = "127.0.0.2".parse().unwrap();
+
+        ban_service.increase_counter(&ip);
+
+        assert_eq!(ban_service.get_count(&ip), Some(1));
+    }
+
+    #[test]
+    fn it_should_ban_ips_with_counters_exceeding_a_predefined_limit() {
+        let mut ban_service = ban_service(1);
+
+        let ip: IpAddr = "127.0.0.2".parse().unwrap();
+
+        ban_service.increase_counter(&ip); // Counter = 1
+        ban_service.increase_counter(&ip); // Counter = 2
+
+        println!("Counter: {}", ban_service.get_count(&ip).unwrap());
+
+        assert!(ban_service.is_banned(&ip));
+    }
+
+    #[test]
+    fn it_should_not_ban_ips_whose_counters_do_not_exceed_the_predefined_limit() {
+        let mut ban_service = ban_service(1);
+
+        let ip: IpAddr = "127.0.0.2".parse().unwrap();
+
+        ban_service.increase_counter(&ip);
+
+        assert!(!ban_service.is_banned(&ip));
+    }
+
+    #[test]
+    fn it_should_allow_resetting_all_the_counters() {
+        let mut ban_service = ban_service(1);
+
+        let ip: IpAddr = "127.0.0.2".parse().unwrap();
+
+        ban_service.increase_counter(&ip); // Counter = 1
+
+        ban_service.reset_bans();
+
+        assert_eq!(ban_service.get_estimate_count(&ip), 0);
+    }
+}
diff --git a/src/servers/udp/server/launcher.rs b/src/servers/udp/server/launcher.rs
index d6827346d..f314e3721 100644
--- a/src/servers/udp/server/launcher.rs
+++ b/src/servers/udp/server/launcher.rs
@@ -6,9 +6,11 @@ use bittorrent_tracker_client::udp::client::check;
 use derive_more::Constructor;
 use futures_util::StreamExt;
 use tokio::select;
-use tokio::sync::oneshot;
+use tokio::sync::{oneshot, RwLock};
+use tokio::time::interval;
 use tracing::instrument;
 
+use super::banning::BanService;
 use super::request_buffer::ActiveRequests;
 use crate::bootstrap::jobs::Started;
 use crate::core::{statistics, Tracker};
@@ -20,6 +22,11 @@ use crate::servers::udp::server::processor::Processor;
 use crate::servers::udp::server::receiver::Receiver;
 use crate::servers::udp::UDP_TRACKER_LOG_TARGET;
 
+/// The maximum number of connection ID errors per IP. Clients will be banned
+/// if they exceed this limit.
+const MAX_CONNECTION_ID_ERRORS_PER_IP: u32 = 10;
+const IP_BANS_RESET_INTERVAL_IN_SECS: u64 = 120;
+
 /// A UDP server instance launcher.
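+///
+/// ## Ban-service wiring (illustrative sketch)
+///
+/// The launcher shares a single `BanService` behind an `Arc<RwLock<_>>` and
+/// spawns a task that periodically resets its counters. The snippet below is
+/// a condensed version of what `run_udp_server` does with the constants
+/// defined above; `ban_service` and `local_addr` are local variables of that
+/// function, not part of this struct's API.
+///
+/// ```rust,ignore
+/// let ban_service = Arc::new(RwLock::new(BanService::new(
+///     MAX_CONNECTION_ID_ERRORS_PER_IP,
+///     local_addr.parse().unwrap(),
+/// )));
+///
+/// let ban_cleaner = ban_service.clone();
+///
+/// tokio::spawn(async move {
+///     let mut cleaner_interval = interval(Duration::from_secs(IP_BANS_RESET_INTERVAL_IN_SECS));
+///
+///     // The first tick of a `tokio::time::interval` completes immediately,
+///     // so consume it before entering the reset loop.
+///     cleaner_interval.tick().await;
+///
+///     loop {
+///         cleaner_interval.tick().await;
+///         ban_cleaner.write().await.reset_bans();
+///     }
+/// });
+/// ```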
 #[derive(Constructor)]
 pub struct Launcher;
@@ -115,13 +122,30 @@ impl Launcher {
         let active_requests = &mut ActiveRequests::default();
 
         let addr = receiver.bound_socket_address();
+        let local_addr = format!("udp://{addr}");
         let cookie_lifetime = cookie_lifetime.as_secs_f64();
 
-        loop {
-            let processor = Processor::new(receiver.socket.clone(), tracker.clone(), cookie_lifetime);
+        let ban_service = Arc::new(RwLock::new(BanService::new(
+            MAX_CONNECTION_ID_ERRORS_PER_IP,
+            local_addr.parse().unwrap(),
+        )));
+
+        let ban_cleaner = ban_service.clone();
+
+        tokio::spawn(async move {
+            let mut cleaner_interval = interval(Duration::from_secs(IP_BANS_RESET_INTERVAL_IN_SECS));
+
+            cleaner_interval.tick().await;
+            loop {
+                cleaner_interval.tick().await;
+                ban_cleaner.write().await.reset_bans();
+            }
+        });
+
+        loop {
             if let Some(req) = {
                 tracing::trace!(target: UDP_TRACKER_LOG_TARGET, local_addr, "Udp::run_udp_server (wait for request)");
                 receiver.next().await
@@ -149,18 +173,26 @@
                 }
             }
 
-            // We spawn the new task even if there active requests buffer is
-            // full. This could seem counterintuitive because we are accepting
-            // more request and consuming more memory even if the server is
-            // already busy. However, we "force_push" the new tasks in the
-            // buffer. That means, in the worst scenario we will abort a
-            // running task to make place for the new task.
-            //
-            // Once concern could be to reach an starvation point were we
-            // are only adding and removing tasks without given them the
-            // chance to finish. However, the buffer is yielding before
-            // aborting one tasks, giving it the chance to finish.
-            let abort_handle: tokio::task::AbortHandle = tokio::task::spawn(processor.process_request(req)).abort_handle();
+            if ban_service.read().await.is_banned(&req.from.ip()) {
+                tracing::debug!(target: UDP_TRACKER_LOG_TARGET, local_addr, "Udp::run_udp_server::loop continue: (banned ip)");
+                continue;
+            }
+
+            let processor = Processor::new(receiver.socket.clone(), tracker.clone(), cookie_lifetime);
+
+            /* We spawn the new task even if the active requests buffer is
+            full. This could seem counterintuitive because we are accepting
+            more requests and consuming more memory even if the server is
+            already busy. However, we "force_push" the new tasks into the
+            buffer. That means, in the worst case, we will abort a running
+            task to make room for the new task.
+
+            One concern could be reaching a starvation point where we are
+            only adding and removing tasks without giving them the chance to
+            finish. However, the buffer yields before aborting a task,
+            giving it the chance to finish. */
+            let abort_handle: tokio::task::AbortHandle =
+                tokio::task::spawn(processor.process_request(req, ban_service.clone())).abort_handle();
 
             if abort_handle.is_finished() {
                 continue;
diff --git a/src/servers/udp/server/mod.rs b/src/servers/udp/server/mod.rs
index 7067512b6..9f974ca8c 100644
--- a/src/servers/udp/server/mod.rs
+++ b/src/servers/udp/server/mod.rs
@@ -6,6 +6,7 @@ use thiserror::Error;
 
 use super::RawRequest;
 
+pub mod banning;
 pub mod bound_socket;
 pub mod launcher;
 pub mod processor;
diff --git a/src/servers/udp/server/processor.rs b/src/servers/udp/server/processor.rs
index fc39f28b9..120196431 100644
--- a/src/servers/udp/server/processor.rs
+++ b/src/servers/udp/server/processor.rs
@@ -3,8 +3,10 @@ use std::net::{IpAddr, SocketAddr};
 use std::sync::Arc;
 
 use aquatic_udp_protocol::Response;
+use tokio::sync::RwLock;
 use tracing::{instrument, Level};
 
+use super::banning::BanService;
 use super::bound_socket::BoundSocket;
 use crate::core::{statistics, Tracker};
 use crate::servers::udp::handlers::CookieTimeValues;
@@ -25,16 +27,18 @@ impl Processor {
         }
     }
 
-    #[instrument(skip(self, request))]
-    pub async fn process_request(self, request: RawRequest) {
+    #[instrument(skip(self, request, ban_service))]
+    pub async fn process_request(self, request: RawRequest, ban_service: Arc<RwLock<BanService>>) {
         let from = request.from;
         let response = handlers::handle_packet(
             request,
             &self.tracker,
             self.socket.address(),
             CookieTimeValues::new(self.cookie_lifetime),
+            ban_service,
         )
         .await;
+
         self.send_response(from, response).await;
     }
diff --git a/tests/servers/udp/contract.rs b/tests/servers/udp/contract.rs
index b12a8a900..9e9085e62 100644
--- a/tests/servers/udp/contract.rs
+++ b/tests/servers/udp/contract.rs
@@ -130,10 +130,31 @@ mod receiving_an_announce_request {
     use crate::servers::udp::contract::send_connection_request;
     use crate::servers::udp::Started;
 
-    pub async fn send_and_get_announce(tx_id: TransactionId, c_id: ConnectionId, client: &UdpTrackerClient) {
-        // Send announce request
+    pub async fn assert_send_and_get_announce(tx_id: TransactionId, c_id: ConnectionId, client: &UdpTrackerClient) {
+        let response = send_and_get_announce(tx_id, c_id, client).await;
+        assert!(is_ipv4_announce_response(&response));
+    }
+
+    pub async fn send_and_get_announce(
+        tx_id: TransactionId,
+        c_id: ConnectionId,
+        client: &UdpTrackerClient,
+    ) -> aquatic_udp_protocol::Response {
+        let announce_request = build_sample_announce_request(tx_id, c_id, client.client.socket.local_addr().unwrap().port());
+
+        match client.send(announce_request.into()).await {
+            Ok(_) => (),
+            Err(err) => panic!("{err}"),
+        };
 
-        let announce_request = AnnounceRequest {
+        match client.receive().await {
+            Ok(response) => response,
+            Err(err) => panic!("{err}"),
+        }
+    }
+
+    fn build_sample_announce_request(tx_id: TransactionId, c_id: ConnectionId, port: u16) -> AnnounceRequest {
+        AnnounceRequest {
             connection_id: ConnectionId(c_id.0),
             action_placeholder: AnnounceActionPlaceholder::default(),
             transaction_id: tx_id,
@@ -146,26 +167,34 @@ mod receiving_an_announce_request {
             ip_address: Ipv4Addr::new(0, 0, 0, 0).into(),
             key: PeerKey::new(0i32),
             peers_wanted: NumberOfPeers(1i32.into()),
-            port: Port(client.client.socket.local_addr().unwrap().port().into()),
-        };
+            port: Port(port.into()),
+        }
+    }
 
-        match client.send(announce_request.into()).await {
-            Ok(_) => (),
-            Err(err) => panic!("{err}"),
-        };
+    #[tokio::test]
+    async fn should_return_an_announce_response() {
+        INIT.call_once(|| {
+            tracing_stderr_init(LevelFilter::ERROR);
+        });
 
-        let response = match client.receive().await {
-            Ok(response) => response,
+        let env = Started::new(&configuration::ephemeral().into()).await;
+
+        let client = match UdpTrackerClient::new(env.bind_address(), DEFAULT_TIMEOUT).await {
+            Ok(udp_tracker_client) => udp_tracker_client,
             Err(err) => panic!("{err}"),
         };
 
-        // println!("test response {response:?}");
+        let tx_id = TransactionId::new(123);
 
-        assert!(is_ipv4_announce_response(&response));
+        let c_id = send_connection_request(tx_id, &client).await;
+
+        assert_send_and_get_announce(tx_id, c_id, &client).await;
+
+        env.stop().await;
     }
 
     #[tokio::test]
-    async fn should_return_an_announce_response() {
+    async fn should_return_many_announce_response() {
         INIT.call_once(|| {
             tracing_stderr_init(LevelFilter::ERROR);
         });
@@ -181,13 +210,16 @@ mod receiving_an_announce_request {
 
         let c_id = send_connection_request(tx_id, &client).await;
 
-        send_and_get_announce(tx_id, c_id, &client).await;
+        for x in 0..1000 {
+            tracing::info!("req no: {x}");
+            assert_send_and_get_announce(tx_id, c_id, &client).await;
+        }
 
         env.stop().await;
     }
 
     #[tokio::test]
-    async fn should_return_many_announce_response() {
+    async fn should_ban_the_client_ip_if_it_sends_more_than_10_requests_with_a_cookie_value_not_normal() {
         INIT.call_once(|| {
             tracing_stderr_init(LevelFilter::ERROR);
         });
@@ -201,13 +233,30 @@ mod receiving_an_announce_request {
 
         let tx_id = TransactionId::new(123);
 
-        let c_id = send_connection_request(tx_id, &client).await;
+        // The first eleven requests should be fine
 
-        for x in 0..1000 {
+        let invalid_connection_id = ConnectionId::new(0); // Zero is one of the not-normal values.
+
+        for x in 0..=10 {
             tracing::info!("req no: {x}");
-            send_and_get_announce(tx_id, c_id, &client).await;
+            send_and_get_announce(tx_id, invalid_connection_id, &client).await;
         }
 
+        // The twelfth request should be banned (timeout error)
+
+        let announce_request = build_sample_announce_request(
+            tx_id,
+            invalid_connection_id,
+            client.client.socket.local_addr().unwrap().port(),
+        );
+
+        match client.send(announce_request.into()).await {
+            Ok(_) => (),
+            Err(err) => panic!("{err}"),
+        };
+
+        assert!(client.receive().await.is_err());
+
         env.stop().await;
     }
 }
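The new `bloom` dependency provides the `CountingBloomFilter` behind `BanService::fuzzy_error_counter`. Below is a minimal standalone sketch of that filter, using only the calls the service itself makes (`with_rate`, `insert`, `estimate_count`, `clear`); the constructor arguments mirror `BanService::new`, and their exact meaning (counter width, target false-positive rate, expected item count) should be checked against the crate's documentation.

```rust
use bloom::{CountingBloomFilter, ASMS};

fn main() {
    // Same constructor arguments as `BanService::new` uses.
    let mut filter = CountingBloomFilter::with_rate(4, 0.01, 100);

    let ip = "127.0.0.2".to_string();

    // Every connection ID error inserts the IP again, bumping its counters.
    filter.insert(&ip);
    filter.insert(&ip);

    // The estimate can over-count (hash collisions) but never under-count.
    assert!(filter.estimate_count(&ip) >= 2);

    // A periodic reset clears all counters, as `reset_bans` does.
    filter.clear();
    assert_eq!(filter.estimate_count(&ip), 0);
}
```

The over-counting (but never under-counting) behaviour shown by the assertions is what lets `is_banned` use the filter as a cheap first pass and the exact `HashMap` as the final check.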