diff --git a/Cargo.lock b/Cargo.lock index 71ad9c779..a73929e4b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -101,6 +101,12 @@ version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + [[package]] name = "async-broadcast" version = "0.7.2" @@ -3103,13 +3109,17 @@ dependencies = [ name = "stackable-webhook" version = "0.4.0" dependencies = [ + "arc-swap", "axum", + "clap", "futures-util", "hyper", "hyper-util", "k8s-openapi", "kube", "opentelemetry", + "opentelemetry-semantic-conventions", + "rand 0.9.1", "serde_json", "snafu 0.8.6", "stackable-certs", @@ -3121,6 +3131,7 @@ dependencies = [ "tower-http", "tracing", "tracing-opentelemetry", + "x509-cert", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 1bc0bc32f..e50a67c40 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ repository = "https://github.com/stackabletech/operator-rs" [workspace.dependencies] product-config = { git = "https://github.com/stackabletech/product-config.git", tag = "0.7.0" } +arc-swap = "1.7" axum = { version = "0.8.1", features = ["http2"] } chrono = { version = "0.4.38", default-features = false } clap = { version = "4.5.17", features = ["derive", "cargo", "env"] } diff --git a/crates/stackable-certs/src/ca/consts.rs b/crates/stackable-certs/src/ca/consts.rs index 125a63a05..bcd080cd4 100644 --- a/crates/stackable-certs/src/ca/consts.rs +++ b/crates/stackable-certs/src/ca/consts.rs @@ -1,6 +1,6 @@ use stackable_operator::time::Duration; -/// The default CA validity time span of one hour (3600 seconds). +/// The default CA validity time span pub const DEFAULT_CA_VALIDITY: Duration = Duration::from_hours_unchecked(1); /// The root CA subject name containing only the common name. diff --git a/crates/stackable-certs/src/ca/mod.rs b/crates/stackable-certs/src/ca/mod.rs index d04da34ef..08f57d918 100644 --- a/crates/stackable-certs/src/ca/mod.rs +++ b/crates/stackable-certs/src/ca/mod.rs @@ -38,7 +38,7 @@ pub enum Error { #[snafu(display("failed to generate RSA signing key"))] GenerateRsaSigningKey { source: rsa::Error }, - #[snafu(display("failed to generate ECDSA signign key"))] + #[snafu(display("failed to generate ECDSA signing key"))] GenerateEcdsaSigningKey { source: ecdsa::Error }, #[snafu(display("failed to parse {subject:?} as subject"))] diff --git a/crates/stackable-operator/CHANGELOG.md b/crates/stackable-operator/CHANGELOG.md index 86945a87f..481aff0b8 100644 --- a/crates/stackable-operator/CHANGELOG.md +++ b/crates/stackable-operator/CHANGELOG.md @@ -4,15 +4,24 @@ All notable changes to this project will be documented in this file. ## [Unreleased] +### Added + +- BREAKING: Add two new required CLI arguments: `--operator-namespace` and `--operator-service-name`. + These two values are used to construct the service name in the CRD conversion webhook ([#1066]). + ### Changed - BREAKING: The `ResolvedProductImage` field `app_version_label` was renamed to `app_version_label_value` to match changes to its type ([#1076]). 
+- BREAKING: Rename two fields of the `ProductOperatorRun` struct for consistency and clarity ([#1066]): + - `telemetry_arguments` -> `telemetry` + - `cluster_info_opts` -> `cluster_info` ### Fixed - BREAKING: Fix bug where `ResolvedProductImage::app_version_label` could not be used as a label value because it can contain invalid characters. This is the case when referencing custom images via a `@sha256:...` hash. As such, the `product_image_selection::resolve` function is now fallible ([#1076]). +[#1066]: https://github.com/stackabletech/operator-rs/pull/1066 [#1076]: https://github.com/stackabletech/operator-rs/pull/1076 ## [0.94.0] - 2025-07-10 @@ -65,6 +74,7 @@ All notable changes to this project will be documented in this file. [#1058]: https://github.com/stackabletech/operator-rs/pull/1058 [#1060]: https://github.com/stackabletech/operator-rs/pull/1060 [#1064]: https://github.com/stackabletech/operator-rs/pull/1064 +[#1066]: https://github.com/stackabletech/operator-rs/pull/1066 [#1068]: https://github.com/stackabletech/operator-rs/pull/1068 [#1069]: https://github.com/stackabletech/operator-rs/pull/1069 [#1071]: https://github.com/stackabletech/operator-rs/pull/1071 diff --git a/crates/stackable-operator/src/cli.rs b/crates/stackable-operator/src/cli.rs index d9dafeb07..bfaa1704f 100644 --- a/crates/stackable-operator/src/cli.rs +++ b/crates/stackable-operator/src/cli.rs @@ -116,7 +116,7 @@ use product_config::ProductConfigManager; use snafu::{ResultExt, Snafu}; use stackable_telemetry::tracing::TelemetryOptions; -use crate::{namespace::WatchNamespace, utils::cluster_info::KubernetesClusterInfoOpts}; +use crate::{namespace::WatchNamespace, utils::cluster_info::KubernetesClusterInfoOptions}; pub const AUTHOR: &str = "Stackable GmbH - info@stackable.tech"; @@ -163,10 +163,10 @@ pub enum Command { /// Can be embedded into an extended argument set: /// /// ```rust -/// # use stackable_operator::cli::{Command, ProductOperatorRun, ProductConfigPath}; +/// # use stackable_operator::cli::{Command, OperatorEnvironmentOptions, ProductOperatorRun, ProductConfigPath}; +/// # use stackable_operator::{namespace::WatchNamespace, utils::cluster_info::KubernetesClusterInfoOptions}; +/// # use stackable_telemetry::tracing::TelemetryOptions; /// use clap::Parser; -/// use stackable_operator::{namespace::WatchNamespace, utils::cluster_info::KubernetesClusterInfoOpts}; -/// use stackable_telemetry::tracing::TelemetryOptions; /// /// #[derive(clap::Parser, Debug, PartialEq, Eq)] /// struct Run { @@ -176,17 +176,36 @@ pub enum Command { /// common: ProductOperatorRun, /// } /// -/// let opts = Command::::parse_from(["foobar-operator", "run", "--name", "foo", "--product-config", "bar", "--watch-namespace", "foobar", "--kubernetes-node-name", "baz"]); +/// let opts = Command::::parse_from([ +/// "foobar-operator", +/// "run", +/// "--name", +/// "foo", +/// "--product-config", +/// "bar", +/// "--watch-namespace", +/// "foobar", +/// "--operator-namespace", +/// "stackable-operators", +/// "--operator-service-name", +/// "foo-operator", +/// "--kubernetes-node-name", +/// "baz", +/// ]); /// assert_eq!(opts, Command::Run(Run { /// name: "foo".to_string(), /// common: ProductOperatorRun { /// product_config: ProductConfigPath::from("bar".as_ref()), /// watch_namespace: WatchNamespace::One("foobar".to_string()), -/// telemetry_arguments: TelemetryOptions::default(), -/// cluster_info_opts: KubernetesClusterInfoOpts { +/// telemetry: TelemetryOptions::default(), +/// cluster_info: KubernetesClusterInfoOptions 
{ /// kubernetes_cluster_domain: None, /// kubernetes_node_name: "baz".to_string(), /// }, +/// operator_environment: OperatorEnvironmentOptions { +/// operator_namespace: "stackable-operators".to_string(), +/// operator_service_name: "foo-operator".to_string(), +/// }, /// }, /// })); /// ``` @@ -220,10 +239,13 @@ pub struct ProductOperatorRun { pub watch_namespace: WatchNamespace, #[command(flatten)] - pub telemetry_arguments: TelemetryOptions, + pub operator_environment: OperatorEnvironmentOptions, + + #[command(flatten)] + pub telemetry: TelemetryOptions, #[command(flatten)] - pub cluster_info_opts: KubernetesClusterInfoOpts, + pub cluster_info: KubernetesClusterInfoOptions, } /// A path to a [`ProductConfigManager`] spec file @@ -281,11 +303,26 @@ impl ProductConfigPath { } } +#[derive(clap::Parser, Debug, PartialEq, Eq)] +pub struct OperatorEnvironmentOptions { + /// The namespace the operator is running in, usually `stackable-operators`. + /// + /// Note that when running the operator on Kubernetes we recommend to use the + /// [downward API](https://kubernetes.io/docs/concepts/workloads/pods/downward-api/) + /// to let Kubernetes project the namespace as the `OPERATOR_NAMESPACE` env variable. + #[arg(long, env)] + pub operator_namespace: String, + + /// The name of the service the operator is reachable at, usually + /// something like `-operator`. + #[arg(long, env)] + pub operator_service_name: String, +} + #[cfg(test)] mod tests { - use std::{env, fs::File}; + use std::fs::File; - use clap::Parser; use rstest::*; use tempfile::tempdir; @@ -294,7 +331,6 @@ mod tests { const USER_PROVIDED_PATH: &str = "user_provided_path_properties.yaml"; const DEPLOY_FILE_PATH: &str = "deploy_config_spec_properties.yaml"; const DEFAULT_FILE_PATH: &str = "default_file_path_properties.yaml"; - const WATCH_NAMESPACE: &str = "WATCH_NAMESPACE"; #[test] fn verify_cli() { @@ -378,76 +414,4 @@ mod tests { panic!("must return RequiredFileMissing when file was not found") } } - - #[test] - fn product_operator_run_watch_namespace() { - // clean env var to not interfere if already set - unsafe { env::remove_var(WATCH_NAMESPACE) }; - - // cli with namespace - let opts = ProductOperatorRun::parse_from([ - "run", - "--product-config", - "bar", - "--watch-namespace", - "foo", - "--kubernetes-node-name", - "baz", - ]); - assert_eq!( - opts, - ProductOperatorRun { - product_config: ProductConfigPath::from("bar".as_ref()), - watch_namespace: WatchNamespace::One("foo".to_string()), - cluster_info_opts: KubernetesClusterInfoOpts { - kubernetes_cluster_domain: None, - kubernetes_node_name: "baz".to_string() - }, - telemetry_arguments: Default::default(), - } - ); - - // no cli / no env - let opts = ProductOperatorRun::parse_from([ - "run", - "--product-config", - "bar", - "--kubernetes-node-name", - "baz", - ]); - assert_eq!( - opts, - ProductOperatorRun { - product_config: ProductConfigPath::from("bar".as_ref()), - watch_namespace: WatchNamespace::All, - cluster_info_opts: KubernetesClusterInfoOpts { - kubernetes_cluster_domain: None, - kubernetes_node_name: "baz".to_string() - }, - telemetry_arguments: Default::default(), - } - ); - - // env with namespace - unsafe { env::set_var(WATCH_NAMESPACE, "foo") }; - let opts = ProductOperatorRun::parse_from([ - "run", - "--product-config", - "bar", - "--kubernetes-node-name", - "baz", - ]); - assert_eq!( - opts, - ProductOperatorRun { - product_config: ProductConfigPath::from("bar".as_ref()), - watch_namespace: WatchNamespace::One("foo".to_string()), - cluster_info_opts: 
KubernetesClusterInfoOpts {
-                    kubernetes_cluster_domain: None,
-                    kubernetes_node_name: "baz".to_string()
-                },
-                telemetry_arguments: Default::default(),
-            }
-        );
-    }
 }
diff --git a/crates/stackable-operator/src/client.rs b/crates/stackable-operator/src/client.rs
index 5d493866e..f79a1eb91 100644
--- a/crates/stackable-operator/src/client.rs
+++ b/crates/stackable-operator/src/client.rs
@@ -21,7 +21,7 @@ use tracing::trace;
 
 use crate::{
     kvp::LabelSelectorExt,
-    utils::cluster_info::{KubernetesClusterInfo, KubernetesClusterInfoOpts},
+    utils::cluster_info::{KubernetesClusterInfo, KubernetesClusterInfoOptions},
 };
 
 pub type Result = std::result::Result;
@@ -529,13 +529,13 @@ impl Client {
     /// use k8s_openapi::api::core::v1::Pod;
     /// use stackable_operator::{
     ///     client::{Client, initialize_operator},
-    ///     utils::cluster_info::KubernetesClusterInfoOpts,
+    ///     utils::cluster_info::KubernetesClusterInfoOptions,
     /// };
     ///
     /// #[tokio::main]
     /// async fn main() {
-    ///     let cluster_info_opts = KubernetesClusterInfoOpts::parse();
-    ///     let client = initialize_operator(None, &cluster_info_opts)
+    ///     let cluster_info_options = KubernetesClusterInfoOptions::parse();
+    ///     let client = initialize_operator(None, &cluster_info_options)
     ///         .await
     ///         .expect("Unable to construct client.");
     ///     let watcher_config: watcher::Config =
@@ -652,7 +652,7 @@ where
 
 pub async fn initialize_operator(
     field_manager: Option,
-    cluster_info_opts: &KubernetesClusterInfoOpts,
+    cluster_info_opts: &KubernetesClusterInfoOptions,
 ) -> Result {
     let kubeconfig: Config = kube::Config::infer()
         .await
@@ -687,10 +687,10 @@ mod tests {
     };
     use tokio::time::error::Elapsed;
 
-    use crate::utils::cluster_info::KubernetesClusterInfoOpts;
+    use crate::utils::cluster_info::KubernetesClusterInfoOptions;
 
-    async fn test_cluster_info_opts() -> KubernetesClusterInfoOpts {
-        KubernetesClusterInfoOpts {
+    async fn test_cluster_info_opts() -> KubernetesClusterInfoOptions {
+        KubernetesClusterInfoOptions {
             // We have to hard-code a made-up cluster domain,
             // since kubernetes_node_name (probably) won't be a valid Node that we can query.
             kubernetes_cluster_domain: Some(
diff --git a/crates/stackable-operator/src/utils/cluster_info.rs b/crates/stackable-operator/src/utils/cluster_info.rs
index 56c718f9e..d8c64976b 100644
--- a/crates/stackable-operator/src/utils/cluster_info.rs
+++ b/crates/stackable-operator/src/utils/cluster_info.rs
@@ -16,13 +16,17 @@ pub struct KubernetesClusterInfo {
 }
 
 #[derive(clap::Parser, Debug, PartialEq, Eq)]
-pub struct KubernetesClusterInfoOpts {
+pub struct KubernetesClusterInfoOptions {
     /// Kubernetes cluster domain, usually this is `cluster.local`.
     // We are not using a default value here, as we query the cluster if it is not specified.
     #[arg(long, env)]
     pub kubernetes_cluster_domain: Option,
 
     /// Name of the Kubernetes Node that the operator is running on.
+    ///
+    /// Note that when running the operator on Kubernetes we recommend using the
+    /// [downward API](https://kubernetes.io/docs/concepts/workloads/pods/downward-api/)
+    /// to let Kubernetes project the node name as the `KUBERNETES_NODE_NAME` env variable.
#[arg(long, env)] pub kubernetes_node_name: String, } @@ -30,10 +34,10 @@ pub struct KubernetesClusterInfoOpts { impl KubernetesClusterInfo { pub async fn new( client: &Client, - cluster_info_opts: &KubernetesClusterInfoOpts, + cluster_info_opts: &KubernetesClusterInfoOptions, ) -> Result { let cluster_domain = match cluster_info_opts { - KubernetesClusterInfoOpts { + KubernetesClusterInfoOptions { kubernetes_cluster_domain: Some(cluster_domain), .. } => { @@ -41,7 +45,7 @@ impl KubernetesClusterInfo { cluster_domain.clone() } - KubernetesClusterInfoOpts { + KubernetesClusterInfoOptions { kubernetes_node_name: node_name, .. } => { diff --git a/crates/stackable-telemetry/src/instrumentation/axum/mod.rs b/crates/stackable-telemetry/src/instrumentation/axum/mod.rs index b72450c99..14e7928cd 100644 --- a/crates/stackable-telemetry/src/instrumentation/axum/mod.rs +++ b/crates/stackable-telemetry/src/instrumentation/axum/mod.rs @@ -73,22 +73,6 @@ const OTEL_TRACE_ID_TO: &str = "opentelemetry.trace_id.to"; /// # let _: Router = router; /// ``` /// -/// ### Example with Webhook -/// -/// The usage is even simpler when combined with the `stackable_webhook` crate. -/// The webhook server has built-in support to automatically emit HTTP spans on -/// every incoming request. -/// -/// ``` -/// use stackable_webhook::{WebhookServer, Options}; -/// use axum::Router; -/// -/// let router = Router::new(); -/// let server = WebhookServer::new(router, Options::default()); -/// -/// # let _: WebhookServer = server; -/// ``` -/// /// This layer is implemented based on [this][1] official Tower guide. /// /// [1]: https://github.com/tower-rs/tower/blob/master/guides/building-a-middleware-from-scratch.md diff --git a/crates/stackable-webhook/CHANGELOG.md b/crates/stackable-webhook/CHANGELOG.md index 1cd10e2d9..1871f9102 100644 --- a/crates/stackable-webhook/CHANGELOG.md +++ b/crates/stackable-webhook/CHANGELOG.md @@ -4,6 +4,25 @@ All notable changes to this project will be documented in this file. ## [Unreleased] +### Changed + +- BREAKING: Re-write the `ConversionWebhookServer`. + It can now do CRD conversions, handle multiple CRDs and takes care of reconciling the CRDs ([#1066]). +- BREAKING: The `TlsServer` can now handle certificate rotation. + To achieve this, a new `CertificateResolver` was added. + Also, `TlsServer::new` now returns an additional `mpsc::Receiver`, so that the caller + can get notified about certificate rotations happening ([#1066]). +- `stackable_webhook::Options` has been renamed to `stackable_webhook::WebhookOptions`, as well as + `OptionsBuilder` to `WebhookOptionsBuilder` ([#1066]). + +### Removed + +- Remove `StatefulWebhookHandler` to reduce maintenance effort. + Also, webhooks are ideally stateless, so that they can be scaled horizontally. + It can be re-added once needed ([#1066]). + +[#1066]: https://github.com/stackabletech/operator-rs/pull/1066 + ## [0.4.0] - 2025-07-10 ### Fixed @@ -23,7 +42,7 @@ All notable changes to this project will be documented in this file. ## [0.3.1] - 2024-07-10 -## Changed +### Changed - Remove instrumentation of long running functions, add more granular instrumentation of futures. Adjust span and event levels ([#811]). - Bump rust-toolchain to 1.79.0 ([#822]). 
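The changelog above notes that `TlsServer::new` (and, as the `lib.rs` docs further down show, `WebhookServer::new`) now additionally returns an `mpsc::Receiver` carrying each newly generated serving certificate. A minimal sketch of how a caller might consume that receiver while the server runs; the `run_webhook` wrapper, the boxed error type, and the spawned watcher task are illustrative assumptions, not part of this patch:

```rust
use axum::Router;
use stackable_webhook::{WebhookOptions, WebhookServer};

// Illustrative wrapper, not part of this patch. Assumes a tokio runtime and the
// `tracing` crate being available.
async fn run_webhook() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let router = Router::new();

    // `WebhookServer::new` now also hands back a receiver that yields the
    // x509_cert::Certificate generated (and later rotated) by the TLS server.
    let (server, mut cert_rx) = WebhookServer::new(router, WebhookOptions::default()).await?;

    // React to certificate rotations, e.g. to publish the new CA bundle somewhere.
    let cert_watcher = tokio::spawn(async move {
        while let Some(certificate) = cert_rx.recv().await {
            tracing::info!(?certificate, "webhook serving certificate was rotated");
        }
    });

    server.run().await?;
    cert_watcher.await?;

    Ok(())
}
```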
diff --git a/crates/stackable-webhook/Cargo.toml b/crates/stackable-webhook/Cargo.toml index 398a8a3f0..5c1bf70eb 100644 --- a/crates/stackable-webhook/Cargo.toml +++ b/crates/stackable-webhook/Cargo.toml @@ -11,6 +11,7 @@ stackable-certs = { path = "../stackable-certs", features = ["rustls"] } stackable-telemetry = { path = "../stackable-telemetry" } stackable-operator = { path = "../stackable-operator" } +arc-swap.workspace = true axum.workspace = true futures-util.workspace = true hyper-util.workspace = true @@ -18,6 +19,8 @@ hyper.workspace = true k8s-openapi.workspace = true kube.workspace = true opentelemetry.workspace = true +opentelemetry-semantic-conventions.workspace = true +rand.workspace = true serde_json.workspace = true snafu.workspace = true tokio-rustls.workspace = true @@ -26,3 +29,7 @@ tower-http.workspace = true tower.workspace = true tracing.workspace = true tracing-opentelemetry.workspace = true +x509-cert.workspace = true + +[dev-dependencies] +clap.workspace = true diff --git a/crates/stackable-webhook/src/lib.rs b/crates/stackable-webhook/src/lib.rs index 186f19e12..c98515e1e 100644 --- a/crates/stackable-webhook/src/lib.rs +++ b/crates/stackable-webhook/src/lib.rs @@ -1,6 +1,6 @@ //! Utility types and functions to easily create ready-to-use webhook servers //! which can handle different tasks, for example CRD conversions. All webhook -//! servers use HTTPS by defaultThis library is fully compatible with the +//! servers use HTTPS by default. This library is fully compatible with the //! [`tracing`] crate and emits debug level tracing data. //! //! Most users will only use the top-level exported generic [`WebhookServer`] @@ -8,27 +8,34 @@ //! routes and their handler functions. //! //! ``` -//! use stackable_webhook::{WebhookServer, Options}; +//! use stackable_webhook::{WebhookServer, WebhookOptions}; //! use axum::Router; //! +//! # async fn test() { //! let router = Router::new(); -//! let server = WebhookServer::new(router, Options::default()); +//! let (server, cert_rx) = WebhookServer::new(router, WebhookOptions::default()) +//! .await +//! .expect("failed to create WebhookServer"); +//! # } //! ``` //! //! For some usages, complete end-to-end [`WebhookServer`] implementations -//! exist. One such implementation is the [`ConversionWebhookServer`][1]. The -//! only required parameters are a conversion handler function and [`Options`]. +//! exist. One such implementation is the [`ConversionWebhookServer`][1]. //! //! This library additionally also exposes lower-level structs and functions to -//! enable complete controll over these details if needed. +//! enable complete control over these details if needed. //! //! [1]: crate::servers::ConversionWebhookServer use axum::{Router, routing::get}; use futures_util::{FutureExt as _, pin_mut, select}; use snafu::{ResultExt, Snafu}; use stackable_telemetry::AxumTraceLayer; -use tokio::signal::unix::{SignalKind, signal}; +use tokio::{ + signal::unix::{SignalKind, signal}, + sync::mpsc, +}; use tower::ServiceBuilder; +use x509_cert::Certificate; // use tower_http::trace::TraceLayer; use crate::tls::TlsServer; @@ -39,11 +46,7 @@ pub mod servers; pub mod tls; // Selected re-exports -pub use crate::options::Options; - -/// A result type alias with the library-level [`Error`] type as teh default -/// error type. -pub type Result = std::result::Result; +pub use crate::options::WebhookOptions; /// A generic webhook handler receiving a request and sending back a response. 
/// @@ -56,25 +59,16 @@ pub trait WebhookHandler { fn call(self, req: Req) -> Res; } -/// A generic webhook handler receiving a request and state and sending back -/// a response. -/// -/// This trait is not intended to be implemented by external crates and this -/// library provides various ready-to-use implementations for it. One such an -/// implementation is part of the [`ConversionWebhookServer`][1]. -/// -/// [1]: crate::servers::ConversionWebhookServer -pub trait StatefulWebhookHandler { - fn call(self, req: Req, state: S) -> Res; -} +/// A result type alias with the [`WebhookError`] type as the default error type. +pub type Result = std::result::Result; #[derive(Debug, Snafu)] -pub enum Error { +pub enum WebhookError { #[snafu(display("failed to create TLS server"))] - CreateTlsServer { source: tls::Error }, + CreateTlsServer { source: tls::TlsServerError }, #[snafu(display("failed to run TLS server"))] - RunTlsServer { source: tls::Error }, + RunTlsServer { source: tls::TlsServerError }, } /// A ready-to-use webhook server. @@ -88,17 +82,16 @@ pub enum Error { /// /// [1]: crate::servers::ConversionWebhookServer pub struct WebhookServer { - options: Options, - router: Router, + tls_server: TlsServer, } impl WebhookServer { /// Creates a new ready-to-use webhook server. /// - /// The server listens on `socket_addr` which is provided via the [`Options`] - /// and handles routing based on the provided Axum `router`. Most of the time - /// it is sufficient to use [`Options::default()`]. See the documentation - /// for [`Options`] for more details on the default values. + /// The server listens on `socket_addr` which is provided via the [`WebhookOptions`] and handles + /// routing based on the provided Axum `router`. Most of the time it is sufficient to use + /// [`WebhookOptions::default()`]. See the documentation for [`WebhookOptions`] for more details + /// on the default values. /// /// To start the server, use the [`WebhookServer::run()`] function. This will /// run the server using the Tokio runtime until it is terminated. 
@@ -106,29 +99,66 @@ impl WebhookServer { /// ### Basic Example /// /// ``` - /// use stackable_webhook::{WebhookServer, Options}; + /// use stackable_webhook::{WebhookServer, WebhookOptions}; /// use axum::Router; /// + /// # async fn test() { /// let router = Router::new(); - /// let server = WebhookServer::new(router, Options::default()); + /// let (server, cert_rx) = WebhookServer::new(router, WebhookOptions::default()) + /// .await + /// .expect("failed to create WebhookServer"); + /// # } /// ``` /// /// ### Example with Custom Options /// /// ``` - /// use stackable_webhook::{WebhookServer, Options}; + /// use stackable_webhook::{WebhookServer, WebhookOptions}; /// use axum::Router; /// - /// let options = Options::builder() + /// # async fn test() { + /// let options = WebhookOptions::builder() /// .bind_address([127, 0, 0, 1], 8080) + /// .add_subject_alterative_dns_name("my-san-entry") /// .build(); /// /// let router = Router::new(); - /// let server = WebhookServer::new(router, options); + /// let (server, cert_rx) = WebhookServer::new(router, options) + /// .await + /// .expect("failed to create WebhookServer"); + /// # } /// ``` - pub fn new(router: Router, options: Options) -> Self { + pub async fn new( + router: Router, + options: WebhookOptions, + ) -> Result<(Self, mpsc::Receiver)> { tracing::trace!("create new webhook server"); - Self { options, router } + + // TODO (@Techassi): Make opt-in configurable from the outside + // Create an OpenTelemetry tracing layer + tracing::trace!("create tracing service (layer)"); + let trace_layer = AxumTraceLayer::new().with_opt_in(); + + // Use a service builder to provide multiple layers at once. Recommended + // by the Axum project. + // + // See https://docs.rs/axum/latest/axum/middleware/index.html#applying-multiple-middleware + // TODO (@NickLarsenNZ): rename this server_builder and keep it specific to tracing, since it's placement in the chain is important + let service_builder = ServiceBuilder::new().layer(trace_layer); + + // Create the root router and merge the provided router into it. + tracing::debug!("create core router and merge provided router"); + let router = router + .layer(service_builder) + // The health route is below the AxumTraceLayer so as not to be instrumented + .route("/health", get(|| async { "ok" })); + + tracing::debug!("create TLS server"); + let (tls_server, cert_rx) = TlsServer::new(router, options) + .await + .context(CreateTlsServerSnafu)?; + + Ok((Self { tls_server }, cert_rx)) } /// Runs the Webhook server and sets up signal handlers for shutting down. @@ -170,33 +200,6 @@ impl WebhookServer { async fn run_server(self) -> Result<()> { tracing::debug!("run webhook server"); - // TODO (@Techassi): Make opt-in configurable from the outside - // Create an OpenTelemetry tracing layer - tracing::trace!("create tracing service (layer)"); - let trace_layer = AxumTraceLayer::new().with_opt_in(); - - // Use a service builder to provide multiple layers at once. Recommended - // by the Axum project. - // - // See https://docs.rs/axum/latest/axum/middleware/index.html#applying-multiple-middleware - // TODO (@NickLarsenNZ): rename this server_builder and keep it specific to tracing, since it's placement in the chain is important - let service_builder = ServiceBuilder::new().layer(trace_layer); - - // Create the root router and merge the provided router into it. 
- tracing::debug!("create core router and merge provided router"); - let router = self - .router - .layer(service_builder) - // The health route is below the AxumTraceLayer so as not to be instrumented - .route("/health", get(|| async { "ok" })); - - // Create server for TLS termination - tracing::debug!("create TLS server"); - let tls_server = TlsServer::new(self.options.socket_addr, router) - .await - .context(CreateTlsServerSnafu)?; - - tracing::info!("running TLS server"); - tls_server.run().await.context(RunTlsServerSnafu) + self.tls_server.run().await.context(RunTlsServerSnafu) } } diff --git a/crates/stackable-webhook/src/options.rs b/crates/stackable-webhook/src/options.rs index 99a01133e..90623b093 100644 --- a/crates/stackable-webhook/src/options.rs +++ b/crates/stackable-webhook/src/options.rs @@ -10,65 +10,69 @@ use crate::constants::DEFAULT_SOCKET_ADDRESS; /// Specifies available webhook server options. /// -/// The [`Default`] implemention for this struct contains the following -/// values: +/// The [`Default`] implementation for this struct contains the following values: /// /// - The socket binds to 127.0.0.1 on port 8443 (HTTPS) -/// - The TLS cert used gets auto-generated +/// - An empty list of SANs is provided to the certificate the TLS server uses. /// /// ### Example with Custom HTTPS IP Address and Port /// /// ``` -/// use stackable_webhook::Options; +/// use stackable_webhook::WebhookOptions; /// /// // Set IP address and port at the same time -/// let options = Options::builder() +/// let options = WebhookOptions::builder() /// .bind_address([0, 0, 0, 0], 12345) /// .build(); /// /// // Set IP address only -/// let options = Options::builder() +/// let options = WebhookOptions::builder() /// .bind_ip([0, 0, 0, 0]) /// .build(); /// /// // Set port only -/// let options = Options::builder() +/// let options = WebhookOptions::builder() /// .bind_port(12345) /// .build(); /// ``` #[derive(Debug)] -pub struct Options { +pub struct WebhookOptions { /// The default HTTPS socket address the [`TcpListener`][tokio::net::TcpListener] /// binds to. pub socket_addr: SocketAddr, + + /// The subject alterative DNS names that should be added to the certificates generated for this + /// webhook. + pub subject_alterative_dns_names: Vec, } -impl Default for Options { +impl Default for WebhookOptions { fn default() -> Self { Self::builder().build() } } -impl Options { - /// Returns the default [`OptionsBuilder`] which allows to selectively - /// customize the options. See the documention for [`Options`] for more +impl WebhookOptions { + /// Returns the default [`WebhookOptionsBuilder`] which allows to selectively + /// customize the options. See the documentation for [`WebhookOptions`] for more /// information on available functions. - pub fn builder() -> OptionsBuilder { - OptionsBuilder::default() + pub fn builder() -> WebhookOptionsBuilder { + WebhookOptionsBuilder::default() } } -/// The [`OptionsBuilder`] which allows to selectively customize the webhook -/// server [`Options`]. +/// The [`WebhookOptionsBuilder`] which allows to selectively customize the webhook +/// server [`WebhookOptions`]. /// /// Usually, this struct is not constructed manually, but instead by calling -/// [`Options::builder()`] or [`OptionsBuilder::default()`]. +/// [`WebhookOptions::builder()`] or [`WebhookOptionsBuilder::default()`]. 
#[derive(Debug, Default)] -pub struct OptionsBuilder { +pub struct WebhookOptionsBuilder { socket_addr: Option, + subject_alterative_dns_names: Vec, } -impl OptionsBuilder { +impl WebhookOptionsBuilder { /// Sets the socket address the webhook server uses to bind for HTTPS. pub fn bind_address(mut self, bind_ip: impl Into, bind_port: u16) -> Self { self.socket_addr = Some(SocketAddr::new(bind_ip.into(), bind_port)); @@ -91,11 +95,32 @@ impl OptionsBuilder { self } - /// Builds the final [`Options`] by using default values for any not + /// Sets the subject alterative DNS names that should be added to the certificates generated for + /// this webhook. + pub fn subject_alterative_dns_names( + mut self, + subject_alterative_dns_name: Vec, + ) -> Self { + self.subject_alterative_dns_names = subject_alterative_dns_name; + self + } + + /// Adds the subject alterative DNS name to the list of names. + pub fn add_subject_alterative_dns_name( + mut self, + subject_alterative_dns_name: impl Into, + ) -> Self { + self.subject_alterative_dns_names + .push(subject_alterative_dns_name.into()); + self + } + + /// Builds the final [`WebhookOptions`] by using default values for any not /// explicitly set option. - pub fn build(self) -> Options { - Options { + pub fn build(self) -> WebhookOptions { + WebhookOptions { socket_addr: self.socket_addr.unwrap_or(DEFAULT_SOCKET_ADDRESS), + subject_alterative_dns_names: self.subject_alterative_dns_names, } } } diff --git a/crates/stackable-webhook/src/servers/conversion.rs b/crates/stackable-webhook/src/servers/conversion.rs index 9b1ff197b..0a0412e6c 100644 --- a/crates/stackable-webhook/src/servers/conversion.rs +++ b/crates/stackable-webhook/src/servers/conversion.rs @@ -1,14 +1,62 @@ -use std::fmt::Debug; +use std::{fmt::Debug, net::SocketAddr}; -use axum::{Json, Router, extract::State, routing::post}; +use axum::{Json, Router, routing::post}; +use k8s_openapi::{ + ByteString, + apiextensions_apiserver::pkg::apis::apiextensions::v1::{ + CustomResourceConversion, CustomResourceDefinition, ServiceReference, WebhookClientConfig, + WebhookConversion, + }, +}; // Re-export this type because users of the conversion webhook server require // this type to write the handler function. Instead of importing this type from // kube directly, consumers can use this type instead. This also eliminates // keeping the kube dependency version in sync between here and the operator. 
pub use kube::core::conversion::ConversionReview; +use kube::{ + Api, Client, ResourceExt, + api::{Patch, PatchParams}, +}; +use snafu::{OptionExt, ResultExt, Snafu}; +use stackable_operator::cli::OperatorEnvironmentOptions; +use tokio::{sync::mpsc, try_join}; use tracing::instrument; +use x509_cert::{ + Certificate, + der::{EncodePem, pem::LineEnding}, +}; -use crate::{StatefulWebhookHandler, WebhookHandler, WebhookServer, options::Options}; +use crate::{ + WebhookError, WebhookHandler, WebhookServer, constants::DEFAULT_HTTPS_PORT, + options::WebhookOptions, +}; + +#[derive(Debug, Snafu)] +pub enum ConversionWebhookError { + #[snafu(display("failed to create webhook server"))] + CreateWebhookServer { source: WebhookError }, + + #[snafu(display("failed to run webhook server"))] + RunWebhookServer { source: WebhookError }, + + #[snafu(display("failed to receive certificate from channel"))] + ReceiveCertificateFromChannel, + + #[snafu(display("failed to convert CA certificate into PEM format"))] + ConvertCaToPem { source: x509_cert::der::Error }, + + #[snafu(display("failed to reconcile CRDs"))] + ReconcileCrds { + #[snafu(source(from(ConversionWebhookError, Box::new)))] + source: Box, + }, + + #[snafu(display("failed to update CRD {crd_name:?}"))] + UpdateCrd { + source: stackable_operator::kube::Error, + crd_name: String, + }, +} impl WebhookHandler for F where @@ -19,141 +67,266 @@ where } } -impl StatefulWebhookHandler for F -where - F: FnOnce(ConversionReview, S) -> ConversionReview, -{ - fn call(self, req: ConversionReview, state: S) -> ConversionReview { - self(req, state) - } +// TODO: Add a builder, maybe with `bon`. +#[derive(Debug)] +pub struct ConversionWebhookOptions { + /// The environment the operator is running in, notably the namespace and service name it is + /// reachable at. + pub operator_environment: OperatorEnvironmentOptions, + + /// The bind address to bind the HTTPS server to. + pub socket_addr: SocketAddr, + + /// The field manager used to apply Kubernetes objects, typically the operator name, e.g. + /// `airflow-operator`. + pub field_manager: String, } /// A ready-to-use CRD conversion webhook server. /// -/// See [`ConversionWebhookServer::new()`] and [`ConversionWebhookServer::new_with_state()`] -/// for usage examples. +/// See [`ConversionWebhookServer::new()`] for usage examples. pub struct ConversionWebhookServer { - options: Options, + crds: Vec, + options: ConversionWebhookOptions, router: Router, + client: Client, } impl ConversionWebhookServer { - /// Creates a new conversion webhook server **without** state which expects - /// POST requests being made to the `/convert` endpoint. + /// Creates a new conversion webhook server, which expects POST requests being made to the + /// `/convert/{crd name}` endpoint. + /// + /// You need to provide two things for every CRD passed in via the `crds_and_handlers` argument: /// - /// Each request is handled by the provided `handler` function. Any function - /// with the signature `(ConversionReview) -> ConversionReview` can be - /// provided. The [`ConversionReview`] type can be imported via a re-export at - /// [`crate::servers::ConversionReview`]. + /// 1. The CRD + /// 2. A conversion function to convert between CRD versions. Typically you would use the + /// the auto-generated `try_convert` function on CRD spec definition structs for this. + /// + /// The [`ConversionWebhookServer`] takes care of reconciling the CRDs into the Kubernetes + /// cluster and takes care of adding itself as conversion webhook. 
This includes TLS + /// certificates and CA bundles. /// /// # Example /// - /// ``` + /// ```no_run + /// use clap::Parser; /// use stackable_webhook::{ - /// servers::{ConversionReview, ConversionWebhookServer}, - /// Options + /// servers::{ConversionReview, ConversionWebhookServer, ConversionWebhookOptions}, + /// WebhookOptions + /// }; + /// use stackable_operator::{ + /// kube::Client, + /// crd::s3::{S3Connection, S3ConnectionVersion}, + /// cli::OperatorEnvironmentOptions, + /// }; + /// + /// # async fn test() { + /// let crds_and_handlers = [ + /// ( + /// S3Connection::merged_crd(S3ConnectionVersion::V1Alpha1) + /// .expect("failed to merge S3Connection CRD"), + /// S3Connection::try_convert as fn(_) -> _, + /// ), + /// ]; + /// + /// let client = Client::try_default().await.expect("failed to create Kubernetes client"); + /// let operator_environment = OperatorEnvironmentOptions::parse(); + /// + /// let options = ConversionWebhookOptions { + /// operator_environment, + /// socket_addr: "127.0.0.1:8080".parse().unwrap(), + /// field_manager: String::from("product-operator"), /// }; /// /// // Construct the conversion webhook server - /// let server = ConversionWebhookServer::new(handler, Options::default()); + /// let conversion_webhook = ConversionWebhookServer::new( + /// crds_and_handlers, + /// options, + /// client, + /// ) + /// .await + /// .expect("failed to create ConversionWebhookServer"); /// - /// // Define the handler function - /// fn handler(req: ConversionReview) -> ConversionReview { - /// // In here we can do the CRD conversion - /// req - /// } + /// conversion_webhook.run().await.expect("failed to run ConversionWebhookServer"); + /// # } /// ``` - #[instrument(name = "create_conversion_webhook_server", skip(handler))] - pub fn new(handler: H, options: Options) -> Self + #[instrument( + name = "create_conversion_webhook_server", + skip(crds_and_handlers, client) + )] + pub async fn new( + crds_and_handlers: impl IntoIterator, + options: ConversionWebhookOptions, + client: Client, + ) -> Result where H: WebhookHandler + Clone + Send + Sync + 'static, { tracing::debug!("create new conversion webhook server"); - let handler_fn = |Json(review): Json| async { - let review = handler.call(review); - Json(review) - }; + let mut router = Router::new(); + let mut crds = Vec::new(); + for (crd, handler) in crds_and_handlers { + let crd_name = crd.name_any(); + let handler_fn = |Json(review): Json| async { + let review = handler.call(review); + Json(review) + }; + + let route = format!("/convert/{crd_name}"); + router = router.route(&route, post(handler_fn)); + crds.push(crd); + } - let router = Router::new().route("/convert", post(handler_fn)); - Self { router, options } + Ok(Self { + options, + router, + client, + crds, + }) } - /// Creates a new conversion webhook server **with** state which expects - /// POST requests being made to the `/convert` endpoint. - /// - /// Each request is handled by the provided `handler` function. Any function - /// with the signature `(ConversionReview, S) -> ConversionReview` can be - /// provided. The [`ConversionReview`] type can be imported via a re-export at - /// [`crate::servers::ConversionReview`]. - /// - /// It is recommended to wrap the state in an [`Arc`][std::sync::Arc] if it - /// needs to be mutable, see - /// . 
- /// - /// # Example - /// - /// ``` - /// use std::sync::Arc; - /// - /// use stackable_webhook::{ - /// servers::{ConversionReview, ConversionWebhookServer}, - /// Options - /// }; - /// - /// #[derive(Debug, Clone)] - /// struct State {} - /// - /// let shared_state = Arc::new(State {}); - /// let server = ConversionWebhookServer::new_with_state( - /// handler, - /// shared_state, - /// Options::default(), - /// ); - /// - /// // Define the handler function - /// fn handler(req: ConversionReview, state: Arc) -> ConversionReview { - /// // In here we can do the CRD conversion - /// req - /// } - /// ``` - #[instrument(name = "create_conversion_webhook_server_with_state", skip(handler))] - pub fn new_with_state(handler: H, state: S, options: Options) -> Self - where - H: StatefulWebhookHandler - + Clone - + Send - + Sync - + 'static, - S: Clone + Debug + Send + Sync + 'static, - { - tracing::debug!("create new conversion webhook server with state"); - - // NOTE (@Techassi): Initially, after adding the state extractor, the - // compiler kept throwing a trait error at me stating that the closure - // below doesn't implement the Handler trait from Axum. This had nothing - // to do with the state itself, but rather the order of extractors. All - // body consuming extractors, like the JSON extractor need to come last - // in the handler. - // https://docs.rs/axum/latest/axum/extract/index.html#the-order-of-extractors - let handler_fn = |State(state): State, Json(review): Json| async { - let review = handler.call(review, state); - Json(review) + pub async fn run(self) -> Result<(), ConversionWebhookError> { + tracing::info!("starting conversion webhook server"); + + let Self { + options, + router, + client, + crds, + } = self; + + let ConversionWebhookOptions { + operator_environment: + OperatorEnvironmentOptions { + operator_namespace, + operator_service_name, + }, + socket_addr, + field_manager, + } = &options; + + // This is how Kubernetes calls us, so it decides about the naming. + // AFAIK we can not influence this, so this is the only SAN entry needed. + let subject_alterative_dns_name = + format!("{operator_service_name}.{operator_namespace}.svc",); + + let webhook_options = WebhookOptions { + subject_alterative_dns_names: vec![subject_alterative_dns_name], + socket_addr: *socket_addr, }; - let router = Router::new() - .route("/convert", post(handler_fn)) - .with_state(state); + let (server, mut cert_rx) = WebhookServer::new(router, webhook_options) + .await + .context(CreateWebhookServerSnafu)?; - Self { router, options } + // We block the ConversionWebhookServer creation until the certificates have been generated. + // This way we + // 1. Are able to apply the CRDs before we start the actual controllers relying on them + // 2. Avoid updating them shortly after as cert have been generated. Doing so would cause + // unnecessary "too old resource version" errors in the controllers as the CRD was updated. + let current_cert = cert_rx + .recv() + .await + .context(ReceiveCertificateFromChannelSnafu)?; + Self::reconcile_crds( + &client, + field_manager, + &crds, + &options.operator_environment, + current_cert, + ) + .await + .context(ReconcileCrdsSnafu)?; + + try_join!( + Self::run_webhook_server(server), + Self::run_crd_reconciliation_loop( + cert_rx, + &client, + field_manager, + &crds, + &options.operator_environment, + ), + )?; + + Ok(()) } - /// Starts the conversion webhook server by starting the underlying - /// [`WebhookServer`]. 
- pub async fn run(self) -> Result<(), crate::Error> { - tracing::info!("starting conversion webhook server"); + async fn run_webhook_server(server: WebhookServer) -> Result<(), ConversionWebhookError> { + server.run().await.context(RunWebhookServerSnafu) + } + + async fn run_crd_reconciliation_loop( + mut cert_rx: mpsc::Receiver, + client: &Client, + field_manager: &str, + crds: &[CustomResourceDefinition], + operator_environment: &OperatorEnvironmentOptions, + ) -> Result<(), ConversionWebhookError> { + while let Some(current_cert) = cert_rx.recv().await { + Self::reconcile_crds( + client, + field_manager, + crds, + operator_environment, + current_cert, + ) + .await + .context(ReconcileCrdsSnafu)?; + } + Ok(()) + } + + #[instrument(skip_all)] + async fn reconcile_crds( + client: &Client, + field_manager: &str, + crds: &[CustomResourceDefinition], + operator_environment: &OperatorEnvironmentOptions, + current_cert: Certificate, + ) -> Result<(), ConversionWebhookError> { + tracing::info!( + crds = ?crds.iter().map(CustomResourceDefinition::name_any).collect::>(), + "Reconciling CRDs" + ); + let ca_bundle = current_cert + .to_pem(LineEnding::LF) + .context(ConvertCaToPemSnafu)?; + + let crd_api: Api = Api::all(client.clone()); + for mut crd in crds.iter().cloned() { + let crd_name = crd.name_any(); + + crd.spec.conversion = Some(CustomResourceConversion { + strategy: "Webhook".to_string(), + webhook: Some(WebhookConversion { + // conversionReviewVersions indicates what ConversionReview versions are understood/preferred by the webhook. + // The first version in the list understood by the API server is sent to the webhook. + // The webhook must respond with a ConversionReview object in the same version it received. + conversion_review_versions: vec!["v1".to_string()], + client_config: Some(WebhookClientConfig { + service: Some(ServiceReference { + name: operator_environment.operator_service_name.to_owned(), + namespace: operator_environment.operator_namespace.to_owned(), + path: Some(format!("/convert/{crd_name}")), + port: Some(DEFAULT_HTTPS_PORT.into()), + }), + ca_bundle: Some(ByteString(ca_bundle.as_bytes().to_vec())), + url: None, + }), + }), + }); - let server = WebhookServer::new(self.router, self.options); - server.run().await + let patch = Patch::Apply(&crd); + let patch_params = PatchParams::apply(field_manager); + crd_api + .patch(&crd_name, &patch_params, &patch) + .await + .with_context(|_| UpdateCrdSnafu { + crd_name: crd_name.to_string(), + })?; + } + Ok(()) } } diff --git a/crates/stackable-webhook/src/servers/mod.rs b/crates/stackable-webhook/src/servers/mod.rs index b242df779..6fbadc12d 100644 --- a/crates/stackable-webhook/src/servers/mod.rs +++ b/crates/stackable-webhook/src/servers/mod.rs @@ -2,4 +2,4 @@ //! purposes. mod conversion; -pub use conversion::*; +pub use conversion::{ConversionWebhookError, ConversionWebhookOptions, ConversionWebhookServer}; diff --git a/crates/stackable-webhook/src/tls.rs b/crates/stackable-webhook/src/tls.rs deleted file mode 100644 index 2aad52ee4..000000000 --- a/crates/stackable-webhook/src/tls.rs +++ /dev/null @@ -1,283 +0,0 @@ -//! This module contains structs and functions to easily create a TLS termination -//! server, which can be used in combination with an Axum [`Router`]. 
-use std::{net::SocketAddr, sync::Arc}; - -use axum::{Router, extract::Request}; -use futures_util::pin_mut; -use hyper::{body::Incoming, service::service_fn}; -use hyper_util::rt::{TokioExecutor, TokioIo}; -use opentelemetry::trace::{FutureExt, SpanKind}; -use snafu::{ResultExt, Snafu}; -use stackable_certs::{ - CertificatePairError, - ca::{CertificateAuthority, DEFAULT_CA_VALIDITY}, - keys::rsa, -}; -use tokio::net::TcpListener; -use tokio_rustls::{ - TlsAcceptor, - rustls::{ - ServerConfig, - crypto::ring::default_provider, - version::{TLS12, TLS13}, - }, -}; -use tower::{Service, ServiceExt}; -use tracing::{Instrument, Span, field::Empty, instrument}; -use tracing_opentelemetry::OpenTelemetrySpanExt; - -pub type Result = std::result::Result; - -#[derive(Debug, Snafu)] -pub enum Error { - #[snafu(display("failed to construct TLS server config, bad certificate/key"))] - InvalidTlsPrivateKey { source: tokio_rustls::rustls::Error }, - - #[snafu(display("failed to create TCP listener by binding to socket address {socket_addr:?}"))] - BindTcpListener { - source: std::io::Error, - socket_addr: SocketAddr, - }, - - #[snafu(display("failed to create CA to generate and sign webhook leaf certificate"))] - CreateCertificateAuthority { source: stackable_certs::ca::Error }, - - #[snafu(display("failed to generate webhook leaf certificate"))] - GenerateLeafCertificate { source: stackable_certs::ca::Error }, - - #[snafu(display("failed to encode leaf certificate as DER"))] - EncodeCertificateDer { - source: CertificatePairError, - }, - - #[snafu(display("failed to encode private key as DER"))] - EncodePrivateKeyDer { - source: CertificatePairError, - }, - - #[snafu(display("failed to set safe TLS protocol versions"))] - SetSafeTlsProtocolVersions { source: tokio_rustls::rustls::Error }, - - #[snafu(display("failed to run task in blocking thread"))] - TokioSpawnBlocking { source: tokio::task::JoinError }, -} - -/// Custom implementation of [`std::cmp::PartialEq`] because some inner types -/// don't implement it. -/// -/// Note that this implementation is restritced to testing because there are -/// variants that use [`stackable_certs::ca::Error`] which only implements -/// [`PartialEq`] for tests. -#[cfg(test)] -impl PartialEq for Error { - fn eq(&self, other: &Self) -> bool { - match (self, other) { - ( - Self::BindTcpListener { - source: lhs_source, - socket_addr: lhs_socket_addr, - }, - Self::BindTcpListener { - source: rhs_source, - socket_addr: rhs_socket_addr, - }, - ) => lhs_socket_addr == rhs_socket_addr && lhs_source.kind() == rhs_source.kind(), - (lhs, rhs) => lhs == rhs, - } - } -} - -/// A server which terminates TLS connections and allows clients to commnunicate -/// via HTTPS with the underlying HTTP router. -pub struct TlsServer { - config: Arc, - socket_addr: SocketAddr, - router: Router, -} - -impl TlsServer { - #[instrument(name = "create_tls_server", skip(router))] - pub async fn new(socket_addr: SocketAddr, router: Router) -> Result { - // NOTE(@NickLarsenNZ): This code is not async, and does take some - // non-negligable amount of time to complete (moreso in debug ). - // We run this in a thread reserved for blocking code so that the Tokio - // executor is able to make progress on other futures instead of being - // blocked. 
- // See https://docs.rs/tokio/latest/tokio/task/fn.spawn_blocking.html - let task = tokio::task::spawn_blocking(move || { - let mut certificate_authority = - CertificateAuthority::new_rsa().context(CreateCertificateAuthoritySnafu)?; - - let leaf_certificate = certificate_authority - .generate_rsa_leaf_certificate("Leaf", "webhook", [], DEFAULT_CA_VALIDITY) - .context(GenerateLeafCertificateSnafu)?; - - let certificate_der = leaf_certificate - .certificate_der() - .context(EncodeCertificateDerSnafu)?; - - let private_key_der = leaf_certificate - .private_key_der() - .context(EncodePrivateKeyDerSnafu)?; - - let tls_provider = default_provider(); - let mut config = ServerConfig::builder_with_provider(tls_provider.into()) - .with_protocol_versions(&[&TLS12, &TLS13]) - .context(SetSafeTlsProtocolVersionsSnafu)? - .with_no_client_auth() - .with_single_cert(vec![certificate_der], private_key_der) - .context(InvalidTlsPrivateKeySnafu)?; - - config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - let config = Arc::new(config); - - Ok(Self { - socket_addr, - config, - router, - }) - }) - .await - .context(TokioSpawnBlockingSnafu)??; - - Ok(task) - } - - /// Runs the TLS server by listening for incoming TCP connections on the - /// bound socket address. It only accepts TLS connections. Internally each - /// TLS stream get handled by a Hyper service, which in turn is an Axum - /// router. - pub async fn run(self) -> Result<()> { - let tls_acceptor = TlsAcceptor::from(self.config); - let tcp_listener = - TcpListener::bind(self.socket_addr) - .await - .context(BindTcpListenerSnafu { - socket_addr: self.socket_addr, - })?; - - // To be able to extract the connect info from incoming requests, it is - // required to turn the router into a Tower service which is capable of - // doing that. Calling `into_make_service_with_connect_info` returns a - // new struct `IntoMakeServiceWithConnectInfo` which implements the - // Tower Service trait. This service is called after the TCP connection - // has been accepted. 
- // - // Inspired by: - // - https://github.com/tokio-rs/axum/discussions/2397 - // - https://github.com/tokio-rs/axum/blob/b02ce307371a973039018a13fa012af14775948c/examples/serve-with-hyper/src/main.rs#L98 - - let mut router = self - .router - .into_make_service_with_connect_info::(); - - pin_mut!(tcp_listener); - loop { - let tls_acceptor = tls_acceptor.clone(); - - // Wait for new tcp connection - let (tcp_stream, remote_addr) = match tcp_listener.accept().await { - Ok((stream, addr)) => (stream, addr), - Err(err) => { - tracing::trace!(%err, "failed to accept incoming TCP connection"); - continue; - } - }; - - // Here, the connect info is extracted by calling Tower's Service - // trait function on `IntoMakeServiceWithConnectInfo` - let tower_service = router.call(remote_addr).await.unwrap(); - - let span = tracing::debug_span!("accept tcp connection"); - tokio::spawn( - async move { - let span = tracing::trace_span!( - "accept tls connection", - "otel.kind" = ?SpanKind::Server, - "otel.status_code" = Empty, - "otel.status_message" = Empty, - "client.address" = remote_addr.ip().to_string(), - "client.port" = remote_addr.port() as i64, - "server.address" = Empty, - "server.port" = Empty, - "network.peer.address" = remote_addr.ip().to_string(), - "network.peer.port" = remote_addr.port() as i64, - "network.local.address" = Empty, - "network.local.port" = Empty, - "network.transport" = "tcp", - "network.type" = self.socket_addr.semantic_convention_network_type(), - ); - - if let Ok(local_addr) = tcp_stream.local_addr() { - let addr = &local_addr.ip().to_string(); - let port = local_addr.port(); - span.record("server.address", addr) - .record("server.port", port as i64) - .record("network.local.address", addr) - .record("network.local.port", port as i64); - } - - // Wait for tls handshake to happen - let tls_stream = match tls_acceptor - .accept(tcp_stream) - .instrument(span.clone()) - .await - { - Ok(tls_stream) => tls_stream, - Err(err) => { - span.record("otel.status_code", "Error") - .record("otel.status_message", err.to_string()); - tracing::trace!(%remote_addr, "error during tls handshake connection"); - return; - } - }; - - // Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio. - // `TokioIo` converts between them. - let tls_stream = TokioIo::new(tls_stream); - - // Hyper also has its own `Service` trait and doesn't use tower. We can use - // `hyper::service::service_fn` to create a hyper `Service` that calls our app through - // `tower::Service::call`. 
- let hyper_service = service_fn(move |request: Request| { - // This carries the current context with the trace id so that the TraceLayer can use that as a parent - let otel_context = Span::current().context(); - // We need to clone here, because oneshot consumes self - tower_service - .clone() - .oneshot(request) - .with_context(otel_context) - }); - - let span = tracing::debug_span!("serve connection"); - hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()) - .serve_connection_with_upgrades(tls_stream, hyper_service) - .instrument(span.clone()) - .await - .unwrap_or_else(|err| { - span.record("otel.status_code", "Error") - .record("otel.status_message", err.to_string()); - tracing::warn!(%err, %remote_addr, "failed to serve connection"); - }) - } - .instrument(span), - ); - } - } -} - -pub trait SocketAddrExt { - fn semantic_convention_network_type(&self) -> &'static str; -} - -impl SocketAddrExt for SocketAddr { - fn semantic_convention_network_type(&self) -> &'static str { - match self { - SocketAddr::V4(_) => "ipv4", - SocketAddr::V6(_) => "ipv6", - } - } -} - -// TODO (@NickLarsenNZ): impl record_error(err: impl Error) for Span as a shortcut to set otel.status_* fields -// TODO (@NickLarsenNZ): wrap tracing::span macros to automatically add otel fields diff --git a/crates/stackable-webhook/src/tls/cert_resolver.rs b/crates/stackable-webhook/src/tls/cert_resolver.rs new file mode 100644 index 000000000..c06df8ffe --- /dev/null +++ b/crates/stackable-webhook/src/tls/cert_resolver.rs @@ -0,0 +1,166 @@ +use std::sync::Arc; + +use arc_swap::ArcSwap; +use snafu::{OptionExt, ResultExt, Snafu}; +use stackable_certs::{CertificatePairError, ca::CertificateAuthority, keys::ecdsa}; +use tokio::sync::mpsc; +use tokio_rustls::rustls::{ + crypto::CryptoProvider, server::ResolvesServerCert, sign::CertifiedKey, +}; +use x509_cert::{Certificate, certificate::CertificateInner}; + +use super::{WEBHOOK_CA_LIFETIME, WEBHOOK_CERTIFICATE_LIFETIME}; + +type Result = std::result::Result; + +#[derive(Debug, Snafu)] +pub enum CertificateResolverError { + #[snafu(display("failed send certificate to channel"))] + SendCertificateToChannel, + + #[snafu(display("failed to generate ECDSA signing key"))] + GenerateEcdsaSigningKey { source: ecdsa::Error }, + + #[snafu(display("failed to generate new certificate"))] + GenerateNewCertificate { + #[snafu(source(from(CertificateResolverError, Box::new)))] + source: Box, + }, + + #[snafu(display("failed to create CA to generate and sign webhook leaf certificate"))] + CreateCertificateAuthority { source: stackable_certs::ca::Error }, + + #[snafu(display("failed to generate webhook leaf certificate"))] + GenerateLeafCertificate { source: stackable_certs::ca::Error }, + + #[snafu(display("failed to encode leaf certificate as DER"))] + EncodeCertificateDer { + source: CertificatePairError, + }, + + #[snafu(display("failed to encode private key as DER"))] + EncodePrivateKeyDer { + source: CertificatePairError, + }, + + #[snafu(display("failed to create packaged certificate chain from DER"))] + DecodeCertifiedKeyFromDer { source: tokio_rustls::rustls::Error }, + + #[snafu(display("failed to run task in blocking thread"))] + TokioSpawnBlocking { source: tokio::task::JoinError }, + + #[snafu(display("no default rustls CryptoProvider installed"))] + NoDefaultCryptoProviderInstalled {}, +} + +/// This struct serves as [`ResolvesServerCert`] to always hand out the current certificate for TLS +/// client connections. 
+/// client connections.
+///
+/// It offers the [`Self::rotate_certificate`] function to create a fresh certificate and basically
+/// hot-reload the certificate in the running webhook.
+#[derive(Debug)]
+pub struct CertificateResolver {
+    /// Using an [`ArcSwap`] (over e.g. [`tokio::sync::RwLock`]), so that we can easily
+    /// (and performantly) bridge between async write and sync read.
+    current_certified_key: ArcSwap<CertifiedKey>,
+    subject_alterative_dns_names: Arc<Vec<String>>,
+
+    cert_tx: mpsc::Sender<Certificate>,
+}
+
+impl CertificateResolver {
+    pub async fn new(
+        subject_alterative_dns_names: Vec<String>,
+        cert_tx: mpsc::Sender<Certificate>,
+    ) -> Result<Self> {
+        let subject_alterative_dns_names = Arc::new(subject_alterative_dns_names);
+        let (cert, certified_key) = Self::generate_new_cert(subject_alterative_dns_names.clone())
+            .await
+            .context(GenerateNewCertificateSnafu)?;
+
+        Self::send_certificate_to_channel(cert, &cert_tx).await?;
+
+        Ok(Self {
+            subject_alterative_dns_names,
+            current_certified_key: ArcSwap::new(certified_key),
+            cert_tx,
+        })
+    }
+
+    pub async fn rotate_certificate(&self) -> Result<()> {
+        let (cert, certified_key) =
+            Self::generate_new_cert(self.subject_alterative_dns_names.clone())
+                .await
+                .context(GenerateNewCertificateSnafu)?;
+
+        // TODO: Sign the new cert somehow with the old cert. See https://github.com/stackabletech/decisions/issues/56
+
+        Self::send_certificate_to_channel(cert, &self.cert_tx).await?;
+        self.current_certified_key.store(certified_key);
+
+        Ok(())
+    }
+
+    /// FIXME: This should *not* construct a CA cert and cert, but only a cert!
+    /// This needs some changes in stackable-certs though.
+    /// See [the relevant decision](https://github.com/stackabletech/decisions/issues/56)
+    async fn generate_new_cert(
+        subject_alterative_dns_names: Arc<Vec<String>>,
+    ) -> Result<(Certificate, Arc<CertifiedKey>)> {
+        // The certificate generation can take a while, so we use `spawn_blocking`
+        tokio::task::spawn_blocking(move || {
+            let tls_provider =
+                CryptoProvider::get_default().context(NoDefaultCryptoProviderInstalledSnafu)?;
+
+            let ca_key = ecdsa::SigningKey::new().context(GenerateEcdsaSigningKeySnafu)?;
+            let mut ca =
+                CertificateAuthority::new_with(ca_key, rand::random::<u64>(), WEBHOOK_CA_LIFETIME)
+                    .context(CreateCertificateAuthoritySnafu)?;
+
+            let certificate_pair = ca
+                .generate_ecdsa_leaf_certificate(
+                    "Leaf",
+                    "webhook",
+                    subject_alterative_dns_names.iter().map(|san| san.as_str()),
+                    WEBHOOK_CERTIFICATE_LIFETIME,
+                )
+                .context(GenerateLeafCertificateSnafu)?;
+
+            let certificate_der = certificate_pair
+                .certificate_der()
+                .context(EncodeCertificateDerSnafu)?;
+            let private_key_der = certificate_pair
+                .private_key_der()
+                .context(EncodePrivateKeyDerSnafu)?;
+            let certificate_key =
+                CertifiedKey::from_der(vec![certificate_der], private_key_der, tls_provider)
+                    .context(DecodeCertifiedKeyFromDerSnafu)?;
+
+            Ok((
+                certificate_pair.certificate().clone(),
+                Arc::new(certificate_key),
+            ))
+        })
+        .await
+        .context(TokioSpawnBlockingSnafu)?
+    }
+
+    async fn send_certificate_to_channel(
+        cert: CertificateInner,
+        cert_tx: &mpsc::Sender<Certificate>,
+    ) -> Result<()> {
+        cert_tx
+            .send(cert)
+            .await
+            .map_err(|_err| CertificateResolverError::SendCertificateToChannel)
+    }
+}
+
+impl ResolvesServerCert for CertificateResolver {
+    fn resolve(
+        &self,
+        _client_hello: tokio_rustls::rustls::server::ClientHello<'_>,
+    ) -> Option<Arc<CertifiedKey>> {
+        Some(self.current_certified_key.load().clone())
+    }
+}
diff --git a/crates/stackable-webhook/src/tls/mod.rs b/crates/stackable-webhook/src/tls/mod.rs
new file mode 100644
index 000000000..7d10ecb94
--- /dev/null
+++ b/crates/stackable-webhook/src/tls/mod.rs
@@ -0,0 +1,290 @@
+//! This module contains structs and functions to easily create a TLS termination
+//! server, which can be used in combination with an Axum [`Router`].
+use std::{convert::Infallible, net::SocketAddr, sync::Arc};
+
+use axum::{
+    Router,
+    extract::{ConnectInfo, Request},
+    middleware::AddExtension,
+};
+use hyper::{body::Incoming, service::service_fn};
+use hyper_util::rt::{TokioExecutor, TokioIo};
+use opentelemetry::trace::{FutureExt, SpanKind};
+use opentelemetry_semantic_conventions as semconv;
+use snafu::{ResultExt, Snafu};
+use stackable_operator::time::Duration;
+use tokio::{
+    net::{TcpListener, TcpStream},
+    sync::mpsc,
+};
+use tokio_rustls::{
+    TlsAcceptor,
+    rustls::{
+        ServerConfig,
+        crypto::ring::default_provider,
+        version::{TLS12, TLS13},
+    },
+};
+use tower::{Service, ServiceExt};
+use tracing::{Instrument, Span, field::Empty, instrument};
+use tracing_opentelemetry::OpenTelemetrySpanExt;
+use x509_cert::Certificate;
+
+use crate::{
+    options::WebhookOptions,
+    tls::cert_resolver::{CertificateResolver, CertificateResolverError},
+};
+
+mod cert_resolver;
+
+pub const WEBHOOK_CA_LIFETIME: Duration = Duration::from_minutes_unchecked(3);
+pub const WEBHOOK_CERTIFICATE_LIFETIME: Duration = Duration::from_minutes_unchecked(2);
+pub const WEBHOOK_CERTIFICATE_ROTATION_INTERVAL: Duration = Duration::from_minutes_unchecked(1);
+
+pub type Result<T, E = TlsServerError> = std::result::Result<T, E>;
+
+#[derive(Debug, Snafu)]
+pub enum TlsServerError {
+    #[snafu(display("failed to create certificate resolver"))]
+    CreateCertificateResolver { source: CertificateResolverError },
+
+    #[snafu(display("failed to create TCP listener by binding to socket address {socket_addr:?}"))]
+    BindTcpListener {
+        source: std::io::Error,
+        socket_addr: SocketAddr,
+    },
+
+    #[snafu(display("failed to rotate certificate"))]
+    RotateCertificate { source: CertificateResolverError },
+
+    #[snafu(display("failed to set safe TLS protocol versions"))]
+    SetSafeTlsProtocolVersions { source: tokio_rustls::rustls::Error },
+}
+
+/// A server which terminates TLS connections and allows clients to communicate
+/// via HTTPS with the underlying HTTP router.
+///
+/// It also rotates the generated certificates as needed.
+pub struct TlsServer {
+    config: ServerConfig,
+    cert_resolver: Arc<CertificateResolver>,
+
+    socket_addr: SocketAddr,
+    router: Router,
+}
+
+impl TlsServer {
+    /// Create a new [`TlsServer`].
+    ///
+    /// This internally creates a `CertificateResolver` with the provided
+    /// `subject_alterative_dns_names`, which takes care of the certificate rotation. Afterwards it
+    /// creates the [`ServerConfig`], which lets the `CertificateResolver` provide the needed
+    /// certificates.
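+    ///
+    /// # Example
+    ///
+    /// A minimal usage sketch, assuming an Axum `router` and a [`WebhookOptions`] value are
+    /// constructed elsewhere. The returned receiver yields every newly generated certificate,
+    /// for example so it can be patched into a webhook configuration:
+    ///
+    /// ```ignore
+    /// let (tls_server, mut cert_rx) = TlsServer::new(router, options).await?;
+    ///
+    /// // Forward freshly rotated certificates to whoever needs to trust them
+    /// tokio::spawn(async move {
+    ///     while let Some(certificate) = cert_rx.recv().await {
+    ///         // e.g. update the conversion webhook's caBundle here
+    ///     }
+    /// });
+    ///
+    /// tls_server.run().await?;
+    /// ```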
+    #[instrument(name = "create_tls_server", skip(router))]
+    pub async fn new(
+        router: Router,
+        options: WebhookOptions,
+    ) -> Result<(Self, mpsc::Receiver<Certificate>)> {
+        let (cert_tx, cert_rx) = mpsc::channel(1);
+
+        let WebhookOptions {
+            socket_addr,
+            subject_alterative_dns_names,
+        } = options;
+
+        let cert_resolver = CertificateResolver::new(subject_alterative_dns_names, cert_tx)
+            .await
+            .context(CreateCertificateResolverSnafu)?;
+        let cert_resolver = Arc::new(cert_resolver);
+
+        let tls_provider = default_provider();
+        let mut config = ServerConfig::builder_with_provider(tls_provider.into())
+            .with_protocol_versions(&[&TLS12, &TLS13])
+            .context(SetSafeTlsProtocolVersionsSnafu)?
+            .with_no_client_auth()
+            .with_cert_resolver(cert_resolver.clone());
+        config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
+
+        let tls_server = Self {
+            config,
+            cert_resolver,
+            socket_addr,
+            router,
+        };
+
+        Ok((tls_server, cert_rx))
+    }
+
+    /// Runs the TLS server by listening for incoming TCP connections on the
+    /// bound socket address. It only accepts TLS connections. Internally each
+    /// TLS stream gets handled by a Hyper service, which in turn is an Axum
+    /// router.
+    ///
+    /// It also rotates the certificate as needed, driven by a fixed rotation interval.
+    pub async fn run(self) -> Result<()> {
+        let start = tokio::time::Instant::now() + *WEBHOOK_CERTIFICATE_ROTATION_INTERVAL;
+        let mut interval = tokio::time::interval_at(start, *WEBHOOK_CERTIFICATE_ROTATION_INTERVAL);
+
+        let tls_acceptor = TlsAcceptor::from(Arc::new(self.config));
+        let tcp_listener =
+            TcpListener::bind(self.socket_addr)
+                .await
+                .context(BindTcpListenerSnafu {
+                    socket_addr: self.socket_addr,
+                })?;
+
+        // To be able to extract the connect info from incoming requests, it is
+        // required to turn the router into a Tower service which is capable of
+        // doing that. Calling `into_make_service_with_connect_info` returns a
+        // new struct `IntoMakeServiceWithConnectInfo` which implements the
+        // Tower Service trait. This service is called after the TCP connection
+        // has been accepted.
+        //
+        // Inspired by:
+        // - https://github.com/tokio-rs/axum/discussions/2397
+        // - https://github.com/tokio-rs/axum/blob/b02ce307371a973039018a13fa012af14775948c/examples/serve-with-hyper/src/main.rs#L98
+
+        let mut router = self
+            .router
+            .into_make_service_with_connect_info::<SocketAddr>();
+
+        loop {
+            let tls_acceptor = tls_acceptor.clone();
+
+            // Wait for either a new TCP connection or the certificate rotation interval tick
+            tokio::select! {
+                // We opt for a biased execution of arms to make sure we always check if the
+                // certificate needs rotation based on the interval. This ensures we always use
+                // a valid certificate for the TLS connection.
+                biased;
+
+                // This is cancellation-safe. If this branch is cancelled, the tick is NOT consumed.
+                // As such, we will not miss rotating the certificate.
+                _ = interval.tick() => {
+                    self.cert_resolver
+                        .rotate_certificate()
+                        .await
+                        .context(RotateCertificateSnafu)?
+                }
+
+                // This is cancellation-safe. If cancelled, no new connections are accepted.
+                tcp_connection = tcp_listener.accept() => {
+                    let (tcp_stream, remote_addr) = match tcp_connection {
+                        Ok((stream, addr)) => (stream, addr),
+                        Err(err) => {
+                            tracing::trace!(%err, "failed to accept incoming TCP connection");
+                            continue;
+                        }
+                    };
+
+                    // Here, the connect info is extracted by calling Tower's Service
+                    // trait function on `IntoMakeServiceWithConnectInfo`
+                    let tower_service: Result<_, Infallible> = router.call(remote_addr).await;
+                    let tower_service = tower_service.expect("Infallible error can never happen");
+
+                    let span = tracing::debug_span!("accept tcp connection");
+                    tokio::spawn(
+                        async move {
+                            Self::handle_request(tcp_stream, remote_addr, tls_acceptor, tower_service, self.socket_addr)
+                                .instrument(span)
+                                .await
+                        }
+                    );
+                }
+            };
+        }
+    }
+
+    async fn handle_request(
+        tcp_stream: TcpStream,
+        remote_addr: SocketAddr,
+        tls_acceptor: TlsAcceptor,
+        tower_service: AddExtension<Router, ConnectInfo<SocketAddr>>,
+        socket_addr: SocketAddr,
+    ) {
+        let span = tracing::trace_span!(
+            "accept tls connection",
+            "otel.kind" = ?SpanKind::Server,
+            { semconv::attribute::OTEL_STATUS_CODE } = Empty,
+            { semconv::attribute::OTEL_STATUS_DESCRIPTION } = Empty,
+            { semconv::trace::CLIENT_ADDRESS } = remote_addr.ip().to_string(),
+            { semconv::trace::CLIENT_PORT } = remote_addr.port() as i64,
+            { semconv::trace::SERVER_ADDRESS } = Empty,
+            { semconv::trace::SERVER_PORT } = Empty,
+            { semconv::trace::NETWORK_PEER_ADDRESS } = remote_addr.ip().to_string(),
+            { semconv::trace::NETWORK_PEER_PORT } = remote_addr.port() as i64,
+            { semconv::trace::NETWORK_LOCAL_ADDRESS } = Empty,
+            { semconv::trace::NETWORK_LOCAL_PORT } = Empty,
+            { semconv::trace::NETWORK_TRANSPORT } = "tcp",
+            { semconv::trace::NETWORK_TYPE } = socket_addr.semantic_convention_network_type(),
+        );
+
+        if let Ok(local_addr) = tcp_stream.local_addr() {
+            let addr = &local_addr.ip().to_string();
+            let port = local_addr.port();
+            span.record(semconv::trace::SERVER_ADDRESS, addr)
+                .record(semconv::trace::SERVER_PORT, port as i64)
+                .record(semconv::trace::NETWORK_LOCAL_ADDRESS, addr)
+                .record(semconv::trace::NETWORK_LOCAL_PORT, port as i64);
+        }
+
+        // Wait for the TLS handshake to complete
+        let tls_stream = match tls_acceptor
+            .accept(tcp_stream)
+            .instrument(span.clone())
+            .await
+        {
+            Ok(tls_stream) => tls_stream,
+            Err(err) => {
+                span.record(semconv::attribute::OTEL_STATUS_CODE, "Error")
+                    .record(semconv::attribute::OTEL_STATUS_DESCRIPTION, err.to_string());
+                tracing::trace!(%remote_addr, "error during TLS handshake");
+                return;
+            }
+        };
+
+        // Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio.
+        // `TokioIo` converts between them.
+        let tls_stream = TokioIo::new(tls_stream);
+
+        // Hyper also has its own `Service` trait and doesn't use tower. We can use
+        // `hyper::service::service_fn` to create a hyper `Service` that calls our app through
+        // `tower::Service::call`.
+        let hyper_service = service_fn(move |request: Request<Incoming>| {
+            // This carries the current context with the trace id so that the TraceLayer can use that as a parent
+            let otel_context = Span::current().context();
+            // We need to clone here, because oneshot consumes self
+            tower_service
+                .clone()
+                .oneshot(request)
+                .with_context(otel_context)
+        });
+
+        let span = tracing::debug_span!("serve connection");
+        hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
+            .serve_connection_with_upgrades(tls_stream, hyper_service)
+            .instrument(span.clone())
+            .await
+            .unwrap_or_else(|err| {
+                span.record(semconv::attribute::OTEL_STATUS_CODE, "Error")
+                    .record(semconv::attribute::OTEL_STATUS_DESCRIPTION, err.to_string());
+                tracing::warn!(%err, %remote_addr, "failed to serve connection");
+            })
+    }
+}
+
+pub trait SocketAddrExt {
+    fn semantic_convention_network_type(&self) -> &'static str;
+}
+
+impl SocketAddrExt for SocketAddr {
+    fn semantic_convention_network_type(&self) -> &'static str {
+        match self {
+            SocketAddr::V4(_) => "ipv4",
+            SocketAddr::V6(_) => "ipv6",
+        }
+    }
+}
+
+// TODO (@NickLarsenNZ): impl record_error(err: impl Error) for Span as a shortcut to set otel.status_* fields
+// TODO (@NickLarsenNZ): wrap tracing::span macros to automatically add otel fields
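+
+// A possible shape for the `record_error` helper mentioned in the TODO above. This is only an
+// illustrative sketch (the trait name `SpanErrorExt` is made up): an extension trait on
+// `tracing::Span` that sets both OTel status fields in one call, mirroring the manual
+// `span.record(...)` chains used throughout this module.
+#[allow(dead_code)]
+trait SpanErrorExt {
+    fn record_error(&self, err: &dyn std::error::Error) -> &Self;
+}
+
+#[allow(dead_code)]
+impl SpanErrorExt for Span {
+    fn record_error(&self, err: &dyn std::error::Error) -> &Self {
+        self.record(semconv::attribute::OTEL_STATUS_CODE, "Error")
+            .record(semconv::attribute::OTEL_STATUS_DESCRIPTION, err.to_string())
+    }
+}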