Skip to content

Commit 14c12af

Browse files
committed
Review fixes
1 parent 478324f commit 14c12af

19 files changed

+337
-285
lines changed
Lines changed: 123 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Runtime.CompilerServices;
34
using System.Runtime.InteropServices;
45
using System.Threading;
56
using System.Threading.Tasks;
67
using Microsoft.Extensions.Logging;
7-
using Temporalio.Bridge.Interop;
88

99
namespace Temporalio.Bridge
1010
{
@@ -14,8 +14,8 @@ namespace Temporalio.Bridge
1414
internal class CustomSlotSupplier : NativeInvokeableClass<Interop.CustomSlotSupplierCallbacks>
1515
{
1616
private readonly ILogger logger;
17-
private readonly Temporalio.Worker.Tuning.ICustomSlotSupplier userSupplier;
18-
private readonly Dictionary<uint, Temporalio.Worker.Tuning.ISlotPermit> permits = new();
17+
private readonly Temporalio.Worker.Tuning.CustomSlotSupplier userSupplier;
18+
private readonly Dictionary<uint, Temporalio.Worker.Tuning.SlotPermit> permits = new();
1919
private uint permitId = 1;
2020

2121
/// <summary>
@@ -24,7 +24,7 @@ internal class CustomSlotSupplier : NativeInvokeableClass<Interop.CustomSlotSupp
2424
/// <param name="userSupplier">User's slot supplier implementation'.</param>
2525
/// <param name="loggerFactory">Logger factory.</param>
2626
internal unsafe CustomSlotSupplier(
27-
Temporalio.Worker.Tuning.ICustomSlotSupplier userSupplier,
27+
Temporalio.Worker.Tuning.CustomSlotSupplier userSupplier,
2828
ILoggerFactory loggerFactory)
2929
{
3030
this.logger = loggerFactory.CreateLogger<CustomSlotSupplier>();
@@ -37,36 +37,36 @@ internal unsafe CustomSlotSupplier(
3737
try_reserve = FunctionPointer<Interop.CustomTryReserveSlotCallback>(TryReserve),
3838
mark_used = FunctionPointer<Interop.CustomMarkSlotUsedCallback>(MarkUsed),
3939
release = FunctionPointer<Interop.CustomReleaseSlotCallback>(Release),
40+
free = FunctionPointer<Interop.CustomSlotImplFreeCallback>(Free),
4041
};
4142

4243
PinCallbackHolder(interopCallbacks);
4344
}
4445

45-
private static void SetCancelTokenOnCtx(ref SlotReserveCtx ctx, CancellationTokenSource cancelTokenSrc)
46+
private static Temporalio.Worker.Tuning.SlotInfo SlotInfoFromBridge(Interop.SlotInfo slotInfo)
4647
{
47-
unsafe
48+
return slotInfo.tag switch
4849
{
49-
try
50-
{
51-
var handle = GCHandle.Alloc(cancelTokenSrc);
52-
fixed (Interop.SlotReserveCtx* p = &ctx)
53-
{
54-
Interop.Methods.set_reserve_cancel_target(p, GCHandle.ToIntPtr(handle).ToPointer());
55-
}
56-
}
57-
catch (Exception e)
58-
{
59-
Console.WriteLine($"Error setting cancel token on ctx: {e}");
60-
throw;
61-
}
62-
}
50+
Interop.SlotInfo_Tag.WorkflowSlotInfo =>
51+
new Temporalio.Worker.Tuning.SlotInfo.WorkflowSlotInfo(
52+
ByteArrayRef.ToUtf8(slotInfo.workflow_slot_info.workflow_type), slotInfo.workflow_slot_info.is_sticky != 0),
53+
Interop.SlotInfo_Tag.ActivitySlotInfo =>
54+
new Temporalio.Worker.Tuning.SlotInfo.ActivitySlotInfo(
55+
ByteArrayRef.ToUtf8(slotInfo.activity_slot_info.activity_type)),
56+
Interop.SlotInfo_Tag.LocalActivitySlotInfo =>
57+
new Temporalio.Worker.Tuning.SlotInfo.LocalActivitySlotInfo(
58+
ByteArrayRef.ToUtf8(slotInfo.local_activity_slot_info.activity_type)),
59+
_ => throw new System.ArgumentOutOfRangeException(nameof(slotInfo)),
60+
};
6361
}
6462

65-
private unsafe void Reserve(Interop.SlotReserveCtx ctx, void* sender)
63+
private unsafe void Reserve(Interop.SlotReserveCtx* ctx, void* sender)
6664
{
67-
SafeReserve(ctx, new IntPtr(sender));
65+
SafeReserve(new IntPtr(ctx), new IntPtr(sender));
6866
}
6967

68+
// Note that this is always called by Rust, either because the call is cancelled or because
69+
// it completed. Therefore the GCHandle is always freed.
7070
private unsafe void CancelReserve(void* tokenSrc)
7171
{
7272
var handle = GCHandle.FromIntPtr(new IntPtr(tokenSrc));
@@ -75,49 +75,66 @@ private unsafe void CancelReserve(void* tokenSrc)
7575
handle.Free();
7676
}
7777

78-
private void SafeReserve(Interop.SlotReserveCtx ctx, IntPtr sender)
78+
private void SafeReserve(IntPtr ctx, IntPtr sender)
7979
{
80-
var reserveTask = Task.Run(async () =>
80+
_ = Task.Run(async () =>
8181
{
82-
var cancelTokenSrc = new System.Threading.CancellationTokenSource();
83-
SetCancelTokenOnCtx(ref ctx, cancelTokenSrc);
84-
while (true)
82+
using (var cancelTokenSrc = new System.Threading.CancellationTokenSource())
8583
{
86-
try
84+
unsafe
8785
{
88-
var permit = await userSupplier.ReserveSlotAsync(
89-
new(ctx), cancelTokenSrc.Token).ConfigureAwait(false);
90-
var usedPermitId = AddPermitToMap(permit);
91-
unsafe
92-
{
93-
Interop.Methods.complete_async_reserve(sender.ToPointer(), new(usedPermitId));
94-
}
95-
cancelTokenSrc.Dispose();
96-
return;
86+
var srcHandle = GCHandle.Alloc(cancelTokenSrc);
87+
Interop.Methods.set_reserve_cancel_target(
88+
(Interop.SlotReserveCtx*)ctx.ToPointer(),
89+
GCHandle.ToIntPtr(srcHandle).ToPointer());
9790
}
98-
catch (OperationCanceledException)
91+
while (true)
9992
{
100-
cancelTokenSrc.Dispose();
101-
return;
102-
}
93+
try
94+
{
95+
ConfiguredTaskAwaitable<Temporalio.Worker.Tuning.SlotPermit> reserveTask;
96+
unsafe
97+
{
98+
reserveTask = userSupplier.ReserveSlotAsync(
99+
ReserveCtxFromBridge((Interop.SlotReserveCtx*)ctx.ToPointer()),
100+
cancelTokenSrc.Token).ConfigureAwait(false);
101+
}
102+
var permit = await reserveTask;
103+
unsafe
104+
{
105+
var usedPermitId = AddPermitToMap(permit);
106+
Interop.Methods.complete_async_reserve(sender.ToPointer(), new(usedPermitId));
107+
}
108+
return;
109+
}
110+
catch (OperationCanceledException) when (cancelTokenSrc.Token.IsCancellationRequested)
111+
{
112+
unsafe
113+
{
114+
// Always call this to ensure the sender is freed
115+
Interop.Methods.complete_async_reserve(sender.ToPointer(), new(0));
116+
}
117+
return;
118+
}
103119
#pragma warning disable CA1031 // We are ok catching all exceptions here
104-
catch (Exception e)
105-
{
120+
catch (Exception e)
121+
{
106122
#pragma warning restore CA1031
107-
logger.LogError(e, "Error reserving slot");
123+
logger.LogError(e, "Error reserving slot");
124+
}
125+
// Wait for a bit to avoid spamming errors
126+
await Task.Delay(1000, cancelTokenSrc.Token).ConfigureAwait(false);
108127
}
109-
// Wait for a bit to avoid spamming errors
110-
await Task.Delay(1000, cancelTokenSrc.Token).ConfigureAwait(false);
111128
}
112129
});
113130
}
114131

115-
private unsafe UIntPtr TryReserve(Interop.SlotReserveCtx ctx)
132+
private unsafe UIntPtr TryReserve(Interop.SlotReserveCtx* ctx)
116133
{
117-
Temporalio.Worker.Tuning.ISlotPermit? maybePermit;
134+
Temporalio.Worker.Tuning.SlotPermit? maybePermit;
118135
try
119136
{
120-
maybePermit = userSupplier.TryReserveSlot(new(ctx));
137+
maybePermit = userSupplier.TryReserveSlot(ReserveCtxFromBridge(ctx));
121138
}
122139
#pragma warning disable CA1031 // We are ok catching all exceptions here
123140
catch (Exception e)
@@ -135,11 +152,16 @@ private unsafe UIntPtr TryReserve(Interop.SlotReserveCtx ctx)
135152
return new(usedPermitId);
136153
}
137154

138-
private void MarkUsed(Interop.SlotMarkUsedCtx ctx)
155+
private unsafe void MarkUsed(Interop.SlotMarkUsedCtx* ctx)
139156
{
140157
try
141158
{
142-
userSupplier.MarkSlotUsed(new(ctx, permits[ctx.slot_permit.ToUInt32()]));
159+
Temporalio.Worker.Tuning.SlotPermit permit;
160+
lock (permits)
161+
{
162+
permit = permits[(*ctx).slot_permit.ToUInt32()];
163+
}
164+
userSupplier.MarkSlotUsed(MarkUsedCtxFromBridge(ctx, permit));
143165
}
144166
#pragma warning disable CA1031 // We are ok catching all exceptions here
145167
catch (Exception e)
@@ -149,23 +171,34 @@ private void MarkUsed(Interop.SlotMarkUsedCtx ctx)
149171
}
150172
}
151173

152-
private void Release(Interop.SlotReleaseCtx ctx)
174+
private unsafe void Release(Interop.SlotReleaseCtx* ctx)
153175
{
154-
var permitId = ctx.slot_permit.ToUInt32();
176+
var permitId = (*ctx).slot_permit.ToUInt32();
177+
Temporalio.Worker.Tuning.SlotPermit permit;
178+
lock (permits)
179+
{
180+
permit = permits[permitId];
181+
}
155182
try
156183
{
157-
userSupplier.ReleaseSlot(new(ctx, permits[permitId]));
184+
userSupplier.ReleaseSlot(ReleaseCtxFromBridge(ctx, permit));
158185
}
159186
#pragma warning disable CA1031 // We are ok catching all exceptions here
160187
catch (Exception e)
161188
{
162189
#pragma warning restore CA1031
163190
logger.LogError(e, "Error releasing slot");
164191
}
165-
permits.Remove(permitId);
192+
finally
193+
{
194+
lock (permits)
195+
{
196+
permits.Remove(permitId);
197+
}
198+
}
166199
}
167200

168-
private uint AddPermitToMap(Temporalio.Worker.Tuning.ISlotPermit permit)
201+
private uint AddPermitToMap(Temporalio.Worker.Tuning.SlotPermit permit)
169202
{
170203
lock (permits)
171204
{
@@ -175,5 +208,39 @@ private uint AddPermitToMap(Temporalio.Worker.Tuning.ISlotPermit permit)
175208
return usedPermitId;
176209
}
177210
}
211+
212+
private unsafe Temporalio.Worker.Tuning.SlotReserveContext ReserveCtxFromBridge(Interop.SlotReserveCtx* ctx)
213+
{
214+
return new(
215+
SlotType: (*ctx).slot_type switch
216+
{
217+
Interop.SlotKindType.WorkflowSlotKindType => Temporalio.Worker.Tuning.SlotType.Workflow,
218+
Interop.SlotKindType.ActivitySlotKindType => Temporalio.Worker.Tuning.SlotType.Activity,
219+
Interop.SlotKindType.LocalActivitySlotKindType => Temporalio.Worker.Tuning.SlotType.LocalActivity,
220+
_ => throw new System.ArgumentOutOfRangeException(nameof(ctx)),
221+
},
222+
TaskQueue: ByteArrayRef.ToUtf8((*ctx).task_queue),
223+
WorkerIdentity: ByteArrayRef.ToUtf8((*ctx).worker_identity),
224+
WorkerBuildId: ByteArrayRef.ToUtf8((*ctx).worker_build_id),
225+
IsSticky: (*ctx).is_sticky != 0);
226+
}
227+
228+
private unsafe Temporalio.Worker.Tuning.SlotReleaseContext ReleaseCtxFromBridge(
229+
Interop.SlotReleaseCtx* ctx,
230+
Temporalio.Worker.Tuning.SlotPermit permit)
231+
{
232+
return new(
233+
SlotInfo: (*ctx).slot_info is null ? null : SlotInfoFromBridge(*(*ctx).slot_info),
234+
Permit: permit);
235+
}
236+
237+
private unsafe Temporalio.Worker.Tuning.SlotMarkUsedContext MarkUsedCtxFromBridge(
238+
Interop.SlotMarkUsedCtx* ctx,
239+
Temporalio.Worker.Tuning.SlotPermit permit)
240+
{
241+
return new(
242+
SlotInfo: SlotInfoFromBridge((*ctx).slot_info),
243+
Permit: permit);
244+
}
178245
}
179246
}

