Skip to content

Commit 89b5781

Browse files
committed
Callbacks working / test passes
1 parent e1d94b8 commit 89b5781

File tree

12 files changed

+157
-94
lines changed

12 files changed

+157
-94
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,6 @@ obj/
77
/tests/golangworker/golangworker
88
/.vs
99
/.vscode
10-
/.idea
10+
/.idea
11+
/.zed
12+
Temporalio.sln.DotSettings.user

src/Temporalio/Bridge/CustomSlotSupplier.cs

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Runtime.InteropServices;
14
using System.Threading.Tasks;
25

36
namespace Temporalio.Bridge
@@ -8,6 +11,8 @@ namespace Temporalio.Bridge
811
internal class CustomSlotSupplier : NativeInvokeableClass<Interop.CustomSlotSupplierCallbacks>
912
{
1013
private readonly Temporalio.Worker.Tuning.ICustomSlotSupplier userSupplier;
14+
private readonly Dictionary<uint, GCHandle> permits = new();
15+
private uint permitId = 1;
1116

1217
/// <summary>
1318
/// Initializes a new instance of the <see cref="CustomSlotSupplier" /> class.
@@ -28,25 +33,65 @@ internal unsafe CustomSlotSupplier(Temporalio.Worker.Tuning.ICustomSlotSupplier
2833
PinCallbackHolder(interopCallbacks);
2934
}
3035

31-
private void Reserve(Interop.SlotReserveCtx ctx)
36+
private unsafe void Reserve(Interop.SlotReserveCtx ctx, void* sender)
3237
{
33-
// TODO: Need to call callback with result that will put it in a channel to await in Rust
34-
var reserveTask = Task.Run(() => userSupplier.ReserveSlotAsync(new(ctx)));
38+
SafeReserve(ctx, new IntPtr(sender));
3539
}
3640

37-
private void TryReserve(Interop.SlotReserveCtx ctx)
41+
private void SafeReserve(Interop.SlotReserveCtx ctx, IntPtr sender)
3842
{
39-
userSupplier.TryReserveSlot(new(ctx));
43+
var reserveTask = Task.Run(async () =>
44+
{
45+
try
46+
{
47+
var permit = await userSupplier.ReserveSlotAsync(new(ctx)).ConfigureAwait(false);
48+
var usedPermitId = AddPermitToMap(permit);
49+
unsafe
50+
{
51+
Interop.Methods.complete_async_reserve(sender.ToPointer(), new(usedPermitId));
52+
}
53+
}
54+
catch (Exception e)
55+
{
56+
Console.WriteLine("Exception in reserve: " + e.Message);
57+
throw;
58+
}
59+
});
60+
}
61+
62+
private unsafe UIntPtr TryReserve(Interop.SlotReserveCtx ctx)
63+
{
64+
var maybePermit = userSupplier.TryReserveSlot(new(ctx));
65+
if (maybePermit == null)
66+
{
67+
return UIntPtr.Zero;
68+
}
69+
var usedPermitId = AddPermitToMap(maybePermit);
70+
return new(usedPermitId);
4071
}
4172

4273
private void MarkUsed(Interop.SlotMarkUsedCtx ctx)
4374
{
44-
userSupplier.MarkSlotUsed(new(ctx));
75+
userSupplier.MarkSlotUsed(new(ctx, permits[ctx.slot_permit.ToUInt32()]));
4576
}
4677

4778
private void Release(Interop.SlotReleaseCtx ctx)
4879
{
49-
userSupplier.ReleaseSlot(new(ctx));
80+
var permitId = ctx.slot_permit.ToUInt32();
81+
userSupplier.ReleaseSlot(new(ctx, permits[permitId]));
82+
permits.Remove(permitId);
83+
}
84+
85+
private uint AddPermitToMap(Temporalio.Worker.Tuning.SlotPermit permit)
86+
{
87+
var handle = GCHandle.Alloc(permit);
88+
lock (permits)
89+
{
90+
var usedPermitId = permitId;
91+
permits.Add(permitId, handle);
92+
permitId += 1;
93+
return usedPermitId;
94+
}
5095
}
5196
}
5297
}

