Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
using Xunit;
using Stride.Engine;
using Stride.Graphics.Regression;
using System.Threading.Tasks;
using System.Threading;

namespace Stride.BepuPhysics.Tests
{
Expand Down Expand Up @@ -125,6 +127,56 @@ public static void ConstraintsTest()
RunGameTest(game);
}

[Fact]
public static void ThreadContextTest()
{
var game = new GameTest();
game.Script.AddTask(async () =>
{
var thread = Thread.CurrentThread;

game.ScreenShotAutomationEnabled = false;

var e1 = new BodyComponent { Collider = new CompoundCollider { Colliders = { new BoxCollider() } } };
var e2 = new BodyComponent { Collider = new CompoundCollider { Colliders = { new BoxCollider() } } };
var e3 = new BodyComponent { Collider = new CompoundCollider { Colliders = { new BoxCollider() } } };

game.SceneSystem.SceneInstance.RootScene.Entities.AddRange(new EntityComponent[] { e1, e2, e3 }.Select(x => new Entity { x }));

await Task.Run(() =>
{
Assert.NotEqual(Thread.CurrentThread, thread);
});

Assert.Equal(Thread.CurrentThread, thread);

await e1.Simulation!.AfterUpdate();

Assert.Equal(Thread.CurrentThread, thread);

await Task.Run(() =>
{
Assert.NotEqual(Thread.CurrentThread, thread);
});

Assert.Equal(Thread.CurrentThread, thread);

await e1.Simulation!.NextUpdate();

Assert.Equal(Thread.CurrentThread, thread);

await Task.Run(() =>
{
Assert.NotEqual(Thread.CurrentThread, thread);
});

Assert.Equal(Thread.CurrentThread, thread);

game.Exit();
});
RunGameTest(game);
}

[Fact]
public static void ConstraintsForceTest()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
using Stride.Core;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using Stride.Core.MicroThreading;
using Stride.Core.Serialization;
using Stride.Engine;
using NVector3 = System.Numerics.Vector3;
Expand All @@ -43,6 +44,7 @@ public sealed class BepuSimulation : IDisposable
private TimeSpan _softStartRemainingDuration;
private bool _softStartScheduled = false;
private UrlReference<Scene>? _associatedScene = null;
private Scheduler? _scheduler;
private AwaitRunner _preTickRunner = new();
private AwaitRunner _postTickRunner = new();

Expand Down Expand Up @@ -335,13 +337,23 @@ public StaticComponent GetComponent(StaticHandle handle)
/// Yields execution until right before the next physics tick
/// </summary>
/// <returns>Task that will resume next tick.</returns>
public TickAwaiter NextUpdate() => new TickAwaiter(_preTickRunner);
public TickAwaiter NextUpdate()
{
if (Scheduler.CurrentMicroThread is null || SynchronizationContext.Current is null)
throw new Exception($"{nameof(NextUpdate)} cannot be called out of the micro-thread context.");
return new TickAwaiter(_preTickRunner, Scheduler.CurrentMicroThread, SynchronizationContext.Current);
}

/// <summary>
/// Yields execution until right after the next physics tick
/// </summary>
/// <returns>Task that will resume next tick.</returns>
public TickAwaiter AfterUpdate() => new TickAwaiter(_postTickRunner);
public TickAwaiter AfterUpdate()
{
if (Scheduler.CurrentMicroThread is null || SynchronizationContext.Current is null)
throw new Exception($"{nameof(AfterUpdate)} cannot be called out of the micro-thread context.");
return new TickAwaiter(_postTickRunner, Scheduler.CurrentMicroThread, SynchronizationContext.Current);
}

/// <summary>
/// Whether a physics test with <paramref name="mask"/> against <paramref name="collidable"/> should be performed or entirely ignored
Expand Down Expand Up @@ -967,14 +979,14 @@ protected override void AfterSimulationUpdate(BepuSimulation sim, float deltaTim
internal class AwaitRunner
{
private Lock _addLock = new();
private List<Action> _scheduled = new();
private List<Action> _processed = new();
private List<(Action action, SynchronizationContext context)> _scheduled = new();
private List<(Action action, SynchronizationContext context)> _processed = new();

public void Add(Action a)
public void Add(Action action, SynchronizationContext context)
{
lock (_addLock)
{
_scheduled.Add(a);
_scheduled.Add((action, context));
}
}

Expand All @@ -985,8 +997,19 @@ public void Run()
(_processed, _scheduled) = (_scheduled, _processed);
}

foreach (var item in _processed)
item.Invoke();
foreach (var (action, context) in _processed)
{
var previousSyncContext = SynchronizationContext.Current;
SynchronizationContext.SetSynchronizationContext(context);
try
{
action.Invoke();
}
finally
{
SynchronizationContext.SetSynchronizationContext(previousSyncContext);
}
}

_processed.Clear();
}
Expand All @@ -995,20 +1018,30 @@ public void Run()
/// <summary>
/// Await this struct to continue during a physics tick
/// </summary>
public struct TickAwaiter : INotifyCompletion
public readonly struct TickAwaiter : INotifyCompletion
{
private AwaitRunner _runner;
private readonly AwaitRunner _runner;
private readonly MicroThread _microThread;
private readonly SynchronizationContext _context;

internal TickAwaiter(AwaitRunner runner)
internal TickAwaiter(AwaitRunner runner, MicroThread microThread, SynchronizationContext context)
{
_runner = runner;
_microThread = microThread;
_context = context;
}

public bool IsCompleted => false; // Forces the awaiter to call OnCompleted() right away to schedule asynchronous method continuation with our runner
public bool IsCompleted
{
get
{
return _microThread.IsOver;
}
}

public void OnCompleted(Action continuation) => _runner.Add(continuation);
public void OnCompleted(Action continuation) => _runner.Add(continuation, _context);

public void GetResult() { }
public void GetResult() => _microThread.CancellationToken.ThrowIfCancellationRequested();

public TickAwaiter GetAwaiter() => this;
}
Expand Down
Loading