Skip to content

Commit 8e57847

Browse files
authored
Support binary values in gRPC metadata (#348)
Fixes #329
1 parent 7354b08 commit 8e57847

File tree

2 files changed

+102
-11
lines changed

2 files changed

+102
-11
lines changed

temporalio/ext/src/client.rs

Lines changed: 69 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ use temporal_client::{
77
};
88

99
use magnus::{
10-
DataTypeFunctions, Error, RString, Ruby, TypedData, Value, class, function, method, prelude::*,
11-
scan_args,
10+
DataTypeFunctions, Error, RHash, RString, Ruby, TypedData, Value, class, function, method,
11+
prelude::*, scan_args,
1212
};
1313
use tonic::{Status, metadata::MetadataKey};
1414
use url::Url;
@@ -84,9 +84,11 @@ macro_rules! rpc_call {
8484
impl Client {
8585
pub fn async_new(runtime: &Runtime, options: Struct, queue: Value) -> Result<(), Error> {
8686
runtime.handle.fork_check("create client")?;
87+
let ruby = Ruby::get().expect("Ruby not available");
8788
// Build options
8889
let mut opts_build = ClientOptionsBuilder::default();
8990
let tls = options.child(id!("tls"))?;
91+
let headers = partition_grpc_headers(&ruby, options.member(id!("rpc_metadata"))?)?;
9092
opts_build
9193
.target_url(
9294
Url::parse(
@@ -101,7 +103,8 @@ impl Client {
101103
)
102104
.client_name(options.member::<String>(id!("client_name"))?)
103105
.client_version(options.member::<String>(id!("client_version"))?)
104-
.headers(Some(options.member(id!("rpc_metadata"))?))
106+
.headers(Some(headers.headers))
107+
.binary_headers(Some(headers.binary_headers))
105108
.api_key(options.member(id!("api_key"))?)
106109
.identity(options.member::<String>(id!("identity"))?);
107110
if let Some(tls) = tls {
@@ -193,6 +196,7 @@ impl Client {
193196

194197
pub fn async_invoke_rpc(&self, args: &[Value]) -> Result<(), Error> {
195198
self.runtime_handle.fork_check("use client")?;
199+
let ruby = Ruby::get().expect("Ruby not available");
196200
let args = scan_args::scan_args::<(), (), (), (), _, ()>(args)?;
197201
let (service, rpc, request, retry, metadata, timeout, cancel_token, queue) =
198202
scan_args::get_kwargs::<
@@ -202,7 +206,7 @@ impl Client {
202206
String,
203207
RString,
204208
bool,
205-
Option<HashMap<String, String>>,
209+
Option<RHash>,
206210
Option<f64>,
207211
Option<&CancellationToken>,
208212
Value,
@@ -224,11 +228,16 @@ impl Client {
224228
&[],
225229
)?
226230
.required;
231+
let headers = if let Some(metadata) = metadata {
232+
Some(partition_grpc_headers(&ruby, metadata)?)
233+
} else {
234+
None
235+
};
227236
let call = RpcCall {
228237
rpc,
229238
request: unsafe { request.as_slice() },
230239
retry,
231-
metadata,
240+
headers,
232241
timeout,
233242
cancel_token: cancel_token.map(|c| c.token.clone()),
234243
_not_send_sync: PhantomData,
@@ -237,18 +246,59 @@ impl Client {
237246
self.invoke_rpc(service, callback, call)
238247
}
239248

240-
pub fn update_metadata(&self, headers: HashMap<String, String>) -> Result<(), Error> {
249+
pub fn update_metadata(&self, headers: RHash) -> Result<(), Error> {
250+
let ruby = Ruby::get().expect("Ruby not available");
251+
let headers = partition_grpc_headers(&ruby, headers)?;
252+
self.core
253+
.get_client()
254+
.set_headers(headers.headers)
255+
.map_err(|err| error!("Invalid headers: {}", err))?;
241256
self.core
242257
.get_client()
243-
.set_headers(headers)
244-
.map_err(|err| error!("Invalid headers: {}", err))
258+
.set_binary_headers(headers.binary_headers)
259+
.map_err(|err| error!("Invalid headers: {}", err))?;
260+
Ok(())
245261
}
246262

247263
pub fn update_api_key(&self, api_key: Option<String>) {
248264
self.core.get_client().set_api_key(api_key);
249265
}
250266
}
251267

268+
pub(crate) struct GrpcHeaders {
269+
headers: HashMap<String, String>,
270+
binary_headers: HashMap<String, Vec<u8>>,
271+
}
272+
273+
fn partition_grpc_headers(ruby: &Ruby, hash: RHash) -> Result<GrpcHeaders, Error> {
274+
let mut headers = HashMap::new();
275+
let mut binary_headers = HashMap::new();
276+
hash.foreach(|key: String, value: RString| {
277+
if key.ends_with("-bin") {
278+
if value.enc_get() != ruby.ascii8bit_encindex() {
279+
return Err(Error::new(
280+
ruby.exception_arg_error(),
281+
format!("Value for metadata key {key} must be ASCII-8BIT"),
282+
));
283+
}
284+
binary_headers.insert(key, unsafe { value.as_slice().to_vec() });
285+
} else {
286+
let value = value.to_string().map_err(|err| {
287+
Error::new(
288+
ruby.exception_arg_error(),
289+
format!("Value for metadata key {key} invalid: {err}"),
290+
)
291+
})?;
292+
headers.insert(key, value);
293+
}
294+
Ok(magnus::r_hash::ForEach::Continue)
295+
})?;
296+
Ok(GrpcHeaders {
297+
headers,
298+
binary_headers,
299+
})
300+
}
301+
252302
#[derive(DataTypeFunctions, TypedData)]
253303
#[magnus(
254304
class = "Temporalio::Internal::Bridge::Client::RPCFailure",
@@ -280,7 +330,7 @@ pub(crate) struct RpcCall<'a> {
280330
pub rpc: String,
281331
pub request: &'a [u8],
282332
pub retry: bool,
283-
pub metadata: Option<HashMap<String, String>>,
333+
pub headers: Option<GrpcHeaders>,
284334
pub timeout: Option<f64>,
285335
pub cancel_token: Option<tokio_util::sync::CancellationToken>,
286336

@@ -294,15 +344,23 @@ impl RpcCall<'_> {
294344
pub fn into_request<P: prost::Message + Default>(self) -> Result<tonic::Request<P>, Error> {
295345
let proto = P::decode(self.request).map_err(|err| error!("Invalid proto: {}", err))?;
296346
let mut req = tonic::Request::new(proto);
297-
if let Some(metadata) = self.metadata {
298-
for (k, v) in metadata {
347+
if let Some(headers) = self.headers {
348+
for (k, v) in headers.headers {
299349
req.metadata_mut().insert(
300350
MetadataKey::from_str(k.as_str())
301351
.map_err(|err| error!("Invalid metadata key: {}", err))?,
302352
v.parse()
303353
.map_err(|err| error!("Invalid metadata value: {}", err))?,
304354
);
305355
}
356+
for (k, v) in headers.binary_headers {
357+
req.metadata_mut().insert_bin(
358+
MetadataKey::from_str(k.as_str())
359+
.map_err(|err| error!("Invalid metadata key: {}", err))?,
360+
v.try_into()
361+
.map_err(|err| error!("Invalid metadata value: {}", err))?,
362+
);
363+
}
306364
}
307365
if let Some(timeout) = self.timeout {
308366
req.set_timeout(Duration::from_secs_f64(timeout));

temporalio/test/client_test.rb

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,4 +234,37 @@ def test_fork
234234
assert status.success?
235235
assert_equal 'started workflow', reader.read.strip
236236
end
237+
238+
def test_binary_metadata
239+
orig_metadata = env.client.connection.rpc_metadata
240+
241+
# Connect a new client with some bad metadata
242+
err = assert_raises(ArgumentError) do
243+
Temporalio::Client.connect(
244+
env.client.connection.target_host,
245+
env.client.namespace,
246+
rpc_metadata: { 'connect-bin' => 'not-allowed' }
247+
)
248+
end
249+
assert_equal 'Value for metadata key connect-bin must be ASCII-8BIT', err.message
250+
251+
# Update a client with some bad metadata
252+
err = assert_raises(ArgumentError) do
253+
env.client.connection.rpc_metadata = { 'update-bin' => 'not-allowed' }
254+
end
255+
assert_equal 'Value for metadata key update-bin must be ASCII-8BIT', err.message
256+
257+
# Make an RPC call with some bad metadata
258+
err = assert_raises(ArgumentError) do
259+
env.client.start_workflow(
260+
:MyWorkflow,
261+
id: "wf-#{SecureRandom.uuid}",
262+
task_queue: "tq-#{SecureRandom.uuid}",
263+
rpc_options: Temporalio::Client::RPCOptions.new(metadata: { 'rpc-bin' => 'not-allowed' })
264+
)
265+
end
266+
assert_equal 'Value for metadata key rpc-bin must be ASCII-8BIT', err.message
267+
ensure
268+
env.client.connection.rpc_metadata = orig_metadata
269+
end
237270
end

0 commit comments

Comments
 (0)