|
| 1 | +use anyhow::{Context as _, Result, anyhow, bail}; |
| 2 | +use core::future::{Future as _, poll_fn}; |
| 3 | +use core::pin::pin; |
| 4 | +use core::str; |
| 5 | +use core::task::{Poll, ready}; |
| 6 | +use futures::try_join; |
| 7 | +use test_programs::p3::wasi::sockets::ip_name_lookup::resolve_addresses; |
| 8 | +use test_programs::p3::wasi::sockets::types::{IpAddress, IpSocketAddress, TcpSocket}; |
| 9 | +use test_programs::p3::wasi::tls; |
| 10 | +use test_programs::p3::wasi::tls::client::Hello; |
| 11 | +use test_programs::p3::wit_stream; |
| 12 | +use wit_bindgen::StreamResult; |
| 13 | + |
| 14 | +struct Component; |
| 15 | + |
| 16 | +test_programs::p3::export!(Component); |
| 17 | + |
| 18 | +const PORT: u16 = 443; |
| 19 | + |
| 20 | +async fn test_tls_sample_application(domain: &str, ip: IpAddress) -> Result<()> { |
| 21 | + let request = format!( |
| 22 | + "GET / HTTP/1.1\r\nHost: {domain}\r\nUser-Agent: wasmtime-wasi-rust\r\nConnection: close\r\n\r\n" |
| 23 | + ); |
| 24 | + |
| 25 | + let sock = TcpSocket::create(ip.family()).unwrap(); |
| 26 | + sock.connect(IpSocketAddress::new(ip, PORT)) |
| 27 | + .await |
| 28 | + .context("tcp connect failed")?; |
| 29 | + |
| 30 | + let (sock_rx, sock_rx_fut) = sock.receive(); |
| 31 | + let hello = Hello::new(); |
| 32 | + hello |
| 33 | + .set_server_name(domain) |
| 34 | + .map_err(|()| anyhow!("failed to set SNI"))?; |
| 35 | + let (sock_tx, conn) = tls::client::connect(hello, sock_rx); |
| 36 | + let sock_tx_fut = sock.send(sock_tx); |
| 37 | + |
| 38 | + let mut conn = pin!(conn.into_future()); |
| 39 | + let mut sock_rx_fut = pin!(sock_rx_fut.into_future()); |
| 40 | + let mut sock_tx_fut = pin!(sock_tx_fut); |
| 41 | + let conn = poll_fn(|cx| match conn.as_mut().poll(cx) { |
| 42 | + Poll::Ready(Ok(conn)) => Poll::Ready(Ok(conn)), |
| 43 | + Poll::Ready(Err(())) => Poll::Ready(Err(anyhow!("tls handshake failed"))), |
| 44 | + Poll::Pending => match sock_tx_fut.as_mut().poll(cx) { |
| 45 | + Poll::Ready(Ok(())) => Poll::Ready(Err(anyhow!("Tx stream closed unexpectedly"))), |
| 46 | + Poll::Ready(Err(err)) => { |
| 47 | + Poll::Ready(Err(anyhow!("Tx stream closed with error: {err:?}"))) |
| 48 | + } |
| 49 | + Poll::Pending => match ready!(sock_rx_fut.as_mut().poll(cx)) { |
| 50 | + Ok(_) => Poll::Ready(Err(anyhow!("Rx stream closed unexpectedly"))), |
| 51 | + Err(err) => Poll::Ready(Err(anyhow!("Rx stream closed with error: {err:?}"))), |
| 52 | + }, |
| 53 | + }, |
| 54 | + }) |
| 55 | + .await?; |
| 56 | + |
| 57 | + let (mut req_tx, req_rx) = wit_stream::new(); |
| 58 | + let (mut res_rx, result_fut) = tls::client::Handshake::finish(conn, req_rx); |
| 59 | + |
| 60 | + let res = Vec::with_capacity(8192); |
| 61 | + try_join!( |
| 62 | + async { |
| 63 | + let buf = req_tx.write_all(request.into()).await; |
| 64 | + assert_eq!(buf, []); |
| 65 | + drop(req_tx); |
| 66 | + Ok(()) |
| 67 | + }, |
| 68 | + async { |
| 69 | + let (result, buf) = res_rx.read(res).await; |
| 70 | + match result { |
| 71 | + StreamResult::Complete(..) => { |
| 72 | + drop(res_rx); |
| 73 | + let res = String::from_utf8(buf)?; |
| 74 | + if res.contains("HTTP/1.1 200 OK") { |
| 75 | + Ok(()) |
| 76 | + } else { |
| 77 | + bail!("server did not respond with 200 OK: {res}") |
| 78 | + } |
| 79 | + } |
| 80 | + StreamResult::Dropped => bail!("read dropped"), |
| 81 | + StreamResult::Cancelled => bail!("read cancelled"), |
| 82 | + } |
| 83 | + }, |
| 84 | + async { result_fut.await.map_err(|()| anyhow!("TLS session failed")) }, |
| 85 | + async { sock_rx_fut.await.context("TCP receipt failed") }, |
| 86 | + async { sock_tx_fut.await.context("TCP transmit failed") }, |
| 87 | + )?; |
| 88 | + Ok(()) |
| 89 | +} |
| 90 | + |
| 91 | +/// This test sets up a TCP connection using one domain, and then attempts to |
| 92 | +/// perform a TLS handshake using another unrelated domain. This should result |
| 93 | +/// in a handshake error. |
| 94 | +async fn test_tls_invalid_certificate(_domain: &str, ip: IpAddress) -> Result<()> { |
| 95 | + const BAD_DOMAIN: &'static str = "wrongdomain.localhost"; |
| 96 | + |
| 97 | + let sock = TcpSocket::create(ip.family()).unwrap(); |
| 98 | + sock.connect(IpSocketAddress::new(ip, PORT)) |
| 99 | + .await |
| 100 | + .context("tcp connect failed")?; |
| 101 | + |
| 102 | + let (sock_rx, sock_rx_fut) = sock.receive(); |
| 103 | + let hello = Hello::new(); |
| 104 | + hello |
| 105 | + .set_server_name(BAD_DOMAIN) |
| 106 | + .map_err(|()| anyhow!("failed to set SNI"))?; |
| 107 | + let (sock_tx, conn) = tls::client::connect(hello, sock_rx); |
| 108 | + let sock_tx_fut = sock.send(sock_tx); |
| 109 | + |
| 110 | + try_join!( |
| 111 | + async { |
| 112 | + match conn.await { |
| 113 | + Err(()) => Ok(()), |
| 114 | + Ok(_) => panic!("expecting server name mismatch"), |
| 115 | + } |
| 116 | + }, |
| 117 | + async { sock_rx_fut.await.context("TCP receipt failed") }, |
| 118 | + async { sock_tx_fut.await.context("TCP transmit failed") }, |
| 119 | + )?; |
| 120 | + Ok(()) |
| 121 | +} |
| 122 | + |
| 123 | +async fn try_live_endpoints<'a, Fut>(test: impl Fn(&'a str, IpAddress) -> Fut) |
| 124 | +where |
| 125 | + Fut: Future<Output = Result<()>> + 'a, |
| 126 | +{ |
| 127 | + // since this is testing remote endpoints to ensure system cert store works |
| 128 | + // the test uses a couple different endpoints to reduce the number of flakes |
| 129 | + const DOMAINS: &'static [&'static str] = &[ |
| 130 | + "example.com", |
| 131 | + "api.github.com", |
| 132 | + "docs.wasmtime.dev", |
| 133 | + "bytecodealliance.org", |
| 134 | + "www.rust-lang.org", |
| 135 | + ]; |
| 136 | + |
| 137 | + for &domain in DOMAINS { |
| 138 | + let result = (|| async { |
| 139 | + let ip = resolve_addresses(domain.into()) |
| 140 | + .await? |
| 141 | + .first() |
| 142 | + .map(|a| a.to_owned()) |
| 143 | + .ok_or_else(|| anyhow!("DNS lookup failed."))?; |
| 144 | + test(&domain, ip).await |
| 145 | + })(); |
| 146 | + |
| 147 | + match result.await { |
| 148 | + Ok(()) => return, |
| 149 | + Err(e) => { |
| 150 | + eprintln!("test for {domain} failed: {e:#}"); |
| 151 | + } |
| 152 | + } |
| 153 | + } |
| 154 | + |
| 155 | + panic!("all tests failed"); |
| 156 | +} |
| 157 | + |
| 158 | +impl test_programs::p3::exports::wasi::cli::run::Guest for Component { |
| 159 | + async fn run() -> Result<(), ()> { |
| 160 | + println!("sample app"); |
| 161 | + try_live_endpoints(test_tls_sample_application).await; |
| 162 | + println!("invalid cert"); |
| 163 | + try_live_endpoints(test_tls_invalid_certificate).await; |
| 164 | + Ok(()) |
| 165 | + } |
| 166 | +} |
| 167 | + |
| 168 | +fn main() {} |
0 commit comments