Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 214 additions & 20 deletions client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,10 @@ use tonic::{
body::Body,
client::GrpcService,
codegen::InterceptedService,
metadata::{MetadataKey, MetadataMap, MetadataValue},
metadata::{
AsciiMetadataKey, AsciiMetadataValue, BinaryMetadataKey, BinaryMetadataValue, MetadataMap,
MetadataValue,
},
service::Interceptor,
transport::{Certificate, Channel, Endpoint, Identity},
};
Expand Down Expand Up @@ -146,9 +149,20 @@ pub struct ClientOptions {
pub keep_alive: Option<ClientKeepAliveConfig>,

/// HTTP headers to include on every RPC call.
///
/// These must be valid gRPC metadata keys, and must not be binary metadata keys (ending in
/// `-bin). To set binary headers, use [ClientOptions::binary_headers]. Invalid header keys or
/// values will cause an error to be returned when connecting.
#[builder(default)]
pub headers: Option<HashMap<String, String>>,

/// HTTP headers to include on every RPC call as binary gRPC metadata (encoded as base64).
///
/// These must be valid binary gRPC metadata keys (and end with a `-bin` suffix). Invalid
/// header keys will cause an error to be returned when connecting.
#[builder(default)]
pub binary_headers: Option<HashMap<String, Vec<u8>>>,

/// API key which is set as the "Authorization" header with "Bearer " prepended. This will only
/// be applied if the headers don't already have an "Authorization" header.
#[builder(default)]
Expand Down Expand Up @@ -322,6 +336,9 @@ pub enum ClientInitError {
/// Invalid URI. Configuration error, fatal.
#[error("Invalid URI: {0:?}")]
InvalidUri(#[from] InvalidUri),
/// Invalid gRPC metadata headers. Configuration error.
#[error("Invalid headers: {0}")]
InvalidHeaders(#[from] InvalidHeaderError),
/// Server connection error. Crashing and restarting the worker is likely best.
#[error("Server connection error: {0:?}")]
TonicTransportError(#[from] tonic::transport::Error),
Expand All @@ -331,6 +348,37 @@ pub enum ClientInitError {
SystemInfoCallError(tonic::Status),
}

/// Errors thrown when a gRPC metadata header is invalid.
#[derive(thiserror::Error, Debug)]
pub enum InvalidHeaderError {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would have also been ok w/ just anyhow::Error (if it works here) to keep it simple since this is rare, but this more verbose one is probably fine too

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to switch over to an anyhow::Error if you feel strongly; otherwise I'll leave it as-is

Part of why I biased towards wrapper errors is that the tonic errors don't include the key/value that was invalid (though we could always add one/both to the anyhow error context if we wanted)

/// A binary header key was invalid
#[error("Invalid binary header key '{key}': {source}")]
InvalidBinaryHeaderKey {
/// The invalid key
key: String,
/// The source error from tonic
source: tonic::metadata::errors::InvalidMetadataKey,
},
/// An ASCII header key was invalid
#[error("Invalid ASCII header key '{key}': {source}")]
InvalidAsciiHeaderKey {
/// The invalid key
key: String,
/// The source error from tonic
source: tonic::metadata::errors::InvalidMetadataKey,
},
/// An ASCII header value was invalid
#[error("Invalid ASCII header value for key '{key}': {source}")]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't include the invalid value here because I was a bit worried that including it in error messages may lead to a credentials logging risk for users (if headers are used for security values, for example).

Let me know if that makes sense, or if I should make any changes

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense to me

InvalidAsciiHeaderValue {
/// The key
key: String,
/// The invalid value
value: String,
/// The source error from tonic
source: tonic::metadata::errors::InvalidMetadataValue,
},
}

/// A client with [ClientOptions] attached, which can be passed to initialize workers,
/// or can be used directly. Is cheap to clone.
#[derive(Clone, Debug)]
Expand All @@ -344,9 +392,33 @@ pub struct ConfiguredClient<C> {
}

impl<C> ConfiguredClient<C> {
/// Set HTTP request headers overwriting previous headers
pub fn set_headers(&self, headers: HashMap<String, String>) {
self.headers.write().user_headers = headers;
/// Set HTTP request headers overwriting previous headers.
///
/// This will not affect headers set via [ClientOptions::binary_headers].
///
/// # Errors
///
/// Will return an error if any of the provided keys or values are not valid gRPC metadata.
/// If an error is returned, the previous headers will remain unchanged.
pub fn set_headers(&self, headers: HashMap<String, String>) -> Result<(), InvalidHeaderError> {
self.headers.write().user_headers = parse_ascii_headers(headers)?;
Ok(())
}

/// Set binary HTTP request headers overwriting previous headers.
///
/// This will not affect headers set via [ClientOptions::headers].
///
/// # Errors
///
/// Will return an error if any of the provided keys are not valid gRPC binary metadata keys.
/// If an error is returned, the previous headers will remain unchanged.
pub fn set_binary_headers(
&self,
binary_headers: HashMap<String, Vec<u8>>,
) -> Result<(), InvalidHeaderError> {
self.headers.write().user_binary_headers = parse_binary_headers(binary_headers)?;
Ok(())
}

/// Set API key, overwriting previous
Expand All @@ -373,7 +445,8 @@ impl<C> ConfiguredClient<C> {

#[derive(Debug)]
struct ClientHeaders {
user_headers: HashMap<String, String>,
user_headers: HashMap<AsciiMetadataKey, AsciiMetadataValue>,
user_binary_headers: HashMap<BinaryMetadataKey, BinaryMetadataValue>,
api_key: Option<String>,
}

Expand All @@ -382,10 +455,13 @@ impl ClientHeaders {
for (key, val) in self.user_headers.iter() {
// Only if not already present
if !metadata.contains_key(key) {
// Ignore invalid keys/values
if let (Ok(key), Ok(val)) = (MetadataKey::from_str(key), val.parse()) {
metadata.insert(key, val);
}
metadata.insert(key, val.clone());
}
}
for (key, val) in self.user_binary_headers.iter() {
// Only if not already present
if !metadata.contains_key(key) {
metadata.insert_bin(key, val.clone());
}
}
if let Some(api_key) = &self.api_key {
Expand Down Expand Up @@ -491,7 +567,10 @@ impl ClientOptions {
};

let headers = Arc::new(RwLock::new(ClientHeaders {
user_headers: self.headers.clone().unwrap_or_default(),
user_headers: parse_ascii_headers(self.headers.clone().unwrap_or_default())?,
user_binary_headers: parse_binary_headers(
self.binary_headers.clone().unwrap_or_default(),
)?,
api_key: self.api_key.clone(),
}));
let interceptor = ServiceCallInterceptor {
Expand Down Expand Up @@ -558,6 +637,57 @@ impl ClientOptions {
}
}

fn parse_ascii_headers(
headers: HashMap<String, String>,
) -> Result<HashMap<AsciiMetadataKey, AsciiMetadataValue>, InvalidHeaderError> {
let mut parsed_headers = HashMap::with_capacity(headers.len());
for (k, v) in headers.into_iter() {
let key = match AsciiMetadataKey::from_str(&k) {
Ok(key) => key,
Err(err) => {
return Err(InvalidHeaderError::InvalidAsciiHeaderKey {
Copy link
Contributor Author

@jazev-stripe jazev-stripe Sep 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The match statement here is somewhat verbose, but I used it instead of .map_err to avoid cloning the key, in the case that we need to return an error

key: k,
source: err,
});
}
};
let value = match MetadataValue::from_str(&v) {
Ok(value) => value,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here (for the value)

Err(err) => {
return Err(InvalidHeaderError::InvalidAsciiHeaderValue {
key: k,
value: v,
source: err,
});
}
};
parsed_headers.insert(key, value);
}

Ok(parsed_headers)
}

fn parse_binary_headers(
headers: HashMap<String, Vec<u8>>,
) -> Result<HashMap<BinaryMetadataKey, BinaryMetadataValue>, InvalidHeaderError> {
let mut parsed_headers = HashMap::with_capacity(headers.len());
for (k, v) in headers.into_iter() {
let key = match BinaryMetadataKey::from_str(&k) {
Ok(key) => key,
Err(err) => {
return Err(InvalidHeaderError::InvalidBinaryHeaderKey {
key: k,
source: err,
});
}
};
let value = BinaryMetadataValue::from_bytes(&v);
parsed_headers.insert(key, value);
}

Ok(parsed_headers)
}

/// Interceptor which attaches common metadata (like "client-name") to every outgoing call
#[derive(Clone)]
pub struct ServiceCallInterceptor {
Expand Down Expand Up @@ -1770,13 +1900,17 @@ mod tests {
// Initial header set
let headers = Arc::new(RwLock::new(ClientHeaders {
user_headers: HashMap::new(),
user_binary_headers: HashMap::new(),
api_key: Some("my-api-key".to_owned()),
}));
headers
.clone()
.write()
.user_headers
.insert("my-meta-key".to_owned(), "my-meta-val".to_owned());
headers.clone().write().user_headers.insert(
"my-meta-key".parse().unwrap(),
"my-meta-val".parse().unwrap(),
);
headers.clone().write().user_binary_headers.insert(
"my-bin-meta-key-bin".parse().unwrap(),
vec![1, 2, 3].try_into().unwrap(),
);
let mut interceptor = ServiceCallInterceptor {
opts,
headers: headers.clone(),
Expand All @@ -1789,33 +1923,44 @@ mod tests {
req.metadata().get("authorization").unwrap(),
"Bearer my-api-key"
);
assert_eq!(
req.metadata().get_bin("my-bin-meta-key-bin").unwrap(),
vec![1, 2, 3].as_slice()
);

// Overwrite at request time
let mut req = tonic::Request::new(());
req.metadata_mut()
.insert("my-meta-key", "my-meta-val2".parse().unwrap());
req.metadata_mut()
.insert("authorization", "my-api-key2".parse().unwrap());
req.metadata_mut()
.insert_bin("my-bin-meta-key-bin", vec![4, 5, 6].try_into().unwrap());
let req = interceptor.call(req).unwrap();
assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val2");
assert_eq!(req.metadata().get("authorization").unwrap(), "my-api-key2");
assert_eq!(
req.metadata().get_bin("my-bin-meta-key-bin").unwrap(),
vec![4, 5, 6].as_slice()
);

// Overwrite auth on header
headers
.clone()
.write()
.user_headers
.insert("authorization".to_owned(), "my-api-key3".to_owned());
headers.clone().write().user_headers.insert(
"authorization".parse().unwrap(),
"my-api-key3".parse().unwrap(),
);
let req = interceptor.call(tonic::Request::new(())).unwrap();
assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val");
assert_eq!(req.metadata().get("authorization").unwrap(), "my-api-key3");

// Remove headers and auth and confirm gone
headers.clone().write().user_headers.clear();
headers.clone().write().user_binary_headers.clear();
headers.clone().write().api_key.take();
let req = interceptor.call(tonic::Request::new(())).unwrap();
assert!(!req.metadata().contains_key("my-meta-key"));
assert!(!req.metadata().contains_key("authorization"));
assert!(!req.metadata().contains_key("my-bin-meta-key-bin"));

// Timeout header not overriden
let mut req = tonic::Request::new(());
Expand All @@ -1828,6 +1973,55 @@ mod tests {
);
}

#[test]
fn invalid_ascii_header_key() {
let invalid_headers = {
let mut h = HashMap::new();
h.insert("x-binary-key-bin".to_owned(), "value".to_owned());
h
};

let result = parse_ascii_headers(invalid_headers);
assert!(result.is_err());
assert_eq!(
result.err().unwrap().to_string(),
"Invalid ASCII header key 'x-binary-key-bin': invalid gRPC metadata key name"
);
}

#[test]
fn invalid_ascii_header_value() {
let invalid_headers = {
let mut h = HashMap::new();
// Nul bytes are valid UTF-8, but not valid ascii gRPC headers:
h.insert("x-ascii-key".to_owned(), "\x00value".to_owned());
h
};

let result = parse_ascii_headers(invalid_headers);
assert!(result.is_err());
assert_eq!(
result.err().unwrap().to_string(),
"Invalid ASCII header value for key 'x-ascii-key': failed to parse metadata value"
);
}

#[test]
fn invalid_binary_header_key() {
let invalid_headers = {
let mut h = HashMap::new();
h.insert("x-ascii-key".to_owned(), vec![1, 2, 3]);
h
};

let result = parse_binary_headers(invalid_headers);
assert!(result.is_err());
assert_eq!(
result.err().unwrap().to_string(),
"Invalid binary header key 'x-ascii-key': invalid gRPC metadata key name"
);
}

#[test]
fn keep_alive_defaults() {
let mut builder = ClientOptionsBuilder::default();
Expand Down
2 changes: 1 addition & 1 deletion core-c-bridge/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ pub extern "C" fn temporal_core_client_update_metadata(
metadata: ByteArrayRef,
) {
let client = unsafe { &*client };
client
let _result = client
.core
.get_client()
.set_headers(metadata.to_string_map_on_newlines());
Expand Down
2 changes: 1 addition & 1 deletion tests/integ_tests/client_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ async fn per_call_timeout_respected_whole_client() {
let mut raw_client = opts.connect_no_namespace(None).await.unwrap();
let mut hm = HashMap::new();
hm.insert("grpc-timeout".to_string(), "0S".to_string());
raw_client.get_client().set_headers(hm);
raw_client.get_client().set_headers(hm).unwrap();
let err = raw_client
.describe_namespace(DescribeNamespaceRequest {
namespace: NAMESPACE.to_string(),
Expand Down
Loading