diff --git a/Cargo.lock b/Cargo.lock index e1a76d404b..8f4b84ef89 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8517,6 +8517,7 @@ dependencies = [ "reqwest 0.12.9", "rustls 0.23.18", "serde", + "spin-common", "spin-factor-outbound-networking", "spin-factor-variables", "spin-factors", diff --git a/crates/common/src/assert.rs b/crates/common/src/assert.rs new file mode 100644 index 0000000000..ad41eed6c3 --- /dev/null +++ b/crates/common/src/assert.rs @@ -0,0 +1,33 @@ +//! Assertion macros. + +/// Asserts that the expression matches the pattern. +/// +/// This is equivalent to `assert!(matches!(...))` except that it produces nicer +/// errors. +#[macro_export] +macro_rules! assert_matches { + ($expr:expr, $pat:pat $(,)?) => {{ + let val = $expr; + assert!( + matches!(val, $pat), + "expected {val:?} to match {}", + stringify!($pat), + ) + }}; +} + +/// Asserts that the expression does not match the pattern. +/// +/// This is equivalent to `assert!(!matches!(...))` except that it produces +/// nicer errors. +#[macro_export] +macro_rules! assert_not_matches { + ($expr:expr, $pat:pat $(,)?) => {{ + let val = $expr; + assert!( + !matches!(val, $pat), + "expected {val:?} to NOT match {}", + stringify!($pat), + ) + }}; +} diff --git a/crates/common/src/lib.rs b/crates/common/src/lib.rs index f747d1544f..778852d379 100644 --- a/crates/common/src/lib.rs +++ b/crates/common/src/lib.rs @@ -9,6 +9,7 @@ // - Code should have at least 2 dependents pub mod arg_parser; +pub mod assert; pub mod data_dir; pub mod paths; pub mod sha256; diff --git a/crates/factor-outbound-http/Cargo.toml b/crates/factor-outbound-http/Cargo.toml index 2dba3e9f13..03f5136296 100644 --- a/crates/factor-outbound-http/Cargo.toml +++ b/crates/factor-outbound-http/Cargo.toml @@ -27,6 +27,7 @@ wasmtime-wasi = { workspace = true } wasmtime-wasi-http = { workspace = true } [dev-dependencies] +spin-common = { path = "../common" } spin-factor-variables = { path = "../factor-variables" } spin-factors-test = { path = "../factors-test" } diff --git a/crates/factor-outbound-http/tests/factor_test.rs b/crates/factor-outbound-http/tests/factor_test.rs index 7074605f41..95816712d7 100644 --- a/crates/factor-outbound-http/tests/factor_test.rs +++ b/crates/factor-outbound-http/tests/factor_test.rs @@ -2,6 +2,7 @@ use std::time::Duration; use anyhow::bail; use http::{Request, Uri}; +use spin_common::{assert_matches, assert_not_matches}; use spin_factor_outbound_http::{OutboundHttpFactor, SelfRequestOrigin}; use spin_factor_outbound_networking::OutboundNetworkingFactor; use spin_factor_variables::VariablesFactor; @@ -31,10 +32,10 @@ async fn allowed_host_is_allowed() -> anyhow::Result<()> { // Different systems handle the discard prefix differently; some will // immediately reject it while others will silently let it time out - match future_resp.unwrap_ready().unwrap() { - Err(ErrorCode::ConnectionRefused | ErrorCode::ConnectionTimeout) => (), - other => bail!("expected Err(ConnectionRefused | ConnectionTimeout), got {other:?}"), - }; + assert_matches!( + future_resp.unwrap_ready().unwrap(), + Err(ErrorCode::ConnectionRefused | ErrorCode::ConnectionTimeout), + ); Ok(()) } @@ -52,10 +53,10 @@ async fn self_request_smoke_test() -> anyhow::Result<()> { // Different systems handle the discard prefix differently; some will // immediately reject it while others will silently let it time out - match future_resp.unwrap_ready().unwrap() { - Err(ErrorCode::ConnectionRefused | ErrorCode::ConnectionTimeout) => (), - other => bail!("expected Err(ConnectionRefused | ConnectionTimeout), got {other:?}"), - }; + assert_matches!( + future_resp.unwrap_ready().unwrap(), + Err(ErrorCode::ConnectionRefused | ErrorCode::ConnectionTimeout), + ); Ok(()) } @@ -67,10 +68,10 @@ async fn disallowed_host_fails() -> anyhow::Result<()> { let req = Request::get("https://denied.test").body(Default::default())?; let mut future_resp = wasi_http.send_request(req, test_request_config())?; future_resp.ready().await; - match future_resp.unwrap_ready().unwrap() { - Ok(_) => bail!("expected Err, got Ok"), - Err(err) => assert!(matches!(err, ErrorCode::HttpRequestDenied)), - }; + assert_matches!( + future_resp.unwrap_ready().unwrap(), + Err(ErrorCode::HttpRequestDenied), + ); Ok(()) } @@ -89,11 +90,11 @@ async fn disallowed_private_ips_fails() -> anyhow::Result<()> { Ok(_) => {} // If private IPs are disallowed, we should get an error saying the destination is prohibited Err(err) if !allow_private_ips => { - assert!(matches!(err, ErrorCode::DestinationIpProhibited)) + assert_matches!(err, ErrorCode::DestinationIpProhibited); } // Otherwise, we should get some non-DestinationIpProhibited error Err(err) => { - assert!(!matches!(err, ErrorCode::DestinationIpProhibited)) + assert_not_matches!(err, ErrorCode::DestinationIpProhibited); } }; Ok(())