Skip to content

Commit b641b9e

Browse files
committed
Calling async reserve working
1 parent 3381c61 commit b641b9e

File tree

7 files changed

+188
-16
lines changed

7 files changed

+188
-16
lines changed

temporalio/bridge/Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

temporalio/bridge/Cargo.toml

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ crate-type = ["cdylib"]
99

1010
[dependencies]
1111
anyhow = "1.0"
12+
async-trait = "0.1"
1213
futures = "0.3"
1314
log = "0.4"
1415
once_cell = "1.16"
@@ -17,10 +18,14 @@ prost-types = "0.13"
1718
pyo3 = { version = "0.20", features = ["extension-module", "abi3-py38", "anyhow"] }
1819
pyo3-asyncio = { version = "0.20", features = ["tokio-runtime"] }
1920
pythonize = "0.20"
20-
temporal-client = { version = "0.1.0", path = "./sdk-core/client" }
21-
temporal-sdk-core = { version = "0.1.0", path = "./sdk-core/core", features = ["ephemeral-server"] }
22-
temporal-sdk-core-api = { version = "0.1.0", path = "./sdk-core/core-api" }
23-
temporal-sdk-core-protos = { version = "0.1.0", path = "./sdk-core/sdk-core-protos" }
21+
#temporal-client = { version = "0.1.0", path = "./sdk-core/client" }
22+
#temporal-sdk-core = { version = "0.1.0", path = "./sdk-core/core", features = ["ephemeral-server"] }
23+
#temporal-sdk-core-api = { version = "0.1.0", path = "./sdk-core/core-api" }
24+
#temporal-sdk-core-protos = { version = "0.1.0", path = "./sdk-core/sdk-core-protos" }
25+
temporal-client = { version = "0.1.0", path = "../../../sdk-core/client" }
26+
temporal-sdk-core = { version = "0.1.0", path = "../../../sdk-core/core", features = ["ephemeral-server"] }
27+
temporal-sdk-core-api = { version = "0.1.0", path = "../../../sdk-core/core-api" }
28+
temporal-sdk-core-protos = { version = "0.1.0", path = "../../../sdk-core/sdk-core-protos" }
2429
tokio = "1.26"
2530
tokio-stream = "0.1"
2631
tonic = "0.12"

temporalio/bridge/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ fn temporal_sdk_bridge(py: Python, m: &PyModule) -> PyResult<()> {
4545
)?;
4646
m.add_class::<worker::WorkerRef>()?;
4747
m.add_class::<worker::HistoryPusher>()?;
48+
m.add_class::<worker::CustomSlotSupplier>()?;
4849
m.add_function(wrap_pyfunction!(new_worker, m)?)?;
4950
m.add_function(wrap_pyfunction!(new_replay_worker, m)?)?;
5051
Ok(())

