Skip to content

Commit 80cddc8

Browse files
committed
point sdk-core to fork, add more complete metadata patch
1 parent 648ab71 commit 80cddc8

File tree

8 files changed

+53
-20
lines changed

8 files changed

+53
-20
lines changed

.gitmodules

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
[submodule "sdk-core"]
22
path = temporalio/bridge/sdk-core
3-
url = https://github.com/temporalio/sdk-core.git
3+
url = https://github.com/jazev-stripe/sdk-core.git

temporalio/bridge/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class ClientConfig:
5959
"""Python representation of the Rust struct for configuring the client."""
6060

6161
target_url: str
62-
metadata: Mapping[str, str]
62+
metadata: Mapping[str, str | bytes]
6363
api_key: Optional[str]
6464
identity: str
6565
tls_config: Optional[ClientTlsConfig]
@@ -108,7 +108,7 @@ def __init__(
108108
self._runtime = runtime
109109
self._ref = ref
110110

111-
def update_metadata(self, metadata: Mapping[str, str]) -> None:
111+
def update_metadata(self, metadata: Mapping[str, str | bytes]) -> None:
112112
"""Update underlying metadata on Core client."""
113113
self._ref.update_metadata(metadata)
114114

temporalio/bridge/sdk-core

temporalio/bridge/src/client.rs

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ pub struct ClientConfig {
3030
target_url: String,
3131
client_name: String,
3232
client_version: String,
33-
metadata: HashMap<String, String>,
33+
metadata: HashMap<String, RpcMetadataValue>,
3434
api_key: Option<String>,
3535
identity: String,
3636
tls_config: Option<ClientTlsConfig>,
@@ -126,8 +126,10 @@ macro_rules! rpc_call_on_trait {
126126

127127
#[pymethods]
128128
impl ClientRef {
129-
fn update_metadata(&self, headers: HashMap<String, String>) {
130-
self.retry_client.get_client().set_headers(headers);
129+
fn update_metadata(&self, headers: HashMap<String, RpcMetadataValue>) {
130+
let (ascii_headers, binary_headers) = partition_headers(headers);
131+
self.retry_client.get_client().set_headers(ascii_headers);
132+
self.retry_client.get_client().set_binary_headers(binary_headers);
131133
}
132134

133135
fn update_api_key(&self, api_key: Option<String>) {
@@ -595,11 +597,41 @@ where
595597
}
596598
}
597599

600+
fn partition_headers(
601+
headers: HashMap<String, RpcMetadataValue>,
602+
) -> (HashMap<String, String>, HashMap<String, Vec<u8>>) {
603+
let (ascii_enum_headers, binary_enum_headers): (HashMap<_, _>, HashMap<_, _>) = headers
604+
.into_iter()
605+
.partition(|(_, v)| matches!(v, RpcMetadataValue::Str(_)));
606+
607+
let ascii_headers = ascii_enum_headers
608+
.into_iter()
609+
.map(|(k, v)| {
610+
let RpcMetadataValue::Str(s) = v else {
611+
unreachable!();
612+
};
613+
(k, s)
614+
})
615+
.collect();
616+
let binary_headers = binary_enum_headers
617+
.into_iter()
618+
.map(|(k, v)| {
619+
let RpcMetadataValue::Bytes(b) = v else {
620+
unreachable!();
621+
};
622+
(k, b)
623+
})
624+
.collect();
625+
626+
(ascii_headers, binary_headers)
627+
}
628+
598629
impl TryFrom<ClientConfig> for ClientOptions {
599630
type Error = PyErr;
600631

601632
fn try_from(opts: ClientConfig) -> PyResult<Self> {
602633
let mut gateway_opts = ClientOptionsBuilder::default();
634+
let (ascii_headers, binary_headers) = partition_headers(opts.metadata);
603635
gateway_opts
604636
.target_url(
605637
Url::parse(&opts.target_url)
@@ -614,7 +646,8 @@ impl TryFrom<ClientConfig> for ClientOptions {
614646
)
615647
.keep_alive(opts.keep_alive_config.map(Into::into))
616648
.http_connect_proxy(opts.http_connect_proxy_config.map(Into::into))
617-
.headers(Some(opts.metadata))
649+
.headers(Some(ascii_headers))
650+
.binary_headers(Some(binary_headers))
618651
.api_key(opts.api_key);
619652
// Builder does not allow us to set option here, so we have to make
620653
// a conditional to even call it

temporalio/client.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ async def connect(
118118
tls: Union[bool, TLSConfig] = False,
119119
retry_config: Optional[RetryConfig] = None,
120120
keep_alive_config: Optional[KeepAliveConfig] = KeepAliveConfig.default,
121-
rpc_metadata: Mapping[str, str] = {},
121+
rpc_metadata: Mapping[str, str | bytes] = {},
122122
identity: Optional[str] = None,
123123
lazy: bool = False,
124124
runtime: Optional[temporalio.runtime.Runtime] = None,
@@ -296,7 +296,7 @@ def data_converter(self) -> temporalio.converter.DataConverter:
296296
return self._config["data_converter"]
297297

298298
@property
299-
def rpc_metadata(self) -> Mapping[str, str]:
299+
def rpc_metadata(self) -> Mapping[str, str | bytes]:
300300
"""Headers for every call made by this client.
301301
302302
Do not use mutate this mapping. Rather, set this property with an
@@ -305,7 +305,7 @@ def rpc_metadata(self) -> Mapping[str, str]:
305305
return self.service_client.config.rpc_metadata
306306

307307
@rpc_metadata.setter
308-
def rpc_metadata(self, value: Mapping[str, str]) -> None:
308+
def rpc_metadata(self, value: Mapping[str, str | bytes]) -> None:
309309
"""Update the headers for this client.
310310
311311
Do not mutate this mapping after set. Rather, set an entirely new
@@ -7202,7 +7202,7 @@ async def connect(
72027202
tls: Union[bool, TLSConfig] = True,
72037203
retry_config: Optional[RetryConfig] = None,
72047204
keep_alive_config: Optional[KeepAliveConfig] = KeepAliveConfig.default,
7205-
rpc_metadata: Mapping[str, str] = {},
7205+
rpc_metadata: Mapping[str, str | bytes] = {},
72067206
identity: Optional[str] = None,
72077207
lazy: bool = False,
72087208
runtime: Optional[temporalio.runtime.Runtime] = None,
@@ -7294,7 +7294,7 @@ def identity(self) -> str:
72947294
return self._service_client.config.identity
72957295

72967296
@property
7297-
def rpc_metadata(self) -> Mapping[str, str]:
7297+
def rpc_metadata(self) -> Mapping[str, str | bytes]:
72987298
"""Headers for every call made by this client.
72997299
73007300
Do not use mutate this mapping. Rather, set this property with an
@@ -7304,7 +7304,7 @@ def rpc_metadata(self) -> Mapping[str, str]:
73047304
return self.service_client.config.rpc_metadata
73057305

73067306
@rpc_metadata.setter
7307-
def rpc_metadata(self, value: Mapping[str, str]) -> None:
7307+
def rpc_metadata(self, value: Mapping[str, str | bytes]) -> None:
73087308
"""Update the headers for this client.
73097309
73107310
Do not mutate this mapping after set. Rather, set an entirely new

temporalio/service.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ class ConnectConfig:
143143
tls: Union[bool, TLSConfig] = False
144144
retry_config: Optional[RetryConfig] = None
145145
keep_alive_config: Optional[KeepAliveConfig] = KeepAliveConfig.default
146-
rpc_metadata: Mapping[str, str] = field(default_factory=dict)
146+
rpc_metadata: Mapping[str, str | bytes] = field(default_factory=dict)
147147
identity: str = ""
148148
lazy: bool = False
149149
runtime: Optional[temporalio.runtime.Runtime] = None
@@ -260,7 +260,7 @@ def worker_service_client(self) -> _BridgeServiceClient:
260260
raise NotImplementedError
261261

262262
@abstractmethod
263-
def update_rpc_metadata(self, metadata: Mapping[str, str]) -> None:
263+
def update_rpc_metadata(self, metadata: Mapping[str, str | bytes]) -> None:
264264
"""Update service client's RPC metadata."""
265265
raise NotImplementedError
266266

@@ -1312,7 +1312,7 @@ def worker_service_client(self) -> _BridgeServiceClient:
13121312
"""Underlying service client."""
13131313
return self
13141314

1315-
def update_rpc_metadata(self, metadata: Mapping[str, str]) -> None:
1315+
def update_rpc_metadata(self, metadata: Mapping[str, str | bytes]) -> None:
13161316
"""Update Core client metadata."""
13171317
# Mutate the bridge config and then only mutate the running client
13181318
# metadata if already connected

temporalio/testing/_workflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ async def start_local(
8484
temporalio.common.QueryRejectCondition
8585
] = None,
8686
retry_config: Optional[temporalio.client.RetryConfig] = None,
87-
rpc_metadata: Mapping[str, str] = {},
87+
rpc_metadata: Mapping[str, str | bytes] = {},
8888
identity: Optional[str] = None,
8989
tls: bool | temporalio.client.TLSConfig = False,
9090
ip: str = "127.0.0.1",
@@ -244,7 +244,7 @@ async def start_time_skipping(
244244
temporalio.common.QueryRejectCondition
245245
] = None,
246246
retry_config: Optional[temporalio.client.RetryConfig] = None,
247-
rpc_metadata: Mapping[str, str] = {},
247+
rpc_metadata: Mapping[str, str | bytes] = {},
248248
identity: Optional[str] = None,
249249
port: Optional[int] = None,
250250
download_dest_dir: Optional[str] = None,

tests/worker/test_update_with_start.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -836,7 +836,7 @@ async def __call__(
836836
req: temporalio.api.workflowservice.v1.ExecuteMultiOperationRequest,
837837
*,
838838
retry: bool = False,
839-
metadata: Mapping[str, str] = {},
839+
metadata: Mapping[str, str | bytes] = {},
840840
timeout: Optional[timedelta] = None,
841841
) -> temporalio.api.workflowservice.v1.ExecuteMultiOperationResponse:
842842
raise self.empty_details_err

0 commit comments

Comments
 (0)