Skip to content

Commit 008aa55

Browse files
committed
Review fixes
1 parent 478324f commit 008aa55

File tree

12 files changed

+238
-225
lines changed

12 files changed

+238
-225
lines changed
Lines changed: 103 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Runtime.CompilerServices;
34
using System.Runtime.InteropServices;
45
using System.Threading;
56
using System.Threading.Tasks;
67
using Microsoft.Extensions.Logging;
7-
using Temporalio.Bridge.Interop;
88

99
namespace Temporalio.Bridge
1010
{
@@ -15,7 +15,7 @@ internal class CustomSlotSupplier : NativeInvokeableClass<Interop.CustomSlotSupp
1515
{
1616
private readonly ILogger logger;
1717
private readonly Temporalio.Worker.Tuning.ICustomSlotSupplier userSupplier;
18-
private readonly Dictionary<uint, Temporalio.Worker.Tuning.ISlotPermit> permits = new();
18+
private readonly Dictionary<uint, Temporalio.Worker.Tuning.SlotPermit> permits = new();
1919
private uint permitId = 1;
2020

2121
/// <summary>
@@ -37,36 +37,19 @@ internal unsafe CustomSlotSupplier(
3737
try_reserve = FunctionPointer<Interop.CustomTryReserveSlotCallback>(TryReserve),
3838
mark_used = FunctionPointer<Interop.CustomMarkSlotUsedCallback>(MarkUsed),
3939
release = FunctionPointer<Interop.CustomReleaseSlotCallback>(Release),
40+
free = FunctionPointer<Interop.CustomSlotImplFreeCallback>(Free),
4041
};
4142

4243
PinCallbackHolder(interopCallbacks);
4344
}
4445

45-
private static void SetCancelTokenOnCtx(ref SlotReserveCtx ctx, CancellationTokenSource cancelTokenSrc)
46+
private unsafe void Reserve(Interop.SlotReserveCtx* ctx, void* sender)
4647
{
47-
unsafe
48-
{
49-
try
50-
{
51-
var handle = GCHandle.Alloc(cancelTokenSrc);
52-
fixed (Interop.SlotReserveCtx* p = &ctx)
53-
{
54-
Interop.Methods.set_reserve_cancel_target(p, GCHandle.ToIntPtr(handle).ToPointer());
55-
}
56-
}
57-
catch (Exception e)
58-
{
59-
Console.WriteLine($"Error setting cancel token on ctx: {e}");
60-
throw;
61-
}
62-
}
63-
}
64-
65-
private unsafe void Reserve(Interop.SlotReserveCtx ctx, void* sender)
66-
{
67-
SafeReserve(ctx, new IntPtr(sender));
48+
SafeReserve(new IntPtr(ctx), new IntPtr(sender));
6849
}
6950

51+
// Note that this is always called by Rust, either because the call is cancelled or because
52+
// it completed. Therefore the GCHandle is always freed.
7053
private unsafe void CancelReserve(void* tokenSrc)
7154
{
7255
var handle = GCHandle.FromIntPtr(new IntPtr(tokenSrc));
@@ -75,49 +58,61 @@ private unsafe void CancelReserve(void* tokenSrc)
7558
handle.Free();
7659
}
7760

78-
private void SafeReserve(Interop.SlotReserveCtx ctx, IntPtr sender)
61+
private void SafeReserve(IntPtr ctx, IntPtr sender)
7962
{
80-
var reserveTask = Task.Run(async () =>
63+
_ = Task.Run(async () =>
8164
{
82-
var cancelTokenSrc = new System.Threading.CancellationTokenSource();
83-
SetCancelTokenOnCtx(ref ctx, cancelTokenSrc);
84-
while (true)
65+
using (var cancelTokenSrc = new System.Threading.CancellationTokenSource())
8566
{
86-
try
67+
unsafe
8768
{
88-
var permit = await userSupplier.ReserveSlotAsync(
89-
new(ctx), cancelTokenSrc.Token).ConfigureAwait(false);
90-
var usedPermitId = AddPermitToMap(permit);
91-
unsafe
92-
{
93-
Interop.Methods.complete_async_reserve(sender.ToPointer(), new(usedPermitId));
94-
}
95-
cancelTokenSrc.Dispose();
96-
return;
69+
var srcHandle = GCHandle.Alloc(cancelTokenSrc);
70+
Interop.Methods.set_reserve_cancel_target(
71+
(Interop.SlotReserveCtx*)ctx.ToPointer(),
72+
GCHandle.ToIntPtr(srcHandle).ToPointer());
9773
}
98-
catch (OperationCanceledException)
74+
while (true)
9975
{
100-
cancelTokenSrc.Dispose();
101-
return;
102-
}
76+
try
77+
{
78+
ConfiguredTaskAwaitable<Temporalio.Worker.Tuning.SlotPermit> reserveTask;
79+
unsafe
80+
{
81+
reserveTask = userSupplier.ReserveSlotAsync(
82+
ReserveCtxFromBridge((Interop.SlotReserveCtx*)ctx.ToPointer()),
83+
cancelTokenSrc.Token).ConfigureAwait(false);
84+
}
85+
var permit = await reserveTask;
86+
unsafe
87+
{
88+
var usedPermitId = AddPermitToMap(permit);
89+
Interop.Methods.complete_async_reserve(sender.ToPointer(), new(usedPermitId));
90+
}
91+
return;
92+
}
93+
catch (OperationCanceledException)
94+
{
95+
return;
96+
}
10397
#pragma warning disable CA1031 // We are ok catching all exceptions here
104-
catch (Exception e)
105-
{
98+
catch (Exception e)
99+
{
106100
#pragma warning restore CA1031
107-
logger.LogError(e, "Error reserving slot");
101+
logger.LogError(e, "Error reserving slot");
102+
}
103+
// Wait for a bit to avoid spamming errors
104+
await Task.Delay(1000, cancelTokenSrc.Token).ConfigureAwait(false);
108105
}
109-
// Wait for a bit to avoid spamming errors
110-
await Task.Delay(1000, cancelTokenSrc.Token).ConfigureAwait(false);
111106
}
112107
});
113108
}
114109

115-
private unsafe UIntPtr TryReserve(Interop.SlotReserveCtx ctx)
110+
private unsafe UIntPtr TryReserve(Interop.SlotReserveCtx* ctx)
116111
{
117-
Temporalio.Worker.Tuning.ISlotPermit? maybePermit;
112+
Temporalio.Worker.Tuning.SlotPermit? maybePermit;
118113
try
119114
{
120-
maybePermit = userSupplier.TryReserveSlot(new(ctx));
115+
maybePermit = userSupplier.TryReserveSlot(ReserveCtxFromBridge(ctx));
121116
}
122117
#pragma warning disable CA1031 // We are ok catching all exceptions here
123118
catch (Exception e)
@@ -135,11 +130,16 @@ private unsafe UIntPtr TryReserve(Interop.SlotReserveCtx ctx)
135130
return new(usedPermitId);
136131
}
137132

138-
private void MarkUsed(Interop.SlotMarkUsedCtx ctx)
133+
private unsafe void MarkUsed(Interop.SlotMarkUsedCtx* ctx)
139134
{
140135
try
141136
{
142-
userSupplier.MarkSlotUsed(new(ctx, permits[ctx.slot_permit.ToUInt32()]));
137+
Temporalio.Worker.Tuning.SlotPermit permit;
138+
lock (permits)
139+
{
140+
permit = permits[(*ctx).slot_permit.ToUInt32()];
141+
}
142+
userSupplier.MarkSlotUsed(MarkUsedCtxFromBridge(ctx, permit));
143143
}
144144
#pragma warning disable CA1031 // We are ok catching all exceptions here
145145
catch (Exception e)
@@ -149,23 +149,34 @@ private void MarkUsed(Interop.SlotMarkUsedCtx ctx)
149149
}
150150
}
151151

152-
private void Release(Interop.SlotReleaseCtx ctx)
152+
private unsafe void Release(Interop.SlotReleaseCtx* ctx)
153153
{
154-
var permitId = ctx.slot_permit.ToUInt32();
154+
var permitId = (*ctx).slot_permit.ToUInt32();
155+
Temporalio.Worker.Tuning.SlotPermit permit;
156+
lock (permits)
157+
{
158+
permit = permits[permitId];
159+
}
155160
try
156161
{
157-
userSupplier.ReleaseSlot(new(ctx, permits[permitId]));
162+
userSupplier.ReleaseSlot(ReleaseCtxFromBridge(ctx, permit));
158163
}
159164
#pragma warning disable CA1031 // We are ok catching all exceptions here
160165
catch (Exception e)
161166
{
162167
#pragma warning restore CA1031
163168
logger.LogError(e, "Error releasing slot");
164169
}
165-
permits.Remove(permitId);
170+
finally
171+
{
172+
lock (permits)
173+
{
174+
permits.Remove(permitId);
175+
}
176+
}
166177
}
167178

168-
private uint AddPermitToMap(Temporalio.Worker.Tuning.ISlotPermit permit)
179+
private uint AddPermitToMap(Temporalio.Worker.Tuning.SlotPermit permit)
169180
{
170181
lock (permits)
171182
{
@@ -175,5 +186,39 @@ private uint AddPermitToMap(Temporalio.Worker.Tuning.ISlotPermit permit)
175186
return usedPermitId;
176187
}
177188
}
189+
190+
private unsafe Temporalio.Worker.Tuning.SlotReserveContext ReserveCtxFromBridge(Interop.SlotReserveCtx* ctx)
191+
{
192+
return new(
193+
SlotType: (*ctx).slot_type switch
194+
{
195+
Interop.SlotKindType.WorkflowSlotKindType => Temporalio.Worker.Tuning.SlotType.Workflow,
196+
Interop.SlotKindType.ActivitySlotKindType => Temporalio.Worker.Tuning.SlotType.Activity,
197+
Interop.SlotKindType.LocalActivitySlotKindType => Temporalio.Worker.Tuning.SlotType.LocalActivity,
198+
_ => throw new System.ArgumentOutOfRangeException(nameof(ctx)),
199+
},
200+
TaskQueue: ByteArrayRef.ToUtf8((*ctx).task_queue),
201+
WorkerIdentity: ByteArrayRef.ToUtf8((*ctx).worker_identity),
202+
WorkerBuildId: ByteArrayRef.ToUtf8((*ctx).worker_build_id),
203+
IsSticky: (*ctx).is_sticky != 0);
204+
}
205+
206+
private unsafe Temporalio.Worker.Tuning.SlotReleaseContext ReleaseCtxFromBridge(
207+
Interop.SlotReleaseCtx* ctx,
208+
Temporalio.Worker.Tuning.SlotPermit permit)
209+
{
210+
return new(
211+
SlotInfo: (*ctx).slot_info is null ? null : Temporalio.Worker.Tuning.SlotInfo.FromBridge(*(*ctx).slot_info),
212+
Permit: permit);
213+
}
214+
215+
private unsafe Temporalio.Worker.Tuning.SlotMarkUsedContext MarkUsedCtxFromBridge(
216+
Interop.SlotMarkUsedCtx* ctx,
217+
Temporalio.Worker.Tuning.SlotPermit permit)
218+
{
219+
return new(
220+
SlotInfo: Temporalio.Worker.Tuning.SlotInfo.FromBridge((*ctx).slot_info),
221+
Permit: permit);
222+
}
178223
}
179224
}

