Skip to content

Commit ff2f074

Browse files
committed
Cancellation Tokens
1 parent 46a8ecb commit ff2f074

File tree

5 files changed

+50
-6
lines changed

5 files changed

+50
-6
lines changed

src/Temporalio/Bridge/CustomSlotSupplier.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,15 @@ private unsafe void Reserve(Interop.SlotReserveCtx ctx, void* sender)
4646

4747
private void SafeReserve(Interop.SlotReserveCtx ctx, IntPtr sender)
4848
{
49+
var cancelTokenSrc = new System.Threading.CancellationTokenSource();
4950
var reserveTask = Task.Run(async () =>
5051
{
5152
while (true)
5253
{
5354
try
5455
{
55-
var permit = await userSupplier.ReserveSlotAsync(new(ctx)).ConfigureAwait(false);
56+
var permit = await userSupplier.ReserveSlotAsync(
57+
new(ctx), cancelTokenSrc.Token).ConfigureAwait(false);
5658
var usedPermitId = AddPermitToMap(permit);
5759
unsafe
5860
{

src/Temporalio/Bridge/include/temporal-sdk-bridge.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,10 +389,13 @@ typedef struct SlotReserveCtx {
389389
struct ByteArrayRef worker_identity;
390390
struct ByteArrayRef worker_build_id;
391391
bool is_sticky;
392+
void *token_src;
392393
} SlotReserveCtx;
393394

394395
typedef void (*CustomReserveSlotCallback)(struct SlotReserveCtx ctx, void *sender);
395396

397+
typedef void (*CustomCancelReserveCallback)(void *token_source);
398+
396399
/**
397400
* Must return C#-tracked id for the permit. A zero value means no permit was reserved.
398401
*/
@@ -448,6 +451,7 @@ typedef void (*CustomReleaseSlotCallback)(struct SlotReleaseCtx ctx);
448451

449452
typedef struct CustomSlotSupplierCallbacks {
450453
CustomReserveSlotCallback reserve;
454+
CustomCancelReserveCallback cancel_reserve;
451455
CustomTryReserveSlotCallback try_reserve;
452456
CustomMarkSlotUsedCallback mark_used;
453457
CustomReleaseSlotCallback release;
@@ -694,6 +698,8 @@ struct WorkerReplayPushResult worker_replay_push(struct Worker *worker,
694698

695699
void complete_async_reserve(void *sender, uintptr_t permit_id);
696700

701+
void set_reserve_cancel_target(struct SlotReserveCtx *ctx, void *token_ptr);
702+
697703
#ifdef __cplusplus
698704
} // extern "C"
699705
#endif // __cplusplus

src/Temporalio/Bridge/src/worker.rs

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ unsafe impl<SK> Sync for CustomSlotSupplier<SK> {}
9696

9797
type CustomReserveSlotCallback =
9898
unsafe extern "C" fn(ctx: SlotReserveCtx, sender: *mut libc::c_void);
99+
type CustomCancelReserveCallback = unsafe extern "C" fn(token_source: *mut libc::c_void);
99100
/// Must return C#-tracked id for the permit. A zero value means no permit was reserved.
100101
type CustomTryReserveSlotCallback = unsafe extern "C" fn(ctx: SlotReserveCtx) -> usize;
101102
type CustomMarkSlotUsedCallback = unsafe extern "C" fn(ctx: SlotMarkUsedCtx);
@@ -108,6 +109,7 @@ pub struct CustomSlotSupplierCallbacksImpl(*const CustomSlotSupplierCallbacks);
108109
#[repr(C)]
109110
pub struct CustomSlotSupplierCallbacks {
110111
reserve: CustomReserveSlotCallback,
112+
cancel_reserve: CustomCancelReserveCallback,
111113
try_reserve: CustomTryReserveSlotCallback,
112114
mark_used: CustomMarkSlotUsedCallback,
113115
release: CustomReleaseSlotCallback,
@@ -139,6 +141,8 @@ pub struct SlotReserveCtx {
139141
worker_identity: ByteArrayRef,
140142
worker_build_id: ByteArrayRef,
141143
is_sticky: bool,
144+
// The C# side will store a pointer here to the cancellation token source
145+
token_src: *mut libc::c_void,
142146
}
143147

144148
#[repr(C)]
@@ -169,6 +173,21 @@ pub struct SlotReleaseCtx {
169173
slot_permit: usize,
170174
}
171175

176+
struct CancelReserveGuard {
177+
token_src: *mut libc::c_void,
178+
callback: CustomCancelReserveCallback,
179+
}
180+
impl Drop for CancelReserveGuard {
181+
fn drop(&mut self) {
182+
if !self.token_src.is_null() {
183+
unsafe {
184+
(self.callback)(self.token_src);
185+
}
186+
}
187+
}
188+
}
189+
unsafe impl Send for CancelReserveGuard {}
190+
172191
#[async_trait::async_trait]
173192
impl<SK: SlotKind + Send + Sync> temporal_sdk_core_api::worker::SlotSupplier
174193
for CustomSlotSupplier<SK>
@@ -180,10 +199,13 @@ impl<SK: SlotKind + Send + Sync> temporal_sdk_core_api::worker::SlotSupplier
180199
let ctx = Self::convert_reserve_ctx(ctx);
181200
let tx = Box::into_raw(Box::new(tx)) as *mut libc::c_void;
182201
unsafe {
202+
let _drop_guard = CancelReserveGuard {
203+
token_src: ctx.token_src,
204+
callback: (*self.inner.0).cancel_reserve,
205+
};
183206
((*self.inner.0).reserve)(ctx, tx);
207+
rx.await.expect("reserve channel is not closed")
184208
}
185-
let r = rx.await.expect("reserve channel is not closed");
186-
r
187209
}
188210

189211
fn try_reserve_slot(&self, ctx: &dyn SlotReservationContext) -> Option<SlotSupplierPermit> {
@@ -244,6 +266,7 @@ impl<SK: SlotKind + Send + Sync> CustomSlotSupplier<SK> {
244266
worker_identity: ctx.worker_identity().into(),
245267
worker_build_id: ctx.worker_build_id().into(),
246268
is_sticky: ctx.is_sticky(),
269+
token_src: std::ptr::null_mut(),
247270
}
248271
}
249272

@@ -738,6 +761,16 @@ pub extern "C" fn complete_async_reserve(sender: *mut libc::c_void, permit_id: u
738761
}
739762
}
740763

764+
#[no_mangle]
765+
pub unsafe extern "C" fn set_reserve_cancel_target(
766+
ctx: *mut SlotReserveCtx,
767+
token_ptr: *mut libc::c_void,
768+
) {
769+
if let Some(ctx) = ctx.as_mut() {
770+
ctx.token_src = token_ptr;
771+
}
772+
}
773+
741774
impl TryFrom<&WorkerOptions> for temporal_sdk_core::WorkerConfig {
742775
type Error = anyhow::Error;
743776

src/Temporalio/Worker/Tuning/ICustomSlotSupplier.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using System.Threading;
23
using System.Threading.Tasks;
34

45
namespace Temporalio.Worker.Tuning
@@ -22,9 +23,11 @@ public interface ICustomSlotSupplier : ISlotSupplier
2223
/// This method will be called concurrently from multiple threads, so it must be thread-safe.
2324
/// </remarks>
2425
/// <param name="ctx">The context for slot reservation.</param>
26+
/// <param name="cancellationToken">A cancellation token that the SDK may use
27+
/// to cancel the operation.</param>
2528
/// <returns>A permit to use the slot which may be populated with your own data.</returns>
2629
/// <exception cref="OperationCanceledException">Cancellation requested.</exception>
27-
public Task<ISlotPermit> ReserveSlotAsync(SlotReserveContext ctx);
30+
public Task<ISlotPermit> ReserveSlotAsync(SlotReserveContext ctx, CancellationToken cancellationToken);
2831

2932
/// <summary>
3033
/// This function is called when trying to reserve slots for "eager" workflow and activity tasks.

tests/Temporalio.Tests/Worker/WorkerTuningTests.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ private class MySlotSupplier : ICustomSlotSupplier
156156

157157
public HashSet<bool> SeenReleaseInfoPresence { get; } = new();
158158

159-
public async Task<ISlotPermit> ReserveSlotAsync(SlotReserveContext ctx)
159+
public async Task<ISlotPermit> ReserveSlotAsync(SlotReserveContext ctx, CancellationToken cancellationToken)
160160
{
161161
// Do something async to make sure that works
162162
await Task.Delay(10);
@@ -243,7 +243,7 @@ await Env.Client.ExecuteWorkflowAsync(
243243

244244
private class ThrowingSlotSupplier : ICustomSlotSupplier
245245
{
246-
public Task<ISlotPermit> ReserveSlotAsync(SlotReserveContext ctx)
246+
public Task<ISlotPermit> ReserveSlotAsync(SlotReserveContext ctx, CancellationToken cancellationToken)
247247
{
248248
// Let the workflow complete, but other reservations fail
249249
if (ctx.SlotType == SlotType.Workflow)

0 commit comments

Comments
 (0)