diff --git a/TUnit.Engine.Tests/HookTimeoutTests.cs b/TUnit.Engine.Tests/HookTimeoutTests.cs new file mode 100644 index 0000000000..818a72c2fe --- /dev/null +++ b/TUnit.Engine.Tests/HookTimeoutTests.cs @@ -0,0 +1,33 @@ +using Shouldly; +using TUnit.Engine.Tests.Enums; + +namespace TUnit.Engine.Tests; + +public class HookTimeoutTests(TestMode testMode) : InvokableTestBase(testMode) +{ + [Test] + public async Task ClassHook_WithTimeout_ShouldFail() + { + await RunTestsWithFilter( + "/*/*/ClassHookTimeoutTests/*", + [ + result => result.ResultSummary.Outcome.ShouldBe("Failed"), + result => result.ResultSummary.Counters.Total.ShouldBe(1), + result => result.ResultSummary.Counters.Passed.ShouldBe(0), + result => result.ResultSummary.Counters.Failed.ShouldBe(1), + ]); + } + + [Test] + public async Task AssemblyHook_WithTimeout_ShouldPass() + { + await RunTestsWithFilter( + "/*/*/AssemblyHookTimeoutPassTests/*", + [ + result => result.ResultSummary.Outcome.ShouldBe("Completed"), + result => result.ResultSummary.Counters.Total.ShouldBe(1), + result => result.ResultSummary.Counters.Passed.ShouldBe(1), + result => result.ResultSummary.Counters.Failed.ShouldBe(0), + ]); + } +} diff --git a/TUnit.Engine/Services/HookCollectionService.cs b/TUnit.Engine/Services/HookCollectionService.cs index 57ecd0b633..559f0bffaa 100644 --- a/TUnit.Engine/Services/HookCollectionService.cs +++ b/TUnit.Engine/Services/HookCollectionService.cs @@ -50,14 +50,14 @@ public async ValueTask InitializeAsync() // Pre-compute all global hooks that don't depend on specific types/assemblies _beforeEveryTestHooks = await BuildGlobalBeforeEveryTestHooksAsync(); _afterEveryTestHooks = await BuildGlobalAfterEveryTestHooksAsync(); - _beforeTestSessionHooks = BuildGlobalBeforeTestSessionHooks(); - _afterTestSessionHooks = BuildGlobalAfterTestSessionHooks(); - _beforeTestDiscoveryHooks = BuildGlobalBeforeTestDiscoveryHooks(); - _afterTestDiscoveryHooks = BuildGlobalAfterTestDiscoveryHooks(); - _beforeEveryClassHooks = BuildGlobalBeforeEveryClassHooks(); - _afterEveryClassHooks = BuildGlobalAfterEveryClassHooks(); - _beforeEveryAssemblyHooks = BuildGlobalBeforeEveryAssemblyHooks(); - _afterEveryAssemblyHooks = BuildGlobalAfterEveryAssemblyHooks(); + _beforeTestSessionHooks = await BuildGlobalBeforeTestSessionHooksAsync(); + _afterTestSessionHooks = await BuildGlobalAfterTestSessionHooksAsync(); + _beforeTestDiscoveryHooks = await BuildGlobalBeforeTestDiscoveryHooksAsync(); + _afterTestDiscoveryHooks = await BuildGlobalAfterTestDiscoveryHooksAsync(); + _beforeEveryClassHooks = await BuildGlobalBeforeEveryClassHooksAsync(); + _afterEveryClassHooks = await BuildGlobalAfterEveryClassHooksAsync(); + _beforeEveryAssemblyHooks = await BuildGlobalBeforeEveryAssemblyHooksAsync(); + _afterEveryAssemblyHooks = await BuildGlobalAfterEveryAssemblyHooksAsync(); } private async Task>> BuildGlobalBeforeEveryTestHooksAsync() @@ -94,13 +94,13 @@ private async Task>> Bu .ToList(); } - private IReadOnlyList> BuildGlobalBeforeTestSessionHooks() + private async Task>> BuildGlobalBeforeTestSessionHooksAsync() { var allHooks = new List<(int order, int registrationIndex, Func hook)>(Sources.BeforeTestSessionHooks.Count); foreach (var hook in Sources.BeforeTestSessionHooks) { - var hookFunc = CreateTestSessionHookDelegate(hook); + var hookFunc = await CreateTestSessionHookDelegateAsync(hook); allHooks.Add((hook.Order, hook.RegistrationIndex, hookFunc)); } @@ -111,13 +111,13 @@ private IReadOnlyList> BuildGl .ToList(); } - private IReadOnlyList> BuildGlobalAfterTestSessionHooks() + private async Task>> BuildGlobalAfterTestSessionHooksAsync() { var allHooks = new List<(int order, int registrationIndex, Func hook)>(Sources.AfterTestSessionHooks.Count); foreach (var hook in Sources.AfterTestSessionHooks) { - var hookFunc = CreateTestSessionHookDelegate(hook); + var hookFunc = await CreateTestSessionHookDelegateAsync(hook); allHooks.Add((hook.Order, hook.RegistrationIndex, hookFunc)); } @@ -128,13 +128,13 @@ private IReadOnlyList> BuildGl .ToList(); } - private IReadOnlyList> BuildGlobalBeforeTestDiscoveryHooks() + private async Task>> BuildGlobalBeforeTestDiscoveryHooksAsync() { var allHooks = new List<(int order, int registrationIndex, Func hook)>(Sources.BeforeTestDiscoveryHooks.Count); foreach (var hook in Sources.BeforeTestDiscoveryHooks) { - var hookFunc = CreateBeforeTestDiscoveryHookDelegate(hook); + var hookFunc = await CreateBeforeTestDiscoveryHookDelegateAsync(hook); allHooks.Add((hook.Order, hook.RegistrationIndex, hookFunc)); } @@ -145,13 +145,13 @@ private IReadOnlyList> .ToList(); } - private IReadOnlyList> BuildGlobalAfterTestDiscoveryHooks() + private async Task>> BuildGlobalAfterTestDiscoveryHooksAsync() { var allHooks = new List<(int order, int registrationIndex, Func hook)>(Sources.AfterTestDiscoveryHooks.Count); foreach (var hook in Sources.AfterTestDiscoveryHooks) { - var hookFunc = CreateTestDiscoveryHookDelegate(hook); + var hookFunc = await CreateTestDiscoveryHookDelegateAsync(hook); allHooks.Add((hook.Order, hook.RegistrationIndex, hookFunc)); } @@ -162,13 +162,13 @@ private IReadOnlyList> Build .ToList(); } - private IReadOnlyList> BuildGlobalBeforeEveryClassHooks() + private async Task>> BuildGlobalBeforeEveryClassHooksAsync() { var allHooks = new List<(int order, int registrationIndex, Func hook)>(Sources.BeforeEveryClassHooks.Count); foreach (var hook in Sources.BeforeEveryClassHooks) { - var hookFunc = CreateClassHookDelegate(hook); + var hookFunc = await CreateClassHookDelegateAsync(hook); allHooks.Add((hook.Order, hook.RegistrationIndex, hookFunc)); } @@ -179,13 +179,13 @@ private IReadOnlyList> BuildGlob .ToList(); } - private IReadOnlyList> BuildGlobalAfterEveryClassHooks() + private async Task>> BuildGlobalAfterEveryClassHooksAsync() { var allHooks = new List<(int order, int registrationIndex, Func hook)>(Sources.AfterEveryClassHooks.Count); foreach (var hook in Sources.AfterEveryClassHooks) { - var hookFunc = CreateClassHookDelegate(hook); + var hookFunc = await CreateClassHookDelegateAsync(hook); allHooks.Add((hook.Order, hook.RegistrationIndex, hookFunc)); } @@ -196,13 +196,13 @@ private IReadOnlyList> BuildGlob .ToList(); } - private IReadOnlyList> BuildGlobalBeforeEveryAssemblyHooks() + private async Task>> BuildGlobalBeforeEveryAssemblyHooksAsync() { var allHooks = new List<(int order, int registrationIndex, Func hook)>(Sources.BeforeEveryAssemblyHooks.Count); foreach (var hook in Sources.BeforeEveryAssemblyHooks) { - var hookFunc = CreateAssemblyHookDelegate(hook); + var hookFunc = await CreateAssemblyHookDelegateAsync(hook); allHooks.Add((hook.Order, hook.RegistrationIndex, hookFunc)); } @@ -213,13 +213,13 @@ private IReadOnlyList> BuildG .ToList(); } - private IReadOnlyList> BuildGlobalAfterEveryAssemblyHooks() + private async Task>> BuildGlobalAfterEveryAssemblyHooksAsync() { var allHooks = new List<(int order, int registrationIndex, Func hook)>(Sources.AfterEveryAssemblyHooks.Count); foreach (var hook in Sources.AfterEveryAssemblyHooks) { - var hookFunc = CreateAssemblyHookDelegate(hook); + var hookFunc = await CreateAssemblyHookDelegateAsync(hook); allHooks.Add((hook.Order, hook.RegistrationIndex, hookFunc)); } @@ -422,168 +422,196 @@ public ValueTask>> Coll return new ValueTask>>(_afterEveryTestHooks ?? []); } - public ValueTask>> CollectBeforeClassHooksAsync(Type testClassType) + public async ValueTask>> CollectBeforeClassHooksAsync(Type testClassType) { - var hooks = _beforeClassHooksCache.GetOrAdd(testClassType, type => + if (_beforeClassHooksCache.TryGetValue(testClassType, out var cachedHooks)) { - var hooksByType = new List<(Type type, List<(int order, int registrationIndex, Func hook)> hooks)>(); + return cachedHooks; + } - // Collect hooks for each type in the hierarchy - var currentType = type; - while (currentType != null) - { - var typeHooks = new List<(int order, int registrationIndex, Func hook)>(); + var hooks = await BuildBeforeClassHooksAsync(testClassType); + _beforeClassHooksCache.TryAdd(testClassType, hooks); + return hooks; + } + + private async Task>> BuildBeforeClassHooksAsync(Type type) + { + var hooksByType = new List<(Type type, List<(int order, int registrationIndex, Func hook)> hooks)>(); - if (Sources.BeforeClassHooks.TryGetValue(currentType, out var sourceHooks)) + // Collect hooks for each type in the hierarchy + var currentType = type; + while (currentType != null) + { + var typeHooks = new List<(int order, int registrationIndex, Func hook)>(); + + if (Sources.BeforeClassHooks.TryGetValue(currentType, out var sourceHooks)) + { + foreach (var hook in sourceHooks) { - foreach (var hook in sourceHooks) - { - var hookFunc = CreateClassHookDelegate(hook); - typeHooks.Add((hook.Order, hook.RegistrationIndex, hookFunc)); - } + var hookFunc = await CreateClassHookDelegateAsync(hook); + typeHooks.Add((hook.Order, hook.RegistrationIndex, hookFunc)); } + } - // Also check the open generic type definition for generic types - if (currentType is { IsGenericType: true, IsGenericTypeDefinition: false }) + // Also check the open generic type definition for generic types + if (currentType is { IsGenericType: true, IsGenericTypeDefinition: false }) + { + var openGenericType = GetCachedGenericTypeDefinition(currentType); + if (Sources.BeforeClassHooks.TryGetValue(openGenericType, out var openTypeHooks)) { - var openGenericType = GetCachedGenericTypeDefinition(currentType); - if (Sources.BeforeClassHooks.TryGetValue(openGenericType, out var openTypeHooks)) + foreach (var hook in openTypeHooks) { - foreach (var hook in openTypeHooks) - { - var hookFunc = CreateClassHookDelegate(hook); - typeHooks.Add((hook.Order, hook.RegistrationIndex, hookFunc)); - } + var hookFunc = await CreateClassHookDelegateAsync(hook); + typeHooks.Add((hook.Order, hook.RegistrationIndex, hookFunc)); } } - - if (typeHooks.Count > 0) - { - hooksByType.Add((currentType, typeHooks)); - } - - currentType = currentType.BaseType; } - hooksByType.Reverse(); - - var finalHooks = new List>(); - foreach (var (_, typeHooks) in hooksByType) + if (typeHooks.Count > 0) { - SortAndAddClassHooks(finalHooks, typeHooks); + hooksByType.Add((currentType, typeHooks)); } - return finalHooks; - }); + currentType = currentType.BaseType; + } + + hooksByType.Reverse(); + + var finalHooks = new List>(); + foreach (var (_, typeHooks) in hooksByType) + { + SortAndAddClassHooks(finalHooks, typeHooks); + } - return new ValueTask>>(hooks); + return finalHooks; } - public ValueTask>> CollectAfterClassHooksAsync(Type testClassType) + public async ValueTask>> CollectAfterClassHooksAsync(Type testClassType) { - var hooks = _afterClassHooksCache.GetOrAdd(testClassType, type => + if (_afterClassHooksCache.TryGetValue(testClassType, out var cachedHooks)) { - var hooksByType = new List<(Type type, List<(int order, int registrationIndex, Func hook)> hooks)>(); + return cachedHooks; + } - // Collect hooks for each type in the hierarchy - var currentType = type; - while (currentType != null) - { - var typeHooks = new List<(int order, int registrationIndex, Func hook)>(); + var hooks = await BuildAfterClassHooksAsync(testClassType); + _afterClassHooksCache.TryAdd(testClassType, hooks); + return hooks; + } + + private async Task>> BuildAfterClassHooksAsync(Type type) + { + var hooksByType = new List<(Type type, List<(int order, int registrationIndex, Func hook)> hooks)>(); + + // Collect hooks for each type in the hierarchy + var currentType = type; + while (currentType != null) + { + var typeHooks = new List<(int order, int registrationIndex, Func hook)>(); - if (Sources.AfterClassHooks.TryGetValue(currentType, out var sourceHooks)) + if (Sources.AfterClassHooks.TryGetValue(currentType, out var sourceHooks)) + { + foreach (var hook in sourceHooks) { - foreach (var hook in sourceHooks) - { - var hookFunc = CreateClassHookDelegate(hook); - typeHooks.Add((hook.Order, hook.RegistrationIndex, hookFunc)); - } + var hookFunc = await CreateClassHookDelegateAsync(hook); + typeHooks.Add((hook.Order, hook.RegistrationIndex, hookFunc)); } + } - // Also check the open generic type definition for generic types - if (currentType is { IsGenericType: true, IsGenericTypeDefinition: false }) + // Also check the open generic type definition for generic types + if (currentType is { IsGenericType: true, IsGenericTypeDefinition: false }) + { + var openGenericType = GetCachedGenericTypeDefinition(currentType); + if (Sources.AfterClassHooks.TryGetValue(openGenericType, out var openTypeHooks)) { - var openGenericType = GetCachedGenericTypeDefinition(currentType); - if (Sources.AfterClassHooks.TryGetValue(openGenericType, out var openTypeHooks)) + foreach (var hook in openTypeHooks) { - foreach (var hook in openTypeHooks) - { - var hookFunc = CreateClassHookDelegate(hook); - typeHooks.Add((hook.Order, hook.RegistrationIndex, hookFunc)); - } + var hookFunc = await CreateClassHookDelegateAsync(hook); + typeHooks.Add((hook.Order, hook.RegistrationIndex, hookFunc)); } } - - if (typeHooks.Count > 0) - { - hooksByType.Add((currentType, typeHooks)); - } - - currentType = currentType.BaseType; } - var finalHooks = new List>(); - foreach (var (_, typeHooks) in hooksByType) + if (typeHooks.Count > 0) { - SortAndAddClassHooks(finalHooks, typeHooks); + hooksByType.Add((currentType, typeHooks)); } - return finalHooks; - }); + currentType = currentType.BaseType; + } + + var finalHooks = new List>(); + foreach (var (_, typeHooks) in hooksByType) + { + SortAndAddClassHooks(finalHooks, typeHooks); + } - return new ValueTask>>(hooks); + return finalHooks; } - public ValueTask>> CollectBeforeAssemblyHooksAsync(Assembly assembly) + public async ValueTask>> CollectBeforeAssemblyHooksAsync(Assembly assembly) { - var hooks = _beforeAssemblyHooksCache.GetOrAdd(assembly, asm => + if (_beforeAssemblyHooksCache.TryGetValue(assembly, out var cachedHooks)) { - if (!Sources.BeforeAssemblyHooks.TryGetValue(asm, out var assemblyHooks)) - { - return []; - } + return cachedHooks; + } - var allHooks = new List<(int order, Func hook)>(assemblyHooks.Count); + var hooks = await BuildBeforeAssemblyHooksAsync(assembly); + _beforeAssemblyHooksCache.TryAdd(assembly, hooks); + return hooks; + } - foreach (var hook in assemblyHooks) - { - var hookFunc = CreateAssemblyHookDelegate(hook); - allHooks.Add((hook.Order, hookFunc)); - } + private async Task>> BuildBeforeAssemblyHooksAsync(Assembly assembly) + { + if (!Sources.BeforeAssemblyHooks.TryGetValue(assembly, out var assemblyHooks)) + { + return []; + } - return allHooks - .OrderBy(h => h.order) - .Select(h => h.hook) - .ToList(); - }); + var allHooks = new List<(int order, Func hook)>(assemblyHooks.Count); + + foreach (var hook in assemblyHooks) + { + var hookFunc = await CreateAssemblyHookDelegateAsync(hook); + allHooks.Add((hook.Order, hookFunc)); + } - return new ValueTask>>(hooks); + return allHooks + .OrderBy(h => h.order) + .Select(h => h.hook) + .ToList(); } - public ValueTask>> CollectAfterAssemblyHooksAsync(Assembly assembly) + public async ValueTask>> CollectAfterAssemblyHooksAsync(Assembly assembly) { - var hooks = _afterAssemblyHooksCache.GetOrAdd(assembly, asm => + if (_afterAssemblyHooksCache.TryGetValue(assembly, out var cachedHooks)) { - if (!Sources.AfterAssemblyHooks.TryGetValue(asm, out var assemblyHooks)) - { - return []; - } + return cachedHooks; + } - var allHooks = new List<(int order, Func hook)>(assemblyHooks.Count); + var hooks = await BuildAfterAssemblyHooksAsync(assembly); + _afterAssemblyHooksCache.TryAdd(assembly, hooks); + return hooks; + } - foreach (var hook in assemblyHooks) - { - var hookFunc = CreateAssemblyHookDelegate(hook); - allHooks.Add((hook.Order, hookFunc)); - } + private async Task>> BuildAfterAssemblyHooksAsync(Assembly assembly) + { + if (!Sources.AfterAssemblyHooks.TryGetValue(assembly, out var assemblyHooks)) + { + return []; + } + + var allHooks = new List<(int order, Func hook)>(assemblyHooks.Count); - return allHooks - .OrderBy(h => h.order) - .Select(h => h.hook) - .ToList(); - }); + foreach (var hook in assemblyHooks) + { + var hookFunc = await CreateAssemblyHookDelegateAsync(hook); + allHooks.Add((hook.Order, hookFunc)); + } - return new ValueTask>>(hooks); + return allHooks + .OrderBy(h => h.order) + .Select(h => h.hook) + .ToList(); } public ValueTask>> CollectBeforeTestSessionHooksAsync() @@ -700,6 +728,22 @@ private static Func CreateClassHookDe }; } + private async Task> CreateClassHookDelegateAsync(StaticHookMethod hook) + { + // Process hook registration event receivers + await ProcessHookRegistrationAsync(hook); + + return async (context, cancellationToken) => + { + var timeoutAction = HookTimeoutHelper.CreateTimeoutHookAction( + hook, + context, + cancellationToken); + + await timeoutAction(); + }; + } + private static Func CreateAssemblyHookDelegate(StaticHookMethod hook) { return async (context, cancellationToken) => @@ -713,6 +757,22 @@ private static Func CreateAssembly }; } + private async Task> CreateAssemblyHookDelegateAsync(StaticHookMethod hook) + { + // Process hook registration event receivers + await ProcessHookRegistrationAsync(hook); + + return async (context, cancellationToken) => + { + var timeoutAction = HookTimeoutHelper.CreateTimeoutHookAction( + hook, + context, + cancellationToken); + + await timeoutAction(); + }; + } + private static Func CreateTestSessionHookDelegate(StaticHookMethod hook) { return async (context, cancellationToken) => @@ -726,6 +786,22 @@ private static Func CreateTestSessi }; } + private async Task> CreateTestSessionHookDelegateAsync(StaticHookMethod hook) + { + // Process hook registration event receivers + await ProcessHookRegistrationAsync(hook); + + return async (context, cancellationToken) => + { + var timeoutAction = HookTimeoutHelper.CreateTimeoutHookAction( + hook, + context, + cancellationToken); + + await timeoutAction(); + }; + } + private static Func CreateBeforeTestDiscoveryHookDelegate(StaticHookMethod hook) { return async (context, cancellationToken) => @@ -739,6 +815,22 @@ private static Func CreateB }; } + private async Task> CreateBeforeTestDiscoveryHookDelegateAsync(StaticHookMethod hook) + { + // Process hook registration event receivers + await ProcessHookRegistrationAsync(hook); + + return async (context, cancellationToken) => + { + var timeoutAction = HookTimeoutHelper.CreateTimeoutHookAction( + hook, + context, + cancellationToken); + + await timeoutAction(); + }; + } + private static Func CreateTestDiscoveryHookDelegate(StaticHookMethod hook) { return async (context, cancellationToken) => @@ -752,4 +844,20 @@ private static Func CreateTestDis }; } + private async Task> CreateTestDiscoveryHookDelegateAsync(StaticHookMethod hook) + { + // Process hook registration event receivers + await ProcessHookRegistrationAsync(hook); + + return async (context, cancellationToken) => + { + var timeoutAction = HookTimeoutHelper.CreateTimeoutHookAction( + hook, + context, + cancellationToken); + + await timeoutAction(); + }; + } + } diff --git a/TUnit.TestProject/HookTimeoutTests.cs b/TUnit.TestProject/HookTimeoutTests.cs new file mode 100644 index 0000000000..e0e4ec5919 --- /dev/null +++ b/TUnit.TestProject/HookTimeoutTests.cs @@ -0,0 +1,69 @@ +namespace TUnit.TestProject; + +/// +/// Tests for hook timeout attribute functionality. +/// This class tests that [Timeout] attribute on hooks is properly respected. +/// +public class HookTimeoutTests +{ + /// + /// A 100ms timeout on the hook - it should fail because the hook takes 500ms + /// + [Test] + public void Test_WithTimeoutHook() + { + // This test exists to verify that the hook timeout was applied correctly + } +} + +/// +/// Class-level hook with timeout that should fail +/// +public class ClassHookTimeoutTests +{ + private static bool _classHookExecuted; + + [Timeout(100)] // 100ms timeout - should fail + [Before(Class)] + public static async Task BeforeClass(CancellationToken cancellationToken) + { + _classHookExecuted = true; + // This will take longer than the timeout + await Task.Delay(TimeSpan.FromMilliseconds(500), cancellationToken); + } + + [Test] + public async Task Test_ShouldNotRun_BecauseHookTimedOut() + { + // This test should not actually execute because the class hook should timeout + await Assert.That(_classHookExecuted).IsTrue(); + } +} + +/// +/// Assembly-level hook with timeout that should pass (short delay within timeout) +/// +public class AssemblyHookTimeoutPassTests +{ + // Note: We can't test assembly hooks that fail timeout in the same way, + // as they affect the whole assembly. We test that the timeout IS applied + // by checking that a hook with a longer timeout succeeds. + + // The BeforeAssembly hook with a 5 second timeout, but quick execution + private static bool _assemblyHookExecuted; + + [Timeout(5000)] // 5 second timeout + [Before(Assembly)] + public static async Task BeforeAssembly(CancellationToken cancellationToken) + { + _assemblyHookExecuted = true; + // Short delay - well within timeout + await Task.Delay(TimeSpan.FromMilliseconds(50), cancellationToken); + } + + [Test] + public async Task Test_ShouldRun_BecauseHookCompletedInTime() + { + await Assert.That(_assemblyHookExecuted).IsTrue(); + } +}