temporalio/bridge/src/runtime.rs

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use temporal_sdk_core::telemetry::{
1313
build_otlp_metric_exporter, start_prometheus_metric_exporter, CoreLogStreamConsumer,
1414
MetricsCallBuffer,
1515
};
16-
use temporal_sdk_core::CoreRuntime;
16+
use temporal_sdk_core::{CoreRuntime, TokioRuntimeBuilder};
1717
use temporal_sdk_core_api::telemetry::metrics::{CoreMeter, MetricCallBufferer};
1818
use temporal_sdk_core_api::telemetry::{
1919
CoreLog, Logger, MetricTemporality, OtelCollectorOptionsBuilder,
@@ -88,6 +88,7 @@ const FORWARD_LOG_BUFFER_SIZE: usize = 2048;
8888
const FORWARD_LOG_MAX_FREQ_MS: u64 = 10;
8989

9090
pub fn init_runtime(telemetry_config: TelemetryConfig) -> PyResult<RuntimeRef> {
91+
dbg!("Initting runtime");
9192
// Have to build/start telemetry config pieces
9293
let mut telemetry_build = TelemetryOptionsBuilder::default();
9394

@@ -118,12 +119,33 @@ pub fn init_runtime(telemetry_config: TelemetryConfig) -> PyResult<RuntimeRef> {
118119
}
119120
}
120121

122+
let task_locals = Python::with_gil(|py| {
123+
let asyncio = py.import("asyncio")?;
124+
let event_loop = asyncio.call_method0("get_event_loop")?;
125+
dbg!(&event_loop);
126+
let locals = pyo3_asyncio::TaskLocals::with_running_loop(py)?.copy_context(py)?;
127+
PyResult::Ok(locals)
128+
})
129+
.expect("Works");
130+
121131
// Create core runtime which starts tokio multi-thread runtime
122132
let mut core = CoreRuntime::new(
123133
telemetry_build
124134
.build()
125135
.map_err(|err| PyValueError::new_err(format!("Invalid telemetry config: {}", err)))?,
126-
tokio::runtime::Builder::new_multi_thread(),
136+
TokioRuntimeBuilder {
137+
inner: tokio::runtime::Builder::new_multi_thread(),
138+
lang_on_thread_start: Some(move || {
139+
// Set task locals for each thread
140+
Python::with_gil(|py| {
141+
THREAD_TASK_LOCAL.with(|r| {
142+
std::cell::OnceCell::set(r, task_locals.clone()).expect("NOT ALREADY SET");
143+
});
144+
PyResult::Ok(())
145+
})
146+
.expect("Setting event loop works");
147+
}),
148+
},
127149
)
128150
.map_err(|err| PyRuntimeError::new_err(format!("Failed initializing telemetry: {}", err)))?;
129151

@@ -364,12 +386,17 @@ impl TryFrom<MetricsConfig> for Arc<dyn CoreMeter> {
364386
// altered to support spawning based on current Tokio runtime instead of a
365387
// single static one
366388

367-
struct TokioRuntime;
389+
pub(crate) struct TokioRuntime;
368390

369391
tokio::task_local! {
370392
static TASK_LOCALS: once_cell::unsync::OnceCell<pyo3_asyncio::TaskLocals>;
371393
}
372394

395+
thread_local! {
396+
pub(crate) static THREAD_TASK_LOCAL: std::cell::OnceCell<pyo3_asyncio::TaskLocals> =
397+
std::cell::OnceCell::new();
398+
}
399+
373400
impl pyo3_asyncio::generic::Runtime for TokioRuntime {
374401
type JoinError = tokio::task::JoinError;
375402
type JoinHandle = tokio::task::JoinHandle<()>;
@@ -378,9 +405,7 @@ impl pyo3_asyncio::generic::Runtime for TokioRuntime {
378405
where
379406
F: Future<Output = ()> + Send + 'static,
380407
{
381-
tokio::runtime::Handle::current().spawn(async move {
382-
fut.await;
383-
})
408+
tokio::runtime::Handle::current().spawn(fut)
384409
}
385410
}
386411

temporalio/bridge/src/worker.rs

Lines changed: 137 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
11
use anyhow::Context;
22
use prost::Message;
3-
use pyo3::exceptions::{PyException, PyRuntimeError, PyValueError};
3+
use pyo3::exceptions::{PyException, PyRuntimeError, PyTypeError, PyValueError};
44
use pyo3::prelude::*;
55
use pyo3::types::{PyBytes, PyTuple};
6+
use pyo3_asyncio::generic::ContextExt;
67
use std::collections::HashMap;
78
use std::collections::HashSet;
9+
use std::marker::PhantomData;
810
use std::sync::Arc;
911
use std::time::Duration;
1012
use temporal_sdk_core::api::errors::{PollActivityError, PollWfError};
1113
use temporal_sdk_core::replay::{HistoryForReplay, ReplayWorkerInput};
1214
use temporal_sdk_core_api::errors::WorkflowErrorType;
13-
use temporal_sdk_core_api::worker::SlotKind;
15+
use temporal_sdk_core_api::worker::{
16+
SlotKind, SlotMarkUsedContext, SlotReleaseContext, SlotReservationContext,
17+
SlotSupplier as SlotSupplierTrait, SlotSupplierPermit,
18+
};
1419
use temporal_sdk_core_api::Worker;
1520
use temporal_sdk_core_protos::coresdk::workflow_completion::WorkflowActivationCompletion;
1621
use temporal_sdk_core_protos::coresdk::{ActivityHeartbeat, ActivityTaskCompletion};
@@ -20,6 +25,7 @@ use tokio_stream::wrappers::ReceiverStream;
2025

2126
use crate::client;
2227
use crate::runtime;
28+
use crate::runtime::{TokioRuntime, THREAD_TASK_LOCAL};
2329

2430
pyo3::create_exception!(temporal_sdk_bridge, PollShutdownError, PyException);
2531

@@ -63,6 +69,7 @@ pub struct TunerHolder {
6369
pub enum SlotSupplier {
6470
FixedSize(FixedSizeSlotSupplier),
6571
ResourceBased(ResourceBasedSlotSupplier),
72+
Custom(CustomSlotSupplier),
6673
}
6774

6875
#[derive(FromPyObject)]
@@ -79,6 +86,125 @@ pub struct ResourceBasedSlotSupplier {
7986
tuner_config: ResourceBasedTunerConfig,
8087
}
8188

89+
#[pyclass]
90+
pub struct SlotReserveCtx {
91+
slot_type: String, // TODO: Real type
92+
task_queue: String,
93+
worker_identity: String,
94+
worker_build_id: String,
95+
is_sticky: bool,
96+
}
97+
98+
impl SlotReserveCtx {
99+
fn from_ctx(slot_type: String, ctx: &dyn SlotReservationContext) -> Self {
100+
SlotReserveCtx {
101+
slot_type,
102+
task_queue: ctx.task_queue().to_string(),
103+
worker_identity: ctx.worker_identity().to_string(),
104+
worker_build_id: ctx.worker_build_id().to_string(),
105+
is_sticky: ctx.is_sticky(),
106+
}
107+
}
108+
}
109+
110+
#[pyclass]
111+
pub struct SlotMarkUsedCtx {}
112+
113+
#[pyclass]
114+
pub struct SlotReleaseCtx {}
115+
116+
#[pyclass]
117+
#[derive(Clone)]
118+
pub struct CustomSlotSupplier {
119+
inner: PyObject,
120+
}
121+
122+
struct CustomSlotSupplierOfType<SK: SlotKind> {
123+
inner: PyObject,
124+
_phantom: PhantomData<SK>,
125+
}
126+
127+
#[pymethods]
128+
impl CustomSlotSupplier {
129+
#[new]
130+
fn new(inner: PyObject) -> Self {
131+
CustomSlotSupplier { inner }
132+
}
133+
}
134+
135+
impl<SK: SlotKind> CustomSlotSupplierOfType<SK> {
136+
fn call_method<P: IntoPy<PyObject>, F: FnOnce(Python<'_>, &PyAny) -> FR, FR>(
137+
&self,
138+
method_name: &str,
139+
arg: P,
140+
post_closure: F,
141+
) -> FR {
142+
Python::with_gil(|py| {
143+
let py_obj = self.inner.as_ref(py);
144+
let method = py_obj
145+
.getattr(method_name)
146+
.map_err(|_| {
147+
PyTypeError::new_err(format!(
148+
"CustomSlotSupplier must implement '{}' method",
149+
method_name
150+
))
151+
})
152+
.expect("TODO");
153+
154+
post_closure(py, method.call((arg.into_py(py),), None).expect("TODO"))
155+
})
156+
}
157+
}
158+
159+
#[async_trait::async_trait]
160+
impl<SK: SlotKind + Send + Sync> SlotSupplierTrait for CustomSlotSupplierOfType<SK> {
161+
type SlotKind = SK;
162+
163+
async fn reserve_slot(&self, ctx: &dyn SlotReservationContext) -> SlotSupplierPermit {
164+
dbg!("Trying to reserve slot");
165+
let pypermit = Python::with_gil(|py| {
166+
let py_obj = self.inner.as_ref(py);
167+
let called = py_obj.call_method1(
168+
"reserve_slot",
169+
(SlotReserveCtx::from_ctx(
170+
Self::SlotKind::kind().to_string(),
171+
ctx,
172+
),),
173+
)?;
174+
THREAD_TASK_LOCAL
175+
.with(|tl| pyo3_asyncio::into_future_with_locals(tl.get().unwrap(), called))
176+
})
177+
.expect("TODO")
178+
.await;
179+
SlotSupplierPermit::with_user_data(pypermit)
180+
}
181+
182+
fn try_reserve_slot(&self, ctx: &dyn SlotReservationContext) -> Option<SlotSupplierPermit> {
183+
self.call_method(
184+
"try_reserve_slot",
185+
SlotReserveCtx::from_ctx(Self::SlotKind::kind().to_string(), ctx),
186+
|py, pa| {
187+
if pa.is_none() {
188+
return None;
189+
}
190+
Some(SlotSupplierPermit::with_user_data(pa.into_py(py)))
191+
},
192+
)
193+
}
194+
195+
fn mark_slot_used(&self, _ctx: &dyn SlotMarkUsedContext<SlotKind = Self::SlotKind>) {
196+
self.call_method("mark_slot_used", SlotMarkUsedCtx {}, |_, _| ())
197+
}
198+
199+
fn release_slot(&self, _ctx: &dyn SlotReleaseContext<SlotKind = Self::SlotKind>) {
200+
self.call_method("release_slot", SlotReleaseCtx {}, |_, _| ())
201+
}
202+
203+
fn available_slots(&self) -> Option<usize> {
204+
None
205+
}
206+
}
207+
82208
#[derive(FromPyObject, Clone, Copy, PartialEq)]
83209
pub struct ResourceBasedTunerConfig {
84210
target_memory_usage: f64,
@@ -369,7 +495,9 @@ impl TryFrom<TunerHolder> for temporal_sdk_core::TunerHolder {
369495
}
370496
}
371497

372-
impl<SK: SlotKind> TryFrom<SlotSupplier> for temporal_sdk_core::SlotSupplierOptions<SK> {
498+
impl<SK: SlotKind + Send + Sync + 'static> TryFrom<SlotSupplier>
499+
for temporal_sdk_core::SlotSupplierOptions<SK>
500+
{
373501
type Error = PyErr;
374502

375503
fn try_from(supplier: SlotSupplier) -> PyResult<temporal_sdk_core::SlotSupplierOptions<SK>> {
@@ -386,6 +514,12 @@ impl<SK: SlotKind> TryFrom<SlotSupplier> for temporal_sdk_core::SlotSupplierOpti
386514
),
387515
)
388516
}
517+
SlotSupplier::Custom(cs) => temporal_sdk_core::SlotSupplierOptions::Custom(Arc::new(
518+
CustomSlotSupplierOfType::<SK> {
519+
inner: cs.inner,
520+
_phantom: PhantomData,
521+
},
522+
)),
389523
})
390524
}
391525
}