src/Temporalio/Bridge/Interop/Interop.cs

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -579,14 +579,14 @@ internal unsafe partial struct SlotReserveCtx
579579
}
580580

581581
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
582-
internal unsafe delegate void CustomReserveSlotCallback([NativeTypeName("struct SlotReserveCtx")] SlotReserveCtx ctx, void* sender);
582+
internal unsafe delegate void CustomReserveSlotCallback([NativeTypeName("const struct SlotReserveCtx *")] SlotReserveCtx* ctx, void* sender);
583583

584584
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
585585
internal unsafe delegate void CustomCancelReserveCallback(void* token_source);
586586

587587
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
588588
[return: NativeTypeName("uintptr_t")]
589-
internal delegate UIntPtr CustomTryReserveSlotCallback([NativeTypeName("struct SlotReserveCtx")] SlotReserveCtx ctx);
589+
internal unsafe delegate UIntPtr CustomTryReserveSlotCallback([NativeTypeName("const struct SlotReserveCtx *")] SlotReserveCtx* ctx);
590590

591591
internal enum SlotInfo_Tag
592592
{
@@ -680,7 +680,7 @@ internal partial struct SlotMarkUsedCtx
680680
}
681681

682682
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
683-
internal delegate void CustomMarkSlotUsedCallback([NativeTypeName("struct SlotMarkUsedCtx")] SlotMarkUsedCtx ctx);
683+
internal unsafe delegate void CustomMarkSlotUsedCallback([NativeTypeName("const struct SlotMarkUsedCtx *")] SlotMarkUsedCtx* ctx);
684684