src/Temporalio/Bridge/Interop/Interop.cs

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -577,10 +577,11 @@ internal partial struct SlotReserveCtx
577577
}
578578

579579
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
580-
internal delegate void CustomReserveSlotCallback([NativeTypeName("struct SlotReserveCtx")] SlotReserveCtx ctx);
580+
internal unsafe delegate void CustomReserveSlotCallback([NativeTypeName("struct SlotReserveCtx")] SlotReserveCtx ctx, void* sender);
581581

582582
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
583-
internal delegate void CustomTryReserveSlotCallback([NativeTypeName("struct SlotReserveCtx")] SlotReserveCtx ctx);
583+
[return: NativeTypeName("uintptr_t")]
584+
internal delegate UIntPtr CustomTryReserveSlotCallback([NativeTypeName("struct SlotReserveCtx")] SlotReserveCtx ctx);
584585

585586
internal enum SlotInfo_Tag
586587
{
@@ -614,7 +615,7 @@ internal unsafe partial struct SlotInfo
614615
{
615616
public SlotInfo_Tag tag;
616617

617-
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L419_C3")]
618+
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L422_C3")]
618619
public _Anonymous_e__Union Anonymous;
619620

620621
internal ref WorkflowSlotInfo_Body workflow_slot_info
@@ -664,13 +665,13 @@ internal partial struct _Anonymous_e__Union
664665
}
665666
}
666667

667-
internal unsafe partial struct SlotMarkUsedCtx
668+
internal partial struct SlotMarkUsedCtx
668669
{
669670
[NativeTypeName("struct SlotInfo")]
670671
public SlotInfo slot_info;
671672

672-
[NativeTypeName("const void *")]
673-
public void* slot_permit;
673+
[NativeTypeName("uintptr_t")]
674+
public UIntPtr slot_permit;
674675
}
675676

676677
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
@@ -681,8 +682,8 @@ internal unsafe partial struct SlotReleaseCtx
681682
[NativeTypeName("const struct SlotInfo *")]
682683
public SlotInfo* slot_info;
683684

684-
[NativeTypeName("const void *")]
685-
public void* slot_permit;
685+
[NativeTypeName("uintptr_t")]
686+
public UIntPtr slot_permit;
686687
}
687688

