Skip to content

Commit 68e16dc

Browse files
committed
feat: support customizing the reqwest client in the Client builder
1 parent 7011839 commit 68e16dc

File tree

4 files changed

+207
-3
lines changed

4 files changed

+207
-3
lines changed

typesense/Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ trybuild = "1.0.42"
4848
# native-only dev deps
4949
[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
5050
tokio = { workspace = true}
51+
tokio-rustls = "0.26"
52+
rcgen = "0.14"
5153
wiremock = "0.6"
5254

5355
# wasm test deps
@@ -64,4 +66,4 @@ required-features = ["derive"]
6466

6567
[[test]]
6668
name = "client"
67-
path = "tests/client/mod.rs"
69+
path = "tests/client/mod.rs"

typesense/src/client/mod.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ impl Client {
210210
/// - **healthcheck_interval**: 60 seconds.
211211
/// - **retry_policy**: Exponential backoff with a maximum of 3 retries. (disabled on WASM)
212212
/// - **connection_timeout**: 5 seconds. (disabled on WASM)
213+
/// - **http_builder**: An `Fn()` closure returning a `reqwest::ClientBuilder` instance (optional).
213214
#[builder]
214215
pub fn new(
215216
/// The Typesense API key used for authentication.
@@ -235,21 +236,35 @@ impl Client {
235236
#[builder(default = Duration::from_secs(5))]
236237
/// The timeout for each individual network request.
237238
connection_timeout: Duration,
239+
240+
/// An optional custom builder for the HTTP client.
241+
///
242+
/// This is useful if you need to configure custom settings on the HTTP client.
243+
/// The value should be a closure that returns a `reqwest::ClientBuilder` instance.
244+
///
245+
/// Note that this library may apply its own settings before building the client (eg. `connection_timeout`),
246+
/// so not all custom settings may be preserved.
247+
#[builder(with = |f: impl Fn() -> reqwest::ClientBuilder + 'static| Box::new(f))]
248+
http_builder: Option<Box<dyn Fn() -> reqwest::ClientBuilder>>,
238249
) -> Result<Self, &'static str> {
239250
let is_nearest_node_set = nearest_node.is_some();
240251

241252
let nodes: Vec<_> = nodes
242253
.into_iter()
243254
.chain(nearest_node)
244255
.map(|mut url| {
256+
let http_buidler = http_builder
257+
.as_ref()
258+
.map(|f| f())
259+
.unwrap_or_else(reqwest::Client::builder);
245260
#[cfg(target_arch = "wasm32")]
246-
let http_client = reqwest::Client::builder()
261+
let http_client = http_buidler
247262
.build()
248263
.expect("Failed to build reqwest client");
249264

250265
#[cfg(not(target_arch = "wasm32"))]
251266
let http_client = ReqwestMiddlewareClientBuilder::new(
252-
reqwest::Client::builder()
267+
http_buidler
253268
.timeout(connection_timeout)
254269
.build()
255270
.expect("Failed to build reqwest client"),
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
use std::{
2+
net::{IpAddr, Ipv4Addr},
3+
sync::{
4+
Arc,
5+
atomic::{AtomicBool, Ordering},
6+
},
7+
time::Duration,
8+
};
9+
use tokio::{
10+
io::{AsyncReadExt, AsyncWriteExt as _},
11+
net::TcpListener,
12+
};
13+
use tokio_rustls::{
14+
TlsAcceptor,
15+
rustls::{
16+
self, ServerConfig,
17+
pki_types::{CertificateDer, PrivateKeyDer},
18+
},
19+
};
20+
use typesense::ExponentialBackoff;
21+
22+
/// Test that the `http_builder` option can be used to set up a custom DNS resolver.
23+
///
24+
/// In this test we exercise the `http_builder` option by setting up a custom DNS resolver that
25+
/// will set an atomic boolean to true when it is called. It then simply returns an error.
26+
///
27+
/// This test should run on WASM as well so we're not setting up any mock server.
28+
async fn test_http_builder_sideeffect() {
29+
#[derive(Debug, thiserror::Error)]
30+
#[error("test error")]
31+
struct TestError;
32+
33+
#[derive(Default, Clone)]
34+
struct TestResolver(Arc<AtomicBool>);
35+
36+
impl reqwest::dns::Resolve for TestResolver {
37+
fn resolve(&self, _name: reqwest::dns::Name) -> reqwest::dns::Resolving {
38+
self.0.store(true, Ordering::SeqCst);
39+
Box::pin(async move { Err(Box::new(TestError) as _) })
40+
}
41+
}
42+
43+
let test_resolver = TestResolver::default();
44+
let client = typesense::Client::builder()
45+
.nodes(vec!["http://localhost:8108"])
46+
.api_key("xyz")
47+
.http_builder({
48+
let test_resolver = test_resolver.clone();
49+
move || reqwest::Client::builder().dns_resolver2(test_resolver.clone())
50+
})
51+
.build()
52+
.expect("Failed to create Typesense client");
53+
54+
// call the health endpoint, this will fail as the custom DNS fails intentionally
55+
client.operations().health().await.unwrap_err();
56+
57+
// make sure the custom DNS resolver was called
58+
assert!(test_resolver.0.load(Ordering::SeqCst));
59+
}
60+
61+
#[cfg(not(target_arch = "wasm32"))]
62+
/// Reqwest custom builder test.
63+
///
64+
/// In this test we exercise the `reqwest_builder` option by setting up a custom root TLS certificate.
65+
/// If the cusomization doesn't work, reqwest would be unable to connect to the mocked Typesense node.
66+
///
67+
/// This test is non-WASM as it needs TCP.
68+
async fn test_http_builder_tls() {
69+
rustls::crypto::aws_lc_rs::default_provider()
70+
.install_default()
71+
.expect("Failed to install crypto provider");
72+
73+
let api_key = "xxx-api-key";
74+
75+
// generate a self-signed key pair and build TLS config out of it
76+
let (cert, key) = generate_self_signed_cert();
77+
let tls_config = ServerConfig::builder()
78+
.with_no_client_auth()
79+
.with_single_cert(vec![cert.clone()], key)
80+
.expect("failed to build TLS config");
81+
82+
let localhost = IpAddr::V4(Ipv4Addr::LOCALHOST);
83+
let listener = TcpListener::bind((localhost, 0))
84+
.await
85+
.expect("Failed to bind to address");
86+
let server_addr = listener.local_addr().expect("Failed to get local address");
87+
88+
// spawn a handler which handles one /health request over a TLS connection
89+
let handler = tokio::spawn(mock_node_handler(listener, tls_config, api_key));
90+
91+
// create the client, configuring the certificate with reqwest
92+
let client_cert = reqwest::Certificate::from_der(&cert)
93+
.expect("Failed to convert certificate to Certificate");
94+
let client = typesense::Client::builder()
95+
.nodes(vec![format!("https://localhost:{}", server_addr.port())])
96+
.api_key(api_key)
97+
.http_builder(move || {
98+
reqwest::Client::builder()
99+
.add_root_certificate(client_cert.clone())
100+
.https_only(true)
101+
})
102+
.healthcheck_interval(Duration::from_secs(9001)) // we'll do a healthcheck manually
103+
.retry_policy(ExponentialBackoff::builder().build_with_max_retries(0)) // no retries
104+
.connection_timeout(Duration::from_secs(1)) // short
105+
.build()
106+
.expect("Failed to create Typesense client");
107+
108+
// request /health
109+
client
110+
.operations()
111+
.health()
112+
.await
113+
.expect("Failed to get collection health");
114+
115+
handler.await.expect("Failed to join handler");
116+
}
117+
118+
fn generate_self_signed_cert() -> (CertificateDer<'static>, PrivateKeyDer<'static>) {
119+
let pair = rcgen::generate_simple_self_signed(["localhost".into()])
120+
.expect("Failed to generate self-signed certificate");
121+
let cert = pair.cert.der().clone();
122+
let signing_key = pair.signing_key.serialize_der();
123+
let signing_key = PrivateKeyDer::try_from(signing_key)
124+
.expect("Failed to convert signing key to PrivateKeyDer");
125+
(cert, signing_key)
126+
}
127+
128+
async fn mock_node_handler(listener: TcpListener, tls_config: ServerConfig, api_key: &'static str) {
129+
let tls_acceptor = TlsAcceptor::from(Arc::new(tls_config));
130+
let (stream, _addr) = listener
131+
.accept()
132+
.await
133+
.expect("Failed to accept connection");
134+
let mut stream = tls_acceptor
135+
.accept(stream)
136+
.await
137+
.expect("Failed to accept TLS connection");
138+
139+
let mut buf = vec![0u8; 1024];
140+
stream
141+
.read(&mut buf[..])
142+
.await
143+
.expect("Failed to read request");
144+
let request = String::from_utf8(buf).expect("Failed to parse request as UTF-8");
145+
assert!(request.contains("/health"));
146+
assert!(request.contains(api_key));
147+
148+
// mock a /health response
149+
let response = r#"HTTP/1.1 200 OK\r\n\
150+
Content-Type: application/json;\r\n\
151+
Connection: close\r\n
152+
153+
{"ok": true}"#;
154+
stream
155+
.write_all(&response.as_bytes())
156+
.await
157+
.expect("Failed to write to stream");
158+
stream.shutdown().await.expect("Failed to shutdown stream");
159+
}
160+
161+
#[cfg(all(test, not(target_arch = "wasm32")))]
162+
mod tokio_test {
163+
#[tokio::test]
164+
async fn test_http_builder_sideeffect() {
165+
super::test_http_builder_sideeffect().await;
166+
}
167+
168+
#[tokio::test]
169+
async fn test_http_builder_tls() {
170+
super::test_http_builder_tls().await;
171+
}
172+
}
173+
174+
#[cfg(all(test, target_arch = "wasm32"))]
175+
mod wasm_test {
176+
use super::*;
177+
use wasm_bindgen_test::wasm_bindgen_test;
178+
179+
wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);
180+
181+
#[wasm_bindgen_test]
182+
async fn test_http_builder_sideeffect() {
183+
console_error_panic_hook::set_once();
184+
super::test_http_builder_sideeffect().await;
185+
}
186+
}

typesense/tests/client/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ mod collections_test;
44
mod conversation_models_test;
55
mod derive_integration_test;
66
mod documents_test;
7+
mod http_builder_test;
78
mod keys_test;
89
mod multi_search_test;
910
mod operations_test;

0 commit comments

Comments
 (0)