Skip to content

Commit 14539b3

Browse files
committed
point sdk-core to fork, add more complete metadata patch
1 parent 7509e71 commit 14539b3

File tree

8 files changed

+53
-20
lines changed

8 files changed

+53
-20
lines changed

.gitmodules

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
[submodule "sdk-core"]
22
path = temporalio/bridge/sdk-core
3-
url = https://github.com/temporalio/sdk-core.git
3+
url = https://github.com/jazev-stripe/sdk-core.git

temporalio/bridge/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class ClientConfig:
5959
"""Python representation of the Rust struct for configuring the client."""
6060

6161
target_url: str
62-
metadata: Mapping[str, str]
62+
metadata: Mapping[str, str | bytes]
6363
api_key: Optional[str]
6464
identity: str
6565
tls_config: Optional[ClientTlsConfig]
@@ -108,7 +108,7 @@ def __init__(
108108
self._runtime = runtime
109109
self._ref = ref
110110

111-
def update_metadata(self, metadata: Mapping[str, str]) -> None:
111+
def update_metadata(self, metadata: Mapping[str, str | bytes]) -> None:
112112
"""Update underlying metadata on Core client."""
113113
self._ref.update_metadata(metadata)
114114

temporalio/bridge/src/client.rs

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ pub struct ClientConfig {
3030
target_url: String,
3131
client_name: String,
3232
client_version: String,
33-
metadata: HashMap<String, String>,
33+
metadata: HashMap<String, RpcMetadataValue>,
3434
api_key: Option<String>,
3535
identity: String,
3636
tls_config: Option<ClientTlsConfig>,
@@ -126,8 +126,10 @@ macro_rules! rpc_call_on_trait {
126126

127127
#[pymethods]
128128
impl ClientRef {
129-
fn update_metadata(&self, headers: HashMap<String, String>) {
130-
self.retry_client.get_client().set_headers(headers);
129+
fn update_metadata(&self, headers: HashMap<String, RpcMetadataValue>) {
130+
let (ascii_headers, binary_headers) = partition_headers(headers);
131+
self.retry_client.get_client().set_headers(ascii_headers);
132+
self.retry_client.get_client().set_binary_headers(binary_headers);
131133
}
132134

133135
fn update_api_key(&self, api_key: Option<String>) {
@@ -598,11 +600,41 @@ where
598600
}
599601
}
600602

603+
fn partition_headers(
604+
headers: HashMap<String, RpcMetadataValue>,
605+
) -> (HashMap<String, String>, HashMap<String, Vec<u8>>) {
606+
let (ascii_enum_headers, binary_enum_headers): (HashMap<_, _>, HashMap<_, _>) = headers
607+
.into_iter()
608+
.partition(|(_, v)| matches!(v, RpcMetadataValue::Str(_)));
609+
610+
let ascii_headers = ascii_enum_headers
611+
.into_iter()
612+
.map(|(k, v)| {
613+
let RpcMetadataValue::Str(s) = v else {
614+
unreachable!();
615+
};
616+
(k, s)
617+
})
618+
.collect();
619+
let binary_headers = binary_enum_headers
620+
.into_iter()
621+
.map(|(k, v)| {
622+
let RpcMetadataValue::Bytes(b) = v else {
623+
unreachable!();
624+
};
625+
(k, b)
626+
})
627+
.collect();
628+
629+
(ascii_headers, binary_headers)
630+
}
631+
601632
impl TryFrom<ClientConfig> for ClientOptions {
602633
type Error = PyErr;
603634

604635
fn try_from(opts: ClientConfig) -> PyResult<Self> {
605636
let mut gateway_opts = ClientOptionsBuilder::default();
637+
let (ascii_headers, binary_headers) = partition_headers(opts.metadata);
606638
gateway_opts
607639
.target_url(
608640
Url::parse(&opts.target_url)
@@ -617,7 +649,8 @@ impl TryFrom<ClientConfig> for ClientOptions {
617649
)
618650
.keep_alive(opts.keep_alive_config.map(Into::into))
619651
.http_connect_proxy(opts.http_connect_proxy_config.map(Into::into))
620-
.headers(Some(opts.metadata))
652+
.headers(Some(ascii_headers))
653+
.binary_headers(Some(binary_headers))
621654
.api_key(opts.api_key);
622655
// Builder does not allow us to set option here, so we have to make
623656
// a conditional to even call it

temporalio/client.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ async def connect(
118118
tls: Union[bool, TLSConfig] = False,
119119
retry_config: Optional[RetryConfig] = None,
120120
keep_alive_config: Optional[KeepAliveConfig] = KeepAliveConfig.default,
121-
rpc_metadata: Mapping[str, str] = {},
121+
rpc_metadata: Mapping[str, str | bytes] = {},
122122
identity: Optional[str] = None,
123123
lazy: bool = False,
124124
runtime: Optional[temporalio.runtime.Runtime] = None,
@@ -296,7 +296,7 @@ def data_converter(self) -> temporalio.converter.DataConverter:
296296
return self._config["data_converter"]
297297

298298
@property
299-
def rpc_metadata(self) -> Mapping[str, str]:
299+
def rpc_metadata(self) -> Mapping[str, str | bytes]:
300300
"""Headers for every call made by this client.
301301
302302
Do not use mutate this mapping. Rather, set this property with an
@@ -305,7 +305,7 @@ def rpc_metadata(self) -> Mapping[str, str]:
305305
return self.service_client.config.rpc_metadata
306306

307307
@rpc_metadata.setter
308-
def rpc_metadata(self, value: Mapping[str, str]) -> None:
308+
def rpc_metadata(self, value: Mapping[str, str | bytes]) -> None:
309309
"""Update the headers for this client.
310310
311311
Do not mutate this mapping after set. Rather, set an entirely new
@@ -7209,7 +7209,7 @@ async def connect(
72097209
tls: Union[bool, TLSConfig] = True,
72107210
retry_config: Optional[RetryConfig] = None,
72117211
keep_alive_config: Optional[KeepAliveConfig] = KeepAliveConfig.default,
7212-
rpc_metadata: Mapping[str, str] = {},
7212+
rpc_metadata: Mapping[str, str | bytes] = {},
72137213
identity: Optional[str] = None,
72147214
lazy: bool = False,
72157215
runtime: Optional[temporalio.runtime.Runtime] = None,
@@ -7301,7 +7301,7 @@ def identity(self) -> str:
73017301
return self._service_client.config.identity
73027302

73037303
@property
7304-
def rpc_metadata(self) -> Mapping[str, str]:
7304+
def rpc_metadata(self) -> Mapping[str, str | bytes]:
73057305
"""Headers for every call made by this client.
73067306
73077307
Do not use mutate this mapping. Rather, set this property with an
@@ -7311,7 +7311,7 @@ def rpc_metadata(self) -> Mapping[str, str]:
73117311
return self.service_client.config.rpc_metadata
73127312

73137313
@rpc_metadata.setter
7314-
def rpc_metadata(self, value: Mapping[str, str]) -> None:
7314+
def rpc_metadata(self, value: Mapping[str, str | bytes]) -> None:
73157315
"""Update the headers for this client.
73167316
73177317
Do not mutate this mapping after set. Rather, set an entirely new

temporalio/service.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ class ConnectConfig:
143143
tls: Union[bool, TLSConfig] = False
144144
retry_config: Optional[RetryConfig] = None
145145
keep_alive_config: Optional[KeepAliveConfig] = KeepAliveConfig.default
146-
rpc_metadata: Mapping[str, str] = field(default_factory=dict)
146+
rpc_metadata: Mapping[str, str | bytes] = field(default_factory=dict)
147147
identity: str = ""
148148
lazy: bool = False
149149
runtime: Optional[temporalio.runtime.Runtime] = None
@@ -264,7 +264,7 @@ def worker_service_client(self) -> _BridgeServiceClient:
264264
raise NotImplementedError
265265

266266
@abstractmethod
267-
def update_rpc_metadata(self, metadata: Mapping[str, str]) -> None:
267+
def update_rpc_metadata(self, metadata: Mapping[str, str | bytes]) -> None:
268268
"""Update service client's RPC metadata."""
269269
raise NotImplementedError
270270

@@ -1316,7 +1316,7 @@ def worker_service_client(self) -> _BridgeServiceClient:
13161316
"""Underlying service client."""
13171317
return self
13181318

1319-
def update_rpc_metadata(self, metadata: Mapping[str, str]) -> None:
1319+
def update_rpc_metadata(self, metadata: Mapping[str, str | bytes]) -> None:
13201320
"""Update Core client metadata."""
13211321
# Mutate the bridge config and then only mutate the running client
13221322
# metadata if already connected

temporalio/testing/_workflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ async def start_local(
8484
temporalio.common.QueryRejectCondition
8585
] = None,
8686
retry_config: Optional[temporalio.client.RetryConfig] = None,
87-
rpc_metadata: Mapping[str, str] = {},
87+
rpc_metadata: Mapping[str, str | bytes] = {},
8888
identity: Optional[str] = None,
8989
tls: bool | temporalio.client.TLSConfig = False,
9090
ip: str = "127.0.0.1",
@@ -244,7 +244,7 @@ async def start_time_skipping(
244244
temporalio.common.QueryRejectCondition
245245
] = None,
246246
retry_config: Optional[temporalio.client.RetryConfig] = None,
247-
rpc_metadata: Mapping[str, str] = {},
247+
rpc_metadata: Mapping[str, str | bytes] = {},
248248
identity: Optional[str] = None,
249249
port: Optional[int] = None,
250250
download_dest_dir: Optional[str] = None,

tests/worker/test_update_with_start.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -822,7 +822,7 @@ async def __call__(
822822
req: temporalio.api.workflowservice.v1.ExecuteMultiOperationRequest,
823823
*,
824824
retry: bool = False,
825-
metadata: Mapping[str, str] = {},
825+
metadata: Mapping[str, str | bytes] = {},
826826
timeout: Optional[timedelta] = None,
827827
) -> temporalio.api.workflowservice.v1.ExecuteMultiOperationResponse:
828828
raise self.empty_details_err

0 commit comments

Comments
 (0)