src/Temporalio/Bridge/Interop/Interop.cs

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -579,14 +579,14 @@ internal unsafe partial struct SlotReserveCtx
579579
}
580580

581581
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
582-
internal unsafe delegate void CustomReserveSlotCallback([NativeTypeName("struct SlotReserveCtx")] SlotReserveCtx ctx, void* sender);
582+
internal unsafe delegate void CustomReserveSlotCallback([NativeTypeName("const struct SlotReserveCtx *")] SlotReserveCtx* ctx, void* sender);
583583

584584
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
585585
internal unsafe delegate void CustomCancelReserveCallback(void* token_source);
586586

587587
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
588588
[return: NativeTypeName("uintptr_t")]
589-
internal delegate UIntPtr CustomTryReserveSlotCallback([NativeTypeName("struct SlotReserveCtx")] SlotReserveCtx ctx);
589+
internal unsafe delegate UIntPtr CustomTryReserveSlotCallback([NativeTypeName("const struct SlotReserveCtx *")] SlotReserveCtx* ctx);
590590

591591
internal enum SlotInfo_Tag
592592
{
@@ -680,7 +680,7 @@ internal partial struct SlotMarkUsedCtx
680680
}
681681

682682
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
683-
internal delegate void CustomMarkSlotUsedCallback([NativeTypeName("struct SlotMarkUsedCtx")] SlotMarkUsedCtx ctx);
683+
internal unsafe delegate void CustomMarkSlotUsedCallback([NativeTypeName("const struct SlotMarkUsedCtx *")] SlotMarkUsedCtx* ctx);
684684

