@@ -8,7 +8,9 @@ use temporal_client::{
88 ConfiguredClient , HealthService , HttpConnectProxyOptions , RetryClient , RetryConfig ,
99 TemporalServiceClientWithMetrics , TestService , TlsConfig , WorkflowService ,
1010} ;
11- use tonic:: metadata:: MetadataKey ;
11+ use tonic:: metadata:: {
12+ AsciiMetadataKey , AsciiMetadataValue , BinaryMetadataKey , BinaryMetadataValue ,
13+ } ;
1214use url:: Url ;
1315
1416use crate :: runtime;
@@ -28,7 +30,7 @@ pub struct ClientConfig {
2830 target_url : String ,
2931 client_name : String ,
3032 client_version : String ,
31- metadata : HashMap < String , String > ,
33+ metadata : HashMap < String , RpcMetadataValue > ,
3234 api_key : Option < String > ,
3335 identity : String ,
3436 tls_config : Option < ClientTlsConfig > ,
@@ -72,10 +74,18 @@ struct RpcCall {
7274 rpc : String ,
7375 req : Vec < u8 > ,
7476 retry : bool ,
75- metadata : HashMap < String , String > ,
77+ metadata : HashMap < String , RpcMetadataValue > ,
7678 timeout_millis : Option < u64 > ,
7779}
7880
81+ #[ derive( FromPyObject ) ]
82+ enum RpcMetadataValue {
83+ #[ pyo3( transparent, annotation = "str" ) ]
84+ Str ( String ) ,
85+ #[ pyo3( transparent, annotation = "bytes" ) ]
86+ Bytes ( Vec < u8 > ) ,
87+ }
88+
7989pub fn connect_client < ' a > (
8090 py : Python < ' a > ,
8191 runtime_ref : & runtime:: RuntimeRef ,
@@ -116,8 +126,19 @@ macro_rules! rpc_call_on_trait {
116126
117127#[ pymethods]
118128impl ClientRef {
119- fn update_metadata ( & self , headers : HashMap < String , String > ) {
120- self . retry_client . get_client ( ) . set_headers ( headers) ;
129+ fn update_metadata ( & self , headers : HashMap < String , RpcMetadataValue > ) -> PyResult < ( ) > {
130+ let ( ascii_headers, binary_headers) = partition_headers ( headers) ;
131+
132+ self . retry_client
133+ . get_client ( )
134+ . set_headers ( ascii_headers)
135+ . map_err ( |err| PyValueError :: new_err ( err. to_string ( ) ) ) ?;
136+ self . retry_client
137+ . get_client ( )
138+ . set_binary_headers ( binary_headers)
139+ . map_err ( |err| PyValueError :: new_err ( err. to_string ( ) ) ) ?;
140+
141+ Ok ( ( ) )
121142 }
122143
123144 fn update_api_key ( & self , api_key : Option < String > ) {
@@ -536,12 +557,32 @@ fn rpc_req<P: prost::Message + Default>(call: RpcCall) -> PyResult<tonic::Reques
536557 . map_err ( |err| PyValueError :: new_err ( format ! ( "Invalid proto: {err}" ) ) ) ?;
537558 let mut req = tonic:: Request :: new ( proto) ;
538559 for ( k, v) in call. metadata {
539- req. metadata_mut ( ) . insert (
540- MetadataKey :: from_str ( k. as_str ( ) )
541- . map_err ( |err| PyValueError :: new_err ( format ! ( "Invalid metadata key: {err}" ) ) ) ?,
542- v. parse ( )
543- . map_err ( |err| PyValueError :: new_err ( format ! ( "Invalid metadata value: {err}" ) ) ) ?,
544- ) ;
560+ if let Ok ( binary_key) = BinaryMetadataKey :: from_str ( & k) {
561+ let RpcMetadataValue :: Bytes ( bytes) = v else {
562+ return Err ( PyValueError :: new_err ( format ! (
563+ "Invalid metadata value for binary key {k}: expected bytes"
564+ ) ) ) ;
565+ } ;
566+
567+ req. metadata_mut ( )
568+ . insert_bin ( binary_key, BinaryMetadataValue :: from_bytes ( & bytes) ) ;
569+ } else {
570+ let ascii_key = AsciiMetadataKey :: from_str ( & k)
571+ . map_err ( |err| PyValueError :: new_err ( format ! ( "Invalid metadata key: {err}" ) ) ) ?;
572+
573+ let RpcMetadataValue :: Str ( string) = v else {
574+ return Err ( PyValueError :: new_err ( format ! (
575+ "Invalid metadata value for ASCII key {k}: expected str"
576+ ) ) ) ;
577+ } ;
578+
579+ req. metadata_mut ( ) . insert (
580+ ascii_key,
581+ AsciiMetadataValue :: from_str ( & string) . map_err ( |err| {
582+ PyValueError :: new_err ( format ! ( "Invalid metadata value: {err}" ) )
583+ } ) ?,
584+ ) ;
585+ }
545586 }
546587 if let Some ( timeout_millis) = call. timeout_millis {
547588 req. set_timeout ( Duration :: from_millis ( timeout_millis) ) ;
@@ -568,11 +609,41 @@ where
568609 }
569610}
570611
612+ fn partition_headers (
613+ headers : HashMap < String , RpcMetadataValue > ,
614+ ) -> ( HashMap < String , String > , HashMap < String , Vec < u8 > > ) {
615+ let ( ascii_enum_headers, binary_enum_headers) : ( HashMap < _ , _ > , HashMap < _ , _ > ) = headers
616+ . into_iter ( )
617+ . partition ( |( _, v) | matches ! ( v, RpcMetadataValue :: Str ( _) ) ) ;
618+
619+ let ascii_headers = ascii_enum_headers
620+ . into_iter ( )
621+ . map ( |( k, v) | {
622+ let RpcMetadataValue :: Str ( s) = v else {
623+ unreachable ! ( ) ;
624+ } ;
625+ ( k, s)
626+ } )
627+ . collect ( ) ;
628+ let binary_headers = binary_enum_headers
629+ . into_iter ( )
630+ . map ( |( k, v) | {
631+ let RpcMetadataValue :: Bytes ( b) = v else {
632+ unreachable ! ( ) ;
633+ } ;
634+ ( k, b)
635+ } )
636+ . collect ( ) ;
637+
638+ ( ascii_headers, binary_headers)
639+ }
640+
571641impl TryFrom < ClientConfig > for ClientOptions {
572642 type Error = PyErr ;
573643
574644 fn try_from ( opts : ClientConfig ) -> PyResult < Self > {
575645 let mut gateway_opts = ClientOptionsBuilder :: default ( ) ;
646+ let ( ascii_headers, binary_headers) = partition_headers ( opts. metadata ) ;
576647 gateway_opts
577648 . target_url (
578649 Url :: parse ( & opts. target_url )
@@ -587,7 +658,8 @@ impl TryFrom<ClientConfig> for ClientOptions {
587658 )
588659 . keep_alive ( opts. keep_alive_config . map ( Into :: into) )
589660 . http_connect_proxy ( opts. http_connect_proxy_config . map ( Into :: into) )
590- . headers ( Some ( opts. metadata ) )
661+ . headers ( Some ( ascii_headers) )
662+ . binary_headers ( Some ( binary_headers) )
591663 . api_key ( opts. api_key ) ;
592664 // Builder does not allow us to set option here, so we have to make
593665 // a conditional to even call it
0 commit comments