diff --git a/client/src/lib.rs b/client/src/lib.rs index 6135acdcb..41cd2b79f 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -78,7 +78,10 @@ use tonic::{ body::Body, client::GrpcService, codegen::InterceptedService, - metadata::{MetadataKey, MetadataMap, MetadataValue}, + metadata::{ + AsciiMetadataKey, AsciiMetadataValue, BinaryMetadataKey, BinaryMetadataValue, MetadataMap, + MetadataValue, + }, service::Interceptor, transport::{Certificate, Channel, Endpoint, Identity}, }; @@ -146,9 +149,20 @@ pub struct ClientOptions { pub keep_alive: Option, /// HTTP headers to include on every RPC call. + /// + /// These must be valid gRPC metadata keys, and must not be binary metadata keys (ending in + /// `-bin). To set binary headers, use [ClientOptions::binary_headers]. Invalid header keys or + /// values will cause an error to be returned when connecting. #[builder(default)] pub headers: Option>, + /// HTTP headers to include on every RPC call as binary gRPC metadata (encoded as base64). + /// + /// These must be valid binary gRPC metadata keys (and end with a `-bin` suffix). Invalid + /// header keys will cause an error to be returned when connecting. + #[builder(default)] + pub binary_headers: Option>>, + /// API key which is set as the "Authorization" header with "Bearer " prepended. This will only /// be applied if the headers don't already have an "Authorization" header. #[builder(default)] @@ -322,6 +336,9 @@ pub enum ClientInitError { /// Invalid URI. Configuration error, fatal. #[error("Invalid URI: {0:?}")] InvalidUri(#[from] InvalidUri), + /// Invalid gRPC metadata headers. Configuration error. + #[error("Invalid headers: {0}")] + InvalidHeaders(#[from] InvalidHeaderError), /// Server connection error. Crashing and restarting the worker is likely best. #[error("Server connection error: {0:?}")] TonicTransportError(#[from] tonic::transport::Error), @@ -331,6 +348,37 @@ pub enum ClientInitError { SystemInfoCallError(tonic::Status), } +/// Errors thrown when a gRPC metadata header is invalid. +#[derive(thiserror::Error, Debug)] +pub enum InvalidHeaderError { + /// A binary header key was invalid + #[error("Invalid binary header key '{key}': {source}")] + InvalidBinaryHeaderKey { + /// The invalid key + key: String, + /// The source error from tonic + source: tonic::metadata::errors::InvalidMetadataKey, + }, + /// An ASCII header key was invalid + #[error("Invalid ASCII header key '{key}': {source}")] + InvalidAsciiHeaderKey { + /// The invalid key + key: String, + /// The source error from tonic + source: tonic::metadata::errors::InvalidMetadataKey, + }, + /// An ASCII header value was invalid + #[error("Invalid ASCII header value for key '{key}': {source}")] + InvalidAsciiHeaderValue { + /// The key + key: String, + /// The invalid value + value: String, + /// The source error from tonic + source: tonic::metadata::errors::InvalidMetadataValue, + }, +} + /// A client with [ClientOptions] attached, which can be passed to initialize workers, /// or can be used directly. Is cheap to clone. #[derive(Clone, Debug)] @@ -344,9 +392,33 @@ pub struct ConfiguredClient { } impl ConfiguredClient { - /// Set HTTP request headers overwriting previous headers - pub fn set_headers(&self, headers: HashMap) { - self.headers.write().user_headers = headers; + /// Set HTTP request headers overwriting previous headers. + /// + /// This will not affect headers set via [ClientOptions::binary_headers]. + /// + /// # Errors + /// + /// Will return an error if any of the provided keys or values are not valid gRPC metadata. + /// If an error is returned, the previous headers will remain unchanged. + pub fn set_headers(&self, headers: HashMap) -> Result<(), InvalidHeaderError> { + self.headers.write().user_headers = parse_ascii_headers(headers)?; + Ok(()) + } + + /// Set binary HTTP request headers overwriting previous headers. + /// + /// This will not affect headers set via [ClientOptions::headers]. + /// + /// # Errors + /// + /// Will return an error if any of the provided keys are not valid gRPC binary metadata keys. + /// If an error is returned, the previous headers will remain unchanged. + pub fn set_binary_headers( + &self, + binary_headers: HashMap>, + ) -> Result<(), InvalidHeaderError> { + self.headers.write().user_binary_headers = parse_binary_headers(binary_headers)?; + Ok(()) } /// Set API key, overwriting previous @@ -373,7 +445,8 @@ impl ConfiguredClient { #[derive(Debug)] struct ClientHeaders { - user_headers: HashMap, + user_headers: HashMap, + user_binary_headers: HashMap, api_key: Option, } @@ -382,10 +455,13 @@ impl ClientHeaders { for (key, val) in self.user_headers.iter() { // Only if not already present if !metadata.contains_key(key) { - // Ignore invalid keys/values - if let (Ok(key), Ok(val)) = (MetadataKey::from_str(key), val.parse()) { - metadata.insert(key, val); - } + metadata.insert(key, val.clone()); + } + } + for (key, val) in self.user_binary_headers.iter() { + // Only if not already present + if !metadata.contains_key(key) { + metadata.insert_bin(key, val.clone()); } } if let Some(api_key) = &self.api_key { @@ -491,7 +567,10 @@ impl ClientOptions { }; let headers = Arc::new(RwLock::new(ClientHeaders { - user_headers: self.headers.clone().unwrap_or_default(), + user_headers: parse_ascii_headers(self.headers.clone().unwrap_or_default())?, + user_binary_headers: parse_binary_headers( + self.binary_headers.clone().unwrap_or_default(), + )?, api_key: self.api_key.clone(), })); let interceptor = ServiceCallInterceptor { @@ -558,6 +637,57 @@ impl ClientOptions { } } +fn parse_ascii_headers( + headers: HashMap, +) -> Result, InvalidHeaderError> { + let mut parsed_headers = HashMap::with_capacity(headers.len()); + for (k, v) in headers.into_iter() { + let key = match AsciiMetadataKey::from_str(&k) { + Ok(key) => key, + Err(err) => { + return Err(InvalidHeaderError::InvalidAsciiHeaderKey { + key: k, + source: err, + }); + } + }; + let value = match MetadataValue::from_str(&v) { + Ok(value) => value, + Err(err) => { + return Err(InvalidHeaderError::InvalidAsciiHeaderValue { + key: k, + value: v, + source: err, + }); + } + }; + parsed_headers.insert(key, value); + } + + Ok(parsed_headers) +} + +fn parse_binary_headers( + headers: HashMap>, +) -> Result, InvalidHeaderError> { + let mut parsed_headers = HashMap::with_capacity(headers.len()); + for (k, v) in headers.into_iter() { + let key = match BinaryMetadataKey::from_str(&k) { + Ok(key) => key, + Err(err) => { + return Err(InvalidHeaderError::InvalidBinaryHeaderKey { + key: k, + source: err, + }); + } + }; + let value = BinaryMetadataValue::from_bytes(&v); + parsed_headers.insert(key, value); + } + + Ok(parsed_headers) +} + /// Interceptor which attaches common metadata (like "client-name") to every outgoing call #[derive(Clone)] pub struct ServiceCallInterceptor { @@ -1770,13 +1900,17 @@ mod tests { // Initial header set let headers = Arc::new(RwLock::new(ClientHeaders { user_headers: HashMap::new(), + user_binary_headers: HashMap::new(), api_key: Some("my-api-key".to_owned()), })); - headers - .clone() - .write() - .user_headers - .insert("my-meta-key".to_owned(), "my-meta-val".to_owned()); + headers.clone().write().user_headers.insert( + "my-meta-key".parse().unwrap(), + "my-meta-val".parse().unwrap(), + ); + headers.clone().write().user_binary_headers.insert( + "my-bin-meta-key-bin".parse().unwrap(), + vec![1, 2, 3].try_into().unwrap(), + ); let mut interceptor = ServiceCallInterceptor { opts, headers: headers.clone(), @@ -1789,6 +1923,10 @@ mod tests { req.metadata().get("authorization").unwrap(), "Bearer my-api-key" ); + assert_eq!( + req.metadata().get_bin("my-bin-meta-key-bin").unwrap(), + vec![1, 2, 3].as_slice() + ); // Overwrite at request time let mut req = tonic::Request::new(()); @@ -1796,26 +1934,33 @@ mod tests { .insert("my-meta-key", "my-meta-val2".parse().unwrap()); req.metadata_mut() .insert("authorization", "my-api-key2".parse().unwrap()); + req.metadata_mut() + .insert_bin("my-bin-meta-key-bin", vec![4, 5, 6].try_into().unwrap()); let req = interceptor.call(req).unwrap(); assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val2"); assert_eq!(req.metadata().get("authorization").unwrap(), "my-api-key2"); + assert_eq!( + req.metadata().get_bin("my-bin-meta-key-bin").unwrap(), + vec![4, 5, 6].as_slice() + ); // Overwrite auth on header - headers - .clone() - .write() - .user_headers - .insert("authorization".to_owned(), "my-api-key3".to_owned()); + headers.clone().write().user_headers.insert( + "authorization".parse().unwrap(), + "my-api-key3".parse().unwrap(), + ); let req = interceptor.call(tonic::Request::new(())).unwrap(); assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val"); assert_eq!(req.metadata().get("authorization").unwrap(), "my-api-key3"); // Remove headers and auth and confirm gone headers.clone().write().user_headers.clear(); + headers.clone().write().user_binary_headers.clear(); headers.clone().write().api_key.take(); let req = interceptor.call(tonic::Request::new(())).unwrap(); assert!(!req.metadata().contains_key("my-meta-key")); assert!(!req.metadata().contains_key("authorization")); + assert!(!req.metadata().contains_key("my-bin-meta-key-bin")); // Timeout header not overriden let mut req = tonic::Request::new(()); @@ -1828,6 +1973,55 @@ mod tests { ); } + #[test] + fn invalid_ascii_header_key() { + let invalid_headers = { + let mut h = HashMap::new(); + h.insert("x-binary-key-bin".to_owned(), "value".to_owned()); + h + }; + + let result = parse_ascii_headers(invalid_headers); + assert!(result.is_err()); + assert_eq!( + result.err().unwrap().to_string(), + "Invalid ASCII header key 'x-binary-key-bin': invalid gRPC metadata key name" + ); + } + + #[test] + fn invalid_ascii_header_value() { + let invalid_headers = { + let mut h = HashMap::new(); + // Nul bytes are valid UTF-8, but not valid ascii gRPC headers: + h.insert("x-ascii-key".to_owned(), "\x00value".to_owned()); + h + }; + + let result = parse_ascii_headers(invalid_headers); + assert!(result.is_err()); + assert_eq!( + result.err().unwrap().to_string(), + "Invalid ASCII header value for key 'x-ascii-key': failed to parse metadata value" + ); + } + + #[test] + fn invalid_binary_header_key() { + let invalid_headers = { + let mut h = HashMap::new(); + h.insert("x-ascii-key".to_owned(), vec![1, 2, 3]); + h + }; + + let result = parse_binary_headers(invalid_headers); + assert!(result.is_err()); + assert_eq!( + result.err().unwrap().to_string(), + "Invalid binary header key 'x-ascii-key': invalid gRPC metadata key name" + ); + } + #[test] fn keep_alive_defaults() { let mut builder = ClientOptionsBuilder::default(); diff --git a/core-c-bridge/src/client.rs b/core-c-bridge/src/client.rs index f7e6a7cef..3ae47e6f4 100644 --- a/core-c-bridge/src/client.rs +++ b/core-c-bridge/src/client.rs @@ -240,7 +240,7 @@ pub extern "C" fn temporal_core_client_update_metadata( metadata: ByteArrayRef, ) { let client = unsafe { &*client }; - client + let _result = client .core .get_client() .set_headers(metadata.to_string_map_on_newlines()); diff --git a/tests/integ_tests/client_tests.rs b/tests/integ_tests/client_tests.rs index bf456f6ed..bd3f6e15d 100644 --- a/tests/integ_tests/client_tests.rs +++ b/tests/integ_tests/client_tests.rs @@ -77,7 +77,7 @@ async fn per_call_timeout_respected_whole_client() { let mut raw_client = opts.connect_no_namespace(None).await.unwrap(); let mut hm = HashMap::new(); hm.insert("grpc-timeout".to_string(), "0S".to_string()); - raw_client.get_client().set_headers(hm); + raw_client.get_client().set_headers(hm).unwrap(); let err = raw_client .describe_namespace(DescribeNamespaceRequest { namespace: NAMESPACE.to_string(),