Skip to content

Commit a8de364

Browse files
wprzytulamuzarski
authored andcommitted
network: extract tls module from connection
As the tls module got quite large, it makes sense to extract it to the network supermodule.
1 parent 7c95833 commit a8de364

File tree

6 files changed

+231
-237
lines changed

6 files changed

+231
-237
lines changed

scylla/src/client/session.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use crate::errors::{
1818
RequestAttemptError, RequestError, SchemaAgreementError, TracingError, UseKeyspaceError,
1919
};
2020
use crate::frame::response::result;
21-
use crate::network::TlsProvider;
21+
use crate::network::tls::TlsProvider;
2222
use crate::network::{Connection, ConnectionConfig, PoolConfig, VerifiedKeyspaceName};
2323
use crate::observability::driver_tracing::RequestSpan;
2424
use crate::observability::history::{self, HistoryListener};

scylla/src/cloud/config.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use uuid::Uuid;
1111
use crate::client::session::TlsContext;
1212
use crate::cluster::node::resolve_hostname;
1313
use crate::errors::TranslationError;
14-
use crate::network::{TlsConfig, TlsError};
14+
use crate::network::tls::{TlsConfig, TlsError};
1515
use crate::policies::address_translator::{AddressTranslator, UntranslatedPeer};
1616

1717
#[non_exhaustive]