688689
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
@@ -720,7 +721,7 @@ internal unsafe partial struct SlotSupplier
720721
{
721722
public SlotSupplier_Tag tag;
722723

723-
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L465_C3")]
724+
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L468_C3")]
724725
public _Anonymous_e__Union Anonymous;
725726

726727
internal ref FixedSizeSlotSupplier fixed_size
@@ -760,15 +761,15 @@ internal ref CustomSlotSupplierCallbacksImpl custom
760761
internal unsafe partial struct _Anonymous_e__Union
761762
{
762763
[FieldOffset(0)]
763-
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L466_C5")]
764+
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L469_C5")]
764765
public _Anonymous1_e__Struct Anonymous1;
765766

766767
[FieldOffset(0)]
767-
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L469_C5")]
768+
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L472_C5")]
768769
public _Anonymous2_e__Struct Anonymous2;
769770

770771
[FieldOffset(0)]
771-
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L472_C5")]
772+
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L475_C5")]
772773
public _Anonymous3_e__Struct Anonymous3;
773774

774775
internal partial struct _Anonymous1_e__Struct
@@ -1057,5 +1058,8 @@ internal static unsafe partial class Methods
10571058
[DllImport("temporal_sdk_bridge", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
10581059
[return: NativeTypeName("struct WorkerReplayPushResult")]
10591060
public static extern WorkerReplayPushResult worker_replay_push([NativeTypeName("struct Worker *")] Worker* worker, [NativeTypeName("struct WorkerReplayPusher *")] WorkerReplayPusher* worker_replay_pusher, [NativeTypeName("struct ByteArrayRef")] ByteArrayRef workflow_id, [NativeTypeName("struct ByteArrayRef")] ByteArrayRef history);
1061+
1062+
[DllImport("temporal_sdk_bridge", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
1063+
public static extern void complete_async_reserve(void* sender, [NativeTypeName("uintptr_t")] UIntPtr permit_id);
10601064
}
10611065
}

src/Temporalio/Bridge/include/temporal-sdk-bridge.h

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -391,9 +391,12 @@ typedef struct SlotReserveCtx {
391391
bool is_sticky;
392392
} SlotReserveCtx;
393393

394-
typedef void (*CustomReserveSlotCallback)(struct SlotReserveCtx ctx);
394+
typedef void (*CustomReserveSlotCallback)(struct SlotReserveCtx ctx, void *sender);
395395

396-
typedef void (*CustomTryReserveSlotCallback)(struct SlotReserveCtx ctx);
396+
/**
397+
* Must return C#-tracked id for the permit. A zero value means no permit was reserved.
398+
*/
399+
typedef uintptr_t (*CustomTryReserveSlotCallback)(struct SlotReserveCtx ctx);
397400

398401
typedef enum SlotInfo_Tag {
399402
WorkflowSlotInfo,
@@ -426,19 +429,19 @@ typedef struct SlotInfo {
426429
typedef struct SlotMarkUsedCtx {
427430
struct SlotInfo slot_info;
428431
/**
429-
* User instance of a slot permit.
432+
* C# id for the slot permit.
430433
*/
431-
const void *slot_permit;
434+
uintptr_t slot_permit;
432435
} SlotMarkUsedCtx;
433436

434437
typedef void (*CustomMarkSlotUsedCallback)(struct SlotMarkUsedCtx ctx);
435438

436439
typedef struct SlotReleaseCtx {
437440
const struct SlotInfo *slot_info;
438441
/**
439-
* User instance of a slot permit.
442+
* C# id for the slot permit.
440443
*/
441-
const void *slot_permit;
444+
uintptr_t slot_permit;
442445
} SlotReleaseCtx;
443446

444447
typedef void (*CustomReleaseSlotCallback)(struct SlotReleaseCtx ctx);
@@ -689,6 +692,8 @@ struct WorkerReplayPushResult worker_replay_push(struct Worker *worker,
689692
struct ByteArrayRef workflow_id,
690693
struct ByteArrayRef history);
691694

695+
void complete_async_reserve(void *sender, uintptr_t permit_id);
696+
692697
#ifdef __cplusplus
693698
} // extern "C"
694699
#endif // __cplusplus

src/Temporalio/Bridge/src/lib.rs

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,6 @@ impl ByteArrayRef {
3131
}
3232
}
3333

34-
fn from_string(s: &String) -> ByteArrayRef {
35-
ByteArrayRef {
36-
data: s.as_ptr(),
37-
size: s.len(),
38-
}
39-
}
40-
4134
fn to_slice(&self) -> &[u8] {
4235
unsafe { std::slice::from_raw_parts(self.data, self.size) }
4336
}
@@ -113,7 +106,7 @@ impl From<&str> for ByteArrayRef {
113106

114107
impl From<String> for ByteArrayRef {
115108
fn from(s: String) -> ByteArrayRef {
116-
ByteArrayRef::from_string(&s)
109+
ByteArrayRef::from_str(&s)
117110
}
118111
}
119112

src/Temporalio/Bridge/src/metric.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ impl CustomMetricMeterRef {
417417
),
418418
};
419419
CustomMetricAttribute {
420-
key: ByteArrayRef::from_string(&kv.key),
420+
key: ByteArrayRef::from_str(&kv.key),
421421
value,
422422
value_type,
423423
}

src/Temporalio/Bridge/src/runtime.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -316,13 +316,13 @@ impl fmt::Debug for LogForwarder {
316316
#[no_mangle]
317317
pub extern "C" fn forwarded_log_target(log: *const ForwardedLog) -> ByteArrayRef {
318318
let log = unsafe { &*log };
319-
ByteArrayRef::from_string(&log.core.target)
319+
ByteArrayRef::from_str(&log.core.target)
320320
}
321321

322322
#[no_mangle]
323323
pub extern "C" fn forwarded_log_message(log: *const ForwardedLog) -> ByteArrayRef {
324324
let log = unsafe { &*log };
325-
ByteArrayRef::from_string(&log.core.message)
325+
ByteArrayRef::from_str(&log.core.message)
326326
}
327327

328328
#[no_mangle]

0 commit comments

Comments
 (0)