diff --git a/Cargo.lock b/Cargo.lock index ce5ce4e8b2c..949ce9663bf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1221,6 +1221,29 @@ dependencies = [ "zeroize", ] +[[package]] +name = "aws-lc-rs" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fcc8f365936c834db5514fc45aee5b1202d677e6b40e48468aaaa8183ca8c7" +dependencies = [ + "aws-lc-sys", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61b1d86e7705efe1be1b569bab41d4fa1e14e220b60a160f78de2db687add079" +dependencies = [ + "bindgen", + "cc", + "cmake", + "dunce", + "fs_extra", +] + [[package]] name = "aws-runtime" version = "1.5.5" @@ -1702,6 +1725,29 @@ dependencies = [ "serde", ] +[[package]] +name = "bindgen" +version = "0.69.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" +dependencies = [ + "bitflags 2.9.0", + "cexpr", + "clang-sys", + "itertools 0.12.1", + "lazy_static", + "lazycell", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash 1.1.0", + "shlex", + "syn 2.0.90", + "which 4.4.2", +] + [[package]] name = "bit-set" version = "0.8.0" @@ -2049,6 +2095,15 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c" +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", +] + [[package]] name = "cfb-mode" version = "0.8.2" @@ -2141,7 +2196,7 @@ dependencies = [ "tokio", "tracing", "url", - "which", + "which 6.0.3", "winreg", ] @@ -2302,6 +2357,17 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ceab37c9e94f42414cccae77e930232c517f1bb190947018cffb0ab41fc40992" +[[package]] +name = "clang-sys" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" +dependencies = [ + "glob", + "libc", + "libloading", +] + [[package]] name = "clap" version = "4.5.23" @@ -2387,6 +2453,15 @@ dependencies = [ "digest", ] +[[package]] +name = "cmake" +version = "0.1.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" +dependencies = [ + "cc", +] + [[package]] name = "codespan-reporting" version = "0.11.1" @@ -2506,6 +2581,7 @@ dependencies = [ "rand", "regex", "reqwest", + "rustls 0.23.20", "sea-orm", "segment", "serde", @@ -4065,6 +4141,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + [[package]] name = "funty" version = "2.0.0" @@ -5249,6 +5331,12 @@ dependencies = [ "spin", ] +[[package]] +name = "lazycell" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" + [[package]] name = "lettre" version = "0.11.11" @@ -5376,6 +5464,16 @@ dependencies = [ "rle-decode-fast", ] +[[package]] +name = "libloading" +version = "0.8.8" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" +dependencies = [ + "cfg-if", + "windows-targets 0.52.6", +] + [[package]] name = "libm" version = "0.2.11" @@ -6272,6 +6370,7 @@ dependencies = [ "pyroscope_pprofrs", "rand", "rayon", + "rcgen", "regex", "regex-syntax 0.8.5", "report_server", @@ -6296,6 +6395,7 @@ dependencies = [ "tikv-jemallocator", "time", "tokio", + "tokio-rustls 0.26.1", "tokio-stream", "tokio-tungstenite", "tokio-util", @@ -7643,6 +7743,19 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "rcgen" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75e669e5202259b5314d1ea5397316ad400819437857b90861765f24c4cf80a2" +dependencies = [ + "pem", + "ring", + "rustls-pki-types", + "time", + "yasna", +] + [[package]] name = "recursive" version = "0.1.1" @@ -8089,6 +8202,7 @@ version = "0.23.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5065c3f250cbd332cd894be57c40fa52387247659b14a2d6041d121547903b1b" dependencies = [ + "aws-lc-rs", "log", "once_cell", "ring", @@ -8178,6 +8292,7 @@ version = "0.102.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" dependencies = [ + "aws-lc-rs", "ring", "rustls-pki-types", "untrusted", @@ -10678,6 +10793,18 @@ dependencies = [ "rustls-pki-types", ] +[[package]] +name = "which" +version = "4.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" +dependencies = [ + "either", + "home", + "once_cell", + "rustix", +] + [[package]] name = "which" version = "6.0.3" @@ -11157,6 +11284,15 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" +[[package]] +name = "yasna" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd" +dependencies = [ + "time", +] + [[package]] name = "yoke" version = "0.7.5" diff --git a/Cargo.toml b/Cargo.toml index 97b8e48241b..4121f9b9369 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -188,12 +188,14 @@ actix-service = "2.0.3" actix-utils = "3.0.1" derive_more = { version = "2.0.1", features = ["full"] } brotli = "8.0.1" +tokio-rustls = "0.26.1" [dev-dependencies] async-walkdir.workspace = true expect-test.workspace = true base64 = "0.22" float-cmp = "0.10" +rcgen = "0.13" [workspace] members = [ @@ -339,6 +341,7 @@ tempfile = "3" thiserror = "1.0" time = "0.3" tokio = { version = "1", features = ["full"] } +tokio-rustls = "0.26" tokio-util = { version = "0.7.12", features = ["compat"] } tokio-stream = "0.1" tonic = { version = "0.12.3", features = ["gzip", "prost", "tls"] } diff --git a/coverage.sh b/coverage.sh index 7d554b252fd..0f08c316492 100755 --- a/coverage.sh +++ b/coverage.sh @@ -29,7 +29,7 @@ _cov_test() { cargo llvm-cov test \ --workspace \ --verbose \ - --ignore-filename-regex job \ + --ignore-filename-regex 'job|.*generated.*' \ "$@" } diff --git a/src/common/utils/auth_tests.rs b/src/common/utils/auth_tests.rs index ea7c7044a72..5b92e028ff1 100644 --- a/src/common/utils/auth_tests.rs +++ b/src/common/utils/auth_tests.rs @@ -3268,6 +3268,7 @@ mod tests { file_download_priority_queue_window_secs: Default::default(), 
file_download_enable_priority_queue: Default::default(), histogram_enabled: Default::default(), + calculate_stats_step_limit: Default::default(), }, compact: config::Compact { enabled: bool::default(), diff --git a/src/common/utils/redirect_response.rs b/src/common/utils/redirect_response.rs index 854e70c658b..9bb17396488 100644 --- a/src/common/utils/redirect_response.rs +++ b/src/common/utils/redirect_response.rs @@ -59,6 +59,7 @@ impl RedirectResponse { .append_header((LOCATION, redirect_uri)) .finish() } else { + // if the URL is too long, we send the original URL and let FE handle the redirect. let html = format!( r#" diff --git a/src/config/Cargo.toml b/src/config/Cargo.toml index ce7904eed25..32a093b68fb 100644 --- a/src/config/Cargo.toml +++ b/src/config/Cargo.toml @@ -76,6 +76,7 @@ tracing-subscriber.workspace = true urlencoding.workspace = true utoipa.workspace = true vrl.workspace = true +rustls = "0.23.20" [dev-dependencies] expect-test.workspace = true diff --git a/src/config/src/config.rs b/src/config/src/config.rs index 6530aab8562..44b3b90e1a3 100644 --- a/src/config/src/config.rs +++ b/src/config/src/config.rs @@ -659,6 +659,14 @@ pub struct TCP { pub tcp_port: u16, #[env_config(name = "ZO_UDP_PORT", default = 5514)] pub udp_port: u16, + #[env_config(name = "ZO_TCP_TLS_ENABLED", default = false)] + pub tcp_tls_enabled: bool, + #[env_config(name = "ZO_TCP_TLS_CERT_PATH", default = "")] + pub tcp_tls_cert_path: String, + #[env_config(name = "ZO_TCP_TLS_KEY_PATH", default = "")] + pub tcp_tls_key_path: String, + #[env_config(name = "ZO_TCP_TLS_CA_CERT_PATH", default = "")] + pub tcp_tls_ca_cert_path: String, } #[derive(EnvConfig)] diff --git a/src/config/src/utils/cert.rs b/src/config/src/utils/cert.rs new file mode 100644 index 00000000000..8fbe4b27fa8 --- /dev/null +++ b/src/config/src/utils/cert.rs @@ -0,0 +1,128 @@ +use rustls::client::danger; +use rustls::crypto::{verify_tls12_signature, verify_tls13_signature}; +use rustls::{DigitallySignedStruct, SignatureScheme}; +use rustls::pki_types::{CertificateDer, ServerName, UnixTime}; + +/// A custom certificate verifier that accepts any certificate +/// as long as it's in the list of trusted certificates +#[derive(Debug)] +pub struct SelfSignedCertVerifier<'a> { + allowed_certs: Vec<CertificateDer<'a>>, +} + +impl<'a> SelfSignedCertVerifier<'a> { + pub fn new(allowed_certs: Vec<CertificateDer<'a>>) -> Self { + Self { allowed_certs: allowed_certs } + } +} + +impl<'a> danger::ServerCertVerifier for SelfSignedCertVerifier<'a> { + fn verify_server_cert( + &self, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &ServerName<'_>, + _ocsp: &[u8], + _now: UnixTime, + ) -> Result<danger::ServerCertVerified, rustls::Error> { + // Check if the server's certificate matches our trusted certificate + if self.allowed_certs.iter().any(|cert| cert.as_ref() == _end_entity.as_ref()) { + Ok(danger::ServerCertVerified::assertion()) + } else { + Err(rustls::Error::General("Server certificate not trusted".into())) + } + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result<danger::HandshakeSignatureValid, rustls::Error> { + let provider = rustls::crypto::ring::default_provider(); + verify_tls12_signature( + message, + cert, + dss, + &provider.signature_verification_algorithms, + ) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result<danger::HandshakeSignatureValid, rustls::Error> { + let provider = rustls::crypto::ring::default_provider(); + verify_tls13_signature( + message, + cert, + dss,
&provider.signature_verification_algorithms, + ) + } + + fn supported_verify_schemes(&self) -> Vec<SignatureScheme> { + let provider = rustls::crypto::ring::default_provider(); + provider.signature_verification_algorithms.supported_schemes() + } +} + +#[cfg(test)] +mod tests { + use rustls::client::danger::ServerCertVerifier; + use super::*; + + #[test] + fn test_verifies_self_signed_cert() { + // Create a test certificate + let cert_data = b"-----MOCK CERTIFICATE DATA-----"; + let cert = CertificateDer::from(cert_data.to_vec()); + + // Create a verifier with this certificate as allowed + let verifier = SelfSignedCertVerifier::new(vec![cert.clone()]); + + // Test the verification + let result = verifier.verify_server_cert( + &cert, + &[], + &ServerName::try_from("example.com").unwrap(), + &[], + UnixTime::now(), + ); + + assert!(result.is_ok(), "Certificate verification should succeed"); + } + + #[test] + fn test_fails_on_invalid_cert() { + // Create a trusted certificate + let trusted_cert_data = b"-----TRUSTED CERTIFICATE DATA-----"; + let trusted_cert = CertificateDer::from(trusted_cert_data.to_vec()); + + // Create a different, untrusted certificate + let untrusted_cert_data = b"-----UNTRUSTED CERTIFICATE DATA-----"; + let untrusted_cert = CertificateDer::from(untrusted_cert_data.to_vec()); + + // Create a verifier that only trusts the first certificate + let verifier = SelfSignedCertVerifier::new(vec![trusted_cert]); + + // Test verification with the untrusted certificate + let result = verifier.verify_server_cert( + &untrusted_cert, + &[], + &ServerName::try_from("example.com").unwrap(), + &[], + UnixTime::now(), + ); + + assert!(result.is_err(), "Certificate verification should fail"); + match result { + Err(rustls::Error::General(msg)) => { + assert_eq!(msg, "Server certificate not trusted"); + } + _ => panic!("Expected General error with 'Server certificate not trusted' message"), + } + } +} \ No newline at end of file diff --git a/src/config/src/utils/mod.rs b/src/config/src/utils/mod.rs index d985514af74..a52f1fb42ab 100644 --- a/src/config/src/utils/mod.rs +++ b/src/config/src/utils/mod.rs @@ -16,6 +16,7 @@ pub mod arrow; pub mod async_file; pub mod base64; +pub mod cert; pub mod download_utils; pub mod file; pub mod flatten; diff --git a/src/handler/grpc/flight/mod.rs b/src/handler/grpc/flight/mod.rs index 0da8582d4f0..20fdaf2c03b 100644 --- a/src/handler/grpc/flight/mod.rs +++ b/src/handler/grpc/flight/mod.rs @@ -136,9 +136,7 @@ impl FlightService for FlightServiceImpl { Ok(v) => v, Err(e) => { // clear session data - crate::service::search::datafusion::storage::file_list::clear(&trace_id); - // release wal lock files - crate::common::infra::wal::release_request(&trace_id); + clear_session_data(&trace_id); log::error!( "[trace_id {}] flight->search: do_get physical plan generate error: {e:?}", trace_id, ); @@ -183,20 +181,31 @@ impl FlightService for FlightServiceImpl { let start = std::time::Instant::now(); let write_options: IpcWriteOptions = IpcWriteOptions::default() .try_with_compression(Some(CompressionType::ZSTD)) - .map_err(|e| Status::internal(e.to_string()))?; + .map_err(|e| { + // clear session data + clear_session_data(&trace_id); + log::error!( + "[trace_id {}] flight->search: do_get create IPC write options error: {e:?}", + trace_id, + ); + Status::internal(e.to_string()) + })?; + let stream = execute_stream(physical_plan, ctx.task_ctx().clone()).map_err(|e| { + // clear session data + clear_session_data(&trace_id); + log::error!( + "[trace_id {}] flight->search: do_get physical
plan execution error: {e:?}", + trace_id, + ); + Status::internal(e.to_string()) + })?; let flight_data_stream = FlightDataEncoderBuilder::new() .with_schema(schema) .with_max_flight_data_size(33554432) // 32MB .with_options(write_options) .build(FlightSenderStream::new( trace_id.to_string(), - execute_stream(physical_plan, ctx.task_ctx().clone()).map_err(|e| { - log::error!( - "[trace_id {}] flight->search: do_get physical plan execution error: {e:?}", - trace_id, - ); - Status::internal(e.to_string()) - })?, + stream, defer, start, timeout, @@ -353,9 +362,7 @@ impl Drop for FlightSenderStream { self.trace_id ); // clear session data - crate::service::search::datafusion::storage::file_list::clear(&self.trace_id); - // release wal lock files - crate::common::infra::wal::release_request(&self.trace_id); + clear_session_data(&self.trace_id); } } } @@ -399,3 +406,10 @@ fn add_scan_stats_to_schema(schema: Arc, scan_stats: ScanStats) -> Arc, body: web::Bytes) -> Result, +} + /// Retrieve the original URL from a short_id /// /// #{"ratelimit_module":"ShortUrl", "ratelimit_module_operation":"get"}# @@ -84,12 +91,14 @@ pub async fn shorten(org_id: web::Path, body: web::Bytes) -> Result, Query, description = "Response type - if 'ui', returns JSON object instead of redirect", example = "ui") ), responses( - (status = 302, description = "Redirect to the original URL", headers( - ("Location" = String, description = "The original URL to which the client is redirected") + (status = 302, description = "Redirect to original URL (if < 1024 chars) or /web/short_url/{short_id}", headers( + ("Location" = String, description = "The original URL or /web/short_url/{short_id} to which the client is redirected") )), + (status = 200, description = "JSON response when type=ui", body = String, content_type = "application/json"), (status = 404, description = "Short URL not found", content_type = "text/plain") ), tag = "Short Url" @@ -98,16 +107,32 @@ pub async fn shorten(org_id: web::Path, body: web::Bytes) -> Result, + query: web::Query, ) -> Result { log::info!( "short_url::retrieve handler called for path: {}", req.path() ); - let (_org_id, short_id) = path.into_inner(); + let (org_id, short_id) = path.into_inner(); let original_url = short_url::retrieve(&short_id).await; - if let Some(url) = original_url { - let redirect_http = RedirectResponseBuilder::new(&url).build().redirect_http(); + // Check if type=ui for JSON response + if let Some(ref type_param) = query.type_param { + if type_param == "ui" { + if let Some(url) = original_url { + return Ok(HttpResponse::Ok().json(url)); + } else { + return Ok(HttpResponse::NotFound().finish()); + } + } + } + + // Here we redirect the legacy short urls to the new short url + // the redirection then will be handled by the frontend using this flow + // TODO: Remove this once we are sure there is no more legacy short urls + if original_url.is_some() { + let redirect_url = short_url::construct_short_url(&org_id, &short_id); + let redirect_http = RedirectResponseBuilder::new(&redirect_url).build().redirect_http(); Ok(redirect_http) } else { let redirect = RedirectResponseBuilder::default().build(); diff --git a/src/handler/http/request/syslog/mod.rs b/src/handler/http/request/syslog/mod.rs index eb6c0092c5d..83ba0afa44b 100644 --- a/src/handler/http/request/syslog/mod.rs +++ b/src/handler/http/request/syslog/mod.rs @@ -135,3 +135,45 @@ async fn delete_route(path: web::Path<(String, String)>) -> impl Responder { let (_, id) = path.into_inner(); 
syslogs_route::delete_route(&id).await } + +/// GetSyslogTCPServerCACert +/// +/// #{"ratelimit_module":"Syslog Routes", "ratelimit_module_operation":"get"}# +#[utoipa::path( + context_path = "/api", + tag = "Syslog Routes", + operation_id = "GetSyslogTCPServerCACert", + security( + ("Authorization" = []) + ), + responses( + (status = StatusCode::OK, description = "PEM-formatted TLS CA certificate", body = String, content_type = "application/x-pem-file"), + (status = StatusCode::NOT_FOUND, description = "Certificate not found"), + (status = StatusCode::INTERNAL_SERVER_ERROR, description = "Internal server error") + ), +)] +#[get("/syslog-tcp-server-ca-cert")] +async fn get_tcp_tls_ca_cert() -> impl Responder { + syslogs_route::get_tcp_tls_ca_cert().await +} + +/// GetSyslogTCPServerCert +/// +/// #{"ratelimit_module":"Syslog Routes", "ratelimit_module_operation":"get"}# +#[utoipa::path( + context_path = "/api", + tag = "Syslog Routes", + operation_id = "GetSyslogTCPServerCert", + security( + ("Authorization" = []) + ), + responses( + (status = StatusCode::OK, description = "PEM-formatted TLS certificate", body = String, content_type = "application/x-pem-file"), + (status = StatusCode::NOT_FOUND, description = "Certificate not found"), + (status = StatusCode::INTERNAL_SERVER_ERROR, description = "Internal server error") + ), +)] +#[get("/syslog-tcp-server-cert")] +async fn get_tcp_tls_cert() -> impl Responder { + syslogs_route::get_tcp_tls_cert().await +} \ No newline at end of file diff --git a/src/handler/http/router/mod.rs b/src/handler/http/router/mod.rs index 864d84ce271..fb124b02216 100644 --- a/src/handler/http/router/mod.rs +++ b/src/handler/http/router/mod.rs @@ -521,6 +521,8 @@ pub fn get_service_routes(svc: &mut web::ServiceConfig) { .service(syslog::delete_route) .service(syslog::update_route) .service(syslog::toggle_state) + .service(syslog::get_tcp_tls_ca_cert) + .service(syslog::get_tcp_tls_cert) .service(enrichment_table::save_enrichment_table) .service(metrics::ingest::otlp_metrics_write) .service(logs::ingest::otlp_logs_write) diff --git a/src/handler/http/router/openapi.rs b/src/handler/http/router/openapi.rs index 85b9ca994a3..6db182d8a27 100644 --- a/src/handler/http/router/openapi.rs +++ b/src/handler/http/router/openapi.rs @@ -139,6 +139,8 @@ use crate::{common::meta, handler::http::request}; request::syslog::update_route, request::syslog::list_routes, request::syslog::delete_route, + request::syslog::get_tcp_tls_ca_cert, + request::syslog::get_tcp_tls_cert, request::clusters::list_clusters, request::short_url::shorten, request::short_url::retrieve, diff --git a/src/handler/tcp_udp/mod.rs b/src/handler/tcp_udp/mod.rs index 132d8fe82dd..32664587d8c 100644 --- a/src/handler/tcp_udp/mod.rs +++ b/src/handler/tcp_udp/mod.rs @@ -13,12 +13,11 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . 
+use std::net::SocketAddr; use bytes::BytesMut; -use tokio::{ - io::AsyncReadExt, - net::{TcpListener, UdpSocket}, -}; - +use tokio::{io::AsyncReadExt, net::{TcpListener, UdpSocket}}; +use tokio::io::AsyncRead; +use tokio_rustls::TlsAcceptor; use crate::{job::syslog_server::BROADCASTER, service::logs::syslog}; pub static STOP_SRV: &str = "ZO_STOP_TCP_UDP"; @@ -55,63 +54,75 @@ pub async fn udp_server(socket: UdpSocket) { } } -pub async fn tcp_server(listener: TcpListener) { +pub async fn tls_tcp_server(listener: TcpListener, tls_acceptor: Option<TlsAcceptor>) { let sender = BROADCASTER.read().await; let mut tcp_receiver_rx = sender.subscribe(); loop { - let (mut stream, _) = match listener.accept().await { + let (tcp_stream, peer_addr) = match listener.accept().await { Ok(val) => val, Err(e) => { log::error!("Error while accepting TCP connection: {}", e); continue; } }; - tokio::task::spawn(async move { - let mut buf_tcp = vec![0u8; 1460]; - let peer_addr = match stream.peer_addr() { - Ok(addr) => addr, + match tls_acceptor.clone() { + Some(acceptor) => match acceptor.accept(tcp_stream).await { + Ok(tls_stream) => { + log::info!("accepted TLS connection for peer {}", peer_addr); + tokio::task::spawn(handle_connection(tls_stream, peer_addr)); + }, Err(e) => { - log::error!("Error while reading peer_addr from TCP stream: {}", e); - return; - } - }; - log::info!("spawned new syslog tcp receiver for peer {}", peer_addr); - loop { - let n = match stream.read(&mut buf_tcp).await { - Ok(0) => { - log::info!("received 0 bytes, closing for peer {}", peer_addr); - break; - } - Ok(n) => n, - Err(e) => { - log::error!("Error while reading from TCP stream: {}", e); - break; - } - }; - let message = BytesMut::from(&buf_tcp[..n]); - let input_str = match String::from_utf8(message.to_vec()) { - Ok(val) => val, - Err(e) => { - log::error!("Error while converting TCP message to UTF8 string: {}", e); - continue; - } - }; - if input_str != STOP_SRV { - if let Err(e) = syslog::ingest(&input_str, peer_addr).await { - log::error!("Error while ingesting TCP message: {}", e); - } - } else { - log::info!("received stop signal, closing for peer {}", peer_addr); - break; + log::error!("TLS accept error: {}", e); } + }, + None => { + tokio::task::spawn(handle_connection(tcp_stream, peer_addr)); } - }); + } + if let Ok(val) = tcp_receiver_rx.try_recv() { if !val { log::warn!("TCP server - received the stop signal, exiting."); drop(listener); break; } - }; + } } } + +async fn handle_connection<S>(mut stream: S, peer_addr: SocketAddr) +where + S: AsyncRead + Unpin + Send + 'static, +{ + let mut buf_tcp = vec![0u8; 1460]; + log::info!("spawned new syslog tcp receiver for peer {}", peer_addr); + loop { + let n = match stream.read(&mut buf_tcp).await { + Ok(0) => { + log::info!("received 0 bytes, closing for peer {}", peer_addr); + break; + } + Ok(n) => n, + Err(e) => { + log::error!("Error while reading from TCP stream: {}", e); + break; + } + }; + let message = BytesMut::from(&buf_tcp[..n]); + let input_str = match String::from_utf8(message.to_vec()) { + Ok(val) => val, + Err(e) => { + log::error!("Error while converting TCP message to UTF8 string: {}", e); + continue; + } + }; + if input_str != STOP_SRV { + if let Err(e) = syslog::ingest(&input_str, peer_addr).await { + log::error!("Error while ingesting TCP message: {}", e); + } + } else { + log::info!("received stop signal, closing for peer {}", peer_addr); + break; + } + } +} \ No newline at end of file diff --git a/src/job/syslog_server.rs index
be7b6d62781..a1726cac0d5 100644 --- a/src/job/syslog_server.rs +++ b/src/job/syslog_server.rs @@ -14,21 +14,24 @@ // along with this program. If not, see . use std::{ - io::Write, - net::{SocketAddr, TcpStream}, + net::{SocketAddr}, }; - +use std::sync::Arc; use once_cell::sync::Lazy; +use rustls::pki_types::ServerName; use tokio::{ net::{TcpListener, UdpSocket}, sync::{RwLock, broadcast}, }; - +use tokio::io::AsyncWriteExt; +use tokio::net::TcpStream; +use tokio_rustls::{TlsAcceptor, TlsConnector}; use crate::{ common::infra::config::SYSLOG_ENABLED, - handler::tcp_udp::{STOP_SRV, tcp_server, udp_server}, + handler::tcp_udp::{STOP_SRV, tls_tcp_server, udp_server}, service::db::syslog::toggle_syslog_setting, }; +use crate::service::tls::{tcp_tls_self_connect_client_config, tcp_tls_server_config}; // TCP UDP Server pub static BROADCASTER: Lazy<RwLock<broadcast::Sender<bool>>> = Lazy::new(|| { @@ -42,13 +45,27 @@ pub async fn run(start_srv: bool, is_init: bool) -> Result<(), anyhow::Error> { let bind_addr = "0.0.0.0"; let tcp_addr: SocketAddr = format!("{bind_addr}:{}", cfg.tcp.tcp_port).parse()?; let udp_addr: SocketAddr = format!("{bind_addr}:{}", cfg.tcp.udp_port).parse()?; + let tcp_tls_enabled = cfg.tcp.tcp_tls_enabled; if (!server_running || is_init) && start_srv { log::info!("Starting TCP UDP server"); let tcp_listener: TcpListener = TcpListener::bind(tcp_addr).await?; let udp_socket = UdpSocket::bind(udp_addr).await?; + let tls_server_config = if tcp_tls_enabled { + log::info!("TCP TLS enabled, preparing TLS server config"); + let tls_server_config = tcp_tls_server_config()?; + log::info!("TCP TLS config prepared"); + Some(tls_server_config) + } else { + None + }; + let tls_acceptor = tls_server_config + .map(Arc::new) + .map(TlsAcceptor::from); + tokio::task::spawn(async move { - _ = tcp_server(tcp_listener).await; + _ = tls_tcp_server(tcp_listener, tls_acceptor).await; }); + tokio::task::spawn(async move { _ = udp_server(udp_socket).await; }); @@ -60,11 +77,25 @@ pub async fn run(start_srv: bool, is_init: bool) -> Result<(), anyhow::Error> { let socket = UdpSocket::bind("0.0.0.0:0").await?; socket.send_to(STOP_SRV.as_bytes(), udp_addr).await?; - let mut stream = TcpStream::connect(tcp_addr)?; - stream.write_all(STOP_SRV.as_bytes())?; - drop(socket); - drop(stream); + + if tcp_tls_enabled { + let config= tcp_tls_self_connect_client_config()?; + let connector = TlsConnector::from(config); + match TcpStream::connect(tcp_addr).await { + Ok(stream) => { + let mut tls_stream = connector.connect(ServerName::try_from("127.0.0.1").unwrap(), stream).await?; + tls_stream.write_all(STOP_SRV.as_bytes()).await?; + } + Err(e) => { + log::error!("Failed to connect to TCP server for stop signal: {}", e); + } + } + } else { + // Plain TCP connection for non-TLS mode + let mut stream = TcpStream::connect(tcp_addr).await?; + stream.write_all(STOP_SRV.as_bytes()).await?; + } toggle_syslog_setting(start_srv).await.unwrap(); } diff --git a/src/service/circuit_breaker.rs b/src/service/circuit_breaker.rs index 349b5ab500d..226e112d3a9 100644 --- a/src/service/circuit_breaker.rs +++ b/src/service/circuit_breaker.rs @@ -186,3 +186,432 @@ impl CircuitBreaker { Utc::now().timestamp() / self.watching_window } } + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use tokio::time; + + use super::*; + + // Helper function to create a test circuit breaker with custom parameters + fn create_test_circuit_breaker( + watching_window: i64, + reset_window_num: i64, + slow_request_threshold: u64, + http_slow_log_threshold: u64, + ) -> Arc<CircuitBreaker> {
CircuitBreaker::new( + watching_window, + reset_window_num, + slow_request_threshold, + http_slow_log_threshold, + ) + } + + // Helper function to create a default test circuit breaker + fn create_default_test_circuit_breaker() -> Arc { + create_test_circuit_breaker(10, 2, 5, 1000) // 10s window, 2x reset, 5 slow req threshold, 1s slow threshold + } + + #[test] + fn test_circuit_breaker_new() { + let cb = create_test_circuit_breaker(5, 3, 10, 2000); + + assert_eq!(cb.watching_window, 5); + assert_eq!(cb.reset_window_num, 3); + assert_eq!(cb.slow_request_threshold, 10); + assert_eq!(cb.http_slow_log_threshold, 2000); + assert_eq!( + cb.state.load(Ordering::Relaxed), + CircuitBreakerState::Closed as u64 + ); + assert_eq!(cb.total_requests.load(Ordering::Relaxed), 0); + assert_eq!(cb.slow_requests.load(Ordering::Relaxed), 0); + } + + #[test] + fn test_circuit_breaker_watch_fast_requests() { + let cb = create_default_test_circuit_breaker(); + + // Send fast requests (below slow threshold) + for _ in 0..10 { + cb.watch(500); // 500ms, below 1000ms threshold + } + + assert_eq!(cb.total_requests.load(Ordering::Relaxed), 10); + assert_eq!(cb.slow_requests.load(Ordering::Relaxed), 0); + assert_eq!( + cb.state.load(Ordering::Relaxed), + CircuitBreakerState::Closed as u64 + ); + } + + #[test] + fn test_circuit_breaker_watch_slow_requests_below_threshold() { + let cb = create_default_test_circuit_breaker(); + + // Send slow requests but below threshold + for _ in 0..4 { + cb.watch(1500); // 1500ms, above 1000ms threshold + } + + assert_eq!(cb.total_requests.load(Ordering::Relaxed), 4); + assert_eq!(cb.slow_requests.load(Ordering::Relaxed), 4); + assert_eq!( + cb.state.load(Ordering::Relaxed), + CircuitBreakerState::Closed as u64 + ); + } + + #[tokio::test] + async fn test_circuit_breaker_opens_when_threshold_exceeded() { + let cb = create_default_test_circuit_breaker(); + + // Send enough slow requests to trigger circuit breaker + // Need 6 requests because fetch_add returns previous value + for _ in 0..6 { + cb.watch(1500); // 1500ms, above 1000ms threshold + } + + assert_eq!(cb.total_requests.load(Ordering::Relaxed), 6); + assert_eq!(cb.slow_requests.load(Ordering::Relaxed), 6); + assert_eq!( + cb.state.load(Ordering::Relaxed), + CircuitBreakerState::Open as u64 + ); + assert!(cb.will_reset_at.load(Ordering::Relaxed) > Utc::now().timestamp()); + } + + #[tokio::test] + async fn test_circuit_breaker_mixed_requests() { + let cb = create_default_test_circuit_breaker(); + + // Send mix of fast and slow requests + for _ in 0..3 { + cb.watch(500); // Fast request + cb.watch(1500); // Slow request + } + + assert_eq!(cb.total_requests.load(Ordering::Relaxed), 6); + assert_eq!(cb.slow_requests.load(Ordering::Relaxed), 3); + assert_eq!( + cb.state.load(Ordering::Relaxed), + CircuitBreakerState::Closed as u64 + ); + + // Add 3 more slow requests to exceed threshold (need total of 6 slow requests) + cb.watch(1500); + cb.watch(1500); + cb.watch(1500); + + assert_eq!(cb.total_requests.load(Ordering::Relaxed), 9); + assert_eq!(cb.slow_requests.load(Ordering::Relaxed), 6); + assert_eq!( + cb.state.load(Ordering::Relaxed), + CircuitBreakerState::Open as u64 + ); + } + + #[tokio::test] + async fn test_circuit_breaker_open_directly() { + let cb = create_default_test_circuit_breaker(); + + assert_eq!( + cb.state.load(Ordering::Relaxed), + CircuitBreakerState::Closed as u64 + ); + + cb.open(); + + assert_eq!( + cb.state.load(Ordering::Relaxed), + CircuitBreakerState::Open as u64 + ); + let reset_time = 
cb.will_reset_at.load(Ordering::Relaxed); + let expected_reset_time = Utc::now().timestamp() + cb.watching_window * cb.reset_window_num; + // Allow 1 second tolerance for timing + assert!((reset_time - expected_reset_time).abs() <= 1); + } + + #[test] + fn test_get_current_window_timestamp() { + let cb = create_test_circuit_breaker(10, 2, 5, 1000); + + let now = Utc::now().timestamp(); + let window_timestamp = cb.get_current_window_timestamp(); + let expected_window = now / 10; + + assert_eq!(window_timestamp, expected_window); + } + + #[test] + fn test_reset_current_window_same_window() { + let cb = create_default_test_circuit_breaker(); + + // Add some requests + cb.watch(500); + cb.watch(1500); + + let initial_total = cb.total_requests.load(Ordering::Relaxed); + let initial_slow = cb.slow_requests.load(Ordering::Relaxed); + let initial_window = cb.current_window.load(Ordering::Relaxed); + + // Reset current window when still in same window should not change counters + cb.reset_current_window(); + + assert_eq!(cb.total_requests.load(Ordering::Relaxed), initial_total); + assert_eq!(cb.slow_requests.load(Ordering::Relaxed), initial_slow); + assert_eq!(cb.current_window.load(Ordering::Relaxed), initial_window); + } + + #[tokio::test] + async fn test_reset_state_when_conditions_met() { + let cb = create_test_circuit_breaker(1, 1, 3, 500); // 1s window, 1x reset, 3 slow req threshold + + // Open the circuit breaker + cb.open(); + assert_eq!( + cb.state.load(Ordering::Relaxed), + CircuitBreakerState::Open as u64 + ); + + // Wait for reset time to pass + time::sleep(Duration::from_secs(2)).await; + + // Make sure slow requests are below threshold + assert!(cb.slow_requests.load(Ordering::Relaxed) < cb.slow_request_threshold); + + // Trigger reset state check + cb.reset_state(); + + assert_eq!( + cb.state.load(Ordering::Relaxed), + CircuitBreakerState::Closed as u64 + ); + } + + #[tokio::test] + async fn test_reset_state_when_conditions_not_met() { + let cb = create_default_test_circuit_breaker(); + + // Open the circuit breaker + cb.open(); + assert_eq!( + cb.state.load(Ordering::Relaxed), + CircuitBreakerState::Open as u64 + ); + + // Reset should not happen immediately (reset time not reached) + cb.reset_state(); + + assert_eq!( + cb.state.load(Ordering::Relaxed), + CircuitBreakerState::Open as u64 + ); + } + + #[test] + fn test_circuit_breaker_state_enum_values() { + assert_eq!(CircuitBreakerState::Closed as u64, 0); + assert_eq!(CircuitBreakerState::Open as u64, 1); + assert_eq!(CircuitBreakerState::HalfOpen as u64, 2); + } + + #[tokio::test] + async fn test_circuit_breaker_boundary_conditions() { + // Test with threshold of 1 + let cb = create_test_circuit_breaker(5, 2, 1, 1000); + + // First slow request should not trigger due to fetch_add behavior + cb.watch(1500); + + assert_eq!(cb.slow_requests.load(Ordering::Relaxed), 1); + assert_eq!( + cb.state.load(Ordering::Relaxed), + CircuitBreakerState::Closed as u64 + ); + + // Second slow request should trigger circuit breaker + cb.watch(1500); + + assert_eq!(cb.slow_requests.load(Ordering::Relaxed), 2); + assert_eq!( + cb.state.load(Ordering::Relaxed), + CircuitBreakerState::Open as u64 + ); + } + + #[tokio::test] + async fn test_circuit_breaker_zero_threshold() { + // Test with threshold of 0 (circuit breaker opens immediately on first slow request) + let cb = create_test_circuit_breaker(5, 2, 0, 1000); + + // First slow request should open the circuit breaker (0 >= 0) + cb.watch(1500); + + 
assert_eq!(cb.slow_requests.load(Ordering::Relaxed), 1); + assert_eq!( + cb.state.load(Ordering::Relaxed), + CircuitBreakerState::Open as u64 + ); + } + + #[tokio::test] + async fn test_circuit_breaker_exact_threshold_boundary() { + let cb = create_test_circuit_breaker(5, 2, 3, 1000); + + // Send exactly 2 slow requests (below threshold) + cb.watch(1500); + cb.watch(1500); + + assert_eq!(cb.slow_requests.load(Ordering::Relaxed), 2); + assert_eq!( + cb.state.load(Ordering::Relaxed), + CircuitBreakerState::Closed as u64 + ); + + // Send 3rd slow request (still below due to fetch_add behavior) + cb.watch(1500); + + assert_eq!(cb.slow_requests.load(Ordering::Relaxed), 3); + assert_eq!( + cb.state.load(Ordering::Relaxed), + CircuitBreakerState::Closed as u64 + ); + + // Send 4th slow request (meets threshold due to fetch_add behavior) + cb.watch(1500); + + assert_eq!(cb.slow_requests.load(Ordering::Relaxed), 4); + assert_eq!( + cb.state.load(Ordering::Relaxed), + CircuitBreakerState::Open as u64 + ); + } + + #[tokio::test] + async fn test_circuit_breaker_concurrent_watch_calls() { + let cb = create_default_test_circuit_breaker(); + let cb_clone = cb.clone(); + + // Simulate concurrent watch calls using tokio tasks + let handle1 = tokio::task::spawn(async move { + for _ in 0..10 { + cb_clone.watch(1500); // Slow requests + } + }); + + let cb_clone2 = cb.clone(); + let handle2 = tokio::task::spawn(async move { + for _ in 0..10 { + cb_clone2.watch(500); // Fast requests + } + }); + + handle1.await.expect("Task 1 should complete successfully"); + handle2.await.expect("Task 2 should complete successfully"); + + assert_eq!(cb.total_requests.load(Ordering::Relaxed), 20); + assert_eq!(cb.slow_requests.load(Ordering::Relaxed), 10); + assert_eq!( + cb.state.load(Ordering::Relaxed), + CircuitBreakerState::Open as u64 + ); + } + + #[test] + fn test_circuit_breaker_watch_request_threshold_edge() { + let cb = create_test_circuit_breaker(5, 2, 5, 1000); + + // Send requests at exactly the slow threshold + cb.watch(999); // Just below threshold - should be considered fast + assert_eq!(cb.slow_requests.load(Ordering::Relaxed), 0); + + cb.watch(1000); // Exactly at threshold - should be considered slow (>= threshold) + assert_eq!(cb.slow_requests.load(Ordering::Relaxed), 1); + + cb.watch(1001); // Just above threshold - should be considered slow + assert_eq!(cb.slow_requests.load(Ordering::Relaxed), 2); + } + + #[test] + fn test_circuit_breaker_large_values() { + let cb = create_test_circuit_breaker(3600, 10, 1000, 5000); // 1 hour window + + // Test with large request times + cb.watch(10000); // 10 seconds + cb.watch(u64::MAX); // Maximum value + + assert_eq!(cb.total_requests.load(Ordering::Relaxed), 2); + assert_eq!(cb.slow_requests.load(Ordering::Relaxed), 2); + assert_eq!( + cb.state.load(Ordering::Relaxed), + CircuitBreakerState::Closed as u64 + ); + } + + #[tokio::test] + async fn test_global_circuit_breaker_watch_request() { + // Test the global watch_request function + watch_request(500); // Fast request + watch_request(2000); // Slow request (assuming default threshold) + + // This test mainly verifies the function doesn't panic + // Since CIRCUIT_BREAKER is a singleton, we can't easily verify state + // without affecting other tests, but we can ensure it doesn't crash + } + + #[test] + fn test_circuit_breaker_window_timestamp_calculation() { + let cb = create_test_circuit_breaker(60, 2, 5, 1000); // 60 second window + + let now = Utc::now().timestamp(); + let window = 
cb.get_current_window_timestamp(); + + // Window should be the current timestamp divided by window size + assert_eq!(window, now / 60); + + // Test with different window sizes + let cb2 = create_test_circuit_breaker(3600, 2, 5, 1000); // 1 hour window + let window2 = cb2.get_current_window_timestamp(); + assert_eq!(window2, now / 3600); + } + + #[tokio::test] + async fn test_circuit_breaker_reset_calculation() { + let cb = create_test_circuit_breaker(10, 3, 5, 1000); + + let before_open = Utc::now().timestamp(); + cb.open(); + let after_open = Utc::now().timestamp(); + + let reset_time = cb.will_reset_at.load(Ordering::Relaxed); + let expected_min = before_open + (10 * 3); // watching_window * reset_window_num + let expected_max = after_open + (10 * 3); + + assert!(reset_time >= expected_min); + assert!(reset_time <= expected_max); + } + + #[tokio::test] + async fn test_circuit_breaker_doesnt_open_when_already_open() { + let cb = create_default_test_circuit_breaker(); + + // Open the circuit breaker + cb.open(); + let first_reset_time = cb.will_reset_at.load(Ordering::Relaxed); + + // Try to open again + cb.open(); + let second_reset_time = cb.will_reset_at.load(Ordering::Relaxed); + + // Reset time should be updated + assert!(second_reset_time >= first_reset_time); + assert_eq!( + cb.state.load(Ordering::Relaxed), + CircuitBreakerState::Open as u64 + ); + } +} diff --git a/src/service/compact/flatten.rs b/src/service/compact/flatten.rs index 05d46e31e5d..b145892330f 100644 --- a/src/service/compact/flatten.rs +++ b/src/service/compact/flatten.rs @@ -295,3 +295,467 @@ fn generate_vertical_partition_recordbatch( let schema = Arc::new(Schema::new(fields)); Ok(vec![RecordBatch::try_new(schema, cols)?]) } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::{ + array::{Array, ArrayRef, StringArray}, + record_batch::RecordBatch, + }; + use arrow_schema::{DataType, Field, Schema}; + + use super::*; + + // Helper function to create a record batch with the "_all" field + fn create_test_batch_with_all_field( + all_field_values: Vec>, + additional_fields: Vec<(&str, Vec>)>, + ) -> RecordBatch { + let mut fields = vec![Field::new("_all", DataType::Utf8, true)]; + let mut columns: Vec = vec![Arc::new(StringArray::from(all_field_values))]; + + for (field_name, field_values) in additional_fields { + fields.push(Field::new(field_name, DataType::Utf8, true)); + columns.push(Arc::new(StringArray::from(field_values))); + } + + let schema = Arc::new(Schema::new(fields)); + RecordBatch::try_new(schema, columns).unwrap() + } + + // Helper function to create a record batch without the "_all" field + fn create_test_batch_without_all_field( + fields_and_values: Vec<(&str, Vec>)>, + ) -> RecordBatch { + let mut fields = vec![]; + let mut columns: Vec = vec![]; + + for (field_name, field_values) in fields_and_values { + fields.push(Field::new(field_name, DataType::Utf8, true)); + columns.push(Arc::new(StringArray::from(field_values))); + } + + let schema = Arc::new(Schema::new(fields)); + RecordBatch::try_new(schema, columns).unwrap() + } + + #[test] + fn test_generate_vertical_partition_recordbatch_empty_batches() { + let batches = vec![]; + let result = generate_vertical_partition_recordbatch(&batches); + assert!(result.is_ok()); + assert!(result.unwrap().is_empty()); + } + + #[test] + fn test_generate_vertical_partition_recordbatch_no_all_field() { + let batch = create_test_batch_without_all_field(vec![ + ("field1", vec![Some("value1"), Some("value2")]), + ("field2", vec![Some("100"), Some("200")]), 
+ ]); + let batches = vec![batch.clone()]; + + let result = generate_vertical_partition_recordbatch(&batches); + assert!(result.is_ok()); + + let output_batches = result.unwrap(); + assert_eq!(output_batches.len(), 1); + + // Should return the original batch unchanged since there's no "_all" field + let output_batch = &output_batches[0]; + assert_eq!(output_batch.num_rows(), batch.num_rows()); + assert_eq!(output_batch.num_columns(), batch.num_columns()); + } + + #[test] + fn test_generate_vertical_partition_recordbatch_simple_json() { + let json_values = vec![ + Some(r#"{"name": "Alice", "age": "30"}"#), + Some(r#"{"name": "Bob", "age": "25"}"#), + ]; + + let batch = + create_test_batch_with_all_field(json_values, vec![("id", vec![Some("1"), Some("2")])]); + let batches = vec![batch]; + + let result = generate_vertical_partition_recordbatch(&batches); + assert!(result.is_ok()); + + let output_batches = result.unwrap(); + assert_eq!(output_batches.len(), 1); + + let output_batch = &output_batches[0]; + assert_eq!(output_batch.num_rows(), 2); + + // Should have original fields + extracted JSON fields + let schema = output_batch.schema(); + let field_names: Vec<&str> = schema.fields().iter().map(|f| f.name().as_str()).collect(); + + assert!(field_names.contains(&"_all")); // Should be null column now + assert!(field_names.contains(&"id")); // Original field + assert!(field_names.contains(&"name")); // Extracted from JSON + assert!(field_names.contains(&"age")); // Extracted from JSON + + // Check that "_all" field is now null + let all_field_index = schema.index_of("_all").unwrap(); + let all_column = output_batch.column(all_field_index); + assert!(all_column.is_null(0)); + assert!(all_column.is_null(1)); + + // Check extracted values + let name_index = schema.index_of("name").unwrap(); + let name_column = output_batch + .column(name_index) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(name_column.value(0), "Alice"); + assert_eq!(name_column.value(1), "Bob"); + + let age_index = schema.index_of("age").unwrap(); + let age_column = output_batch + .column(age_index) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(age_column.value(0), "30"); + assert_eq!(age_column.value(1), "25"); + } + + #[test] + fn test_generate_vertical_partition_recordbatch_missing_fields() { + let json_values = vec![ + Some(r#"{"name": "Alice", "age": "30", "city": "NYC"}"#), + Some(r#"{"name": "Bob", "country": "USA"}"#), // Missing age and city, has country + Some(r#"{"age": "35"}"#), // Missing name, city, and country + ]; + + let batch = create_test_batch_with_all_field(json_values, vec![]); + let batches = vec![batch]; + + let result = generate_vertical_partition_recordbatch(&batches); + assert!(result.is_ok()); + + let output_batches = result.unwrap(); + assert_eq!(output_batches.len(), 1); + + let output_batch = &output_batches[0]; + assert_eq!(output_batch.num_rows(), 3); + + let schema = output_batch.schema(); + + // Check name field + let name_index = schema.index_of("name").unwrap(); + let name_column = output_batch + .column(name_index) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(name_column.value(0), "Alice"); + assert_eq!(name_column.value(1), "Bob"); + assert!(name_column.is_null(2)); // Missing in third row + + // Check age field + let age_index = schema.index_of("age").unwrap(); + let age_column = output_batch + .column(age_index) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(age_column.value(0), "30"); + assert!(age_column.is_null(1)); // Missing in second 
row + assert_eq!(age_column.value(2), "35"); + + // Check city field + let city_index = schema.index_of("city").unwrap(); + let city_column = output_batch + .column(city_index) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(city_column.value(0), "NYC"); + assert!(city_column.is_null(1)); // Missing in second row + assert!(city_column.is_null(2)); // Missing in third row + + // Check country field + let country_index = schema.index_of("country").unwrap(); + let country_column = output_batch + .column(country_index) + .as_any() + .downcast_ref::() + .unwrap(); + assert!(country_column.is_null(0)); // Missing in first row + assert_eq!(country_column.value(1), "USA"); + assert!(country_column.is_null(2)); // Missing in third row + } + + #[test] + fn test_generate_vertical_partition_recordbatch_empty_json() { + let json_values = vec![ + Some(""), // Empty string + Some("{}"), // Empty JSON object + Some(r#"{"name": "Alice"}"#), // Normal JSON + ]; + + let batch = create_test_batch_with_all_field(json_values, vec![]); + let batches = vec![batch]; + + let result = generate_vertical_partition_recordbatch(&batches); + assert!(result.is_ok()); + + let output_batches = result.unwrap(); + assert_eq!(output_batches.len(), 1); + + let output_batch = &output_batches[0]; + assert_eq!(output_batch.num_rows(), 3); + + let schema = output_batch.schema(); + + // Should only have _all field and name field (name appears in row 3) + assert!(schema.index_of("name").is_ok()); + + let name_index = schema.index_of("name").unwrap(); + let name_column = output_batch + .column(name_index) + .as_any() + .downcast_ref::() + .unwrap(); + assert!(name_column.is_null(0)); // Empty string -> empty object -> no name + assert!(name_column.is_null(1)); // Empty object -> no name + assert_eq!(name_column.value(2), "Alice"); // Has name + } + + #[test] + fn test_generate_vertical_partition_recordbatch_null_values_in_json() { + let json_values = vec![ + Some(r#"{"name": "Alice", "age": null, "city": "NYC"}"#), + Some(r#"{"name": null, "age": "25"}"#), + ]; + + let batch = create_test_batch_with_all_field(json_values, vec![]); + let batches = vec![batch]; + + let result = generate_vertical_partition_recordbatch(&batches); + assert!(result.is_ok()); + + let output_batches = result.unwrap(); + assert_eq!(output_batches.len(), 1); + + let output_batch = &output_batches[0]; + assert_eq!(output_batch.num_rows(), 2); + + let schema = output_batch.schema(); + + // Check name field (Alice in first row, null in second) + let name_index = schema.index_of("name").unwrap(); + let name_column = output_batch + .column(name_index) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(name_column.value(0), "Alice"); + assert!(name_column.is_null(1)); // null in JSON becomes null in column + + // Check age field (null in first row, 25 in second) + let age_index = schema.index_of("age").unwrap(); + let age_column = output_batch + .column(age_index) + .as_any() + .downcast_ref::() + .unwrap(); + assert!(age_column.is_null(0)); // null in JSON becomes null in column + assert_eq!(age_column.value(1), "25"); + } + + #[test] + fn test_generate_vertical_partition_recordbatch_invalid_json() { + let json_values = vec![ + Some(r#"{"name": "Alice"}"#), // Valid JSON + Some(r#"invalid json"#), // Invalid JSON + ]; + + let batch = create_test_batch_with_all_field(json_values, vec![]); + let batches = vec![batch]; + + let result = generate_vertical_partition_recordbatch(&batches); + assert!(result.is_err()); + + // Should return an error due to 
invalid JSON parsing + let error = result.unwrap_err(); + assert!(error.to_string().contains("parse all fields value error")); + } + + #[test] + fn test_generate_vertical_partition_recordbatch_complex_json_values() { + let json_values = vec![ + Some(r#"{"user": {"name": "Alice", "details": {"age": 30}}, "active": true}"#), + Some(r#"{"score": 95.5, "tags": ["tag1", "tag2"], "metadata": {"version": "1.0"}}"#), + ]; + + let batch = create_test_batch_with_all_field(json_values, vec![]); + let batches = vec![batch]; + + let result = generate_vertical_partition_recordbatch(&batches); + assert!(result.is_ok()); + + let output_batches = result.unwrap(); + assert_eq!(output_batches.len(), 1); + + let output_batch = &output_batches[0]; + assert_eq!(output_batch.num_rows(), 2); + + let schema = output_batch.schema(); + + // Complex JSON values should be converted to strings + let user_index = schema.index_of("user").unwrap(); + let user_column = output_batch + .column(user_index) + .as_any() + .downcast_ref::() + .unwrap(); + // The nested object should be converted to a string representation + assert!(!user_column.is_null(0)); + assert!(user_column.is_null(1)); // Not present in second row + } + + #[test] + fn test_generate_vertical_partition_recordbatch_zero_rows() { + let json_values: Vec> = vec![]; // No rows + + let batch = create_test_batch_with_all_field(json_values, vec![]); + let batches = vec![batch]; + + let result = generate_vertical_partition_recordbatch(&batches); + assert!(result.is_ok()); + + let output_batches = result.unwrap(); + assert!(output_batches.is_empty()); // Should return empty vec when no rows + } + + #[test] + fn test_generate_vertical_partition_recordbatch_all_field_not_string_array() { + // Create a batch where the "_all" field is not a StringArray + let fields = vec![ + Field::new("_all", DataType::Int64, true), // Not a string array + Field::new("id", DataType::Utf8, true), + ]; + let schema = Arc::new(Schema::new(fields)); + + let columns: Vec = vec![ + Arc::new(arrow::array::Int64Array::from(vec![Some(1), Some(2)])), + Arc::new(StringArray::from(vec![Some("a"), Some("b")])), + ]; + + let batch = RecordBatch::try_new(schema, columns).unwrap(); + let batches = vec![batch.clone()]; + + let result = generate_vertical_partition_recordbatch(&batches); + assert!(result.is_ok()); + + let output_batches = result.unwrap(); + assert_eq!(output_batches.len(), 1); + + // Should return the original batch unchanged since "_all" is not a StringArray + let output_batch = &output_batches[0]; + assert_eq!(output_batch.num_rows(), batch.num_rows()); + assert_eq!(output_batch.num_columns(), batch.num_columns()); + } + + #[test] + fn test_generate_vertical_partition_recordbatch_multiple_batches() { + let batch1 = create_test_batch_with_all_field( + vec![Some(r#"{"name": "Alice", "age": "30"}"#)], + vec![("id", vec![Some("1")])], + ); + + let batch2 = create_test_batch_with_all_field( + vec![Some(r#"{"name": "Bob", "city": "NYC"}"#)], + vec![("id", vec![Some("2")])], + ); + + let batches = vec![batch1, batch2]; + + let result = generate_vertical_partition_recordbatch(&batches); + assert!(result.is_ok()); + + let output_batches = result.unwrap(); + assert_eq!(output_batches.len(), 1); // Should concatenate into one batch + + let output_batch = &output_batches[0]; + assert_eq!(output_batch.num_rows(), 2); // Combined rows from both batches + + let schema = output_batch.schema(); + + // Should have all fields from both batches + assert!(schema.index_of("name").is_ok()); + 
assert!(schema.index_of("age").is_ok()); + assert!(schema.index_of("city").is_ok()); + assert!(schema.index_of("id").is_ok()); + + // Check that missing fields are properly handled with nulls + let age_index = schema.index_of("age").unwrap(); + let age_column = output_batch + .column(age_index) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(age_column.value(0), "30"); // From first batch + assert!(age_column.is_null(1)); // Missing in second batch + + let city_index = schema.index_of("city").unwrap(); + let city_column = output_batch + .column(city_index) + .as_any() + .downcast_ref::() + .unwrap(); + assert!(city_column.is_null(0)); // Missing in first batch + assert_eq!(city_column.value(1), "NYC"); // From second batch + } + + #[test] + fn test_generate_vertical_partition_recordbatch_preserve_original_fields() { + let batch = create_test_batch_with_all_field( + vec![Some(r#"{"extracted": "value"}"#)], + vec![ + ("original_field1", vec![Some("orig1")]), + ("original_field2", vec![Some("orig2")]), + ], + ); + let batches = vec![batch]; + + let result = generate_vertical_partition_recordbatch(&batches); + assert!(result.is_ok()); + + let output_batches = result.unwrap(); + assert_eq!(output_batches.len(), 1); + + let output_batch = &output_batches[0]; + let schema = output_batch.schema(); + + // Should preserve original fields + assert!(schema.index_of("original_field1").is_ok()); + assert!(schema.index_of("original_field2").is_ok()); + + // Should add extracted fields + assert!(schema.index_of("extracted").is_ok()); + + // Check values are preserved + let orig1_index = schema.index_of("original_field1").unwrap(); + let orig1_column = output_batch + .column(orig1_index) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(orig1_column.value(0), "orig1"); + + let extracted_index = schema.index_of("extracted").unwrap(); + let extracted_column = output_batch + .column(extracted_index) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(extracted_column.value(0), "value"); + } +} diff --git a/src/service/compact/merge.rs b/src/service/compact/merge.rs index 8d6c89f4887..99f47b77013 100644 --- a/src/service/compact/merge.rs +++ b/src/service/compact/merge.rs @@ -1239,7 +1239,7 @@ pub fn generate_inverted_idx_recordbatch( .map(|f| f.name()) .collect::>(); - let mut inverted_idx_columns = if !full_text_search_fields.is_empty() { + let mut inverted_idx_columns: Vec = if !full_text_search_fields.is_empty() { full_text_search_fields.to_vec() } else { config::SQL_FULL_TEXT_SEARCH_FIELDS.to_vec() @@ -1446,3 +1446,512 @@ fn sort_by_time_range(mut file_list: Vec) -> Vec { } files } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow_schema::{DataType, Field, Schema}; + use config::meta::stream::{FileKey, FileMeta}; + use hashbrown::HashMap; + + use super::*; + + // Helper function to create test FileKey + fn create_file_key(key: &str, min_ts: i64, max_ts: i64, original_size: i64) -> FileKey { + FileKey { + id: 0, + account: "test_account".to_string(), + key: key.to_string(), + meta: FileMeta { + min_ts, + max_ts, + records: 100, + original_size, + compressed_size: original_size / 2, // assume 50% compression + index_size: 0, + flattened: false, + }, + deleted: false, + segment_ids: None, + } + } + + #[test] + fn test_sort_by_time_range_edge_case_adjacent_files() { + let files = vec![ + create_file_key("file1.parquet", 1000, 2000, 1024), + create_file_key("file2.parquet", 2000, 3000, 1024), // exactly adjacent + create_file_key("file3.parquet", 3000, 4000, 1024), // 
exactly adjacent + ]; + let result = sort_by_time_range(files); + assert_eq!(result.len(), 3); + + // Adjacent files should be able to be in the same group + assert_eq!(result[0].key, "file1.parquet"); + assert_eq!(result[1].key, "file2.parquet"); + assert_eq!(result[2].key, "file3.parquet"); + } + + // Test helper function creation + #[test] + fn test_create_file_key_helper() { + let file_key = create_file_key("test.parquet", 1000, 2000, 1024); + assert_eq!(file_key.key, "test.parquet"); + assert_eq!(file_key.meta.min_ts, 1000); + assert_eq!(file_key.meta.max_ts, 2000); + assert_eq!(file_key.meta.original_size, 1024); + assert_eq!(file_key.meta.compressed_size, 512); // 50% compression + assert_eq!(file_key.meta.records, 100); + assert!(!file_key.meta.flattened); + assert_eq!(file_key.id, 0); + assert_eq!(file_key.account, "test_account"); + assert!(!file_key.deleted); + assert!(file_key.segment_ids.is_none()); + } + + // Boundary tests for sort_by_time_range + #[test] + fn test_sort_by_time_range_negative_timestamps() { + let files = vec![ + create_file_key("file1.parquet", -2000, -1000, 1024), + create_file_key("file2.parquet", -1000, 0, 1024), + create_file_key("file3.parquet", 0, 1000, 1024), + ]; + let result = sort_by_time_range(files); + assert_eq!(result.len(), 3); + assert_eq!(result[0].key, "file1.parquet"); + assert_eq!(result[1].key, "file2.parquet"); + assert_eq!(result[2].key, "file3.parquet"); + } + + #[test] + fn test_sort_by_time_range_large_timestamps() { + let files = vec![ + create_file_key("file1.parquet", i64::MAX - 2000, i64::MAX - 1000, 1024), + create_file_key("file2.parquet", i64::MAX - 1000, i64::MAX, 1024), + ]; + let result = sort_by_time_range(files); + assert_eq!(result.len(), 2); + assert_eq!(result[0].key, "file1.parquet"); + assert_eq!(result[1].key, "file2.parquet"); + } + + // Edge case where min_ts equals max_ts (point in time) + #[test] + fn test_sort_by_time_range_point_in_time() { + let files = vec![ + create_file_key("file1.parquet", 1000, 1000, 1024), // Point in time + create_file_key("file2.parquet", 1000, 2000, 1024), // Overlaps with file1 + create_file_key("file3.parquet", 2000, 2000, 1024), // Point in time, adjacent + ]; + let result = sort_by_time_range(files); + assert_eq!(result.len(), 3); + + // Verify all files are present + let keys: Vec<&String> = result.iter().map(|f| &f.key).collect(); + assert!(keys.contains(&&"file1.parquet".to_string())); + assert!(keys.contains(&&"file2.parquet".to_string())); + assert!(keys.contains(&&"file3.parquet".to_string())); + } + + #[test] + fn test_sort_by_time_range_many_files_random_order() { + let files = vec![ + create_file_key("file_f.parquet", 6000, 7000, 1024), + create_file_key("file_b.parquet", 2000, 3000, 1024), + create_file_key("file_d.parquet", 4000, 5000, 1024), + create_file_key("file_a.parquet", 1000, 2000, 1024), + create_file_key("file_c.parquet", 3000, 4000, 1024), + create_file_key("file_e.parquet", 5000, 6000, 1024), + ]; + let result = sort_by_time_range(files); + assert_eq!(result.len(), 6); + + // Should be sorted by min_ts (all adjacent files) + assert_eq!(result[0].key, "file_a.parquet"); + assert_eq!(result[1].key, "file_b.parquet"); + assert_eq!(result[2].key, "file_c.parquet"); + assert_eq!(result[3].key, "file_d.parquet"); + assert_eq!(result[4].key, "file_e.parquet"); + assert_eq!(result[5].key, "file_f.parquet"); + } + + #[test] + fn test_sort_by_time_range_gaps_between_files() { + let files = vec![ + create_file_key("file1.parquet", 1000, 2000, 1024), + 
create_file_key("file2.parquet", 5000, 6000, 1024), // gap after file1 + create_file_key("file3.parquet", 3000, 4000, 1024), // fits in gap + ]; + let result = sort_by_time_range(files); + assert_eq!(result.len(), 3); + + // Should be sorted by min_ts + assert_eq!(result[0].key, "file1.parquet"); + assert_eq!(result[1].key, "file3.parquet"); + assert_eq!(result[2].key, "file2.parquet"); + } + + #[test] + fn test_sort_by_time_range_empty_list() { + let files = vec![]; + let result = sort_by_time_range(files); + assert_eq!(result.len(), 0); + } + + #[test] + fn test_sort_by_time_range_single_file() { + let files = vec![create_file_key("file1.parquet", 1000, 2000, 1024)]; + let result = sort_by_time_range(files); + assert_eq!(result.len(), 1); + assert_eq!(result[0].key, "file1.parquet"); + assert_eq!(result[0].meta.min_ts, 1000); + assert_eq!(result[0].meta.max_ts, 2000); + } + + #[test] + fn test_sort_by_time_range_already_sorted_non_overlapping() { + let files = vec![ + create_file_key("file1.parquet", 1000, 2000, 1024), + create_file_key("file2.parquet", 2000, 3000, 1024), + create_file_key("file3.parquet", 3000, 4000, 1024), + ]; + let result = sort_by_time_range(files); + assert_eq!(result.len(), 3); + assert_eq!(result[0].key, "file1.parquet"); + assert_eq!(result[1].key, "file2.parquet"); + assert_eq!(result[2].key, "file3.parquet"); + } + + #[test] + fn test_sort_by_time_range_unsorted_non_overlapping() { + let files = vec![ + create_file_key("file3.parquet", 3000, 4000, 1024), + create_file_key("file1.parquet", 1000, 2000, 1024), + create_file_key("file2.parquet", 2000, 3000, 1024), + ]; + let result = sort_by_time_range(files); + assert_eq!(result.len(), 3); + // Should be sorted by min_ts + assert_eq!(result[0].key, "file1.parquet"); + assert_eq!(result[1].key, "file2.parquet"); + assert_eq!(result[2].key, "file3.parquet"); + } + + #[test] + fn test_sort_by_time_range_overlapping_files() { + let files = vec![ + create_file_key("file1.parquet", 1000, 2500, 1024), // overlaps with file2 + create_file_key("file2.parquet", 2000, 3000, 1024), // overlaps with file1 + create_file_key("file3.parquet", 3500, 4000, 1024), // non-overlapping + ]; + let result = sort_by_time_range(files); + assert_eq!(result.len(), 3); + + // First file should be file1 (min_ts = 1000) + assert_eq!(result[0].key, "file1.parquet"); + + // Due to overlapping, file3 should come next (can fit in same group as file1) + // file2 would be in a separate group since it overlaps with file1 + let mut found_file2 = false; + let mut found_file3 = false; + for file in &result { + if file.key == "file2.parquet" { + found_file2 = true; + } + if file.key == "file3.parquet" { + found_file3 = true; + } + } + assert!(found_file2); + assert!(found_file3); + } + + #[test] + fn test_sort_by_time_range_complex_overlapping() { + let files = vec![ + create_file_key("file1.parquet", 1000, 1500, 1024), + create_file_key("file2.parquet", 1200, 1800, 1024), // overlaps with file1 + create_file_key("file3.parquet", 1600, 2000, 1024), // overlaps with file2 + create_file_key("file4.parquet", 2000, 2500, 1024), // adjacent to file3 + create_file_key("file5.parquet", 3000, 3500, 1024), // separate group + ]; + let result = sort_by_time_range(files); + assert_eq!(result.len(), 5); + + // Verify all files are present + let keys: Vec<&String> = result.iter().map(|f| &f.key).collect(); + assert!(keys.contains(&&"file1.parquet".to_string())); + assert!(keys.contains(&&"file2.parquet".to_string())); + 
assert!(keys.contains(&&"file3.parquet".to_string())); + assert!(keys.contains(&&"file4.parquet".to_string())); + assert!(keys.contains(&&"file5.parquet".to_string())); + } + + #[test] + fn test_sort_by_time_range_identical_timestamps() { + let files = vec![ + create_file_key("file1.parquet", 1000, 2000, 1024), + create_file_key("file2.parquet", 1000, 2000, 512), + create_file_key("file3.parquet", 1000, 2000, 2048), + ]; + let result = sort_by_time_range(files); + assert_eq!(result.len(), 3); + + // All files have same timestamp, so they should all be in separate groups + // due to overlap, but ordering should be maintained based on original order after sorting + let keys: Vec<&String> = result.iter().map(|f| &f.key).collect(); + assert_eq!(keys.len(), 3); + assert!(keys.contains(&&"file1.parquet".to_string())); + assert!(keys.contains(&&"file2.parquet".to_string())); + assert!(keys.contains(&&"file3.parquet".to_string())); + } + + // Test cases for generate_schema_diff function + #[test] + fn test_generate_schema_diff_no_differences() { + // Create a schema with fields + let field1 = Arc::new(Field::new("field1", DataType::Utf8, true)); + let field2 = Arc::new(Field::new("field2", DataType::Int64, false)); + let schema = Schema::new(vec![field1.clone(), field2.clone()]); + + // Create latest schema map with same types + let mut latest_schema_map: HashMap<&String, &Arc> = HashMap::new(); + latest_schema_map.insert(field1.name(), &field1); + latest_schema_map.insert(field2.name(), &field2); + + let result = generate_schema_diff(&schema, &latest_schema_map); + assert!(result.is_ok()); + let diff = result.unwrap(); + assert!(diff.is_empty()); + } + + #[test] + fn test_generate_schema_diff_with_type_differences() { + // Create original schema + let field1 = Arc::new(Field::new("field1", DataType::Utf8, true)); + let field2 = Arc::new(Field::new("field2", DataType::Int32, false)); + let schema = Schema::new(vec![field1.clone(), field2.clone()]); + + // Create latest schema with different types + let latest_field1 = Arc::new(Field::new("field1", DataType::Utf8, true)); // Same type + let latest_field2 = Arc::new(Field::new("field2", DataType::Int64, false)); // Different type + + let mut latest_schema_map: HashMap<&String, &Arc> = HashMap::new(); + latest_schema_map.insert(latest_field1.name(), &latest_field1); + latest_schema_map.insert(latest_field2.name(), &latest_field2); + + let result = generate_schema_diff(&schema, &latest_schema_map); + assert!(result.is_ok()); + let diff = result.unwrap(); + + assert_eq!(diff.len(), 1); + assert!(diff.contains_key("field2")); + assert_eq!(diff.get("field2").unwrap(), &DataType::Int64); + } + + #[test] + fn test_generate_schema_diff_multiple_type_differences() { + // Create original schema + let field1 = Arc::new(Field::new("field1", DataType::Int32, true)); + let field2 = Arc::new(Field::new("field2", DataType::Float32, false)); + let field3 = Arc::new(Field::new("field3", DataType::Boolean, true)); + let schema = Schema::new(vec![field1.clone(), field2.clone(), field3.clone()]); + + // Create latest schema with different types + let latest_field1 = Arc::new(Field::new("field1", DataType::Int64, true)); // Different + let latest_field2 = Arc::new(Field::new("field2", DataType::Float64, false)); // Different + let latest_field3 = Arc::new(Field::new("field3", DataType::Boolean, true)); // Same + + let mut latest_schema_map: HashMap<&String, &Arc> = HashMap::new(); + latest_schema_map.insert(latest_field1.name(), &latest_field1); + 
latest_schema_map.insert(latest_field2.name(), &latest_field2); + latest_schema_map.insert(latest_field3.name(), &latest_field3); + + let result = generate_schema_diff(&schema, &latest_schema_map); + assert!(result.is_ok()); + let diff = result.unwrap(); + + assert_eq!(diff.len(), 2); + assert!(diff.contains_key("field1")); + assert!(diff.contains_key("field2")); + assert!(!diff.contains_key("field3")); // Same type, should not be in diff + assert_eq!(diff.get("field1").unwrap(), &DataType::Int64); + assert_eq!(diff.get("field2").unwrap(), &DataType::Float64); + } + + #[test] + fn test_generate_schema_diff_field_missing_in_latest() { + // Create original schema + let field1 = Arc::new(Field::new("field1", DataType::Utf8, true)); + let field2 = Arc::new(Field::new("field2", DataType::Int64, false)); + let schema = Schema::new(vec![field1.clone(), field2.clone()]); + + // Create latest schema map missing field2 + let mut latest_schema_map: HashMap<&String, &Arc<Field>> = HashMap::new(); + latest_schema_map.insert(field1.name(), &field1); + // field2 is missing from latest_schema_map + + let result = generate_schema_diff(&schema, &latest_schema_map); + assert!(result.is_ok()); + let diff = result.unwrap(); + + // Should be empty since field2 is not found in latest_schema_map + assert!(diff.is_empty()); + } + + #[test] + fn test_generate_schema_diff_empty_schema() { + let schema = Schema::empty(); + let latest_schema_map: HashMap<&String, &Arc<Field>> = HashMap::new(); + + let result = generate_schema_diff(&schema, &latest_schema_map); + assert!(result.is_ok()); + let diff = result.unwrap(); + assert!(diff.is_empty()); + } + + #[test] + fn test_generate_schema_diff_empty_latest_schema_map() { + let field1 = Arc::new(Field::new("field1", DataType::Utf8, true)); + let schema = Schema::new(vec![field1]); + let latest_schema_map: HashMap<&String, &Arc<Field>> = HashMap::new(); + + let result = generate_schema_diff(&schema, &latest_schema_map); + assert!(result.is_ok()); + let diff = result.unwrap(); + assert!(diff.is_empty()); // No fields in latest_schema_map to compare against + } + + #[test] + fn test_generate_schema_diff_complex_data_types() { + // Test with complex data types like List, Struct, etc. 
+ let list_field = Arc::new(Field::new( + "list_field", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + )); + let timestamp_field = Arc::new(Field::new( + "timestamp_field", + DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, None), + false, + )); + let schema = Schema::new(vec![list_field.clone(), timestamp_field.clone()]); + + // Create latest schema with different complex types + let latest_list_field = Arc::new(Field::new( + "list_field", + DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), /* Different inner type */ + true, + )); + let latest_timestamp_field = Arc::new(Field::new( + "timestamp_field", + DataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, None), // Different time unit + false, + )); + + let mut latest_schema_map: HashMap<&String, &Arc<Field>> = HashMap::new(); + latest_schema_map.insert(latest_list_field.name(), &latest_list_field); + latest_schema_map.insert(latest_timestamp_field.name(), &latest_timestamp_field); + + let result = generate_schema_diff(&schema, &latest_schema_map); + assert!(result.is_ok()); + let diff = result.unwrap(); + + assert_eq!(diff.len(), 2); + assert!(diff.contains_key("list_field")); + assert!(diff.contains_key("timestamp_field")); + } + + #[test] + fn test_generate_schema_diff_nullable_differences() { + // Test that nullable differences are detected when data types are same + let field1 = Arc::new(Field::new("field1", DataType::Utf8, false)); // Not nullable + let schema = Schema::new(vec![field1.clone()]); + + let latest_field1 = Arc::new(Field::new("field1", DataType::Utf8, true)); // Nullable + let mut latest_schema_map: HashMap<&String, &Arc<Field>> = HashMap::new(); + latest_schema_map.insert(latest_field1.name(), &latest_field1); + + let result = generate_schema_diff(&schema, &latest_schema_map); + assert!(result.is_ok()); + let diff = result.unwrap(); + + // Note: This test verifies current behavior - the function only compares data_type(), + // not the nullable property. If nullable differences should be detected, + // the function would need to be modified. 
+ assert!(diff.is_empty()); // Current implementation doesn't detect nullable differences + } + + #[test] + fn test_generate_schema_diff_mixed_scenario() { + // Create a realistic mixed scenario + let original_fields = vec![ + Arc::new(Field::new("id", DataType::Int32, false)), + Arc::new(Field::new("name", DataType::Utf8, true)), + Arc::new(Field::new("score", DataType::Float32, true)), + Arc::new(Field::new("active", DataType::Boolean, false)), + Arc::new(Field::new("metadata", DataType::Utf8, true)), + ]; + let schema = Schema::new(original_fields); + + // Latest schema with some changes + let latest_id = Arc::new(Field::new("id", DataType::Int64, false)); // Changed type + let latest_name = Arc::new(Field::new("name", DataType::Utf8, true)); // Same + let latest_score = Arc::new(Field::new("score", DataType::Float64, true)); // Changed type + let latest_active = Arc::new(Field::new("active", DataType::Boolean, false)); // Same + // metadata field missing from latest schema + + let mut latest_schema_map: HashMap<&String, &Arc<Field>> = HashMap::new(); + latest_schema_map.insert(latest_id.name(), &latest_id); + latest_schema_map.insert(latest_name.name(), &latest_name); + latest_schema_map.insert(latest_score.name(), &latest_score); + latest_schema_map.insert(latest_active.name(), &latest_active); + + let result = generate_schema_diff(&schema, &latest_schema_map); + assert!(result.is_ok()); + let diff = result.unwrap(); + + assert_eq!(diff.len(), 2); + assert!(diff.contains_key("id")); + assert!(diff.contains_key("score")); + assert!(!diff.contains_key("name")); // Same type + assert!(!diff.contains_key("active")); // Same type + assert!(!diff.contains_key("metadata")); // Missing from latest + + assert_eq!(diff.get("id").unwrap(), &DataType::Int64); + assert_eq!(diff.get("score").unwrap(), &DataType::Float64); + } + + #[test] + fn test_generate_schema_diff_decimal_types() { + // Test with decimal types that have precision and scale + let decimal_field = Arc::new(Field::new( + "decimal_field", + DataType::Decimal128(10, 2), // precision=10, scale=2 + true, + )); + let schema = Schema::new(vec![decimal_field.clone()]); + + let latest_decimal_field = Arc::new(Field::new( + "decimal_field", + DataType::Decimal128(12, 4), // Different precision and scale + true, + )); + + let mut latest_schema_map: HashMap<&String, &Arc<Field>> = HashMap::new(); + latest_schema_map.insert(latest_decimal_field.name(), &latest_decimal_field); + + let result = generate_schema_diff(&schema, &latest_schema_map); + assert!(result.is_ok()); + let diff = result.unwrap(); + + assert_eq!(diff.len(), 1); + assert!(diff.contains_key("decimal_field")); + assert_eq!( + diff.get("decimal_field").unwrap(), + &DataType::Decimal128(12, 4) + ); + } +} diff --git a/src/service/search/cache/multi.rs b/src/service/search/cache/multi.rs index b71b178711f..dd7be2b1eff 100644 --- a/src/service/search/cache/multi.rs +++ b/src/service/search/cache/multi.rs @@ -320,3 +320,221 @@ fn get_allowed_up_to( } } } + +#[cfg(test)] +mod tests { + use infra::cache::meta::ResultCacheMeta; + + use super::*; + + fn create_test_cache_meta(start_time: i64, end_time: i64) -> ResultCacheMeta { + ResultCacheMeta { + start_time, + end_time, + is_aggregate: false, + is_descending: false, + } + } + + fn create_test_cache_request(q_start_time: i64, q_end_time: i64) -> CacheQueryRequest { + CacheQueryRequest { + q_start_time, + q_end_time, + is_aggregate: false, + ts_column: "_timestamp".to_string(), + discard_interval: 0, + is_descending: false, + } + } + + #[test] + fn 
test_select_cache_meta_overlap_strategy() { + let strategy = ResultCacheSelectionStrategy::Overlap; + + // Test case 1: Complete overlap + let meta = create_test_cache_meta(100, 200); + let req = create_test_cache_request(150, 180); + let score = select_cache_meta(&meta, &req, &strategy); + assert_eq!(score, 30); // overlap: 180 - 150 = 30 + + // Test case 2: Partial overlap (cache starts before query) + let meta = create_test_cache_meta(50, 150); + let req = create_test_cache_request(100, 200); + let score = select_cache_meta(&meta, &req, &strategy); + assert_eq!(score, 50); // overlap: 150 - 100 = 50 + + // Test case 3: Partial overlap (cache starts after query) + let meta = create_test_cache_meta(150, 250); + let req = create_test_cache_request(100, 200); + let score = select_cache_meta(&meta, &req, &strategy); + assert_eq!(score, 50); // overlap: 200 - 150 = 50 + + // Test case 4: No overlap (cache is before query) + let meta = create_test_cache_meta(50, 80); + let req = create_test_cache_request(100, 200); + let score = select_cache_meta(&meta, &req, &strategy); + assert_eq!(score, -20); // overlap: 80 - 100 = -20 (negative means no overlap) + + // Test case 5: No overlap (cache is after query) + let meta = create_test_cache_meta(250, 300); + let req = create_test_cache_request(100, 200); + let score = select_cache_meta(&meta, &req, &strategy); + assert_eq!(score, -50); // overlap: 200 - 250 = -50 (negative means no overlap) + + // Test case 6: Cache completely contains query + let meta = create_test_cache_meta(50, 300); + let req = create_test_cache_request(100, 200); + let score = select_cache_meta(&meta, &req, &strategy); + assert_eq!(score, 100); // overlap: 200 - 100 = 100 + } + + #[test] + fn test_select_cache_meta_duration_strategy() { + let strategy = ResultCacheSelectionStrategy::Duration; + let req = create_test_cache_request(100, 200); // Query parameters don't affect duration strategy + + // Test case 1: Short duration cache + let meta = create_test_cache_meta(100, 150); + let score = select_cache_meta(&meta, &req, &strategy); + assert_eq!(score, 50); // duration: 150 - 100 = 50 + + // Test case 2: Long duration cache + let meta = create_test_cache_meta(50, 300); + let score = select_cache_meta(&meta, &req, &strategy); + assert_eq!(score, 250); // duration: 300 - 50 = 250 + + // Test case 3: Zero duration cache (edge case) + let meta = create_test_cache_meta(100, 100); + let score = select_cache_meta(&meta, &req, &strategy); + assert_eq!(score, 0); // duration: 100 - 100 = 0 + + // Test case 4: Negative duration (edge case - shouldn't happen but testing robustness) + let meta = create_test_cache_meta(200, 100); + let score = select_cache_meta(&meta, &req, &strategy); + assert_eq!(score, -100); // duration: 100 - 200 = -100 + } + + #[test] + fn test_select_cache_meta_both_strategy() { + let strategy = ResultCacheSelectionStrategy::Both; + + // Test case 1: 100% overlap (query completely within cache) + let meta = create_test_cache_meta(50, 300); + let req = create_test_cache_request(100, 200); + let score = select_cache_meta(&meta, &req, &strategy); + assert_eq!(score, 40); // overlap: 100, cache_duration: 250, (100 * 100) / 250 = 40 + + // Test case 2: 50% overlap + let meta = create_test_cache_meta(100, 200); + let req = create_test_cache_request(150, 250); + let score = select_cache_meta(&meta, &req, &strategy); + assert_eq!(score, 50); // overlap: 50, cache_duration: 100, (50 * 100) / 100 = 50 + + // Test case 3: 25% overlap + let meta = create_test_cache_meta(100, 
400); + let req = create_test_cache_request(350, 450); + let score = select_cache_meta(&meta, &req, &strategy); + assert_eq!(score, 16); // overlap: 50, cache_duration: 300, (50 * 100) / 300 = 16.66... = 16 (integer division) + + // Test case 4: Zero duration cache (edge case) + let meta = create_test_cache_meta(100, 100); + let req = create_test_cache_request(100, 200); + let score = select_cache_meta(&meta, &req, &strategy); + assert_eq!(score, 0); // Special case: cache_duration = 0, returns 0 + + // Test case 5: No overlap + let meta = create_test_cache_meta(100, 150); + let req = create_test_cache_request(200, 300); + let score = select_cache_meta(&meta, &req, &strategy); + assert_eq!(score, -100); // overlap: -50, cache_duration: 50, (-50 * 100) / 50 = -100 + + // Test case 6: Perfect match (cache and query have same time range) + let meta = create_test_cache_meta(100, 200); + let req = create_test_cache_request(100, 200); + let score = select_cache_meta(&meta, &req, &strategy); + assert_eq!(score, 100); // overlap: 100, cache_duration: 100, (100 * 100) / 100 = 100 + } + + #[test] + fn test_select_cache_meta_edge_cases() { + // Test with zero time ranges + let meta = create_test_cache_meta(0, 0); + let req = create_test_cache_request(0, 0); + + let overlap_score = select_cache_meta(&meta, &req, &ResultCacheSelectionStrategy::Overlap); + assert_eq!(overlap_score, 0); + + let duration_score = + select_cache_meta(&meta, &req, &ResultCacheSelectionStrategy::Duration); + assert_eq!(duration_score, 0); + + let both_score = select_cache_meta(&meta, &req, &ResultCacheSelectionStrategy::Both); + assert_eq!(both_score, 0); + } + + #[test] + fn test_select_cache_meta_large_numbers() { + // Test with large timestamp values (microseconds) + let meta = create_test_cache_meta(1_640_995_200_000_000, 1_640_995_260_000_000); // 60 seconds + let req = create_test_cache_request(1_640_995_230_000_000, 1_640_995_290_000_000); // 60 seconds, 30s overlap + + let overlap_score = select_cache_meta(&meta, &req, &ResultCacheSelectionStrategy::Overlap); + assert_eq!(overlap_score, 30_000_000); // 30 seconds in microseconds + + let duration_score = + select_cache_meta(&meta, &req, &ResultCacheSelectionStrategy::Duration); + assert_eq!(duration_score, 60_000_000); // 60 seconds in microseconds + + let both_score = select_cache_meta(&meta, &req, &ResultCacheSelectionStrategy::Both); + assert_eq!(both_score, 50); // (30_000_000 * 100) / 60_000_000 = 50 + } + + #[test] + fn test_select_cache_meta_comparison_scenarios() { + let req = create_test_cache_request(100, 200); + + // Cache 1: Shorter duration but worse overlap + let cache1 = create_test_cache_meta(120, 180); + // Cache 2: Longer duration but better overlap + let cache2 = create_test_cache_meta(50, 250); + + // Overlap strategy should prefer cache2 (better overlap) + let overlap_score1 = + select_cache_meta(&cache1, &req, &ResultCacheSelectionStrategy::Overlap); + let overlap_score2 = + select_cache_meta(&cache2, &req, &ResultCacheSelectionStrategy::Overlap); + assert!(overlap_score2 > overlap_score1); // 100 > 60 + + // Duration strategy should prefer cache2 (longer duration) + let duration_score1 = + select_cache_meta(&cache1, &req, &ResultCacheSelectionStrategy::Duration); + let duration_score2 = + select_cache_meta(&cache2, &req, &ResultCacheSelectionStrategy::Duration); + assert!(duration_score2 > duration_score1); // 200 > 60 + + // Both strategy balances overlap efficiency + let both_score1 = select_cache_meta(&cache1, &req, 
&ResultCacheSelectionStrategy::Both); + let both_score2 = select_cache_meta(&cache2, &req, &ResultCacheSelectionStrategy::Both); + assert!(both_score1 > both_score2); // cache1: 100% efficiency (60/60), cache2: 50% efficiency (100/200) + } + + #[test] + fn test_select_cache_meta_all_strategies_consistency() { + // Test that all strategies return expected types and don't panic + let meta = create_test_cache_meta(100, 200); + let req = create_test_cache_request(150, 250); + + // All strategies should return i64 values without panicking + let strategies = [ + ResultCacheSelectionStrategy::Overlap, + ResultCacheSelectionStrategy::Duration, + ResultCacheSelectionStrategy::Both, + ]; + + for strategy in strategies.iter() { + let score = select_cache_meta(&meta, &req, strategy); + // Score should be a valid i64 (no overflow/underflow for reasonable inputs) + assert!(score >= i64::MIN && score <= i64::MAX); + } + } +} diff --git a/src/service/search/datafusion/distributed_plan/streaming_aggs_exec.rs b/src/service/search/datafusion/distributed_plan/streaming_aggs_exec.rs index 9699693a2ff..9f61734a337 100644 --- a/src/service/search/datafusion/distributed_plan/streaming_aggs_exec.rs +++ b/src/service/search/datafusion/distributed_plan/streaming_aggs_exec.rs @@ -381,3 +381,94 @@ impl StreamingIdItem { self.start_ok && self.end_ok } } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::{ + array::{Int32Array, RecordBatch}, + datatypes::{DataType, Field, Schema}, + }; + + use super::*; + + #[test] + fn test_streaming_aggs_cache_insert_max_entries() { + // Create a cache with max_entries = 2 + let cache = StreamingAggsCache::new(2); + + // Create test schema and record batches + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + let batch1 = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + + let batch2 = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![4, 5, 6]))], + ) + .unwrap(); + + let batch3 = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![7, 8, 9]))], + ) + .unwrap(); + + // Insert first entry + cache.insert("key1".to_string(), batch1); + assert!(cache.get("key1").is_some()); + assert_eq!(cache.data.len(), 1); + + // Insert second entry + cache.insert("key2".to_string(), batch2); + assert!(cache.get("key1").is_some()); + assert!(cache.get("key2").is_some()); + assert_eq!(cache.data.len(), 2); + + // Insert third entry - should evict the first (oldest) entry + cache.insert("key3".to_string(), batch3); + assert!(cache.get("key1").is_none()); // Should be evicted + assert!(cache.get("key2").is_some()); + assert!(cache.get("key3").is_some()); + assert_eq!(cache.data.len(), 2); // Should still be 2 (max_entries) + + // Verify that the cacher queue length matches max_entries + let cacher_len = cache.cacher.lock().len(); + assert_eq!(cacher_len, 2); + } + + #[test] + fn test_streaming_aggs_cache_insert_within_limit() { + // Create a cache with max_entries = 5 + let cache = StreamingAggsCache::new(5); + + // Create test schema and record batch + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + + // Insert 3 entries (within limit) + cache.insert("key1".to_string(), batch.clone()); + cache.insert("key2".to_string(), batch.clone()); + cache.insert("key3".to_string(), batch.clone()); 
+ + // All entries should be present + assert!(cache.get("key1").is_some()); + assert!(cache.get("key2").is_some()); + assert!(cache.get("key3").is_some()); + assert_eq!(cache.data.len(), 3); + + // Verify that the cacher queue length matches number of entries + let cacher_len = cache.cacher.lock().len(); + assert_eq!(cacher_len, 3); + } +} diff --git a/src/service/search/grpc/wal.rs b/src/service/search/grpc/wal.rs index 950f30fc6fe..4cea02d068a 100644 --- a/src/service/search/grpc/wal.rs +++ b/src/service/search/grpc/wal.rs @@ -561,11 +561,23 @@ async fn get_file_list_inner( let lock_guard = wal_lock.lock().await; // filter by pending delete - let files = crate::service::db::file_list::local::filter_by_pending_delete(files).await; + let mut files = crate::service::db::file_list::local::filter_by_pending_delete(files).await; if files.is_empty() { return Ok(vec![]); } + let files_num = files.len(); + files.sort_unstable(); + files.dedup(); + if files_num != files.len() { + log::warn!( + "[trace_id {}] wal->parquet->search: found duplicate files from {} to {}", + query.trace_id, + files_num, + files.len() + ); + } + // lock theses files wal::lock_files(&files); drop(lock_guard); diff --git a/src/service/search/index.rs b/src/service/search/index.rs index c52864fb96a..37037545780 100644 --- a/src/service/search/index.rs +++ b/src/service/search/index.rs @@ -663,3 +663,235 @@ fn is_blank_or_alphanumeric(s: &str) -> bool { s.chars() .all(|c| c.is_ascii_whitespace() || c.is_ascii_alphanumeric()) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_condition_get_tantivy_fields_equal() { + let condition = Condition::Equal("field1".to_string(), "value1".to_string()); + let fields = condition.get_tantivy_fields(); + + assert_eq!(fields.len(), 1); + assert!(fields.contains("field1")); + } + + #[test] + fn test_condition_get_tantivy_fields_in() { + let condition = Condition::In( + "field2".to_string(), + vec![ + "value1".to_string(), + "value2".to_string(), + "value3".to_string(), + ], + ); + let fields = condition.get_tantivy_fields(); + + assert_eq!(fields.len(), 1); + assert!(fields.contains("field2")); + } + + #[test] + fn test_condition_get_tantivy_fields_regex() { + let condition = Condition::Regex("field3".to_string(), "pattern.*".to_string()); + let fields = condition.get_tantivy_fields(); + + assert_eq!(fields.len(), 1); + assert!(fields.contains("field3")); + } + + #[test] + fn test_condition_get_tantivy_fields_match_all() { + let condition = Condition::MatchAll("search_term".to_string()); + let fields = condition.get_tantivy_fields(); + + assert_eq!(fields.len(), 1); + assert!(fields.contains(INDEX_FIELD_NAME_FOR_ALL)); + } + + #[test] + fn test_condition_get_tantivy_fields_fuzzy_match_all() { + let condition = Condition::FuzzyMatchAll("search_term".to_string(), 2); + let fields = condition.get_tantivy_fields(); + + assert_eq!(fields.len(), 1); + assert!(fields.contains(INDEX_FIELD_NAME_FOR_ALL)); + } + + #[test] + fn test_condition_get_tantivy_fields_all() { + let condition = Condition::All(); + let fields = condition.get_tantivy_fields(); + + assert_eq!(fields.len(), 0); + } + + #[test] + fn test_condition_get_tantivy_fields_or_simple() { + let left = Condition::Equal("field1".to_string(), "value1".to_string()); + let right = Condition::Equal("field2".to_string(), "value2".to_string()); + let condition = Condition::Or(Box::new(left), Box::new(right)); + let fields = condition.get_tantivy_fields(); + + assert_eq!(fields.len(), 2); + assert!(fields.contains("field1")); + 
assert!(fields.contains("field2")); + } + + #[test] + fn test_condition_get_tantivy_fields_and_simple() { + let left = Condition::Equal("field1".to_string(), "value1".to_string()); + let right = Condition::In("field2".to_string(), vec!["value1".to_string()]); + let condition = Condition::And(Box::new(left), Box::new(right)); + let fields = condition.get_tantivy_fields(); + + assert_eq!(fields.len(), 2); + assert!(fields.contains("field1")); + assert!(fields.contains("field2")); + } + + #[test] + fn test_condition_get_tantivy_fields_or_with_overlap() { + let left = Condition::Equal("field1".to_string(), "value1".to_string()); + let right = Condition::Equal("field1".to_string(), "value2".to_string()); + let condition = Condition::Or(Box::new(left), Box::new(right)); + let fields = condition.get_tantivy_fields(); + + // Should only have one field since both conditions use the same field + assert_eq!(fields.len(), 1); + assert!(fields.contains("field1")); + } + + #[test] + fn test_condition_get_tantivy_fields_and_with_overlap() { + let left = Condition::Equal("field1".to_string(), "value1".to_string()); + let right = Condition::Regex("field1".to_string(), "pattern.*".to_string()); + let condition = Condition::And(Box::new(left), Box::new(right)); + let fields = condition.get_tantivy_fields(); + + // Should only have one field since both conditions use the same field + assert_eq!(fields.len(), 1); + assert!(fields.contains("field1")); + } + + #[test] + fn test_condition_get_tantivy_fields_nested_complex() { + // Create a complex nested condition: (field1 = value1 OR field2 = value2) AND (field3 = + // value3 OR match_all(term)) + let left_or = Condition::Or( + Box::new(Condition::Equal("field1".to_string(), "value1".to_string())), + Box::new(Condition::Equal("field2".to_string(), "value2".to_string())), + ); + let right_or = Condition::Or( + Box::new(Condition::Equal("field3".to_string(), "value3".to_string())), + Box::new(Condition::MatchAll("search_term".to_string())), + ); + let condition = Condition::And(Box::new(left_or), Box::new(right_or)); + let fields = condition.get_tantivy_fields(); + + assert_eq!(fields.len(), 4); + assert!(fields.contains("field1")); + assert!(fields.contains("field2")); + assert!(fields.contains("field3")); + assert!(fields.contains(INDEX_FIELD_NAME_FOR_ALL)); + } + + #[test] + fn test_condition_get_tantivy_fields_all_types_mixed() { + // Test with all different condition types mixed together + let equal_cond = Condition::Equal("equal_field".to_string(), "value".to_string()); + let in_cond = Condition::In("in_field".to_string(), vec!["val1".to_string()]); + let regex_cond = Condition::Regex("regex_field".to_string(), "pattern.*".to_string()); + let match_all_cond = Condition::MatchAll("search_term".to_string()); + let fuzzy_match_cond = Condition::FuzzyMatchAll("fuzzy_term".to_string(), 1); + let all_cond = Condition::All(); + + // Create nested structure: ((equal OR in) AND (regex OR match_all)) OR (fuzzy_match_all AND + // all) + let left_or = Condition::Or(Box::new(equal_cond), Box::new(in_cond)); + let right_or = Condition::Or(Box::new(regex_cond), Box::new(match_all_cond)); + let left_and = Condition::And(Box::new(left_or), Box::new(right_or)); + let right_and = Condition::And(Box::new(fuzzy_match_cond), Box::new(all_cond)); + let condition = Condition::Or(Box::new(left_and), Box::new(right_and)); + + let fields = condition.get_tantivy_fields(); + + assert_eq!(fields.len(), 4); // equal_field, in_field, regex_field, _all (match_all and fuzzy_match_all both 
use _all, so deduplicated) + assert!(fields.contains("equal_field")); + assert!(fields.contains("in_field")); + assert!(fields.contains("regex_field")); + assert!(fields.contains(INDEX_FIELD_NAME_FOR_ALL)); + } + + #[test] + fn test_condition_get_tantivy_fields_empty_field_names() { + let condition = Condition::Equal("".to_string(), "value".to_string()); + let fields = condition.get_tantivy_fields(); + + assert_eq!(fields.len(), 1); + assert!(fields.contains("")); + } + + #[test] + fn test_condition_get_tantivy_fields_special_characters() { + let condition = Condition::Equal("field.with.dots".to_string(), "value".to_string()); + let fields = condition.get_tantivy_fields(); + + assert_eq!(fields.len(), 1); + assert!(fields.contains("field.with.dots")); + } + + #[test] + fn test_condition_get_tantivy_fields_unicode_field_names() { + let condition = Condition::Equal("поле".to_string(), "значение".to_string()); + let fields = condition.get_tantivy_fields(); + + assert_eq!(fields.len(), 1); + assert!(fields.contains("поле")); + } + + #[test] + fn test_index_condition_get_tantivy_fields() { + let mut index_condition = IndexCondition::new(); + index_condition.add_condition(Condition::Equal("field1".to_string(), "value1".to_string())); + index_condition.add_condition(Condition::MatchAll("search_term".to_string())); + index_condition.add_condition(Condition::In( + "field2".to_string(), + vec!["val1".to_string()], + )); + + let fields = index_condition.get_tantivy_fields(); + + assert_eq!(fields.len(), 3); + assert!(fields.contains("field1")); + assert!(fields.contains("field2")); + assert!(fields.contains(INDEX_FIELD_NAME_FOR_ALL)); + } + + #[test] + fn test_index_condition_get_tantivy_fields_empty() { + let index_condition = IndexCondition::new(); + let fields = index_condition.get_tantivy_fields(); + + assert_eq!(fields.len(), 0); + } + + #[test] + fn test_index_condition_get_tantivy_fields_duplicate_fields() { + let mut index_condition = IndexCondition::new(); + index_condition.add_condition(Condition::Equal("field1".to_string(), "value1".to_string())); + index_condition.add_condition(Condition::Equal("field1".to_string(), "value2".to_string())); + index_condition.add_condition(Condition::Regex( + "field1".to_string(), + "pattern.*".to_string(), + )); + + let fields = index_condition.get_tantivy_fields(); + + // Should deduplicate the field names + assert_eq!(fields.len(), 1); + assert!(fields.contains("field1")); + } +} diff --git a/src/service/search/search_stream.rs b/src/service/search/search_stream.rs index 5e162d52215..9262e8e05c0 100644 --- a/src/service/search/search_stream.rs +++ b/src/service/search/search_stream.rs @@ -1421,3 +1421,422 @@ pub fn get_top_k_values( Ok((top_k_values, result_count as u64)) } + +#[cfg(test)] +mod tests { + use config::meta::{ + search::{Query, Request, SearchEventType}, + sql::OrderBy, + stream::StreamType, + }; + use serde_json::json; + use tokio::sync::mpsc; + + use super::*; + + fn create_sample_hits() -> Vec<Value> { + vec![ + json!({ + "zo_sql_key": "apple", + "zo_sql_num": 100 + }), + json!({ + "zo_sql_key": "banana", + "zo_sql_num": 75 + }), + json!({ + "zo_sql_key": "cherry", + "zo_sql_num": 150 + }), + json!({ + "zo_sql_key": "date", + "zo_sql_num": 25 + }), + json!({ + "zo_sql_key": "elderberry", + "zo_sql_num": 200 + }), + ] + } + + fn create_empty_hits() -> Vec<Value> { + vec![] + } + + fn create_hits_without_required_fields() -> Vec<Value> { + vec![ + json!({ + "some_field": "value1" + }), + json!({ + "another_field": "value2" + }), + ] + } + + #[test] + fn 
test_get_top_k_values_with_count_sorting() { + let hits = create_sample_hits(); + let field = "test_field"; + let top_k = 3; + let no_count = false; + + let result = get_top_k_values(&hits, field, top_k, no_count); + assert!(result.is_ok()); + + let (top_k_values, result_count) = result.unwrap(); + assert_eq!(top_k_values.len(), 1); + assert_eq!(result_count, 3); + + // Verify the structure of the response + if let Some(Value::Object(field_obj)) = top_k_values.first() { + assert_eq!( + field_obj.get("field"), + Some(&Value::String("test_field".to_string())) + ); + + if let Some(Value::Array(values)) = field_obj.get("values") { + assert_eq!(values.len(), 3); + + // Verify sorting by count (descending) + if let Some(Value::Object(first_item)) = values.first() { + assert_eq!( + first_item.get("zo_sql_key"), + Some(&Value::String("elderberry".to_string())) + ); + assert_eq!( + first_item.get("zo_sql_num"), + Some(&Value::Number(200.into())) + ); + } + } + } + } + + #[test] + fn test_get_top_k_values_with_alphabetical_sorting() { + let hits = create_sample_hits(); + let field = "test_field"; + let top_k = 3; + let no_count = true; + + let result = get_top_k_values(&hits, field, top_k, no_count); + assert!(result.is_ok()); + + let (top_k_values, result_count) = result.unwrap(); + assert_eq!(top_k_values.len(), 1); + assert_eq!(result_count, 3); + + // Verify alphabetical sorting + if let Some(Value::Object(field_obj)) = top_k_values.first() { + if let Some(Value::Array(values)) = field_obj.get("values") { + assert_eq!(values.len(), 3); + + // Verify alphabetical order + if let Some(Value::Object(first_item)) = values.first() { + assert_eq!( + first_item.get("zo_sql_key"), + Some(&Value::String("apple".to_string())) + ); + } + if let Some(Value::Object(second_item)) = values.get(1) { + assert_eq!( + second_item.get("zo_sql_key"), + Some(&Value::String("banana".to_string())) + ); + } + } + } + } + + #[test] + fn test_get_top_k_values_empty_field() { + let hits = create_sample_hits(); + let field = ""; + let top_k = 3; + let no_count = false; + + let result = get_top_k_values(&hits, field, top_k, no_count); + assert!(result.is_err()); + + if let Err(error) = result { + assert!(error.to_string().contains("field is empty")); + } + } + + #[test] + fn test_get_top_k_values_empty_hits() { + let hits = create_empty_hits(); + let field = "test_field"; + let top_k = 3; + let no_count = false; + + let result = get_top_k_values(&hits, field, top_k, no_count); + assert!(result.is_ok()); + + let (top_k_values, result_count) = result.unwrap(); + assert_eq!(top_k_values.len(), 1); + assert_eq!(result_count, 0); + } + + #[test] + fn test_get_top_k_values_missing_fields() { + let hits = create_hits_without_required_fields(); + let field = "test_field"; + let top_k = 3; + let no_count = false; + + let result = get_top_k_values(&hits, field, top_k, no_count); + assert!(result.is_ok()); + + let (top_k_values, result_count) = result.unwrap(); + assert_eq!(top_k_values.len(), 1); + assert_eq!(result_count, 2); // Should handle missing fields gracefully + } + + #[test] + fn test_get_top_k_values_zero_top_k() { + let hits = create_sample_hits(); + let field = "test_field"; + let top_k = 0; + let no_count = false; + + let result = get_top_k_values(&hits, field, top_k, no_count); + assert!(result.is_ok()); + + let (top_k_values, result_count) = result.unwrap(); + assert_eq!(top_k_values.len(), 1); + assert_eq!(result_count, 0); + } + + #[test] + fn test_get_top_k_values_large_top_k() { + let hits = create_sample_hits(); + 
let field = "test_field"; + let top_k = 100; // Larger than available hits + let no_count = false; + + let result = get_top_k_values(&hits, field, top_k, no_count); + assert!(result.is_ok()); + + let (top_k_values, result_count) = result.unwrap(); + assert_eq!(top_k_values.len(), 1); + assert_eq!(result_count, 5); // Should return all available hits + } + + #[test] + fn test_handle_partial_response_not_partial() { + let response = Response { + is_partial: false, + function_error: vec!["existing error".to_string()], + ..Default::default() + }; + + let result = handle_partial_response(response); + assert!(!result.is_partial); + assert_eq!(result.function_error, vec!["existing error".to_string()]); + } + + #[test] + fn test_handle_partial_response_partial_with_existing_errors() { + let response = Response { + is_partial: true, + function_error: vec!["existing error".to_string()], + ..Default::default() + }; + + let result = handle_partial_response(response); + assert!(result.is_partial); + assert_eq!(result.function_error.len(), 2); + assert!( + result + .function_error + .contains(&"existing error".to_string()) + ); + assert!( + result + .function_error + .contains(&PARTIAL_ERROR_RESPONSE_MESSAGE.to_string()) + ); + } + + #[test] + fn test_handle_partial_response_partial_without_existing_errors() { + let response = Response { + is_partial: true, + function_error: vec![], + ..Default::default() + }; + + let result = handle_partial_response(response); + assert!(result.is_partial); + assert_eq!( + result.function_error, + vec![PARTIAL_ERROR_RESPONSE_MESSAGE.to_string()] + ); + } + + #[tokio::test] + async fn test_process_search_stream_request_basic_flow() { + let (sender, mut receiver) = mpsc::channel(100); + + let req = Request { + query: Query { + sql: "SELECT * FROM test".to_string(), + from: 0, + size: 10, + start_time: 0, + end_time: 1000000, + track_total_hits: false, + ..Default::default() + }, + use_cache: Some(false), + search_type: Some(SearchEventType::UI), + ..Default::default() + }; + + let search_span = tracing::info_span!("test_search"); + + // This is a complex integration test that would require mocking external dependencies + // For now, we'll test the function signature and basic parameter handling + tokio::spawn(async move { + process_search_stream_request( + "test_org".to_string(), + "test_user".to_string(), + "test_trace".to_string(), + req, + StreamType::Logs, + vec!["test_stream".to_string()], + OrderBy::Desc, + search_span, + sender, + None, + None, + None, + ) + .await; + }); + + // Wait for the progress message + if let Some(Ok(StreamResponses::Progress { percent })) = receiver.recv().await { + assert_eq!(percent, 0); + } + + // Note: Full integration testing would require mocking the search service + // and other external dependencies, which is beyond the scope of unit tests + } + + #[test] + fn test_audit_context_creation() { + let audit_ctx = AuditContext { + method: "GET".to_string(), + path: "/api/search".to_string(), + query_params: "q=test".to_string(), + body: "{}".to_string(), + }; + + assert_eq!(audit_ctx.method, "GET"); + assert_eq!(audit_ctx.path, "/api/search"); + assert_eq!(audit_ctx.query_params, "q=test"); + assert_eq!(audit_ctx.body, "{}"); + } + + // Helper function to create test search requests + fn create_test_request() -> Request { + Request { + query: Query { + sql: "SELECT * FROM test_stream".to_string(), + from: 0, + size: 100, + start_time: 0, + end_time: 1000000, + track_total_hits: false, + ..Default::default() + }, + use_cache: Some(true), 
+ search_type: Some(SearchEventType::UI), + ..Default::default() + } + } + + #[test] + fn test_create_test_request() { + let req = create_test_request(); + assert_eq!(req.query.sql, "SELECT * FROM test_stream"); + assert_eq!(req.query.from, 0); + assert_eq!(req.query.size, 100); + assert!(req.use_cache.unwrap_or(false)); + assert_eq!(req.search_type, Some(SearchEventType::UI)); + } + + // Test for edge cases in get_top_k_values with different data types + #[test] + fn test_get_top_k_values_mixed_data_types() { + let hits = vec![ + json!({ + "zo_sql_key": "item1", + "zo_sql_num": 100 + }), + json!({ + "zo_sql_key": "item2", + "zo_sql_num": "not_a_number" // This should default to 0 + }), + json!({ + "zo_sql_key": 123, // This should convert to string + "zo_sql_num": 50 + }), + ]; + + let field = "test_field"; + let top_k = 5; + let no_count = false; + + let result = get_top_k_values(&hits, field, top_k, no_count); + assert!(result.is_ok()); + + let (top_k_values, result_count) = result.unwrap(); + assert_eq!(top_k_values.len(), 1); + assert_eq!(result_count, 3); + } + + // Test for boundary conditions + #[test] + fn test_get_top_k_values_boundary_conditions() { + let hits = vec![ + json!({ + "zo_sql_key": "max_value", + "zo_sql_num": i64::MAX + }), + json!({ + "zo_sql_key": "min_value", + "zo_sql_num": i64::MIN + }), + json!({ + "zo_sql_key": "zero_value", + "zo_sql_num": 0 + }), + ]; + + let field = "boundary_test"; + let top_k = 3; + let no_count = false; + + let result = get_top_k_values(&hits, field, top_k, no_count); + assert!(result.is_ok()); + + let (top_k_values, result_count) = result.unwrap(); + assert_eq!(result_count, 3); + + // Verify that max value comes first in count-based sorting + if let Some(Value::Object(field_obj)) = top_k_values.first() { + if let Some(Value::Array(values)) = field_obj.get("values") { + if let Some(Value::Object(first_item)) = values.first() { + assert_eq!( + first_item.get("zo_sql_key"), + Some(&Value::String("max_value".to_string())) + ); + } + } + } + } +} diff --git a/src/service/search/tantivy/puffin_directory/reader.rs b/src/service/search/tantivy/puffin_directory/reader.rs index aace21aa251..1868a292547 100644 --- a/src/service/search/tantivy/puffin_directory/reader.rs +++ b/src/service/search/tantivy/puffin_directory/reader.rs @@ -283,3 +283,593 @@ async fn warm_up_fastfield( .await?; Ok(()) } + +#[cfg(test)] +mod tests { + use std::{collections::HashMap, io::ErrorKind, path::PathBuf, sync::Arc}; + + use hashbrown::HashMap as HashbrownHashMap; + use tantivy::{ + HasLen, Index, Term, + directory::{Directory, FileHandle, error::OpenReadError}, + doc, + schema::{STORED, Schema, TEXT}, + }; + use tokio::time::{Duration, Instant}; + + use super::*; + use crate::service::search::tantivy::puffin::{ + BlobMetadata, BlobMetadataBuilder, BlobTypes, reader::PuffinBytesReader, + }; + + // Mock data for testing + fn create_mock_object_meta(file_name: &str, size: usize) -> object_store::ObjectMeta { + object_store::ObjectMeta { + location: file_name.into(), + last_modified: chrono::Utc::now(), + size, + e_tag: None, + version: None, + } + } + + fn create_mock_blob_metadata( + blob_type: BlobTypes, + offset: u64, + length: u64, + file_name: &str, + ) -> Result { + let mut properties = HashMap::new(); + properties.insert("blob_tag".to_string(), file_name.to_string()); + + BlobMetadataBuilder::default() + .blob_type(blob_type) + .offset(offset) + .length(length) + .properties(properties) + .build() + } + + #[tokio::test] + async fn 
test_puffin_dir_reader_from_path_success() { + // This test would need to be implemented with proper mocking + // For now, we test the structure and error handling + let account = "test_account".to_string(); + let meta = create_mock_object_meta("test_file.puffin", 1024); + + // Test error case - this will fail because we don't have actual puffin data + let result = PuffinDirReader::from_path(account, meta).await; + assert!(result.is_err(), "Expected error for invalid puffin file"); + + // Verify the error message contains expected text + if let Err(e) = result { + assert!( + e.to_string() + .contains("Error reading metadata from puffin file") + ); + } + } + + #[tokio::test] + async fn test_puffin_dir_reader_from_path_no_metadata() { + let account = "test_account".to_string(); + let meta = create_mock_object_meta("empty_file.puffin", 0); + + let result = PuffinDirReader::from_path(account, meta).await; + assert!(result.is_err(), "Expected error for file without metadata"); + } + + #[test] + fn test_puffin_dir_reader_list_files() { + // Create mock blobs metadata + let mut blobs_metadata = HashbrownHashMap::new(); + + let blob1 = create_mock_blob_metadata(BlobTypes::O2FstV1, 0, 100, "segment1.terms") + .expect("Failed to create blob metadata"); + + let blob2 = create_mock_blob_metadata(BlobTypes::O2TtvV1, 100, 200, "segment1.pos") + .expect("Failed to create blob metadata"); + + blobs_metadata.insert(PathBuf::from("segment1.terms"), Arc::new(blob1)); + blobs_metadata.insert(PathBuf::from("segment1.pos"), Arc::new(blob2)); + + // Create a mock PuffinBytesReader + let mock_reader = PuffinBytesReader::new( + "test_account".to_string(), + create_mock_object_meta("test.puffin", 1024), + ); + + let reader = PuffinDirReader { + source: Arc::new(mock_reader), + blobs_metadata: Arc::new(blobs_metadata), + }; + + let files = reader.list_files(); + assert_eq!(files.len(), 2); + assert!(files.contains(&PathBuf::from("segment1.terms"))); + assert!(files.contains(&PathBuf::from("segment1.pos"))); + } + + #[test] + fn test_puffin_dir_reader_clone() { + let mut blobs_metadata = HashbrownHashMap::new(); + let blob = create_mock_blob_metadata(BlobTypes::O2FstV1, 0, 100, "test_file.terms") + .expect("Failed to create blob metadata"); + + blobs_metadata.insert(PathBuf::from("test_file.terms"), Arc::new(blob)); + + let mock_reader = PuffinBytesReader::new( + "test_account".to_string(), + create_mock_object_meta("test.puffin", 1024), + ); + + let reader = PuffinDirReader { + source: Arc::new(mock_reader), + blobs_metadata: Arc::new(blobs_metadata), + }; + + let cloned_reader = reader.clone(); + assert_eq!(reader.list_files(), cloned_reader.list_files()); + } + + #[test] + fn test_puffin_slice_handle_len() { + let blob = create_mock_blob_metadata(BlobTypes::O2FstV1, 0, 42, "test_file.terms") + .expect("Failed to create blob metadata"); + + let mock_reader = PuffinBytesReader::new( + "test_account".to_string(), + create_mock_object_meta("test.puffin", 1024), + ); + + let handle = PuffinSliceHandle { + path: PathBuf::from("test_file.terms"), + source: Arc::new(mock_reader), + metadata: Arc::new(blob), + }; + + assert_eq!(handle.len(), 42); + } + + #[test] + fn test_puffin_slice_handle_read_bytes_sync_error() { + let blob = create_mock_blob_metadata(BlobTypes::O2FstV1, 0, 100, "test_file.terms") + .expect("Failed to create blob metadata"); + + let mock_reader = PuffinBytesReader::new( + "test_account".to_string(), + create_mock_object_meta("test.puffin", 1024), + ); + + let handle = PuffinSliceHandle { + path: 
PathBuf::from("test_file.terms"), + source: Arc::new(mock_reader), + metadata: Arc::new(blob), + }; + + let result = handle.read_bytes(0..10); + assert!(result.is_err()); + + if let Err(e) = result { + assert_eq!(e.kind(), ErrorKind::Other); + assert!( + e.to_string() + .contains("Not supported with PuffinSliceHandle") + ); + } + } + + #[tokio::test] + async fn test_puffin_slice_handle_read_bytes_async_empty_range() { + let blob = create_mock_blob_metadata(BlobTypes::O2FstV1, 0, 100, "test_file.terms") + .expect("Failed to create blob metadata"); + + let mock_reader = PuffinBytesReader::new( + "test_account".to_string(), + create_mock_object_meta("test.puffin", 1024), + ); + + let handle = PuffinSliceHandle { + path: PathBuf::from("test_file.terms"), + source: Arc::new(mock_reader), + metadata: Arc::new(blob), + }; + + let result = handle.read_bytes_async(0..0).await; + assert!(result.is_ok()); + + if let Ok(bytes) = result { + assert_eq!(bytes.len(), 0); + } + } + + #[test] + fn test_directory_get_file_handle_existing_file() { + let mut blobs_metadata = HashbrownHashMap::new(); + let blob = create_mock_blob_metadata(BlobTypes::O2FstV1, 0, 100, "existing_file.terms") + .expect("Failed to create blob metadata"); + + let path = PathBuf::from("existing_file.terms"); + blobs_metadata.insert(path.clone(), Arc::new(blob)); + + let mock_reader = PuffinBytesReader::new( + "test_account".to_string(), + create_mock_object_meta("test.puffin", 1024), + ); + + let reader = PuffinDirReader { + source: Arc::new(mock_reader), + blobs_metadata: Arc::new(blobs_metadata), + }; + + let result = reader.get_file_handle(&path); + assert!(result.is_ok(), "Expected success for existing file"); + } + + #[test] + fn test_directory_get_file_handle_nonexistent_file() { + let blobs_metadata = HashbrownHashMap::new(); + let mock_reader = PuffinBytesReader::new( + "test_account".to_string(), + create_mock_object_meta("test.puffin", 1024), + ); + + let reader = PuffinDirReader { + source: Arc::new(mock_reader), + blobs_metadata: Arc::new(blobs_metadata), + }; + + let path = PathBuf::from("nonexistent_file.terms"); + let result = reader.get_file_handle(&path); + + // This should try to get from empty puffin directory, which might succeed or fail + // depending on the extension + match result { + Ok(_) => { + // File found in empty puffin directory + } + Err(OpenReadError::FileDoesNotExist(_)) => { + // Expected for files not in empty puffin directory + } + Err(e) => panic!("Unexpected error: {:?}", e), + } + } + + #[test] + fn test_directory_get_file_handle_no_extension() { + let blobs_metadata = HashbrownHashMap::new(); + let mock_reader = PuffinBytesReader::new( + "test_account".to_string(), + create_mock_object_meta("test.puffin", 1024), + ); + + let reader = PuffinDirReader { + source: Arc::new(mock_reader), + blobs_metadata: Arc::new(blobs_metadata), + }; + + let path = PathBuf::from("file_without_extension"); + let result = reader.get_file_handle(&path); + + assert!(matches!(result, Err(OpenReadError::FileDoesNotExist(_)))); + } + + #[test] + fn test_directory_exists_existing_file() { + let mut blobs_metadata = HashbrownHashMap::new(); + let blob = create_mock_blob_metadata(BlobTypes::O2FstV1, 0, 100, "existing_file.terms") + .expect("Failed to create blob metadata"); + + let path = PathBuf::from("existing_file.terms"); + blobs_metadata.insert(path.clone(), Arc::new(blob)); + + let mock_reader = PuffinBytesReader::new( + "test_account".to_string(), + create_mock_object_meta("test.puffin", 1024), + ); + + let reader = 
PuffinDirReader { + source: Arc::new(mock_reader), + blobs_metadata: Arc::new(blobs_metadata), + }; + + let result = reader.exists(&path); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), true); + } + + #[test] + fn test_directory_exists_nonexistent_file() { + let blobs_metadata = HashbrownHashMap::new(); + let mock_reader = PuffinBytesReader::new( + "test_account".to_string(), + create_mock_object_meta("test.puffin", 1024), + ); + + let reader = PuffinDirReader { + source: Arc::new(mock_reader), + blobs_metadata: Arc::new(blobs_metadata), + }; + + let path = PathBuf::from("nonexistent_file.unknown"); + let result = reader.exists(&path); + + // Should check empty puffin directory + assert!(result.is_ok()); + } + + #[test] + fn test_directory_readonly_operations() { + let blobs_metadata = HashbrownHashMap::new(); + let mock_reader = PuffinBytesReader::new( + "test_account".to_string(), + create_mock_object_meta("test.puffin", 1024), + ); + + let reader = PuffinDirReader { + source: Arc::new(mock_reader), + blobs_metadata: Arc::new(blobs_metadata), + }; + + let path = PathBuf::from("test_file.txt"); + let data = b"test data"; + + // Test that write operations are unimplemented + let atomic_write_result = std::panic::catch_unwind(|| reader.atomic_write(&path, data)); + assert!(atomic_write_result.is_err()); + + let atomic_read_result = std::panic::catch_unwind(|| reader.atomic_read(&path)); + assert!(atomic_read_result.is_err()); + + let delete_result = std::panic::catch_unwind(|| reader.delete(&path)); + assert!(delete_result.is_err()); + + let open_write_result = std::panic::catch_unwind(|| reader.open_write(&path)); + assert!(open_write_result.is_err()); + + let sync_result = std::panic::catch_unwind(|| reader.sync_directory()); + assert!(sync_result.is_err()); + } + + #[tokio::test] + async fn test_warm_up_terms_empty_terms() { + // Create a simple in-memory index for testing + let mut schema_builder = Schema::builder(); + let text_field = schema_builder.add_text_field("text", TEXT | STORED); + let schema = schema_builder.build(); + + let index = Index::create_in_ram(schema.clone()); + let mut index_writer = index + .writer(50_000_000) + .expect("Failed to create index writer"); + + // Add a document + index_writer + .add_document(doc!(text_field => "hello world")) + .expect("Failed to add document"); + index_writer.commit().expect("Failed to commit"); + + let reader = index + .reader_builder() + .reload_policy(tantivy::ReloadPolicy::Manual) + .try_into() + .expect("Failed to create reader"); + + let searcher = reader.searcher(); + + // Test with empty terms + let terms_grouped_by_field = HashbrownHashMap::new(); + let result = warm_up_terms(&searcher, &terms_grouped_by_field, false).await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_warm_up_terms_with_terms() { + // Create a simple in-memory index for testing + let mut schema_builder = Schema::builder(); + let text_field = schema_builder.add_text_field("text", TEXT | STORED); + let schema = schema_builder.build(); + + let index = Index::create_in_ram(schema.clone()); + let mut index_writer = index + .writer(50_000_000) + .expect("Failed to create index writer"); + + // Add a document + index_writer + .add_document(doc!(text_field => "hello world")) + .expect("Failed to add document"); + index_writer.commit().expect("Failed to commit"); + + let reader = index + .reader_builder() + .reload_policy(tantivy::ReloadPolicy::Manual) + .try_into() + .expect("Failed to create reader"); + + let searcher = 
reader.searcher(); + + // Test with specific terms + let mut terms_grouped_by_field = HashbrownHashMap::new(); + let mut field_terms = HashbrownHashMap::new(); + let term = Term::from_field_text(text_field, "hello"); + field_terms.insert(term, false); + terms_grouped_by_field.insert(text_field, field_terms); + + let result = warm_up_terms(&searcher, &terms_grouped_by_field, false).await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_warm_up_terms_with_fast_fields() { + // Create a simple in-memory index for testing + let mut schema_builder = Schema::builder(); + let text_field = schema_builder.add_text_field("text", TEXT | STORED); + let schema = schema_builder.build(); + + let index = Index::create_in_ram(schema.clone()); + let mut index_writer = index + .writer(50_000_000) + .expect("Failed to create index writer"); + + // Add a document + index_writer + .add_document(doc!(text_field => "hello world")) + .expect("Failed to add document"); + index_writer.commit().expect("Failed to commit"); + + let reader = index + .reader_builder() + .reload_policy(tantivy::ReloadPolicy::Manual) + .try_into() + .expect("Failed to create reader"); + + let searcher = reader.searcher(); + + // Test with fast fields enabled + let terms_grouped_by_field = HashbrownHashMap::new(); + let result = warm_up_terms(&searcher, &terms_grouped_by_field, true).await; + // This might fail if _timestamp field is not present, which is expected in this simple test + // The important thing is that the function doesn't panic + let _ = result; + } + + #[tokio::test] + async fn test_warm_up_terms_performance() { + // Performance test to ensure warming up doesn't take too long + let mut schema_builder = Schema::builder(); + let text_field = schema_builder.add_text_field("text", TEXT | STORED); + let schema = schema_builder.build(); + + let index = Index::create_in_ram(schema.clone()); + let mut index_writer = index + .writer(50_000_000) + .expect("Failed to create index writer"); + + // Add multiple documents + for i in 0..100 { + index_writer + .add_document(doc!(text_field => format!("document {}", i))) + .expect("Failed to add document"); + } + index_writer.commit().expect("Failed to commit"); + + let reader = index + .reader_builder() + .reload_policy(tantivy::ReloadPolicy::Manual) + .try_into() + .expect("Failed to create reader"); + + let searcher = reader.searcher(); + + let mut terms_grouped_by_field = HashbrownHashMap::new(); + let mut field_terms = HashbrownHashMap::new(); + // Add multiple terms + for i in 0..10 { + let term = Term::from_field_text(text_field, &format!("document {}", i)); + field_terms.insert(term, false); + } + terms_grouped_by_field.insert(text_field, field_terms); + + let start = Instant::now(); + let result = warm_up_terms(&searcher, &terms_grouped_by_field, false).await; + let duration = start.elapsed(); + + assert!(result.is_ok()); + // Warming up should complete within a reasonable time (10 seconds for this test) + assert!(duration < Duration::from_secs(10)); + } + + #[test] + fn test_blob_metadata_properties() { + let blob = create_mock_blob_metadata(BlobTypes::O2FstV1, 100, 200, "test_file.terms") + .expect("Failed to create blob metadata"); + + // Test that properties are set correctly + assert_eq!(blob.blob_type, BlobTypes::O2FstV1); + assert_eq!(blob.offset, 100); + assert_eq!(blob.length, 200); + assert_eq!( + blob.properties.get("blob_tag"), + Some(&"test_file.terms".to_string()) + ); + } + + #[test] + fn test_blob_metadata_get_offset() { + let blob = 
create_mock_blob_metadata(BlobTypes::O2FstV1, 100, 200, "test_file.terms") + .expect("Failed to create blob metadata"); + + // Test get_offset with no range + let offset_range = blob.get_offset(None); + assert_eq!(offset_range, 100..300); + + // Test get_offset with specific range + let offset_range = blob.get_offset(Some(10..20)); + assert_eq!(offset_range, 110..120); + } + + #[test] + fn test_error_handling_edge_cases() { + // Test various error conditions that might occur in real usage + + // Test with invalid blob types + let result = BlobMetadataBuilder::default().offset(0).length(100).build(); + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), "blob_type is required"); + + // Test with missing offset + let result = BlobMetadataBuilder::default() + .blob_type(BlobTypes::O2FstV1) + .length(100) + .build(); + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), "offset is required"); + } + + #[test] + fn test_concurrent_access() { + // Test that the reader can be safely shared across threads + let mut blobs_metadata = HashbrownHashMap::new(); + let blob = create_mock_blob_metadata(BlobTypes::O2FstV1, 0, 100, "shared_file.terms") + .expect("Failed to create blob metadata"); + + blobs_metadata.insert(PathBuf::from("shared_file.terms"), Arc::new(blob)); + + let mock_reader = PuffinBytesReader::new( + "test_account".to_string(), + create_mock_object_meta("test.puffin", 1024), + ); + + let reader = Arc::new(PuffinDirReader { + source: Arc::new(mock_reader), + blobs_metadata: Arc::new(blobs_metadata), + }); + + let mut handles = vec![]; + + // Spawn multiple threads that access the reader + for i in 0..10 { + let reader_clone = reader.clone(); + let handle = std::thread::spawn(move || { + let files = reader_clone.list_files(); + assert_eq!(files.len(), 1); + assert!(files.contains(&PathBuf::from("shared_file.terms"))); + + let path = PathBuf::from("shared_file.terms"); + let exists = reader_clone.exists(&path); + assert!(exists.is_ok()); + assert_eq!(exists.unwrap(), true); + + i // Return thread index for verification + }); + handles.push(handle); + } + + // Wait for all threads to complete + for (i, handle) in handles.into_iter().enumerate() { + let result = handle.join().expect("Thread panicked"); + assert_eq!(result, i); + } + } +} diff --git a/src/service/short_url.rs b/src/service/short_url.rs index bec500446cf..3798da335f8 100644 --- a/src/service/short_url.rs +++ b/src/service/short_url.rs @@ -22,21 +22,21 @@ use infra::{ use crate::service::db; -const SHORT_URL_WEB_PATH: &str = "/short/"; +const SHORT_URL_WEB_PATH: &str = "short/"; pub fn get_base_url() -> String { let config = get_config(); format!("{}{}", config.common.web_url, config.common.base_uri) } -fn construct_short_url(org_id: &str, short_id: &str) -> String { +pub fn construct_short_url(org_id: &str, short_id: &str) -> String { format!( - "{}/{}/{}{}{}", + "{}/{}/{}{}?org_identifier={}", get_base_url(), - "api", - org_id, + "web", SHORT_URL_WEB_PATH, - short_id + short_id, + org_id, ) } @@ -137,4 +137,15 @@ mod tests { // Should return the same short_id assert_eq!(short_url1, short_url2); } + + #[tokio::test] + async fn test_generate_short_id() { + let original_url = "https://www.example.com/some/long/url"; + let short_id = generate_short_id(original_url, None); + assert_eq!(short_id.len(), 16); + let timestamp = Utc::now().timestamp_micros(); + let short_id2 = generate_short_id(original_url, Some(timestamp)); + assert_eq!(short_id2.len(), 16); + assert_ne!(short_id, short_id2); + } } diff --git 
a/src/service/syslogs_route.rs b/src/service/syslogs_route.rs index a2d3553ecb6..65fc15e0aa4 100644 --- a/src/service/syslogs_route.rs +++ b/src/service/syslogs_route.rs @@ -18,7 +18,8 @@ use std::io; use actix_web::{HttpResponse, http::StatusCode}; use config::ider; use ipnetwork::IpNetwork; - +use tokio::fs::File; +use tokio::io::AsyncReadExt; use crate::{ common::{ infra::config::SYSLOG_ROUTES, @@ -148,6 +149,38 @@ pub async fn toggle_state(server: SyslogServer) -> Result Result { + let cfg = config::get_config(); + let ca_cert_path = &cfg.tcp.tcp_tls_ca_cert_path; + log::info!("ca_cert_path: {}", ca_cert_path); + let mut file = File::open(ca_cert_path).await?; + let mut contents = String::new(); + file.read_to_string(&mut contents).await?; + log::info!("contents: {}", contents); + + Ok(HttpResponse::Ok() + .content_type("application/x-pem-file") + .insert_header(("Content-Disposition", "attachment; filename=\"ca-cert.pem\"")) + .body(contents)) +} + +#[tracing::instrument(skip_all)] +pub async fn get_tcp_tls_cert() -> Result { + let cfg = config::get_config(); + let cert_path = &cfg.tcp.tcp_tls_cert_path; + log::info!("cert_path: {}", cert_path); + let mut file = File::open(cert_path).await?; + let mut contents = String::new(); + file.read_to_string(&mut contents).await?; + log::info!("contents: {}", contents); + + Ok(HttpResponse::Ok() + .content_type("application/x-pem-file") + .insert_header(("Content-Disposition", "attachment; filename=\"server-cert.pem\"")) + .body(contents)) +} + #[derive(Debug)] enum Response { OkMessage(String), @@ -181,4 +214,4 @@ fn subnets_overlap(net1: &IpNetwork, net2: &IpNetwork) -> bool { || net1.contains(net2.broadcast()) || net2.contains(net1.network()) || net2.contains(net1.broadcast()) -} +} \ No newline at end of file diff --git a/src/service/tls/mod.rs b/src/service/tls/mod.rs index 7334675459e..a5a3264f021 100644 --- a/src/service/tls/mod.rs +++ b/src/service/tls/mod.rs @@ -18,7 +18,10 @@ use std::{io::BufReader, sync::Arc}; use actix_tls::connect::rustls_0_23::{native_roots_cert_store, webpki_roots_cert_store}; use itertools::Itertools as _; use rustls::{ClientConfig, ServerConfig}; +use rustls::crypto::CryptoProvider; +use rustls::crypto::ring::default_provider; use rustls_pemfile::{certs, private_key}; +use config::utils::cert::SelfSignedCertVerifier; pub fn http_tls_config() -> Result { let cfg = config::get_config(); @@ -91,3 +94,88 @@ pub fn client_tls_config() -> Result, anyhow::Error> { pub fn reqwest_client_tls_config() -> Result { todo!() } + +pub fn tcp_tls_server_config() -> Result { + let cfg = config::get_config(); + let _ = CryptoProvider::install_default(default_provider()); + let cert_file = + &mut BufReader::new(std::fs::File::open(&cfg.tcp.tcp_tls_cert_path).map_err(|e| { + anyhow::anyhow!( + "Failed to open TLS certificate file {}: {}", + &cfg.tcp.tcp_tls_cert_path, + e + ) + })?); + let key_file = + &mut BufReader::new(std::fs::File::open(&cfg.tcp.tcp_tls_key_path).map_err(|e| { + anyhow::anyhow!( + "Failed to open TLS key file {}: {}", + &cfg.tcp.tcp_tls_key_path, + e + ) + })?); + + let cert_chain = certs(cert_file); + + let tls_config = ServerConfig::builder_with_protocol_versions(rustls::DEFAULT_VERSIONS) + .with_no_client_auth() + .with_single_cert( + cert_chain.try_collect::<_, Vec<_>, _>()?, + private_key(key_file)?.unwrap(), + )?; + + Ok(tls_config) +} + +pub fn tcp_tls_self_connect_client_config() -> Result, anyhow::Error> { + let cfg = config::get_config(); + let config = if cfg.tcp.tcp_tls_enabled { + if 
!cfg.tcp.tcp_tls_ca_cert_path.is_empty() { + // Case 1: CA certificate is provided - use it to verify the server + let mut cert_store = webpki_roots_cert_store(); + let cert_file = + &mut BufReader::new(std::fs::File::open(&cfg.tcp.tcp_tls_ca_cert_path).map_err(|e| { + anyhow::anyhow!( + "Failed to open TLS CA certificate file {}: {}", + &cfg.tcp.tcp_tls_ca_cert_path, + e + ) + })?); + let cert_chain = certs(cert_file); + cert_store.add_parsable_certificates(cert_chain.try_collect::<_, Vec<_>, _>()?); + + ClientConfig::builder() + .with_root_certificates(cert_store) + .with_no_client_auth() + } else if !cfg.tcp.tcp_tls_cert_path.is_empty() { + // Case 2: Self-signed certificate - use the server's certificate as a trusted root + // We're only using the public certificate, not the private key + let cert_file = + &mut BufReader::new(std::fs::File::open(&cfg.tcp.tcp_tls_cert_path).map_err(|e| { + anyhow::anyhow!( + "Failed to open TLS certificate file {}: {}", + &cfg.tcp.tcp_tls_cert_path, + e + ) + })?); + let server_certs = certs(cert_file).try_collect::<_, Vec<_>, _>()?; + + ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(Arc::new(SelfSignedCertVerifier::new(server_certs))) + .with_no_client_auth() + } else { + // Case 3: No certificates provided but TLS is enabled - use system root certificates + ClientConfig::builder() + .with_root_certificates(native_roots_cert_store()?) + .with_no_client_auth() + } + } else { + // Case 4: TLS is disabled, but we still need a config for the function to work + ClientConfig::builder() + .with_root_certificates(webpki_roots_cert_store()) + .with_no_client_auth() + }; + + Ok(Arc::new(config)) +} \ No newline at end of file diff --git a/tests/ui-testing/pages/streamsPage.js b/tests/ui-testing/pages/streamsPage.js index 6869b3fb624..3a20889fc80 100644 --- a/tests/ui-testing/pages/streamsPage.js +++ b/tests/ui-testing/pages/streamsPage.js @@ -16,12 +16,14 @@ export class StreamsPage { async streamsPageDefaultOrg() { - await this.page.locator('[data-test="navbar-organizations-select"]').getByText('arrow_drop_down').click(); - await this.page.getByText('default', { exact: true }).click(); - - + + await this.page.waitForSelector('text=default'); + + const defaultOption = this.page.locator('text=default').first(); // Target the first occurrence + await defaultOption.click(); } + async streamsPageDefaultMultiOrg() { diff --git a/tests/ui-testing/pages/tracesPage.js b/tests/ui-testing/pages/tracesPage.js index 8db3451dd04..3fd8f715851 100644 --- a/tests/ui-testing/pages/tracesPage.js +++ b/tests/ui-testing/pages/tracesPage.js @@ -36,7 +36,12 @@ export async tracesPageDefaultOrg() { await this.page.locator('[data-test="navbar-organizations-select"]').getByText('arrow_drop_down').click(); - await this.page.getByText('default', { exact: true }).click(); + // Wait for the dropdown options to be visible + await this.page.waitForSelector('text=default'); // Wait for the text "default" to be present + + // Click the specific "default" option within the dropdown + const defaultOption = this.page.locator('text=default').first(); // Target the first occurrence + await defaultOption.click(); } diff --git a/tests/ui-testing/playwright-tests/changeOrg.spec.js b/tests/ui-testing/playwright-tests/changeOrg.spec.js index dee3446607a..2c634f0b264 100644 --- a/tests/ui-testing/playwright-tests/changeOrg.spec.js +++ b/tests/ui-testing/playwright-tests/changeOrg.spec.js @@ -23,8 +23,10 @@ test.describe("Change Organisation", () => { tracesPage, rumPage, 
pipelinesPage, dashboardPage, streamsPage, reportsPage, alertsPage, dataPage, iamPage, managementPage, aboutPage, createOrgPage, multiOrgIdentifier; + const timestamp = Date.now(); + const randomSuffix = Math.floor(Math.random() * 1000); + const newOrgName = `org${timestamp}${randomSuffix}`; - const newOrgName = `organisation${Math.floor(Math.random() * 10000)}`; test.beforeEach(async ({ page }) => { loginPage = new LoginPage(page); ingestionPage = new IngestionPage(page); @@ -118,7 +120,7 @@ test.describe("Change Organisation", () => { await tracesPage.navigateToTraces(); - await homePage.homePageDefaultOrg(); + await tracesPage.tracesPageDefaultOrg(); await homePage.homePageURLValidationDefaultOrg(); await tracesPage.validateTracesPage(); await tracesPage.tracesURLValidation(); @@ -199,7 +201,7 @@ test.describe("Change Organisation", () => { await streamsPage.gotoStreamsPage(); - await homePage.homePageDefaultOrg(); + await streamsPage.streamsPageDefaultOrg(); await homePage.homePageURLValidationDefaultOrg(); await streamsPage.streamsURLValidation(); @@ -309,7 +311,8 @@ test.describe("Change Organisation", () => { await ingestionPage.ingestionMultiOrg(multiOrgIdentifier); await homePage.homePageOrg(newOrgName); await managementPage.goToManagement(); - await homePage.homeURLContains(multiOrgIdentifier); + await managementPage.managementURLValidation(); + }); diff --git a/web/src/components/alerts/AddDestination.vue b/web/src/components/alerts/AddDestination.vue index 7fea630699e..6ac0ff32131 100644 --- a/web/src/components/alerts/AddDestination.vue +++ b/web/src/components/alerts/AddDestination.vue @@ -138,7 +138,9 @@ along with this program. If not, see . tabindex="0" /> -
+
. />
-
+
{ template: props.isAlerts ? formData.value.template : "", headers: headers, name: formData.value.name, - output_format: formData.value.output_format, + }; + if(!props.isAlerts){ + payload["output_format"] = formData.value.output_format; + } + if (formData.value.type === "email") { payload["type"] = "email"; payload["emails"] = (formData.value?.emails || "") diff --git a/web/src/composables/shared/router.ts b/web/src/composables/shared/router.ts index ef5b63734ed..0f4e412e703 100644 --- a/web/src/composables/shared/router.ts +++ b/web/src/composables/shared/router.ts @@ -24,6 +24,7 @@ import Tickets from "@/views/TicketsView.vue"; import About from "@/views/About.vue"; import MemberSubscription from "@/views/MemberSubscription.vue"; import Error404 from "@/views/Error404.vue"; +import ShortUrl from "@/views/ShortUrl.vue"; const Search = () => import("@/views/Search.vue"); const AppMetrics = () => import("@/views/AppMetrics.vue"); @@ -313,6 +314,15 @@ const useRoutes = () => { routeGuard(to, from, next); }, }, + { + path: "short/:id", + name: "shortUrl", + component: ShortUrl, + beforeEnter(to: any, from: any, next: any) { + routeGuard(to, from, next); + }, + props: true, + }, { path: "rum", name: "RUM", diff --git a/web/src/composables/useLogs.ts b/web/src/composables/useLogs.ts index 85b18301e7e..84981274f15 100644 --- a/web/src/composables/useLogs.ts +++ b/web/src/composables/useLogs.ts @@ -1457,10 +1457,13 @@ const useLogs = () => { // partitionDetail.partitions.forEach((item: any, index: number) => { for (const [index, item] of partitionDetail.partitions.entries()) { total = partitionDetail.partitionTotal[index]; - totalPages = Math.ceil(total / rowsPerPage); + if (!partitionDetail.paginations[pageNumber]) { partitionDetail.paginations[pageNumber] = []; } + + totalPages = getPartitionTotalPages(total); + if (totalPages > 0) { partitionFrom = 0; for (let i = 0; i < totalPages; i++) { @@ -1471,9 +1474,16 @@ const useLogs = () => { : rowsPerPage; from = partitionFrom; + + if (total < recordSize) { + recordSize = total; + } + // if (i === 0 && partitionDetail.paginations.length > 0) { lastPartitionSize = 0; - if (pageNumber > 0) { + + // if the pagination array is not empty, then we need to get the last page and add the size of the last page to the last partition size + if (partitionDetail.paginations[pageNumber]?.length) { lastPage = partitionDetail.paginations.length - 1; // partitionDetail.paginations[lastPage].forEach((item: any) => { @@ -1484,7 +1494,12 @@ const useLogs = () => { if (lastPartitionSize != rowsPerPage) { recordSize = rowsPerPage - lastPartitionSize; } + + if (total < recordSize) { + recordSize = total; + } } + if (!partitionDetail.paginations[pageNumber]) { partitionDetail.paginations[pageNumber] = []; } @@ -1531,6 +1546,10 @@ const useLogs = () => { recordSize = 0; } + if (total !== -1 && total < recordSize) { + recordSize = total; + } + partitionDetail.paginations[pageNumber].push({ startTime: item[0], endTime: item[1], @@ -1564,6 +1583,29 @@ const useLogs = () => { } }; + /** + * This function is used to get the total pages for the single partition + * This method handles the case where previous partition is not fully loaded and we are loading the next partition + * In this case, we need to add the size of the previous partition to the total size of the current partition for accurate total pages + * @param total - The total number of records in the partition + * @returns The total number of pages for the partition + */ + const getPartitionTotalPages = (total: number) 
=> { + const lastPage = searchObj.data.queryResults.partitionDetail.paginations?.length - 1; + + let lastPartitionSize = 0; + let partitionTotal = 0; + for (const item of searchObj.data.queryResults.partitionDetail.paginations[lastPage]) { + lastPartitionSize += item.size; + } + + if (lastPartitionSize < searchObj.meta.resultGrid.rowsPerPage) { + partitionTotal = total + lastPartitionSize; + } + + return Math.ceil(partitionTotal / searchObj.meta.resultGrid.rowsPerPage); + } + const getQueryData = async (isPagination = false) => { try { //remove any data that has been cached @@ -1588,7 +1630,7 @@ const useLogs = () => { !searchObj.meta.sqlMode; // Determine communication method based on available options and constraints - if (shouldUseStreaming) { + if (shouldUseStreaming && !isMultiStreamSearch) { searchObj.communicationMethod = "streaming"; } else if (shouldUseWebSocket && !isMultiStreamSearch) { searchObj.communicationMethod = "ws"; @@ -1926,6 +1968,7 @@ const useLogs = () => { } searchObjDebug["queryDataEndTime"] = performance.now(); } catch (e: any) { + console.error(`${notificationMsg.value || "Error occurred during the search operation."}`, e); searchObj.loading = false; showErrorNotification( notificationMsg.value || "Error occurred during the search operation.", @@ -3651,7 +3694,8 @@ const useLogs = () => { let plusSign: string = ""; if ( - searchObj.data.queryResults?.partitionDetail?.partitions?.length > 1 && + searchObj.data.queryResults?.partitionDetail?.partitions?.length > 1 && + endCount < totalCount && searchObj.meta.showHistogram == false ) { plusSign = "+"; @@ -5030,9 +5074,17 @@ const useLogs = () => { const payload = buildWebSocketPayload(queryReq, isPagination, "search"); - if(shouldGetPageCount(queryReq, fnParsedSQL())) { - queryReq.query.size = queryReq.query.size + 1; + if(searchObj.meta.refreshInterval === 0) updatePageCountSearchSize(queryReq); + + // in case of live refresh, reset from to 0 + if ( + searchObj.meta.refreshInterval > 0 && + router.currentRoute.value.name == "logs" + ) { + queryReq.query.from = 0; + searchObj.meta.refreshHistogram = false; } + const requestId = initializeSearchConnection(payload); @@ -5247,31 +5299,24 @@ const useLogs = () => { // Update results const handleStreamingHits = (payload: WebSocketSearchPayload, response: WebSocketSearchResponse, isPagination: boolean, appendResult: boolean = false) => { - if ( - searchObj.meta.refreshInterval > 0 && - router.currentRoute.value.name == "logs" - ) { + // Scan-size and took time in histogram title + // For the initial request, we get histogram and logs data. So, we need to sum the scan_size and took time of both the requests. + // For the pagination request, we only get logs data. So, we need to consider scan_size and took time of only logs request. + if (appendResult) { + searchObj.data.queryResults.hits.push( + ...response.content.results.hits, + ); + } else { searchObj.data.queryResults.hits = response.content.results.hits; } - if (!searchObj.meta.refreshInterval) { - // Scan-size and took time in histogram title - // For the initial request, we get histogram and logs data. So, we need to sum the scan_size and took time of both the requests. - // For the pagination request, we only get logs data. So, we need to consider scan_size and took time of only logs request. 
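The getPartitionTotalPages helper defined earlier in this hunk folds the partially filled last page of the previously loaded partition into the next partition's page count. A minimal standalone sketch of that carry-over idea, with illustrative names and a free-standing signature (the real helper reads rowsPerPage and the paginations array from searchObj):

// Sketch only: standalone arguments instead of searchObj state.
// Rows left on a partially filled last page are topped up by the next
// partition, so that partition's pages are counted over (total + carry).
function partitionTotalPages(
  total: number,              // rows reported for the new partition
  lastPartitionRows: number,  // rows already sitting on the last rendered page
  rowsPerPage: number,
): number {
  const carry = lastPartitionRows < rowsPerPage ? lastPartitionRows : 0;
  return Math.ceil((total + carry) / rowsPerPage);
}

// 50 rows per page, 30 rows left on the last page, new partition has 120 rows:
// ceil((120 + 30) / 50) = 3 pages.
console.assert(partitionTotalPages(120, 30, 50) === 3);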
- if (appendResult) { - searchObj.data.queryResults.hits.push( - ...response.content.results.hits, - ); - } else { - searchObj.data.queryResults.hits = response.content.results.hits; - } - - - if(shouldGetPageCount(payload.queryReq, fnParsedSQL()) && (searchObj.data.queryResults.hits.length === payload.queryReq.query.size)) { - searchObj.data.queryResults.hits = searchObj.data.queryResults.hits.slice(0, payload.queryReq.query.size - 1); - } + if (searchObj.meta.refreshInterval === 0) { + updatePageCountTotal(payload.queryReq, response.content.results.hits.length, searchObj.data.queryResults.hits.length); + trimPageCountExtraHit(payload.queryReq, searchObj.data.queryResults.hits.length); } + refreshPagination(true); + processPostPaginationData(); } @@ -5285,56 +5330,44 @@ const useLogs = () => { ////// Handle reset field values /////// resetFieldValues(); - if ( - searchObj.meta.refreshInterval > 0 && - router.currentRoute.value.name == "logs" - ) { - searchObj.data.queryResults.from = response.content.results.from; - searchObj.data.queryResults.scan_size = + // In page count we set track_total_hits + if (!payload.queryReq.query.hasOwnProperty("track_total_hits")) { + delete response.content.total; + } + // Scan-size and took time in histogram title + // For the initial request, we get histogram and logs data. So, we need to sum the scan_size and took time of both the requests. + // For the pagination request, we only get logs data. So, we need to consider scan_size and took time of only logs request. + if (appendResult) { + searchObj.data.queryResults.total += response.content.results.total; + searchObj.data.queryResults.took += response.content.results.took; + searchObj.data.queryResults.scan_size += response.content.results.scan_size; - searchObj.data.queryResults.took = response.content.results.took; - searchObj.data.queryResults.aggs = response.content.results.aggs; - } - - if (!searchObj.meta.refreshInterval) { - // In page count we set track_total_hits - if (!payload.queryReq.query.hasOwnProperty("track_total_hits")) { - delete response.content.total; - } - - // Scan-size and took time in histogram title - // For the initial request, we get histogram and logs data. So, we need to sum the scan_size and took time of both the requests. - // For the pagination request, we only get logs data. So, we need to consider scan_size and took time of only logs request. 
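The page-count helpers added further down (updatePageCountSearchSize, updatePageCountTotal, trimPageCountExtraHit) rely on a common trick: ask for one row more than a page, and if the extra row comes back there is at least one more page to show. A hedged sketch with assumed names, kept independent of searchObj:

// Sketch only: assumed names, 1-based currentPage.
function bumpRequestSize(pageSize: number): number {
  // Request one extra row so the response reveals whether another page exists.
  return pageSize + 1;
}

function estimatePageCountTotal(
  hits: unknown[],
  requestedSize: number,   // pageSize + 1
  rowsPerPage: number,
  currentPage: number,
): number {
  if (hits.length === requestedSize) {
    // The extra row came back: report "at least one more page".
    return rowsPerPage * currentPage + 1;
  }
  // Short page: the exact total is known.
  return rowsPerPage * (currentPage - 1) + hits.length;
}

function trimExtraHit<T>(hits: T[], requestedSize: number): T[] {
  // Drop the sentinel row before rendering the page.
  return hits.length === requestedSize ? hits.slice(0, hits.length - 1) : hits;
}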
- if (appendResult) { - searchObj.data.queryResults.total += response.content.results.total; - searchObj.data.queryResults.took += response.content.results.took; - searchObj.data.queryResults.scan_size += + } else { + if (isPagination && response.content?.streaming_aggs) { + searchObj.data.queryResults.from = response.content.results.from; + searchObj.data.queryResults.scan_size = response.content.results.scan_size; - } else { - if (isPagination && response.content?.streaming_aggs) { - searchObj.data.queryResults.from = response.content.results.from; - searchObj.data.queryResults.scan_size = - response.content.results.scan_size; - searchObj.data.queryResults.took = response.content.results.took; - } else if (response.content?.streaming_aggs) { - searchObj.data.queryResults = { - ...response.content.results, - took: (searchObj.data?.queryResults?.took || 0) + response.content.results.took, - scan_size: (searchObj.data?.queryResults?.scan_size || 0) + response.content.results.scan_size, - hits: searchObj.data?.queryResults?.hits || [], - streaming_aggs: response.content?.streaming_aggs, - } - } else if (isPagination) { - searchObj.data.queryResults.from = response.content.results.from; - searchObj.data.queryResults.scan_size = - response.content.results.scan_size; - searchObj.data.queryResults.took = response.content.results.took; - searchObj.data.queryResults.total = response.content.results.total; - } else { - searchObj.data.queryResults = response.content.results; + searchObj.data.queryResults.took = response.content.results.took; + } else if (response.content?.streaming_aggs) { + searchObj.data.queryResults = { + ...response.content.results, + took: (searchObj.data?.queryResults?.took || 0) + response.content.results.took, + scan_size: (searchObj.data?.queryResults?.scan_size || 0) + response.content.results.scan_size, + hits: searchObj.data?.queryResults?.hits || [], + streaming_aggs: response.content?.streaming_aggs, } + } else if (isPagination) { + searchObj.data.queryResults.from = response.content.results.from; + searchObj.data.queryResults.scan_size = + response.content.results.scan_size; + searchObj.data.queryResults.took = response.content.results.took; + searchObj.data.queryResults.total = response.content.results.total; + } else { + searchObj.data.queryResults = response.content.results; } + } + if(searchObj.meta.refreshInterval === 0) { if(shouldGetPageCount(payload.queryReq, fnParsedSQL()) && (response.content.results.total === payload.queryReq.query.size)) { searchObj.data.queryResults.pageCountTotal = payload.queryReq.query.size * searchObj.data.resultGrid.currentPage; } @@ -5353,6 +5386,39 @@ const useLogs = () => { if (isPagination) refreshPagination(true); } + const updatePageCountTotal = (queryReq: SearchRequestPayload, currentHits: number, total: any) => { + try { + if(shouldGetPageCount(queryReq, fnParsedSQL()) && (total === queryReq.query.size)) { + searchObj.data.queryResults.pageCountTotal = (searchObj.meta.resultGrid.rowsPerPage * searchObj.data.resultGrid.currentPage) + 1; + } else if(shouldGetPageCount(queryReq, fnParsedSQL()) && (total !== queryReq.query.size)){ + searchObj.data.queryResults.pageCountTotal = ((searchObj.meta.resultGrid.rowsPerPage) * Math.max(searchObj.data.resultGrid.currentPage-1,0)) + currentHits; + } + } catch(e: any) { + console.error("Error while updating page count total", e); + } + } + + const trimPageCountExtraHit = (queryReq: SearchRequestPayload, total: any) => { + try{ + if(shouldGetPageCount(queryReq, fnParsedSQL()) && (total === 
queryReq.query.size)) { + searchObj.data.queryResults.hits = searchObj.data.queryResults.hits.slice(0, searchObj.data.queryResults.hits.length- 1); + } + } catch(e: any) { + console.error("Error while trimming page count extra hit", e); + } + } + + const updatePageCountSearchSize = (queryReq: SearchRequestPayload) => { + try{ + if(shouldGetPageCount(queryReq, fnParsedSQL())) { + queryReq.query.size = queryReq.query.size + 1; + } + } catch(e: any) { + console.error("Error while updating page count search size", e); + return queryReq.query.size; + } + } + const handleHistogramStreamingHits = (payload: WebSocketSearchPayload, response: WebSocketSearchResponse, isPagination: boolean, appendResult: boolean = false) => { searchObj.loading = false; @@ -5990,7 +6056,7 @@ const useLogs = () => { } function isHistogramDataMissing(searchObj: any) { - return searchObj.data.queryResults.aggs === undefined || searchObj.data.queryResults.aggs.length === 0; + return !searchObj.data.queryResults?.aggs?.length; } const handleSearchClose = (payload: any, response: any) => { diff --git a/web/src/composables/useStreamingSearch.ts b/web/src/composables/useStreamingSearch.ts index e8465e5edfc..503c928165e 100644 --- a/web/src/composables/useStreamingSearch.ts +++ b/web/src/composables/useStreamingSearch.ts @@ -194,9 +194,6 @@ const useHttpStreaming = () => { if (meta?.dashboard_id) url += `&dashboard_id=${meta?.dashboard_id}`; if (meta?.folder_id) url += `&folder_id=${meta?.folder_id}`; if (meta?.fallback_order_by_col) url += `&fallback_order_by_col=${meta?.fallback_order_by_col}`; - if (typeof queryReq.query.sql != "string") { - url = `/_search_multi_stream?type=${pageType}&search_type=${searchType}&use_cache=${use_cache}`; - } } else if(type === "values") { const fieldsString = meta?.fields.join(","); url = `/_values_stream` diff --git a/web/src/plugins/logs/SearchBar.vue b/web/src/plugins/logs/SearchBar.vue index 986f6c5cdbd..118a25f9ad5 100644 --- a/web/src/plugins/logs/SearchBar.vue +++ b/web/src/plugins/logs/SearchBar.vue @@ -129,7 +129,7 @@ along with this program. If not, see . @@ -374,100 +374,99 @@ along with this program. If not, see . - -
-
-
-
- -
- - {{ t('search.showHistogramLabel') }} - - - + v-if="store.state.isAiChatEnabled" + class="tw-text-[12px] tw-font-[500] q-ml-xs q-px-sm" + no-caps + menu-anchor="bottom left" + menu-self="top left" + icon="menu" + size="sm" + dense + > + +
+
+
+
+
+ + {{ t("search.showHistogramLabel") }} + +
- + -
+
- + >
- - Wrap Content - - - + Wrap Content
-
-
- -
+
+
+ +
{{ t("search.quickModeLabel") }} - -
-
-
- +
+ - -
- - Syntax Guide - - + > +
- -
-
- Syntax Guide +
+ +
+
+ . @click="resetFilters" > -
- - {{ t('search.resetFilters') }} - -
-
-
- + + {{ t("search.resetFilters") }} +
- - +
+
+ +
. @save:function="fnSavedFunctionDialog" /> . {{ t("search.moreActions") }} -
+
{ + return http().get(`/api/${org_identifier}/short/${id}?type=ui`); + }, }; export default shortURL; diff --git a/web/src/test/unit/helpers/handlers.ts b/web/src/test/unit/helpers/handlers.ts index 81d072b9030..0f03d3da94d 100644 --- a/web/src/test/unit/helpers/handlers.ts +++ b/web/src/test/unit/helpers/handlers.ts @@ -173,4 +173,8 @@ export const restHandlers = [ ], }); }), + + http.get(`${store.state.API_ENDPOINT}/api/${store.state.selectedOrganization.identifier}/short/:id`, ({ request }) => { + return HttpResponse.json("http://localhost:5080/web/logs?stream_type=logs&stream=default1&from=1749120770351000&to=1749121670351000&refresh=0&defined_schemas=user_defined_schema&org_identifier=default&quick_mode=false&show_histogram=true"); + }), ]; diff --git a/web/src/views/ShortUrl.spec.ts b/web/src/views/ShortUrl.spec.ts new file mode 100644 index 00000000000..20094fa5dd4 --- /dev/null +++ b/web/src/views/ShortUrl.spec.ts @@ -0,0 +1,81 @@ +// Copyright 2023 OpenObserve Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + + +import { describe, it, expect, afterEach, beforeEach, vi } from "vitest"; +import { mount } from "@vue/test-utils"; +import ShortUrl from "@/views/ShortUrl.vue"; +import i18n from "@/locales"; +import store from "@/test/unit/helpers/store"; +import router from "@/test/unit/helpers/router"; +import { installQuasar } from "@/test/unit/helpers/install-quasar-plugin"; +import { Dialog, Notify } from "quasar"; +import * as zincutils from "@/utils/zincutils"; + +// Mock only the routeGuard function +vi.spyOn(zincutils, 'routeGuard').mockImplementation(async (to, from, next) => { + next(); +}); + +const node = document.createElement("div"); +node.setAttribute("id", "app"); +document.body.appendChild(node); + +installQuasar({ + plugins: [Dialog, Notify], +}); + + +describe("ShortUrl", () => { + let wrapper; + + beforeEach(async () => { + // Set up the route with an ID parameter + wrapper = mount(ShortUrl, { + attachTo: "#app", + global: { + provide: { + store: store, + }, + plugins: [i18n, router], + }, + props: { + id: "test-id", + } + }); + }); + + afterEach(() => { + wrapper.unmount(); + }); + + it("Should match snapshot", () => { + expect(wrapper.html()).toMatchSnapshot(); + }); + + it("Should render spinner", () => { + expect(wrapper.find('[data-test="spinner"]').exists()).toBe(true); + }); + + it("Should render message", () => { + expect(wrapper.find('[data-test="message"]').text()).toBe("Redirecting..."); + }); + + it("Should redirect to the correct page", async () => { + // Wait for the api call to resolve + await new Promise(resolve => setTimeout(resolve, 1000)); + expect(wrapper.vm.$router.currentRoute.value.name).toBe("logs"); + }); +}); \ No newline at end of file diff --git a/web/src/views/ShortUrl.vue b/web/src/views/ShortUrl.vue new file mode 100644 index 00000000000..69c4b6aaef1 --- /dev/null +++ b/web/src/views/ShortUrl.vue @@ -0,0 +1,111 @@ + + + + + diff --git 
a/web/src/views/__snapshots__/ShortUrl.spec.ts.snap b/web/src/views/__snapshots__/ShortUrl.spec.ts.snap new file mode 100644 index 00000000000..ad7bbb3b723 --- /dev/null +++ b/web/src/views/__snapshots__/ShortUrl.spec.ts.snap @@ -0,0 +1,9 @@ +// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html + +exports[`ShortUrl > Should match snapshot 1`] = ` +"
+ + +
Redirecting...
+
" +`;