685685
internal unsafe partial struct SlotReleaseCtx
686686
{
@@ -692,7 +692,10 @@ internal unsafe partial struct SlotReleaseCtx
692692
}
693693

694694
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
695-
internal delegate void CustomReleaseSlotCallback([NativeTypeName("struct SlotReleaseCtx")] SlotReleaseCtx ctx);
695+
internal unsafe delegate void CustomReleaseSlotCallback([NativeTypeName("const struct SlotReleaseCtx *")] SlotReleaseCtx* ctx);
696+
697+
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
698+
internal unsafe delegate void CustomSlotImplFreeCallback([NativeTypeName("const struct CustomSlotSupplierCallbacks *")] CustomSlotSupplierCallbacks* userimpl);
696699

697700
internal partial struct CustomSlotSupplierCallbacks
698701
{
@@ -710,6 +713,9 @@ internal partial struct CustomSlotSupplierCallbacks
710713

711714
[NativeTypeName("CustomReleaseSlotCallback")]
712715
public IntPtr release;
716+
717+
[NativeTypeName("CustomSlotImplFreeCallback")]
718+
public IntPtr free;
713719
}
714720

715721
internal unsafe partial struct CustomSlotSupplierCallbacksImpl
@@ -729,7 +735,7 @@ internal unsafe partial struct SlotSupplier
729735
{
730736
public SlotSupplier_Tag tag;
731737

732-
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L472_C3")]
738+
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L475_C3")]
733739
public _Anonymous_e__Union Anonymous;
734740

735741
internal ref FixedSizeSlotSupplier fixed_size
@@ -769,15 +775,15 @@ internal ref CustomSlotSupplierCallbacksImpl custom
769775
internal unsafe partial struct _Anonymous_e__Union
770776
{
771777
[FieldOffset(0)]
772-
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L473_C5")]
778+
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L476_C5")]
773779
public _Anonymous1_e__Struct Anonymous1;
774780

775781
[FieldOffset(0)]
776-
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L476_C5")]
782+
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L479_C5")]
777783
public _Anonymous2_e__Struct Anonymous2;
778784

779785
[FieldOffset(0)]
780-
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L479_C5")]
786+
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L482_C5")]
781787
public _Anonymous3_e__Struct Anonymous3;
782788

783789
internal partial struct _Anonymous1_e__Struct

src/Temporalio/Bridge/NativeInvokeableClass.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace Temporalio.Bridge
88
/// Extend this class to help with making a class that has callbacks which are invoked by Rust.
99
/// </summary>
1010
/// <typeparam name="T">The native type that holds the function ptrs for callbacks to C#.</typeparam>
11-
internal class NativeInvokeableClass<T>
11+
internal abstract class NativeInvokeableClass<T>
1212
where T : unmanaged
1313
{
1414
private readonly List<GCHandle> handles = new();
@@ -23,7 +23,7 @@ internal class NativeInvokeableClass<T>
2323
/// the callbacks via <see cref="FunctionPointer"/>. Also adds `this` to the handle list.
2424
/// </summary>
2525
/// <param name="value">The native type to pin.</param>
26-
internal void PinCallbackHolder(T value)
26+
private protected void PinCallbackHolder(T value)
2727
{
2828
// Pin the callback holder & set it as the first handle
2929
var holderHandle = GCHandle.Alloc(value, GCHandleType.Pinned);
@@ -42,7 +42,7 @@ internal void PinCallbackHolder(T value)
4242
/// <typeparam name="TF">The native type of the function pointer.</typeparam>
4343
/// <param name="func">The C# method to use for the callback.</param>
4444
/// <returns>The function pointer to the C# method.</returns>
45-
internal IntPtr FunctionPointer<TF>(TF func)
45+
private protected IntPtr FunctionPointer<TF>(TF func)
4646
where TF : Delegate
4747
{
4848
var handle = GCHandle.Alloc(func);
@@ -53,8 +53,8 @@ internal IntPtr FunctionPointer<TF>(TF func)
5353
/// <summary>
5454
/// Free the memory of the native type and all the function pointers.
5555
/// </summary>
56-
/// <param name="meter">The native type to free.</param>
57-
internal unsafe void Free(T* meter)
56+
/// <param name="ptr">The native type to free.</param>
57+
private protected unsafe void Free(T* ptr)
5858
{
5959
// Free in order which frees function pointers first then object handles
6060
foreach (var handle in handles)

0 commit comments

Comments
 (0)