Skip to content

Commit 589ff19

Browse files
committed
feat: support customizing the reqwest client in the Client builder
1 parent 7011839 commit 589ff19

File tree

4 files changed

+134
-3
lines changed

4 files changed

+134
-3
lines changed

typesense/Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ trybuild = "1.0.42"
4848
# native-only dev deps
4949
[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
5050
tokio = { workspace = true}
51+
tokio-rustls = "0.26"
52+
rcgen = "0.14"
5153
wiremock = "0.6"
5254

5355
# wasm test deps
@@ -64,4 +66,4 @@ required-features = ["derive"]
6466

6567
[[test]]
6668
name = "client"
67-
path = "tests/client/mod.rs"
69+
path = "tests/client/mod.rs"

typesense/src/client/mod.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ impl Client {
210210
/// - **healthcheck_interval**: 60 seconds.
211211
/// - **retry_policy**: Exponential backoff with a maximum of 3 retries. (disabled on WASM)
212212
/// - **connection_timeout**: 5 seconds. (disabled on WASM)
213+
/// - **reqwest_builder**: An `Fn()` closure returning a `reqwest::ClientBuilder` instance (optional).
213214
#[builder]
214215
pub fn new(
215216
/// The Typesense API key used for authentication.
@@ -235,21 +236,35 @@ impl Client {
235236
#[builder(default = Duration::from_secs(5))]
236237
/// The timeout for each individual network request.
237238
connection_timeout: Duration,
239+
240+
/// An optional custom builder for the HTTP client.
241+
///
242+
/// This is useful if you need to configure custom settings on the HTTP client.
243+
/// The value should be a closure that returns a `reqwest::ClientBuilder` instance.
244+
///
245+
/// Note that this library may apply its own settings before building the client (eg. `connection_timeout`),
246+
/// so not all custom settings may be preserved.
247+
#[builder(with = |f: impl Fn() -> reqwest::ClientBuilder + 'static| Box::new(f))]
248+
reqwest_builder: Option<Box<dyn Fn() -> reqwest::ClientBuilder>>,
238249
) -> Result<Self, &'static str> {
239250
let is_nearest_node_set = nearest_node.is_some();
240251

241252
let nodes: Vec<_> = nodes
242253
.into_iter()
243254
.chain(nearest_node)
244255
.map(|mut url| {
256+
let http_buidler = reqwest_builder
257+
.as_ref()
258+
.map(|f| f())
259+
.unwrap_or_else(|| reqwest::Client::builder());
245260
#[cfg(target_arch = "wasm32")]
246-
let http_client = reqwest::Client::builder()
261+
let http_client = http_buidler
247262
.build()
248263
.expect("Failed to build reqwest client");
249264

250265
#[cfg(not(target_arch = "wasm32"))]
251266
let http_client = ReqwestMiddlewareClientBuilder::new(
252-
reqwest::Client::builder()
267+
http_buidler
253268
.timeout(connection_timeout)
254269
.build()
255270
.expect("Failed to build reqwest client"),

typesense/tests/client/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ mod operations_test;
1010
mod presets_test;
1111
mod stemming_dictionaries_test;
1212
mod stopwords_test;
13+
mod tls_certificate_test;
1314

1415
use std::time::Duration;
1516
use typesense::{Client, ExponentialBackoff};
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
use std::{
2+
net::{IpAddr, Ipv4Addr},
3+
sync::Arc,
4+
time::Duration,
5+
};
6+
use tokio::{
7+
io::{AsyncReadExt, AsyncWriteExt as _},
8+
net::TcpListener,
9+
};
10+
use tokio_rustls::{
11+
TlsAcceptor,
12+
rustls::{
13+
self, ServerConfig,
14+
pki_types::{CertificateDer, PrivateKeyDer},
15+
},
16+
};
17+
use typesense::ExponentialBackoff;
18+
19+
/// Reqwest custom builder test.
20+
///
21+
/// In this test we exercise the `reqwest_builder` option by setting up a custom root TLS certificate.
22+
/// If the cusomization doesn't work, reqwest would be unable to connect to the mocked Typesense node.
23+
#[tokio::test]
24+
async fn test_reqwest_cusomt_builder() {
25+
rustls::crypto::aws_lc_rs::default_provider()
26+
.install_default()
27+
.expect("Failed to install crypto provider");
28+
29+
let api_key = "xxx-api-key";
30+
31+
// generate a self-signed key pair and build TLS config out of it
32+
let (cert, key) = generate_self_signed_cert();
33+
let tls_config = ServerConfig::builder()
34+
.with_no_client_auth()
35+
.with_single_cert(vec![cert.clone()], key)
36+
.expect("failed to build TLS config");
37+
38+
let localhost = IpAddr::V4(Ipv4Addr::LOCALHOST);
39+
let listener = TcpListener::bind((localhost, 0))
40+
.await
41+
.expect("Failed to bind to address");
42+
let server_addr = listener.local_addr().expect("Failed to get local address");
43+
44+
// spawn a handler which handles one /health request over a TLS connection
45+
let handler = tokio::spawn(mock_node_handler(listener, tls_config, api_key));
46+
47+
// create the client, configuring the certificate with reqwest
48+
let client_cert = reqwest::Certificate::from_der(&cert)
49+
.expect("Failed to convert certificate to Certificate");
50+
let client = typesense::Client::builder()
51+
.nodes(vec![format!("https://localhost:{}", server_addr.port())])
52+
.api_key(api_key)
53+
.reqwest_builder(move || {
54+
reqwest::Client::builder().add_root_certificate(client_cert.clone())
55+
})
56+
.healthcheck_interval(Duration::from_secs(9001)) // we'll do a healthcheck manually
57+
.retry_policy(ExponentialBackoff::builder().build_with_max_retries(0)) // no retries
58+
.connection_timeout(Duration::from_secs(1)) // short
59+
.build()
60+
.expect("Failed to create Typesense client");
61+
62+
// request /health
63+
client
64+
.operations()
65+
.health()
66+
.await
67+
.expect("Failed to get collection health");
68+
69+
handler.await.expect("Failed to join handler");
70+
}
71+
72+
fn generate_self_signed_cert() -> (CertificateDer<'static>, PrivateKeyDer<'static>) {
73+
let pair = rcgen::generate_simple_self_signed(["localhost".into()])
74+
.expect("Failed to generate self-signed certificate");
75+
let cert = pair.cert.der().clone();
76+
let signing_key = pair.signing_key.serialize_der();
77+
let signing_key = PrivateKeyDer::try_from(signing_key)
78+
.expect("Failed to convert signing key to PrivateKeyDer");
79+
(cert, signing_key)
80+
}
81+
82+
async fn mock_node_handler(listener: TcpListener, tls_config: ServerConfig, api_key: &'static str) {
83+
let tls_acceptor = TlsAcceptor::from(Arc::new(tls_config));
84+
let (stream, _addr) = listener
85+
.accept()
86+
.await
87+
.expect("Failed to accept connection");
88+
let mut stream = tls_acceptor
89+
.accept(stream)
90+
.await
91+
.expect("Failed to accept TLS connection");
92+
93+
let mut buf = vec![0u8; 1024];
94+
stream
95+
.read(&mut buf[..])
96+
.await
97+
.expect("Failed to read request");
98+
let request = String::from_utf8(buf).expect("Failed to parse request as UTF-8");
99+
assert!(request.contains("/health"));
100+
assert!(request.contains(api_key));
101+
102+
// mock a /health response
103+
let response = r#"HTTP/1.1 200 OK\r\n\
104+
Content-Type: application/json;\r\n\
105+
Connection: close\r\n
106+
107+
{"ok": true}"#;
108+
stream
109+
.write_all(&response.as_bytes())
110+
.await
111+
.expect("Failed to write to stream");
112+
stream.shutdown().await.expect("Failed to shutdown stream");
113+
}

0 commit comments

Comments
 (0)