Skip to content

Commit 5b95612

Browse files
committed
Actually cancel reserve slot tasks
1 parent 9f0f51d commit 5b95612

File tree

2 files changed

+54
-9
lines changed

2 files changed

+54
-9
lines changed

temporalio/bridge/src/worker.rs

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use pyo3::types::{PyBytes, PyTuple};
77
use std::collections::HashMap;
88
use std::collections::HashSet;
99
use std::marker::PhantomData;
10-
use std::sync::Arc;
10+
use std::sync::{Arc, OnceLock};
1111
use std::time::Duration;
1212
use temporal_sdk_core::api::errors::{PollActivityError, PollWfError};
1313
use temporal_sdk_core::replay::{HistoryForReplay, ReplayWorkerInput};
@@ -202,16 +202,58 @@ impl CustomSlotSupplier {
202202
}
203203
}
204204

205+
#[pyclass]
206+
struct CreatedTaskForSlotCallback {
207+
stored_task: Arc<OnceLock<PyObject>>,
208+
}
209+
210+
#[pymethods]
211+
impl CreatedTaskForSlotCallback {
212+
fn __call__(&self, task: PyObject) -> PyResult<()> {
213+
self.stored_task.set(task).expect("must only be set once");
214+
Ok(())
215+
}
216+
}
217+
218+
struct TaskCanceller {
219+
stored_task: Arc<OnceLock<PyObject>>,
220+
}
221+
222+
impl TaskCanceller {
223+
fn new(stored_task: Arc<OnceLock<PyObject>>) -> Self {
224+
TaskCanceller { stored_task }
225+
}
226+
}
227+
228+
impl Drop for TaskCanceller {
229+
fn drop(&mut self) {
230+
if let Some(task) = self.stored_task.get() {
231+
Python::with_gil(|py| {
232+
task.call_method0(py, "cancel")
233+
.expect("Failed to cancel task");
234+
});
235+
}
236+
}
237+
}
238+
205239
#[async_trait::async_trait]
206240
impl<SK: SlotKind + Send + Sync> SlotSupplierTrait for CustomSlotSupplierOfType<SK> {
207241
type SlotKind = SK;
208242

209243
async fn reserve_slot(&self, ctx: &dyn SlotReservationContext) -> SlotSupplierPermit {
244+
dbg!("Invoking reserve first time");
210245
loop {
246+
let stored_task = Arc::new(OnceLock::new());
247+
let _task_canceller = TaskCanceller::new(stored_task.clone());
211248
let pypermit = match Python::with_gil(|py| {
212249
let py_obj = self.inner.as_ref(py);
213-
let called = py_obj
214-
.call_method1("reserve_slot", (SlotReserveCtx::from_ctx(SK::kind(), ctx),))?;
250+
let called = py_obj.call_method1(
251+
"reserve_slot",
252+
(
253+
SlotReserveCtx::from_ctx(SK::kind(), ctx),
254+
CreatedTaskForSlotCallback { stored_task },
255+
),
256+
)?;
215257
runtime::THREAD_TASK_LOCAL
216258
.with(|tl| pyo3_asyncio::into_future_with_locals(tl.get().unwrap(), called))
217259
}) {

temporalio/worker/_tuning.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from abc import ABC, abstractmethod
44
from dataclasses import dataclass
55
from datetime import timedelta
6-
from typing import Literal, Optional, Union
6+
from typing import Any, Callable, Literal, Optional, Union
77

88
from typing_extensions import TypeAlias
99

@@ -90,15 +90,18 @@ class ResourceBasedSlotSupplier:
9090
]
9191

9292

93-
class _ErrorLoggingSlotSupplier(CustomSlotSupplier):
93+
class _BridgeSlotSupplierWrapper:
9494
def __init__(self, supplier: CustomSlotSupplier):
9595
self._supplier = supplier
9696

97-
async def reserve_slot(self, ctx: SlotReserveContext) -> SlotPermit:
97+
async def reserve_slot(
98+
self, ctx: SlotReserveContext, reserve_cb: Callable[[Any], None]
99+
) -> SlotPermit:
98100
try:
99-
return await self._supplier.reserve_slot(ctx)
101+
reserve_fut = asyncio.create_task(self._supplier.reserve_slot(ctx))
102+
reserve_cb(reserve_fut)
103+
return await reserve_fut
100104
except asyncio.CancelledError:
101-
logger.exception("saw cancelled error")
102105
raise
103106
except Exception:
104107
logger.warning(
@@ -161,7 +164,7 @@ def _to_bridge_slot_supplier(
161164
)
162165
elif isinstance(slot_supplier, CustomSlotSupplier):
163166
return temporalio.bridge.worker.BridgeCustomSlotSupplier(
164-
_ErrorLoggingSlotSupplier(slot_supplier)
167+
_BridgeSlotSupplierWrapper(slot_supplier)
165168
)
166169
else:
167170
raise TypeError(f"Unknown slot supplier type: {slot_supplier}")

0 commit comments

Comments
 (0)