Skip to content

Commit 9d70d44

Browse files
authored
Add support for gRPC binary metadata values (#1070)
1 parent ce8dc4a commit 9d70d44

File tree

10 files changed

+362
-152
lines changed

10 files changed

+362
-152
lines changed

temporalio/bridge/client.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from dataclasses import dataclass
99
from datetime import timedelta
10-
from typing import Mapping, Optional, Tuple, Type, TypeVar
10+
from typing import Mapping, Optional, Tuple, Type, TypeVar, Union
1111

1212
import google.protobuf.message
1313

@@ -59,7 +59,7 @@ class ClientConfig:
5959
"""Python representation of the Rust struct for configuring the client."""
6060

6161
target_url: str
62-
metadata: Mapping[str, str]
62+
metadata: Mapping[str, Union[str, bytes]]
6363
api_key: Optional[str]
6464
identity: str
6565
tls_config: Optional[ClientTlsConfig]
@@ -77,7 +77,7 @@ class RpcCall:
7777
rpc: str
7878
req: bytes
7979
retry: bool
80-
metadata: Mapping[str, str]
80+
metadata: Mapping[str, Union[str, bytes]]
8181
timeout_millis: Optional[int]
8282

8383

@@ -108,7 +108,7 @@ def __init__(
108108
self._runtime = runtime
109109
self._ref = ref
110110

111-
def update_metadata(self, metadata: Mapping[str, str]) -> None:
111+
def update_metadata(self, metadata: Mapping[str, Union[str, bytes]]) -> None:
112112
"""Update underlying metadata on Core client."""
113113
self._ref.update_metadata(metadata)
114114

@@ -124,7 +124,7 @@ async def call(
124124
req: google.protobuf.message.Message,
125125
resp_type: Type[ProtoMessage],
126126
retry: bool,
127-
metadata: Mapping[str, str],
127+
metadata: Mapping[str, Union[str, bytes]],
128128
timeout: Optional[timedelta],
129129
) -> ProtoMessage:
130130
"""Make RPC call using SDK Core."""

temporalio/bridge/src/client.rs

Lines changed: 84 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ use temporal_client::{
88
ConfiguredClient, HealthService, HttpConnectProxyOptions, RetryClient, RetryConfig,
99
TemporalServiceClientWithMetrics, TestService, TlsConfig, WorkflowService,
1010
};
11-
use tonic::metadata::MetadataKey;
11+
use tonic::metadata::{
12+
AsciiMetadataKey, AsciiMetadataValue, BinaryMetadataKey, BinaryMetadataValue,
13+
};
1214
use url::Url;
1315

1416
use crate::runtime;
@@ -28,7 +30,7 @@ pub struct ClientConfig {
2830
target_url: String,
2931
client_name: String,
3032
client_version: String,
31-
metadata: HashMap<String, String>,
33+
metadata: HashMap<String, RpcMetadataValue>,
3234
api_key: Option<String>,
3335
identity: String,
3436
tls_config: Option<ClientTlsConfig>,
@@ -72,10 +74,18 @@ struct RpcCall {
7274
rpc: String,
7375
req: Vec<u8>,
7476
retry: bool,
75-
metadata: HashMap<String, String>,
77+
metadata: HashMap<String, RpcMetadataValue>,
7678
timeout_millis: Option<u64>,
7779
}
7880

81+
#[derive(FromPyObject)]
82+
enum RpcMetadataValue {
83+
#[pyo3(transparent, annotation = "str")]
84+
Str(String),
85+
#[pyo3(transparent, annotation = "bytes")]
86+
Bytes(Vec<u8>),
87+
}
88+
7989
pub fn connect_client<'a>(
8090
py: Python<'a>,
8191
runtime_ref: &runtime::RuntimeRef,
@@ -116,8 +126,19 @@ macro_rules! rpc_call_on_trait {
116126

117127
#[pymethods]
118128
impl ClientRef {
119-
fn update_metadata(&self, headers: HashMap<String, String>) {
120-
self.retry_client.get_client().set_headers(headers);
129+
fn update_metadata(&self, headers: HashMap<String, RpcMetadataValue>) -> PyResult<()> {
130+
let (ascii_headers, binary_headers) = partition_headers(headers);
131+
132+
self.retry_client
133+
.get_client()
134+
.set_headers(ascii_headers)
135+
.map_err(|err| PyValueError::new_err(err.to_string()))?;
136+
self.retry_client
137+
.get_client()
138+
.set_binary_headers(binary_headers)
139+
.map_err(|err| PyValueError::new_err(err.to_string()))?;
140+
141+
Ok(())
121142
}
122143

123144
fn update_api_key(&self, api_key: Option<String>) {
@@ -536,12 +557,32 @@ fn rpc_req<P: prost::Message + Default>(call: RpcCall) -> PyResult<tonic::Reques
536557
.map_err(|err| PyValueError::new_err(format!("Invalid proto: {err}")))?;
537558
let mut req = tonic::Request::new(proto);
538559
for (k, v) in call.metadata {
539-
req.metadata_mut().insert(
540-
MetadataKey::from_str(k.as_str())
541-
.map_err(|err| PyValueError::new_err(format!("Invalid metadata key: {err}")))?,
542-
v.parse()
543-
.map_err(|err| PyValueError::new_err(format!("Invalid metadata value: {err}")))?,
544-
);
560+
if let Ok(binary_key) = BinaryMetadataKey::from_str(&k) {
561+
let RpcMetadataValue::Bytes(bytes) = v else {
562+
return Err(PyValueError::new_err(format!(
563+
"Invalid metadata value for binary key {k}: expected bytes"
564+
)));
565+
};
566+
567+
req.metadata_mut()
568+
.insert_bin(binary_key, BinaryMetadataValue::from_bytes(&bytes));
569+
} else {
570+
let ascii_key = AsciiMetadataKey::from_str(&k)
571+
.map_err(|err| PyValueError::new_err(format!("Invalid metadata key: {err}")))?;
572+
573+
let RpcMetadataValue::Str(string) = v else {
574+
return Err(PyValueError::new_err(format!(
575+
"Invalid metadata value for ASCII key {k}: expected str"
576+
)));
577+
};
578+
579+
req.metadata_mut().insert(
580+
ascii_key,
581+
AsciiMetadataValue::from_str(&string).map_err(|err| {
582+
PyValueError::new_err(format!("Invalid metadata value: {err}"))
583+
})?,
584+
);
585+
}
545586
}
546587
if let Some(timeout_millis) = call.timeout_millis {
547588
req.set_timeout(Duration::from_millis(timeout_millis));
@@ -568,11 +609,41 @@ where
568609
}
569610
}
570611

612+
fn partition_headers(
613+
headers: HashMap<String, RpcMetadataValue>,
614+
) -> (HashMap<String, String>, HashMap<String, Vec<u8>>) {
615+
let (ascii_enum_headers, binary_enum_headers): (HashMap<_, _>, HashMap<_, _>) = headers
616+
.into_iter()
617+
.partition(|(_, v)| matches!(v, RpcMetadataValue::Str(_)));
618+
619+
let ascii_headers = ascii_enum_headers
620+
.into_iter()
621+
.map(|(k, v)| {
622+
let RpcMetadataValue::Str(s) = v else {
623+
unreachable!();
624+
};
625+
(k, s)
626+
})
627+
.collect();
628+
let binary_headers = binary_enum_headers
629+
.into_iter()
630+
.map(|(k, v)| {
631+
let RpcMetadataValue::Bytes(b) = v else {
632+
unreachable!();
633+
};
634+
(k, b)
635+
})
636+
.collect();
637+
638+
(ascii_headers, binary_headers)
639+
}
640+
571641
impl TryFrom<ClientConfig> for ClientOptions {
572642
type Error = PyErr;
573643

574644
fn try_from(opts: ClientConfig) -> PyResult<Self> {
575645
let mut gateway_opts = ClientOptionsBuilder::default();
646+
let (ascii_headers, binary_headers) = partition_headers(opts.metadata);
576647
gateway_opts
577648
.target_url(
578649
Url::parse(&opts.target_url)
@@ -587,7 +658,8 @@ impl TryFrom<ClientConfig> for ClientOptions {
587658
)
588659
.keep_alive(opts.keep_alive_config.map(Into::into))
589660
.http_connect_proxy(opts.http_connect_proxy_config.map(Into::into))
590-
.headers(Some(opts.metadata))
661+
.headers(Some(ascii_headers))
662+
.binary_headers(Some(binary_headers))
591663
.api_key(opts.api_key);
592664
// Builder does not allow us to set option here, so we have to make
593665
// a conditional to even call it

0 commit comments

Comments
 (0)