temporalio/worker/_tuning.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing_extensions import TypeAlias
77

88
import temporalio.bridge.worker
9+
import temporalio.bridge.temporal_sdk_bridge
910
from temporalio.bridge.worker import (
1011
CustomSlotSupplier,
1112
SlotMarkUsedContext,
@@ -108,6 +109,8 @@ def _to_bridge_slot_supplier(
108109
slot_supplier.tuner_config.target_cpu_usage,
109110
),
110111
)
112+
elif isinstance(slot_supplier, CustomSlotSupplier):
113+
return temporalio.bridge.temporal_sdk_bridge.CustomSlotSupplier(slot_supplier)
111114
else:
112115
raise TypeError(f"Unknown slot supplier type: {slot_supplier}")
113116

tests/worker/test_worker.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -344,22 +344,25 @@ async def test_warns_when_workers_too_lot(client: Client, env: WorkflowEnvironme
344344
async def test_custom_slot_supplier(client: Client, env: WorkflowEnvironment):
345345
class MySlotSupplier(CustomSlotSupplier):
346346
async def reserve_slot(self, ctx: SlotReserveContext) -> SlotPermit:
347+
print("Reserving slot")
347348
return SlotPermit()
348349

349350
def try_reserve_slot(self, ctx: SlotReserveContext) -> Optional[SlotPermit]:
350-
pass
351+
print("Try Reserving slot")
352+
return None
351353

352354
def mark_slot_used(self, ctx: SlotMarkUsedContext) -> None:
353-
pass
355+
print("Marking slot used")
354356

355357
def release_slot(self, ctx: SlotReleaseContext) -> None:
356-
pass
358+
print("Releasing slot")
357359

358360
ss = MySlotSupplier()
359361

360362
tuner = WorkerTuner.create_composite(
361363
workflow_supplier=ss, activity_supplier=ss, local_activity_supplier=ss
362364
)
365+
print("!!!!!!!!1 About to create new worker, event loop is", asyncio.get_event_loop())
363366
async with new_worker(
364367
client,
365368
WaitOnSignalWorkflow,

0 commit comments

Comments
 (0)