Skip to content

Commit 490dc21

Browse files
committed
Passing through contexts working
1 parent b641b9e commit 490dc21

File tree

8 files changed

+258
-85
lines changed

8 files changed

+258
-85
lines changed

temporalio/bridge/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ pythonize = "0.20"
2525
temporal-client = { version = "0.1.0", path = "../../../sdk-core/client" }
2626
temporal-sdk-core = { version = "0.1.0", path = "../../../sdk-core/core", features = ["ephemeral-server"] }
2727
temporal-sdk-core-api = { version = "0.1.0", path = "../../../sdk-core/core-api" }
28-
temporal-sdk-core-protos = { version = "0.1.0", path = "../../../sdk-core/sdk-core-protos" }
28+
temporal-sdk-core-protos = { version = "0.1.0", path = "../../../sdk-core/sdk-core-protos", features = ["serde_serialize"] }
2929
tokio = "1.26"
3030
tokio-stream = "0.1"
3131
tonic = "0.12"

temporalio/bridge/src/lib.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@ fn temporal_sdk_bridge(py: Python, m: &PyModule) -> PyResult<()> {
4646
m.add_class::<worker::WorkerRef>()?;
4747
m.add_class::<worker::HistoryPusher>()?;
4848
m.add_class::<worker::CustomSlotSupplier>()?;
49+
m.add_class::<worker::SlotReserveCtx>()?;
50+
m.add_class::<worker::SlotReleaseCtx>()?;
51+
m.add_class::<worker::SlotMarkUsedCtx>()?;
52+
m.add_class::<worker::WorkflowSlotInfo>()?;
53+
m.add_class::<worker::ActivitySlotInfo>()?;
54+
m.add_class::<worker::LocalActivitySlotInfo>()?;
4955
m.add_function(wrap_pyfunction!(new_worker, m)?)?;
5056
m.add_function(wrap_pyfunction!(new_replay_worker, m)?)?;
5157
Ok(())

temporalio/bridge/src/runtime.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ const FORWARD_LOG_BUFFER_SIZE: usize = 2048;
8888
const FORWARD_LOG_MAX_FREQ_MS: u64 = 10;
8989

9090
pub fn init_runtime(telemetry_config: TelemetryConfig) -> PyResult<RuntimeRef> {
91-
dbg!("Initting runtime");
9291
// Have to build/start telemetry config pieces
9392
let mut telemetry_build = TelemetryOptionsBuilder::default();
9493

@@ -120,9 +119,7 @@ pub fn init_runtime(telemetry_config: TelemetryConfig) -> PyResult<RuntimeRef> {
120119
}
121120

122121
let task_locals = Python::with_gil(|py| {
123-
let asyncio = py.import("asyncio")?;
124-
let event_loop = asyncio.call_method0("get_event_loop")?;
125-
dbg!(&event_loop);
122+
// Event loop is assumed to be running at this point
126123
let locals = pyo3_asyncio::TaskLocals::with_running_loop(py)?.copy_context(py)?;
127124
PyResult::Ok(locals)
128125
})
@@ -137,7 +134,7 @@ pub fn init_runtime(telemetry_config: TelemetryConfig) -> PyResult<RuntimeRef> {
137134
inner: tokio::runtime::Builder::new_multi_thread(),
138135
lang_on_thread_start: Some(move || {
139136
// Set task locals for each thread
140-
Python::with_gil(|py| {
137+
Python::with_gil(|_| {
141138
THREAD_TASK_LOCAL.with(|r| {
142139
std::cell::OnceCell::set(r, task_locals.clone()).expect("NOT ALREADY SET");
143140
});

temporalio/bridge/src/worker.rs

Lines changed: 136 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
use anyhow::Context;
22
use prost::Message;
3-
use pyo3::exceptions::{PyException, PyRuntimeError, PyTypeError, PyValueError};
3+
use pyo3::exceptions::{PyException, PyRuntimeError, PyValueError};
44
use pyo3::prelude::*;
55
use pyo3::types::{PyBytes, PyTuple};
6-
use pyo3_asyncio::generic::ContextExt;
76
use std::collections::HashMap;
87
use std::collections::HashSet;
98
use std::marker::PhantomData;
@@ -13,8 +12,8 @@ use temporal_sdk_core::api::errors::{PollActivityError, PollWfError};
1312
use temporal_sdk_core::replay::{HistoryForReplay, ReplayWorkerInput};
1413
use temporal_sdk_core_api::errors::WorkflowErrorType;
1514
use temporal_sdk_core_api::worker::{
16-
SlotKind, SlotMarkUsedContext, SlotReleaseContext, SlotReservationContext,
17-
SlotSupplier as SlotSupplierTrait, SlotSupplierPermit,
15+
SlotInfo, SlotInfoTrait, SlotKind, SlotMarkUsedContext, SlotReleaseContext,
16+
SlotReservationContext, SlotSupplier as SlotSupplierTrait, SlotSupplierPermit,
1817
};
1918
use temporal_sdk_core_api::Worker;
2019
use temporal_sdk_core_protos::coresdk::workflow_completion::WorkflowActivationCompletion;
@@ -25,7 +24,6 @@ use tokio_stream::wrappers::ReceiverStream;
2524

2625
use crate::client;
2726
use crate::runtime;
28-
use crate::runtime::{TokioRuntime, THREAD_TASK_LOCAL};
2927

3028
pyo3::create_exception!(temporal_sdk_bridge, PollShutdownError, PyException);
3129

@@ -88,11 +86,16 @@ pub struct ResourceBasedSlotSupplier {
8886

8987
#[pyclass]
9088
pub struct SlotReserveCtx {
91-
slot_type: String, // TODO: Real type
92-
task_queue: String,
93-
worker_identity: String,
94-
worker_build_id: String,
95-
is_sticky: bool,
89+
#[pyo3(get)]
90+
pub slot_type: String, // TODO: Real type
91+
#[pyo3(get)]
92+
pub task_queue: String,
93+
#[pyo3(get)]
94+
pub worker_identity: String,
95+
#[pyo3(get)]
96+
pub worker_build_id: String,
97+
#[pyo3(get)]
98+
pub is_sticky: bool,
9699
}
97100

98101
impl SlotReserveCtx {
@@ -108,10 +111,60 @@ impl SlotReserveCtx {
108111
}
109112

110113
#[pyclass]
111-
pub struct SlotMarkUsedCtx {}
114+
pub struct SlotMarkUsedCtx {
115+
#[pyo3(get)]
116+
slot_info: PyObject,
117+
#[pyo3(get)]
118+
permit: PyObject,
119+
}
112120

121+
// NOTE: this is dumb because we already have the generated proto code, we just can't use
122+
// it b/c it's not pyclassable. In theory maybe we could compile-flag enable it in the core
123+
// protos crate but... that's a lot for just this. Maybe if there are other use cases.
124+
125+
#[pyclass]
126+
pub struct WorkflowSlotInfo {
127+
#[pyo3(get)]
128+
pub workflow_type: String,
129+
#[pyo3(get)]
130+
pub is_sticky: bool,
131+
}
132+
#[pyclass]
133+
pub struct ActivitySlotInfo {
134+
#[pyo3(get)]
135+
pub activity_type: String,
136+
}
113137
#[pyclass]
114-
pub struct SlotReleaseCtx {}
138+
pub struct LocalActivitySlotInfo {
139+
#[pyo3(get)]
140+
pub activity_type: String,
141+
}
142+
143+
#[pyclass]
144+
pub struct SlotReleaseCtx {
145+
#[pyo3(get)]
146+
slot_info: Option<PyObject>,
147+
#[pyo3(get)]
148+
permit: PyObject,
149+
}
150+
151+
fn slot_info_to_py_obj(py: Python<'_>, info: SlotInfo) -> PyObject {
152+
match info {
153+
SlotInfo::Workflow(w) => WorkflowSlotInfo {
154+
workflow_type: w.workflow_type.clone(),
155+
is_sticky: w.is_sticky,
156+
}
157+
.into_py(py),
158+
SlotInfo::Activity(a) => ActivitySlotInfo {
159+
activity_type: a.activity_type.clone(),
160+
}
161+
.into_py(py),
162+
SlotInfo::LocalActivity(a) => LocalActivitySlotInfo {
163+
activity_type: a.activity_type.clone(),
164+
}
165+
.into_py(py),
166+
}
167+
}
115168

116169
#[pyclass]
117170
#[derive(Clone)]
@@ -132,72 +185,94 @@ impl CustomSlotSupplier {
132185
}
133186
}
134187

135-
impl<SK: SlotKind> CustomSlotSupplierOfType<SK> {
136-
fn call_method<P: IntoPy<PyObject>, F: FnOnce(Python<'_>, &PyAny) -> FR, FR>(
137-
&self,
138-
method_name: &str,
139-
arg: P,
140-
post_closure: F,
141-
) -> FR {
142-
Python::with_gil(|py| {
143-
let py_obj = self.inner.as_ref(py);
144-
let method = py_obj
145-
.getattr(method_name)
146-
.map_err(|_| {
147-
PyTypeError::new_err(format!(
148-
"CustomSlotSupplier must implement '{}' method",
149-
method_name
150-
))
151-
})
152-
.expect("TODO");
153-
154-
post_closure(py, method.call((arg.into_py(py),), None).expect("TODO"))
155-
})
156-
}
157-
}
158-
159188
#[async_trait::async_trait]
160189
impl<SK: SlotKind + Send + Sync> SlotSupplierTrait for CustomSlotSupplierOfType<SK> {
161190
type SlotKind = SK;
162191

163192
async fn reserve_slot(&self, ctx: &dyn SlotReservationContext) -> SlotSupplierPermit {
164-
dbg!("Trying to reserve slot");
165-
let pypermit = Python::with_gil(|py| {
193+
loop {
194+
let pypermit = Python::with_gil(|py| {
195+
let py_obj = self.inner.as_ref(py);
196+
let called = py_obj.call_method1(
197+
"reserve_slot",
198+
(SlotReserveCtx::from_ctx(
199+
Self::SlotKind::kind().to_string(),
200+
ctx,
201+
),),
202+
)?;
203+
runtime::THREAD_TASK_LOCAL
204+
.with(|tl| pyo3_asyncio::into_future_with_locals(tl.get().unwrap(), called))
205+
})
206+
.expect("TODO")
207+
.await;
208+
match pypermit {
209+
Ok(p) => {
210+
return SlotSupplierPermit::with_user_data(p);
211+
}
212+
Err(e) => {
213+
dbg!("Error in reserve_slot", e);
214+
}
215+
}
216+
}
217+
}
218+
219+
fn try_reserve_slot(&self, ctx: &dyn SlotReservationContext) -> Option<SlotSupplierPermit> {
220+
Python::with_gil(|py| {
166221
let py_obj = self.inner.as_ref(py);
167-
let called = py_obj.call_method1(
168-
"reserve_slot",
222+
let pa = py_obj.call_method1(
223+
"try_reserve_slot",
169224
(SlotReserveCtx::from_ctx(
170225
Self::SlotKind::kind().to_string(),
171226
ctx,
172227
),),
173228
)?;
174-
THREAD_TASK_LOCAL
175-
.with(|tl| pyo3_asyncio::into_future_with_locals(tl.get().unwrap(), called))
229+
230+
if pa.is_none() {
231+
return Ok(None);
232+
}
233+
PyResult::Ok(Some(SlotSupplierPermit::with_user_data(pa.into_py(py))))
176234
})
177235
.expect("TODO")
178-
.await;
179-
SlotSupplierPermit::with_user_data(pypermit)
180236
}
181237

182-
fn try_reserve_slot(&self, ctx: &dyn SlotReservationContext) -> Option<SlotSupplierPermit> {
183-
self.call_method(
184-
"try_reserve_slot",
185-
SlotReserveCtx::from_ctx(Self::SlotKind::kind().to_string(), ctx),
186-
|py, pa| {
187-
if pa.is_none() {
188-
return None;
189-
}
190-
Some(SlotSupplierPermit::with_user_data(pa.into_py(py)))
191-
},
192-
)
193-
}
194-
195-
fn mark_slot_used(&self, _ctx: &dyn SlotMarkUsedContext<SlotKind = Self::SlotKind>) {
196-
self.call_method("mark_slot_used", SlotMarkUsedCtx {}, |_, _| ())
238+
fn mark_slot_used(&self, ctx: &dyn SlotMarkUsedContext<SlotKind = Self::SlotKind>) {
239+
Python::with_gil(|py| {
240+
let permit = ctx
241+
.permit()
242+
.user_data::<PyObject>()
243+
.cloned()
244+
.unwrap_or_else(|| py.None());
245+
let py_obj = self.inner.as_ref(py);
246+
py_obj.call_method1(
247+
"mark_slot_used",
248+
(SlotMarkUsedCtx {
249+
slot_info: slot_info_to_py_obj(py, ctx.info().downcast()),
250+
permit,
251+
},),
252+
)?;
253+
PyResult::Ok(())
254+
})
255+
.expect("TODO");
197256
}
198257

199-
fn release_slot(&self, _ctx: &dyn SlotReleaseContext<SlotKind = Self::SlotKind>) {
200-
self.call_method("release_slot", SlotReleaseCtx {}, |_, _| ())
258+
fn release_slot(&self, ctx: &dyn SlotReleaseContext<SlotKind = Self::SlotKind>) {
259+
Python::with_gil(|py| {
260+
let permit = ctx
261+
.permit()
262+
.user_data::<PyObject>()
263+
.cloned()
264+
.unwrap_or_else(|| py.None());
265+
let py_obj = self.inner.as_ref(py);
266+
py_obj.call_method1(
267+
"release_slot",
268+
(SlotReleaseCtx {
269+
slot_info: ctx.info().map(|i| slot_info_to_py_obj(py, i.downcast())),
270+
permit,
271+
},),
272+
)?;
273+
PyResult::Ok(())
274+
})
275+
.expect("TODO");
201276
}
202277

203278
fn available_slots(&self) -> Option<usize> {

0 commit comments

Comments
 (0)