685685
internal unsafe partial struct SlotReleaseCtx
686686
{
@@ -692,7 +692,10 @@ internal unsafe partial struct SlotReleaseCtx
692692
}
693693

694694
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
695-
internal delegate void CustomReleaseSlotCallback([NativeTypeName("struct SlotReleaseCtx")] SlotReleaseCtx ctx);
695+
internal unsafe delegate void CustomReleaseSlotCallback([NativeTypeName("const struct SlotReleaseCtx *")] SlotReleaseCtx* ctx);
696+
697+
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
698+
internal unsafe delegate void CustomSlotImplFreeCallback([NativeTypeName("const struct CustomSlotSupplierCallbacks *")] CustomSlotSupplierCallbacks* userimpl);
696699

697700
internal partial struct CustomSlotSupplierCallbacks
698701
{
@@ -710,6 +713,9 @@ internal partial struct CustomSlotSupplierCallbacks
710713

711714
[NativeTypeName("CustomReleaseSlotCallback")]
712715
public IntPtr release;
716+
717+
[NativeTypeName("CustomSlotImplFreeCallback")]
718+
public IntPtr free;
713719
}
714720

715721
internal unsafe partial struct CustomSlotSupplierCallbacksImpl
@@ -729,7 +735,7 @@ internal unsafe partial struct SlotSupplier
729735
{
730736
public SlotSupplier_Tag tag;
731737

732-
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L472_C3")]
738+
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L475_C3")]
733739
public _Anonymous_e__Union Anonymous;
734740

735741
internal ref FixedSizeSlotSupplier fixed_size
@@ -769,15 +775,15 @@ internal ref CustomSlotSupplierCallbacksImpl custom
769775
internal unsafe partial struct _Anonymous_e__Union
770776
{
771777
[FieldOffset(0)]
772-
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L473_C5")]
778+
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L476_C5")]
773779
public _Anonymous1_e__Struct Anonymous1;
774780

775781
[FieldOffset(0)]
776-
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L476_C5")]
782+
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L479_C5")]
777783
public _Anonymous2_e__Struct Anonymous2;
778784

779785
[FieldOffset(0)]
780-
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L479_C5")]
786+
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L482_C5")]
781787
public _Anonymous3_e__Struct Anonymous3;
782788

783789
internal partial struct _Anonymous1_e__Struct

src/Temporalio/Bridge/NativeInvokeableClass.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace Temporalio.Bridge
88
/// Extend this class to help with making a class that has callbacks which are invoked by Rust.
99
/// </summary>
1010
/// <typeparam name="T">The native type that holds the function ptrs for callbacks to C#.</typeparam>
11-
internal class NativeInvokeableClass<T>
11+
internal abstract class NativeInvokeableClass<T>
1212
where T : unmanaged
1313
{
1414
private readonly List<GCHandle> handles = new();
@@ -23,7 +23,7 @@ internal class NativeInvokeableClass<T>
2323
/// the callbacks via <see cref="FunctionPointer"/>. Also adds `this` to the handle list.
2424
/// </summary>
2525
/// <param name="value">The native type to pin.</param>
26-
internal void PinCallbackHolder(T value)
26+
private protected void PinCallbackHolder(T value)
2727
{
2828
// Pin the callback holder & set it as the first handle
2929
var holderHandle = GCHandle.Alloc(value, GCHandleType.Pinned);
@@ -42,7 +42,7 @@ internal void PinCallbackHolder(T value)
4242
/// <typeparam name="TF">The native type of the function pointer.</typeparam>
4343
/// <param name="func">The C# method to use for the callback.</param>
4444
/// <returns>The function pointer to the C# method.</returns>
45-
internal IntPtr FunctionPointer<TF>(TF func)
45+
private protected IntPtr FunctionPointer<TF>(TF func)
4646
where TF : Delegate
4747
{
4848
var handle = GCHandle.Alloc(func);
@@ -53,8 +53,8 @@ internal IntPtr FunctionPointer<TF>(TF func)
5353
/// <summary>
5454
/// Free the memory of the native type and all the function pointers.
5555
/// </summary>
56-
/// <param name="meter">The native type to free.</param>
57-
internal unsafe void Free(T* meter)
56+
/// <param name="ptr">The native type to free.</param>
57+
private protected unsafe void Free(T* ptr)
5858
{
5959
// Free in order which frees function pointers first then object handles
6060
foreach (var handle in handles)

0 commit comments

Comments
 (0)