Skip to content

Commit 0ee6a32

Browse files
committed
Cancellation Tokens
1 parent f735b23 commit 0ee6a32

File tree

6 files changed

+146
-14
lines changed

6 files changed

+146
-14
lines changed

src/Temporalio/Bridge/CustomSlotSupplier.cs

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Runtime.InteropServices;
4+
using System.Threading;
45
using System.Threading.Tasks;
56
using Microsoft.Extensions.Logging;
7+
using Temporalio.Bridge.Interop;
68

79
namespace Temporalio.Bridge
810
{
@@ -31,6 +33,7 @@ internal unsafe CustomSlotSupplier(
3133
var interopCallbacks = new Interop.CustomSlotSupplierCallbacks
3234
{
3335
reserve = FunctionPointer<Interop.CustomReserveSlotCallback>(Reserve),
36+
cancel_reserve = FunctionPointer<Interop.CustomCancelReserveCallback>(CancelReserve),
3437
try_reserve = FunctionPointer<Interop.CustomTryReserveSlotCallback>(TryReserve),
3538
mark_used = FunctionPointer<Interop.CustomMarkSlotUsedCallback>(MarkUsed),
3639
release = FunctionPointer<Interop.CustomReleaseSlotCallback>(Release),
@@ -39,25 +42,62 @@ internal unsafe CustomSlotSupplier(
3942
PinCallbackHolder(interopCallbacks);
4043
}
4144

45+
private static void SetCancelTokenOnCtx(ref SlotReserveCtx ctx, CancellationTokenSource cancelTokenSrc)
46+
{
47+
unsafe
48+
{
49+
try
50+
{
51+
var handle = GCHandle.Alloc(cancelTokenSrc);
52+
fixed (Interop.SlotReserveCtx* p = &ctx)
53+
{
54+
Interop.Methods.set_reserve_cancel_target(p, GCHandle.ToIntPtr(handle).ToPointer());
55+
}
56+
}
57+
catch (Exception e)
58+
{
59+
Console.WriteLine($"Error setting cancel token on ctx: {e}");
60+
throw;
61+
}
62+
}
63+
}
64+
4265
private unsafe void Reserve(Interop.SlotReserveCtx ctx, void* sender)
4366
{
4467
SafeReserve(ctx, new IntPtr(sender));
4568
}
4669

70+
private unsafe void CancelReserve(void* tokenSrc)
71+
{
72+
var handle = GCHandle.FromIntPtr(new IntPtr(tokenSrc));
73+
var cancelTokenSrc = (CancellationTokenSource)handle.Target!;
74+
cancelTokenSrc.Cancel();
75+
handle.Free();
76+
}
77+
4778
private void SafeReserve(Interop.SlotReserveCtx ctx, IntPtr sender)
4879
{
4980
var reserveTask = Task.Run(async () =>
5081
{
82+
var cancelTokenSrc = new System.Threading.CancellationTokenSource();
83+
SetCancelTokenOnCtx(ref ctx, cancelTokenSrc);
5184
while (true)
5285
{
5386
try
5487
{
55-
var permit = await userSupplier.ReserveSlotAsync(new(ctx)).ConfigureAwait(false);
88+
var permit = await userSupplier.ReserveSlotAsync(
89+
new(ctx), cancelTokenSrc.Token).ConfigureAwait(false);
5690
var usedPermitId = AddPermitToMap(permit);
5791
unsafe
5892
{
5993
Interop.Methods.complete_async_reserve(sender.ToPointer(), new(usedPermitId));
6094
}
95+
cancelTokenSrc.Dispose();
96+
return;
97+
}
98+
catch (OperationCanceledException)
99+
{
100+
cancelTokenSrc.Dispose();
61101
return;
62102
}
63103
#pragma warning disable CA1031 // We are ok catching all exceptions here
@@ -67,7 +107,7 @@ private void SafeReserve(Interop.SlotReserveCtx ctx, IntPtr sender)
67107
logger.LogError(e, "Error reserving slot");
68108
}
69109
// Wait for a bit to avoid spamming errors
70-
await Task.Delay(1000).ConfigureAwait(false);
110+
await Task.Delay(1000, cancelTokenSrc.Token).ConfigureAwait(false);
71111
}
72112
});
73113
}

src/Temporalio/Bridge/Interop/Interop.cs

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ internal partial struct ResourceBasedSlotSupplier
558558
public ResourceBasedTunerOptions tuner_options;
559559
}
560560

561-
internal partial struct SlotReserveCtx
561+
internal unsafe partial struct SlotReserveCtx
562562
{
563563
[NativeTypeName("enum SlotKindType")]
564564
public SlotKindType slot_type;
@@ -574,11 +574,16 @@ internal partial struct SlotReserveCtx
574574

575575
[NativeTypeName("bool")]
576576
public byte is_sticky;
577+
578+
public void* token_src;
577579
}
578580

579581
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
580582
internal unsafe delegate void CustomReserveSlotCallback([NativeTypeName("struct SlotReserveCtx")] SlotReserveCtx ctx, void* sender);
581583

584+
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
585+
internal unsafe delegate void CustomCancelReserveCallback(void* token_source);
586+
582587
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
583588
[return: NativeTypeName("uintptr_t")]
584589
internal delegate UIntPtr CustomTryReserveSlotCallback([NativeTypeName("struct SlotReserveCtx")] SlotReserveCtx ctx);
@@ -615,7 +620,7 @@ internal unsafe partial struct SlotInfo
615620
{
616621
public SlotInfo_Tag tag;
617622

618-
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L422_C3")]
623+
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L425_C3")]
619624
public _Anonymous_e__Union Anonymous;
620625

621626
internal ref WorkflowSlotInfo_Body workflow_slot_info
@@ -694,6 +699,9 @@ internal partial struct CustomSlotSupplierCallbacks
694699
[NativeTypeName("CustomReserveSlotCallback")]
695700
public IntPtr reserve;
696701

702+
[NativeTypeName("CustomCancelReserveCallback")]
703+
public IntPtr cancel_reserve;
704+
697705
[NativeTypeName("CustomTryReserveSlotCallback")]
698706
public IntPtr try_reserve;
699707

@@ -721,7 +729,7 @@ internal unsafe partial struct SlotSupplier
721729
{
722730
public SlotSupplier_Tag tag;
723731

724-
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L468_C3")]
732+
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L472_C3")]
725733
public _Anonymous_e__Union Anonymous;
726734

727735
internal ref FixedSizeSlotSupplier fixed_size
@@ -761,15 +769,15 @@ internal ref CustomSlotSupplierCallbacksImpl custom
761769
internal unsafe partial struct _Anonymous_e__Union
762770
{
763771
[FieldOffset(0)]
764-
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L469_C5")]
772+
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L473_C5")]
765773
public _Anonymous1_e__Struct Anonymous1;
766774

767775
[FieldOffset(0)]
768-
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L472_C5")]
776+
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L476_C5")]
769777
public _Anonymous2_e__Struct Anonymous2;
770778

771779
[FieldOffset(0)]
772-
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L475_C5")]
780+
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L479_C5")]
773781
public _Anonymous3_e__Struct Anonymous3;
774782

775783
internal partial struct _Anonymous1_e__Struct
@@ -1061,5 +1069,8 @@ internal static unsafe partial class Methods
10611069

10621070
[DllImport("temporal_sdk_bridge", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
10631071
public static extern void complete_async_reserve(void* sender, [NativeTypeName("uintptr_t")] UIntPtr permit_id);
1072+
1073+
[DllImport("temporal_sdk_bridge", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
1074+
public static extern void set_reserve_cancel_target([NativeTypeName("struct SlotReserveCtx *")] SlotReserveCtx* ctx, void* token_ptr);
10641075
}
10651076
}

src/Temporalio/Bridge/include/temporal-sdk-bridge.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,10 +389,13 @@ typedef struct SlotReserveCtx {
389389
struct ByteArrayRef worker_identity;
390390
struct ByteArrayRef worker_build_id;
391391
bool is_sticky;
392+
void *token_src;
392393
} SlotReserveCtx;
393394

394395
typedef void (*CustomReserveSlotCallback)(struct SlotReserveCtx ctx, void *sender);
395396

397+
typedef void (*CustomCancelReserveCallback)(void *token_source);
398+
396399
/**
397400
* Must return C#-tracked id for the permit. A zero value means no permit was reserved.
398401
*/
@@ -448,6 +451,7 @@ typedef void (*CustomReleaseSlotCallback)(struct SlotReleaseCtx ctx);
448451

449452
typedef struct CustomSlotSupplierCallbacks {
450453
CustomReserveSlotCallback reserve;
454+
CustomCancelReserveCallback cancel_reserve;
451455
CustomTryReserveSlotCallback try_reserve;
452456
CustomMarkSlotUsedCallback mark_used;
453457
CustomReleaseSlotCallback release;
@@ -694,6 +698,8 @@ struct WorkerReplayPushResult worker_replay_push(struct Worker *worker,
694698

695699
void complete_async_reserve(void *sender, uintptr_t permit_id);
696700

701+
void set_reserve_cancel_target(struct SlotReserveCtx *ctx, void *token_ptr);
702+
697703
#ifdef __cplusplus
698704
} // extern "C"
699705
#endif // __cplusplus

src/Temporalio/Bridge/src/worker.rs

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ unsafe impl<SK> Sync for CustomSlotSupplier<SK> {}
9696

9797
type CustomReserveSlotCallback =
9898
unsafe extern "C" fn(ctx: SlotReserveCtx, sender: *mut libc::c_void);
99+
type CustomCancelReserveCallback = unsafe extern "C" fn(token_source: *mut libc::c_void);
99100
/// Must return C#-tracked id for the permit. A zero value means no permit was reserved.
100101
type CustomTryReserveSlotCallback = unsafe extern "C" fn(ctx: SlotReserveCtx) -> usize;
101102
type CustomMarkSlotUsedCallback = unsafe extern "C" fn(ctx: SlotMarkUsedCtx);
@@ -108,6 +109,7 @@ pub struct CustomSlotSupplierCallbacksImpl(*const CustomSlotSupplierCallbacks);
108109
#[repr(C)]
109110
pub struct CustomSlotSupplierCallbacks {
110111
reserve: CustomReserveSlotCallback,
112+
cancel_reserve: CustomCancelReserveCallback,
111113
try_reserve: CustomTryReserveSlotCallback,
112114
mark_used: CustomMarkSlotUsedCallback,
113115
release: CustomReleaseSlotCallback,
@@ -139,6 +141,8 @@ pub struct SlotReserveCtx {
139141
worker_identity: ByteArrayRef,
140142
worker_build_id: ByteArrayRef,
141143
is_sticky: bool,
144+
// The C# side will store a pointer here to the cancellation token source
145+
token_src: *mut libc::c_void,
142146
}
143147

144148
#[repr(C)]
@@ -169,6 +173,21 @@ pub struct SlotReleaseCtx {
169173
slot_permit: usize,
170174
}
171175

176+
struct CancelReserveGuard {
177+
token_src: *mut libc::c_void,
178+
callback: CustomCancelReserveCallback,
179+
}
180+
impl Drop for CancelReserveGuard {
181+
fn drop(&mut self) {
182+
if !self.token_src.is_null() {
183+
unsafe {
184+
(self.callback)(self.token_src);
185+
}
186+
}
187+
}
188+
}
189+
unsafe impl Send for CancelReserveGuard {}
190+
172191
#[async_trait::async_trait]
173192
impl<SK: SlotKind + Send + Sync> temporal_sdk_core_api::worker::SlotSupplier
174193
for CustomSlotSupplier<SK>
@@ -180,10 +199,13 @@ impl<SK: SlotKind + Send + Sync> temporal_sdk_core_api::worker::SlotSupplier
180199
let ctx = Self::convert_reserve_ctx(ctx);
181200
let tx = Box::into_raw(Box::new(tx)) as *mut libc::c_void;
182201
unsafe {
202+
let _drop_guard = CancelReserveGuard {
203+
token_src: ctx.token_src,
204+
callback: (*self.inner.0).cancel_reserve,
205+
};
183206
((*self.inner.0).reserve)(ctx, tx);
207+
rx.await.expect("reserve channel is not closed")
184208
}
185-
let r = rx.await.expect("reserve channel is not closed");
186-
r
187209
}
188210

189211
fn try_reserve_slot(&self, ctx: &dyn SlotReservationContext) -> Option<SlotSupplierPermit> {
@@ -244,6 +266,7 @@ impl<SK: SlotKind + Send + Sync> CustomSlotSupplier<SK> {
244266
worker_identity: ctx.worker_identity().into(),
245267
worker_build_id: ctx.worker_build_id().into(),
246268
is_sticky: ctx.is_sticky(),
269+
token_src: std::ptr::null_mut(),
247270
}
248271
}
249272

@@ -738,6 +761,16 @@ pub extern "C" fn complete_async_reserve(sender: *mut libc::c_void, permit_id: u
738761
}
739762
}
740763

764+
#[no_mangle]
765+
pub unsafe extern "C" fn set_reserve_cancel_target(
766+
ctx: *mut SlotReserveCtx,
767+
token_ptr: *mut libc::c_void,
768+
) {
769+
if let Some(ctx) = ctx.as_mut() {
770+
ctx.token_src = token_ptr;
771+
}
772+
}
773+
741774
impl TryFrom<&WorkerOptions> for temporal_sdk_core::WorkerConfig {
742775
type Error = anyhow::Error;
743776

src/Temporalio/Worker/Tuning/ICustomSlotSupplier.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using System.Threading;
23
using System.Threading.Tasks;
34

45
namespace Temporalio.Worker.Tuning
@@ -22,9 +23,11 @@ public interface ICustomSlotSupplier : ISlotSupplier
2223
/// This method will be called concurrently from multiple threads, so it must be thread-safe.
2324
/// </remarks>
2425
/// <param name="ctx">The context for slot reservation.</param>
26+
/// <param name="cancellationToken">A cancellation token that the SDK may use
27+
/// to cancel the operation.</param>
2528
/// <returns>A permit to use the slot which may be populated with your own data.</returns>
2629
/// <exception cref="OperationCanceledException">Cancellation requested.</exception>
27-
public Task<ISlotPermit> ReserveSlotAsync(SlotReserveContext ctx);
30+
public Task<ISlotPermit> ReserveSlotAsync(SlotReserveContext ctx, CancellationToken cancellationToken);
2831

2932
/// <summary>
3033
/// This function is called when trying to reserve slots for "eager" workflow and activity tasks.

tests/Temporalio.Tests/Worker/WorkerTuningTests.cs

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,10 +156,10 @@ private class MySlotSupplier : ICustomSlotSupplier
156156

157157
public HashSet<bool> SeenReleaseInfoPresence { get; } = new();
158158

159-
public async Task<ISlotPermit> ReserveSlotAsync(SlotReserveContext ctx)
159+
public async Task<ISlotPermit> ReserveSlotAsync(SlotReserveContext ctx, CancellationToken cancellationToken)
160160
{
161161
// Do something async to make sure that works
162-
await Task.Delay(10);
162+
await Task.Delay(10, cancellationToken);
163163
ReserveTracking(ctx);
164164
return new MyPermit(ReserveCount);
165165
}
@@ -243,7 +243,7 @@ await Env.Client.ExecuteWorkflowAsync(
243243

244244
private class ThrowingSlotSupplier : ICustomSlotSupplier
245245
{
246-
public Task<ISlotPermit> ReserveSlotAsync(SlotReserveContext ctx)
246+
public Task<ISlotPermit> ReserveSlotAsync(SlotReserveContext ctx, CancellationToken cancellationToken)
247247
{
248248
// Let the workflow complete, but other reservations fail
249249
if (ctx.SlotType == SlotType.Workflow)
@@ -286,4 +286,43 @@ await Env.Client.ExecuteWorkflowAsync(
286286
new(id: $"workflow-{Guid.NewGuid()}", taskQueue: worker.Options.TaskQueue!));
287287
});
288288
}
289+
290+
private class BlockingSlotSupplier : ICustomSlotSupplier
291+
{
292+
public async Task<ISlotPermit> ReserveSlotAsync(SlotReserveContext ctx, CancellationToken cancellationToken)
293+
{
294+
await Task.Delay(100_000, cancellationToken);
295+
cancellationToken.ThrowIfCancellationRequested();
296+
throw new InvalidOperationException("Should not be reachable");
297+
}
298+
299+
public ISlotPermit? TryReserveSlot(SlotReserveContext ctx)
300+
{
301+
return null;
302+
}
303+
304+
public void MarkSlotUsed(SlotMarkUsedContext ctx)
305+
{
306+
}
307+
308+
public void ReleaseSlot(SlotReleaseContext ctx)
309+
{
310+
}
311+
}
312+
313+
[Fact]
314+
public async Task CanRunWith_BlockingSlotSupplier()
315+
{
316+
var mySlotSupplier = new BlockingSlotSupplier();
317+
using var worker = new TemporalWorker(
318+
Client,
319+
new TemporalWorkerOptions($"tq-{Guid.NewGuid()}")
320+
{
321+
Tuner = new WorkerTuner(mySlotSupplier, mySlotSupplier, mySlotSupplier),
322+
}.AddWorkflow<OneTaskWf>());
323+
await worker.ExecuteAsync(async () =>
324+
{
325+
await Task.Delay(1000);
326+
});
327+
}
289328
}

0 commit comments

Comments
 (0)