scylla/src/errors.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ pub use crate::response::query_result::{
2222
pub use crate::authentication::AuthError;
2323

2424
// Re-export error type from network module.
25-
pub use crate::network::TlsError;
25+
pub use crate::network::tls::TlsError;
2626

2727
// Re-export error types from scylla-cql.
2828
pub use scylla_cql::deserialize::{DeserializationError, TypeCheckError};

scylla/src/network/connection.rs

Lines changed: 5 additions & 230 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use super::tls::{TlsConfig, TlsProvider};
12
use crate::authentication::AuthenticatorProvider;
23
use crate::batch::{Batch, BatchStatement};
34
use crate::client::pager::{NextRowError, QueryPager};
@@ -56,8 +57,6 @@ use std::{
5657
cmp::Ordering,
5758
net::{Ipv4Addr, Ipv6Addr},
5859
};
59-
pub use tls_config::TlsError;
60-
pub(crate) use tls_config::{TlsConfig, TlsProvider};
6160
use tokio::io::{split, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
6261
use tokio::net::{TcpSocket, TcpStream};
6362
use tokio::sync::{mpsc, oneshot};
@@ -206,230 +205,6 @@ struct TaskResponse {
206205
body: Bytes,
207206
}
208207

209-
mod tls_config {
210-
//! This module contains abstractions related to the TLS layer of driver connections.
211-
//!
212-
//! The full picture looks like this:
213-
//!
214-
//! ┌─←─ TlsContext (openssl::SslContext / rustls::ClientConfig)
215-
//! │
216-
//! ├─←─ CloudConfig (powered by either TLS backend)
217-
//! │
218-
//! │ gets wrapped in
219-
//! │
220-
//! ↳TlsProvider (same for all connections)
221-
//! │
222-
//! │ produces
223-
//! │
224-
//! ↳TlsConfig (specific for the particular connection)
225-
//! │
226-
//! │ produces
227-
//! │
228-
//! ↳Tls (wrapper over TCP stream which adds encryption)
229-
230-
use std::io;
231-
#[cfg(feature = "unstable-cloud")]
232-
use std::sync::Arc;
233-
234-
#[cfg(feature = "unstable-cloud")]
235-
use tracing::warn;
236-
#[cfg(feature = "unstable-cloud")]
237-
use uuid::Uuid;
238-
239-
use crate::client::session::TlsContext;
240-
#[cfg(feature = "unstable-cloud")]
241-
use crate::cloud::CloudConfig;
242-
#[cfg(feature = "unstable-cloud")]
243-
use crate::cluster::metadata::PeerEndpoint;
244-
use crate::cluster::metadata::UntranslatedEndpoint;
245-
#[cfg(feature = "unstable-cloud")]
246-
use crate::cluster::node::ResolvedContactPoint;
247-
248-
/// Abstraction capable of producing [TlsConfig] for connections on-demand.
249-
#[derive(Clone)] // Cheaply clonable (reference-counted)
250-
pub(crate) enum TlsProvider {
251-
GlobalContext(TlsContext),
252-
#[cfg(feature = "unstable-cloud")]
253-
ScyllaCloud(Arc<CloudConfig>),
254-
}
255-
256-
impl TlsProvider {
257-
/// Used in case when the user provided their own [TlsContext] to be used in all connections.
258-
pub(crate) fn new_with_global_context(context: TlsContext) -> Self {
259-
Self::GlobalContext(context)
260-
}
261-
262-
/// Used in the cloud case.
263-
#[cfg(feature = "unstable-cloud")]
264-
pub(crate) fn new_cloud(cloud_config: Arc<CloudConfig>) -> Self {
265-
Self::ScyllaCloud(cloud_config)
266-
}
267-
268-
/// Produces a [TlsConfig] that is specific for the given endpoint.
269-
pub(crate) fn make_tls_config(
270-
&self,
271-
// Currently, this is only used for cloud; but it makes abstract sense to pass endpoint here
272-
// also for non-cloud cases, so let's just allow(unused).
273-
#[allow(unused)] endpoint: &UntranslatedEndpoint,
274-
) -> Option<TlsConfig> {
275-
match self {
276-
TlsProvider::GlobalContext(context) => {
277-
Some(TlsConfig::new_with_global_context(context.clone()))
278-
}
279-
#[cfg(feature = "unstable-cloud")]
280-
TlsProvider::ScyllaCloud(cloud_config) => {
281-
let (host_id, address, dc) = match *endpoint {
282-
UntranslatedEndpoint::ContactPoint(ResolvedContactPoint {
283-
address,
284-
ref datacenter,
285-
}) => (None, address, datacenter.as_deref()), // FIXME: Pass DC in ContactPoint
286-
UntranslatedEndpoint::Peer(PeerEndpoint {
287-
host_id,
288-
address,
289-
ref datacenter,
290-
..
291-
}) => (Some(host_id), address.into_inner(), datacenter.as_deref()),
292-
};
293-
294-
cloud_config.make_tls_config_for_scylla_cloud_host(host_id, dc, address)
295-
// inspect_err() is stable since 1.76.
296-
// TODO: use inspect_err once we bump MSRV to at least 1.76.
297-
.map_err(|err| {
298-
warn!(
299-
"TlsProvider for SNI connection to Scylla Cloud node {{ host_id={:?}, dc={:?} at {} }} could not be set up: {}\n Proceeding with attempting probably nonworking connection",
300-
host_id,
301-
dc,
302-
address,
303-
err
304-
);
305-
}).ok().flatten()
306-
}
307-
}
308-
}
309-
}
310-
311-
/// Encapsulates TLS-regarding configuration that is specific for a particular endpoint.
312-
///
313-
/// Both use cases are supported:
314-
/// 1. User-provided global TlsContext. Then, the global TlsContext is simply cloned here.
315-
/// 2. Serverless Cloud. Then the TlsContext is customized for the given endpoint,
316-
/// and its SNI information is stored alongside.
317-
#[derive(Clone)]
318-
pub(crate) struct TlsConfig {
319-
context: TlsContext,
320-
#[cfg(feature = "unstable-cloud")]
321-
sni: Option<String>,
322-
}
323-
324-
/// An abstraction over connection's TLS layer which holds its state and configuration.
325-
pub(crate) enum Tls {
326-
#[cfg(feature = "openssl-010")]
327-
OpenSsl010(openssl::ssl::Ssl),
328-
#[cfg(feature = "rustls-023")]
329-
Rustls023 {
330-
connector: tokio_rustls::TlsConnector,
331-
#[cfg(feature = "unstable-cloud")]
332-
sni: Option<rustls::pki_types::ServerName<'static>>,
333-
},
334-
}
335-
336-
/// A wrapper around a TLS error.
337-
///
338-
/// The original error came from one of the supported TLS backends.
339-
#[derive(Debug, thiserror::Error)]
340-
#[error(transparent)]
341-
#[non_exhaustive]
342-
pub enum TlsError {
343-
#[cfg(feature = "openssl-010")]
344-
OpenSsl010(#[from] openssl::error::ErrorStack),
345-
#[cfg(feature = "rustls-023")]
346-
InvalidName(#[from] rustls::pki_types::InvalidDnsNameError),
347-
#[cfg(feature = "rustls-023")]
348-
PemParse(#[from] rustls::pki_types::pem::Error),
349-
#[cfg(feature = "rustls-023")]
350-
Rustls023(#[from] rustls::Error),
351-
}
352-
353-
impl From<TlsError> for io::Error {
354-
fn from(value: TlsError) -> Self {
355-
match value {
356-
#[cfg(feature = "openssl-010")]
357-
TlsError::OpenSsl010(e) => e.into(),
358-
#[cfg(feature = "rustls-023")]
359-
TlsError::InvalidName(e) => io::Error::new(io::ErrorKind::Other, e),
360-
#[cfg(feature = "rustls-023")]
361-
TlsError::PemParse(e) => io::Error::new(io::ErrorKind::Other, e),
362-
#[cfg(feature = "rustls-023")]
363-
TlsError::Rustls023(e) => io::Error::new(io::ErrorKind::Other, e),
364-
}
365-
}
366-
}
367-
368-
impl TlsConfig {
369-
/// Used in case when the user provided their own TlsContext to be used in all connections.
370-
pub(crate) fn new_with_global_context(context: TlsContext) -> Self {
371-
Self {
372-
context,
373-
#[cfg(feature = "unstable-cloud")]
374-
sni: None,
375-
}
376-
}
377-
378-
/// Used in case of Serverless Cloud connections.
379-
#[cfg(feature = "unstable-cloud")]
380-
pub(crate) fn new_for_sni(
381-
context: TlsContext,
382-
domain_name: &str,
383-
host_id: Option<Uuid>,
384-
) -> Self {
385-
Self {
386-
context,
387-
#[cfg(feature = "unstable-cloud")]
388-
sni: Some(if let Some(host_id) = host_id {
389-
format!("{}.{}", host_id, domain_name)
390-
} else {
391-
domain_name.into()
392-
}),
393-
}
394-
}
395-
396-
/// Produces a new Tls object that is able to wrap a TCP stream.
397-
pub(crate) fn new_tls(&self) -> Result<Tls, TlsError> {
398-
// To silence warnings when TlsContext is an empty enum (tls features are disabled).
399-
#[allow(unreachable_code)]
400-
match self.context {
401-
#[cfg(feature = "openssl-010")]
402-
TlsContext::OpenSsl010(ref context) => {
403-
#[allow(unused_mut)]
404-
let mut ssl = openssl::ssl::Ssl::new(context)?;
405-
#[cfg(feature = "unstable-cloud")]
406-
if let Some(sni) = self.sni.as_ref() {
407-
ssl.set_hostname(sni)?;
408-
}
409-
Ok(Tls::OpenSsl010(ssl))
410-
}
411-
#[cfg(feature = "rustls-023")]
412-
TlsContext::Rustls023(ref config) => {
413-
let connector = tokio_rustls::TlsConnector::from(config.clone());
414-
#[cfg(feature = "unstable-cloud")]
415-
let sni = self
416-
.sni
417-
.as_deref()
418-
.map(rustls::pki_types::ServerName::try_from)
419-
.transpose()?
420-
.map(|s| s.to_owned());
421-
422-
Ok(Tls::Rustls023 {
423-
connector,
424-
#[cfg(feature = "unstable-cloud")]
425-
sni,
426-
})
427-
}
428-
}
429-
}
430-
}
431-
}
432-
433208
impl<'id: 'map, 'map> SelfIdentity<'id> {
434209
fn add_startup_options(&'id self, options: &'map mut HashMap<Cow<'id, str>, Cow<'id, str>>) {
435210
/* Driver identity. */
@@ -1598,9 +1373,9 @@ impl Connection {
15981373
#[allow(unreachable_code)]
15991374
match tls_config.new_tls()? {
16001375
#[cfg(feature = "openssl-010")]
1601-
tls_config::Tls::OpenSsl010(ssl) => {
1602-
let mut stream =
1603-
tokio_openssl::SslStream::new(ssl, stream).map_err(TlsError::OpenSsl010)?;
1376+
crate::network::tls::Tls::OpenSsl010(ssl) => {
1377+
let mut stream = tokio_openssl::SslStream::new(ssl, stream)
1378+
.map_err(crate::network::tls::TlsError::OpenSsl010)?;
16041379
std::pin::Pin::new(&mut stream)
16051380
.connect()
16061381
.await
@@ -1617,7 +1392,7 @@ impl Connection {
16171392
.await);
16181393
}
16191394
#[cfg(feature = "rustls-023")]
1620-
tls_config::Tls::Rustls023 {
1395+
crate::network::tls::Tls::Rustls023 {
16211396
connector,
16221397
#[cfg(feature = "unstable-cloud")]
16231398
sni,

scylla/src/network/mod.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,11 @@
55
//! - NodeConnectionPool - a manager that keeps a desired number of connections opened to each shard.
66
77
mod connection;
8-
#[cfg(feature = "unstable-cloud")]
9-
pub(crate) use connection::TlsConfig;
10-
pub use connection::TlsError;
11-
pub(crate) use connection::TlsProvider;
128
pub(crate) use connection::{Connection, ConnectionConfig, VerifiedKeyspaceName};
139

1410
mod connection_pool;
1511

1612
pub use connection_pool::PoolSize;
1713
pub(crate) use connection_pool::{NodeConnectionPool, PoolConfig};
14+
15+
pub(crate) mod tls;

0 commit comments

Comments
 (0)