diff --git a/TUnit.Assertions.SourceGenerator.Tests/AssertOverloadsGeneratorTests.cs b/TUnit.Assertions.SourceGenerator.Tests/AssertOverloadsGeneratorTests.cs new file mode 100644 index 0000000000..8050f6ddab --- /dev/null +++ b/TUnit.Assertions.SourceGenerator.Tests/AssertOverloadsGeneratorTests.cs @@ -0,0 +1,30 @@ +using TUnit.Assertions.SourceGenerator.Generators; + +namespace TUnit.Assertions.SourceGenerator.Tests; + +internal class AssertOverloadsGeneratorTests : TestsBase +{ + [Test] + public Task BasicOverloadGeneration() => RunTest( + Path.Combine(Sourcy.Git.RootDirectory.FullName, + "TUnit.Assertions.SourceGenerator.Tests", + "TestData", + "SimpleAssertOverloadsTest.cs"), + async generatedFiles => + { + await Assert.That(generatedFiles).HasCount().GreaterThanOrEqualTo(1); + + // Verify the generator produces wrapper types + var mainFile = generatedFiles.First(); + await Assert.That(mainFile).Contains("FuncString"); + await Assert.That(mainFile).Contains("AsyncFuncString"); + await Assert.That(mainFile).Contains("TaskString"); + await Assert.That(mainFile).Contains("ValueTaskString"); + await Assert.That(mainFile).Contains("IAssertionSource"); + + // Verify overloads are generated + await Assert.That(mainFile).Contains("public static FuncString That("); + await Assert.That(mainFile).Contains("Func func"); + await Assert.That(mainFile).Contains("[OverloadResolutionPriority(3)]"); + }); +} diff --git a/TUnit.Assertions.SourceGenerator.Tests/TestData/SimpleAssertOverloadsTest.cs b/TUnit.Assertions.SourceGenerator.Tests/TestData/SimpleAssertOverloadsTest.cs new file mode 100644 index 0000000000..1f751b9ed4 --- /dev/null +++ b/TUnit.Assertions.SourceGenerator.Tests/TestData/SimpleAssertOverloadsTest.cs @@ -0,0 +1,13 @@ +using TUnit.Assertions.Attributes; + +namespace TUnit.Assertions.SourceGenerator.Tests.TestData; + +/// +/// Test case: Simple method decorated with [GenerateAssertOverloads] +/// Should generate wrapper types and overloads for Func, Task, and ValueTask variants. +/// +public static partial class TestAssert +{ + [GenerateAssertOverloads(Priority = 3)] + public static string That(string? value) => value ?? ""; +} diff --git a/TUnit.Assertions.SourceGenerator/Generators/AssertOverloadsGenerator.cs b/TUnit.Assertions.SourceGenerator/Generators/AssertOverloadsGenerator.cs new file mode 100644 index 0000000000..0864de630f --- /dev/null +++ b/TUnit.Assertions.SourceGenerator/Generators/AssertOverloadsGenerator.cs @@ -0,0 +1,865 @@ +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using System.Text; +using System.Threading; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace TUnit.Assertions.SourceGenerator.Generators; + +/// +/// Source generator that creates wrapper type overloads from methods decorated with [GenerateAssertOverloads]. +/// Generates: +/// - Wrapper types (FuncXxxAssertion, TaskXxxAssertion, etc.) implementing IAssertionSource<T> +/// - Assert.That() overloads for Func, Task, ValueTask variants +/// +[Generator] +public sealed class AssertOverloadsGenerator : IIncrementalGenerator +{ + public void Initialize(IncrementalGeneratorInitializationContext context) + { + // Find all methods decorated with [GenerateAssertOverloads] + var methods = context.SyntaxProvider + .ForAttributeWithMetadataName( + "TUnit.Assertions.Attributes.GenerateAssertOverloadsAttribute", + predicate: static (node, _) => node is MethodDeclarationSyntax, + transform: static (ctx, ct) => GetMethodData(ctx, ct)) + .Where(x => x != null); + + // Generate output + context.RegisterSourceOutput(methods.Collect(), static (context, methods) => + { + GenerateOverloads(context, methods!); + }); + } + + private static OverloadMethodData? GetMethodData( + GeneratorAttributeSyntaxContext context, + CancellationToken cancellationToken) + { + if (context.TargetSymbol is not IMethodSymbol methodSymbol) + { + return null; + } + + // Extract attribute properties + int priority = 0; + bool generateFunc = true; + bool generateFuncTask = true; + bool generateFuncValueTask = true; + bool generateTask = true; + bool generateValueTask = true; + + var attribute = context.Attributes.FirstOrDefault(); + if (attribute != null) + { + foreach (var namedArg in attribute.NamedArguments) + { + switch (namedArg.Key) + { + case "Priority" when namedArg.Value.Value is int p: + priority = p; + break; + case "Func" when namedArg.Value.Value is bool f: + generateFunc = f; + break; + case "FuncTask" when namedArg.Value.Value is bool ft: + generateFuncTask = ft; + break; + case "FuncValueTask" when namedArg.Value.Value is bool fvt: + generateFuncValueTask = fvt; + break; + case "Task" when namedArg.Value.Value is bool t: + generateTask = t; + break; + case "ValueTask" when namedArg.Value.Value is bool vt: + generateValueTask = vt; + break; + } + } + } + + return new OverloadMethodData( + methodSymbol, + priority, + generateFunc, + generateFuncTask, + generateFuncValueTask, + generateTask, + generateValueTask); + } + + private static void GenerateOverloads( + SourceProductionContext context, + ImmutableArray methods) + { + var validMethods = methods.Where(m => m != null).Select(m => m!).ToList(); + if (validMethods.Count == 0) + { + return; + } + + // Group methods by return type to avoid generating duplicate wrapper types + // when multiple source types (e.g., IDictionary and IReadOnlyDictionary) share the same assertion type + var methodsByReturnType = validMethods + .GroupBy(m => m.Method.ReturnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)) + .ToList(); + + foreach (var group in methodsByReturnType) + { + GenerateForMethodGroup(context, group.ToList()); + } + } + + /// + /// Generates wrapper types and overloads for a group of methods that share the same return type. + /// This avoids generating duplicate wrapper types when multiple source types (e.g., IDictionary and IReadOnlyDictionary) + /// share the same assertion type (e.g., DictionaryAssertion). + /// + private static void GenerateForMethodGroup( + SourceProductionContext context, + List methodGroup) + { + if (methodGroup.Count == 0) + { + return; + } + + // Use the first method to get shared info (return type, containing type, namespace, type parameters) + var firstMethodData = methodGroup[0]; + var firstMethod = firstMethodData.Method; + var returnType = firstMethod.ReturnType as INamedTypeSymbol; + if (returnType == null) + { + return; + } + + var firstParam = firstMethod.Parameters.FirstOrDefault(); + if (firstParam == null) + { + return; + } + + var containingType = firstMethod.ContainingType; + var namespaceName = containingType.ContainingNamespace?.ToDisplayString() ?? "TUnit.Assertions"; + var returnTypeName = returnType.Name; + + // Handle generic methods - extract type parameters from the method + var methodTypeParameters = firstMethod.TypeParameters; + + // For wrapper types, we need to use the source type from the first parameter + // This is typically the type that assertions operate on + var sourceTypeInfo = GetSourceTypeInfo(firstParam.Type); + + var sb = new StringBuilder(); + sb.AppendLine("// "); + sb.AppendLine("#nullable enable"); + sb.AppendLine(); + sb.AppendLine("using System;"); + sb.AppendLine("using System.Runtime.CompilerServices;"); + sb.AppendLine("using System.Text;"); + sb.AppendLine("using System.Threading.Tasks;"); + sb.AppendLine("using TUnit.Assertions.Conditions;"); + sb.AppendLine("using TUnit.Assertions.Core;"); + sb.AppendLine(); + sb.AppendLine($"namespace {namespaceName};"); + sb.AppendLine(); + + // Determine which wrapper types to generate (based on first method's settings) + // All methods in the group should typically have the same generation settings + // since they share the same return type + if (firstMethodData.GenerateFunc) + { + GenerateFuncWrapperType(sb, returnTypeName, sourceTypeInfo, methodTypeParameters); + } + + if (firstMethodData.GenerateFuncTask) + { + GenerateAsyncFuncWrapperType(sb, returnTypeName, sourceTypeInfo, methodTypeParameters); + } + + if (firstMethodData.GenerateFuncValueTask) + { + GenerateValueTaskAsyncFuncWrapperType(sb, returnTypeName, sourceTypeInfo, methodTypeParameters); + } + + if (firstMethodData.GenerateTask) + { + GenerateTaskWrapperType(sb, returnTypeName, sourceTypeInfo, methodTypeParameters); + } + + if (firstMethodData.GenerateValueTask) + { + GenerateValueTaskWrapperType(sb, returnTypeName, sourceTypeInfo, methodTypeParameters); + } + + // Generate partial class with overloads for ALL methods in the group + sb.AppendLine($"public static partial class {containingType.Name}"); + sb.AppendLine("{"); + + foreach (var methodData in methodGroup) + { + var method = methodData.Method; + var param = method.Parameters.FirstOrDefault(); + if (param == null) + { + continue; + } + + var paramSourceTypeInfo = GetSourceTypeInfo(param.Type); + var paramTypeParameters = method.TypeParameters; + + if (methodData.GenerateFunc) + { + GenerateFuncOverload(sb, returnTypeName, paramSourceTypeInfo, paramTypeParameters, methodData.Priority); + } + + if (methodData.GenerateFuncTask) + { + GenerateAsyncFuncOverload(sb, returnTypeName, paramSourceTypeInfo, paramTypeParameters, methodData.Priority); + } + + if (methodData.GenerateFuncValueTask) + { + GenerateValueTaskAsyncFuncOverload(sb, returnTypeName, paramSourceTypeInfo, paramTypeParameters, methodData.Priority); + } + + if (methodData.GenerateTask) + { + GenerateTaskOverload(sb, returnTypeName, paramSourceTypeInfo, paramTypeParameters, methodData.Priority); + } + + if (methodData.GenerateValueTask) + { + GenerateValueTaskOverload(sb, returnTypeName, paramSourceTypeInfo, paramTypeParameters, methodData.Priority); + } + } + + sb.AppendLine("}"); + + // Generate a unique file name based on return type + var safeReturnTypeName = returnTypeName + .Replace("<", "_") + .Replace(">", "_") + .Replace(",", "_") + .Replace(" ", ""); + var fileName = $"{containingType.Name}.{safeReturnTypeName}.Overloads.g.cs"; + context.AddSource(fileName, sb.ToString()); + } + + private static void GenerateForMethod( + SourceProductionContext context, + OverloadMethodData methodData) + { + var method = methodData.Method; + var returnType = method.ReturnType as INamedTypeSymbol; + if (returnType == null) + { + return; + } + + // Get the source type (first parameter type) + var firstParam = method.Parameters.FirstOrDefault(); + if (firstParam == null) + { + return; + } + + var containingType = method.ContainingType; + var namespaceName = containingType.ContainingNamespace?.ToDisplayString() ?? "TUnit.Assertions"; + + // Extract type information + var sourceTypeInfo = GetSourceTypeInfo(firstParam.Type); + var returnTypeName = returnType.Name; + + // Handle generic methods - extract type parameters from the method + var methodTypeParameters = method.TypeParameters; + var typeParameterList = ""; + var typeParameterConstraints = ""; + if (methodTypeParameters.Length > 0) + { + typeParameterList = "<" + string.Join(", ", methodTypeParameters.Select(tp => tp.Name)) + ">"; + typeParameterConstraints = GetTypeConstraints(methodTypeParameters); + } + + // Build the wrapper type name suffix (includes type params for generic methods) + var wrapperTypeSuffix = returnTypeName + typeParameterList; + + var sb = new StringBuilder(); + sb.AppendLine("// "); + sb.AppendLine("#nullable enable"); + sb.AppendLine(); + sb.AppendLine("using System;"); + sb.AppendLine("using System.Runtime.CompilerServices;"); + sb.AppendLine("using System.Text;"); + sb.AppendLine("using System.Threading.Tasks;"); + sb.AppendLine("using TUnit.Assertions.Conditions;"); + sb.AppendLine("using TUnit.Assertions.Core;"); + sb.AppendLine(); + sb.AppendLine($"namespace {namespaceName};"); + sb.AppendLine(); + + // Generate wrapper types + if (methodData.GenerateFunc) + { + GenerateFuncWrapperType(sb, returnTypeName, sourceTypeInfo, methodTypeParameters); + } + + if (methodData.GenerateFuncTask) + { + GenerateAsyncFuncWrapperType(sb, returnTypeName, sourceTypeInfo, methodTypeParameters); + } + + if (methodData.GenerateFuncValueTask) + { + GenerateValueTaskAsyncFuncWrapperType(sb, returnTypeName, sourceTypeInfo, methodTypeParameters); + } + + if (methodData.GenerateTask) + { + GenerateTaskWrapperType(sb, returnTypeName, sourceTypeInfo, methodTypeParameters); + } + + if (methodData.GenerateValueTask) + { + GenerateValueTaskWrapperType(sb, returnTypeName, sourceTypeInfo, methodTypeParameters); + } + + // Generate partial class with overloads + sb.AppendLine($"public static partial class {containingType.Name}"); + sb.AppendLine("{"); + + if (methodData.GenerateFunc) + { + GenerateFuncOverload(sb, returnTypeName, sourceTypeInfo, methodTypeParameters, methodData.Priority); + } + + if (methodData.GenerateFuncTask) + { + GenerateAsyncFuncOverload(sb, returnTypeName, sourceTypeInfo, methodTypeParameters, methodData.Priority); + } + + if (methodData.GenerateFuncValueTask) + { + GenerateValueTaskAsyncFuncOverload(sb, returnTypeName, sourceTypeInfo, methodTypeParameters, methodData.Priority); + } + + if (methodData.GenerateTask) + { + GenerateTaskOverload(sb, returnTypeName, sourceTypeInfo, methodTypeParameters, methodData.Priority); + } + + if (methodData.GenerateValueTask) + { + GenerateValueTaskOverload(sb, returnTypeName, sourceTypeInfo, methodTypeParameters, methodData.Priority); + } + + sb.AppendLine("}"); + + // Generate a unique file name + var safeTypeName = sourceTypeInfo.SafeTypeName; + var fileName = $"{containingType.Name}.{safeTypeName}.Overloads.g.cs"; + context.AddSource(fileName, sb.ToString()); + } + + private static SourceTypeInfo GetSourceTypeInfo(ITypeSymbol typeSymbol) + { + // Get display string for use in generated code + var fullTypeName = typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var minimalTypeName = typeSymbol.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat); + + // Handle nullable reference types + var isNullable = typeSymbol.NullableAnnotation == NullableAnnotation.Annotated; + var underlyingType = typeSymbol; + + if (isNullable && typeSymbol is INamedTypeSymbol namedType) + { + underlyingType = namedType.WithNullableAnnotation(NullableAnnotation.None); + } + + // Create safe file name component + var safeTypeName = minimalTypeName + .Replace("<", "_") + .Replace(">", "_") + .Replace(",", "_") + .Replace(" ", "") + .Replace("?", "Nullable"); + + return new SourceTypeInfo( + fullTypeName, + minimalTypeName, + safeTypeName, + isNullable, + underlyingType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), + underlyingType.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat)); + } + + private static string GetTypeConstraints(ImmutableArray typeParameters) + { + if (typeParameters.Length == 0) + { + return ""; + } + + var constraints = new StringBuilder(); + foreach (var tp in typeParameters) + { + var constraintParts = new List(); + + // Reference type constraint + if (tp.HasReferenceTypeConstraint) + { + constraintParts.Add("class"); + } + + // Value type constraint + if (tp.HasValueTypeConstraint) + { + constraintParts.Add("struct"); + } + + // Unmanaged constraint + if (tp.HasUnmanagedTypeConstraint) + { + constraintParts.Add("unmanaged"); + } + + // notnull constraint + if (tp.HasNotNullConstraint) + { + constraintParts.Add("notnull"); + } + + // Type constraints + foreach (var constraintType in tp.ConstraintTypes) + { + constraintParts.Add(constraintType.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat)); + } + + // new() constraint (must be last) + if (tp.HasConstructorConstraint) + { + constraintParts.Add("new()"); + } + + if (constraintParts.Count > 0) + { + constraints.AppendLine(); + constraints.Append($" where {tp.Name} : {string.Join(", ", constraintParts)}"); + } + } + + return constraints.ToString(); + } + + private static string GetTypeParameterList(ImmutableArray typeParameters) + { + if (typeParameters.Length == 0) + { + return ""; + } + + return "<" + string.Join(", ", typeParameters.Select(tp => tp.Name)) + ">"; + } + + private static void GenerateFuncWrapperType( + StringBuilder sb, + string returnTypeName, + SourceTypeInfo sourceTypeInfo, + ImmutableArray typeParameters) + { + var typeParamList = GetTypeParameterList(typeParameters); + var constraints = GetTypeConstraints(typeParameters); + var wrapperTypeName = $"Func{returnTypeName}{typeParamList}"; + var sourceType = sourceTypeInfo.MinimalTypeName; + // Ensure nullable type for tuple first element + var nullableSourceType = sourceTypeInfo.IsNullable ? sourceType : sourceType + "?"; + + sb.AppendLine($"/// "); + sb.AppendLine($"/// Func wrapper for {returnTypeName}. Implements IAssertionSource for lazy synchronous evaluation."); + sb.AppendLine($"/// "); + sb.AppendLine($"public class Func{returnTypeName}{typeParamList} : IAssertionSource<{sourceType}>{constraints}"); + sb.AppendLine("{"); + sb.AppendLine($" public AssertionContext<{sourceType}> Context {{ get; }}"); + sb.AppendLine(); + sb.AppendLine($" public Func{returnTypeName}(Func<{sourceType}> func, string? expression)"); + sb.AppendLine(" {"); + sb.AppendLine(" var expressionBuilder = new StringBuilder();"); + sb.AppendLine(" expressionBuilder.Append($\"Assert.That({expression ?? \"?\"})\");"); + sb.AppendLine(); + sb.AppendLine($" var evaluationContext = new EvaluationContext<{sourceType}>(() =>"); + sb.AppendLine(" {"); + sb.AppendLine(" try"); + sb.AppendLine(" {"); + sb.AppendLine(" var result = func();"); + sb.AppendLine($" return Task.FromResult<({nullableSourceType}, Exception?)>((result, null));"); + sb.AppendLine(" }"); + sb.AppendLine(" catch (Exception ex)"); + sb.AppendLine(" {"); + sb.AppendLine($" return Task.FromResult<({nullableSourceType}, Exception?)>((default, ex));"); + sb.AppendLine(" }"); + sb.AppendLine(" });"); + sb.AppendLine(); + sb.AppendLine($" Context = new AssertionContext<{sourceType}>(evaluationContext, expressionBuilder);"); + sb.AppendLine(" }"); + sb.AppendLine(); + GenerateIAssertionSourceMethods(sb, sourceType); + sb.AppendLine("}"); + sb.AppendLine(); + } + + private static void GenerateAsyncFuncWrapperType( + StringBuilder sb, + string returnTypeName, + SourceTypeInfo sourceTypeInfo, + ImmutableArray typeParameters) + { + var typeParamList = GetTypeParameterList(typeParameters); + var constraints = GetTypeConstraints(typeParameters); + var sourceType = sourceTypeInfo.MinimalTypeName; + // Ensure nullable type for tuple first element + var nullableSourceType = sourceTypeInfo.IsNullable ? sourceType : sourceType + "?"; + + sb.AppendLine($"/// "); + sb.AppendLine($"/// Async Func wrapper for {returnTypeName}. Implements IAssertionSource for async factory evaluation."); + sb.AppendLine($"/// "); + sb.AppendLine($"public class AsyncFunc{returnTypeName}{typeParamList} : IAssertionSource<{sourceType}>{constraints}"); + sb.AppendLine("{"); + sb.AppendLine($" public AssertionContext<{sourceType}> Context {{ get; }}"); + sb.AppendLine(); + sb.AppendLine($" public AsyncFunc{returnTypeName}(Func> func, string? expression)"); + sb.AppendLine(" {"); + sb.AppendLine(" var expressionBuilder = new StringBuilder();"); + sb.AppendLine(" expressionBuilder.Append($\"Assert.That({expression ?? \"?\"})\");"); + sb.AppendLine(); + sb.AppendLine($" var evaluationContext = new EvaluationContext<{sourceType}>(async () =>"); + sb.AppendLine(" {"); + sb.AppendLine(" try"); + sb.AppendLine(" {"); + sb.AppendLine(" var result = await func().ConfigureAwait(false);"); + sb.AppendLine($" return (({nullableSourceType})result, (Exception?)null);"); + sb.AppendLine(" }"); + sb.AppendLine(" catch (Exception ex)"); + sb.AppendLine(" {"); + sb.AppendLine($" return (default({nullableSourceType}), ex);"); + sb.AppendLine(" }"); + sb.AppendLine(" });"); + sb.AppendLine(); + sb.AppendLine($" Context = new AssertionContext<{sourceType}>(evaluationContext, expressionBuilder);"); + sb.AppendLine(" }"); + sb.AppendLine(); + GenerateIAssertionSourceMethods(sb, sourceType); + sb.AppendLine("}"); + sb.AppendLine(); + } + + private static void GenerateValueTaskAsyncFuncWrapperType( + StringBuilder sb, + string returnTypeName, + SourceTypeInfo sourceTypeInfo, + ImmutableArray typeParameters) + { + var typeParamList = GetTypeParameterList(typeParameters); + var constraints = GetTypeConstraints(typeParameters); + var sourceType = sourceTypeInfo.MinimalTypeName; + // Ensure nullable type for tuple first element + var nullableSourceType = sourceTypeInfo.IsNullable ? sourceType : sourceType + "?"; + + sb.AppendLine($"/// "); + sb.AppendLine($"/// ValueTask Async Func wrapper for {returnTypeName}. Implements IAssertionSource for ValueTask async factory evaluation."); + sb.AppendLine($"/// "); + sb.AppendLine($"public class ValueTaskAsyncFunc{returnTypeName}{typeParamList} : IAssertionSource<{sourceType}>{constraints}"); + sb.AppendLine("{"); + sb.AppendLine($" public AssertionContext<{sourceType}> Context {{ get; }}"); + sb.AppendLine(); + sb.AppendLine($" public ValueTaskAsyncFunc{returnTypeName}(Func> func, string? expression)"); + sb.AppendLine(" {"); + sb.AppendLine(" var expressionBuilder = new StringBuilder();"); + sb.AppendLine(" expressionBuilder.Append($\"Assert.That({expression ?? \"?\"})\");"); + sb.AppendLine(); + sb.AppendLine($" var evaluationContext = new EvaluationContext<{sourceType}>(async () =>"); + sb.AppendLine(" {"); + sb.AppendLine(" try"); + sb.AppendLine(" {"); + sb.AppendLine(" var result = await func().ConfigureAwait(false);"); + sb.AppendLine($" return (({nullableSourceType})result, (Exception?)null);"); + sb.AppendLine(" }"); + sb.AppendLine(" catch (Exception ex)"); + sb.AppendLine(" {"); + sb.AppendLine($" return (default({nullableSourceType}), ex);"); + sb.AppendLine(" }"); + sb.AppendLine(" });"); + sb.AppendLine(); + sb.AppendLine($" Context = new AssertionContext<{sourceType}>(evaluationContext, expressionBuilder);"); + sb.AppendLine(" }"); + sb.AppendLine(); + GenerateIAssertionSourceMethods(sb, sourceType); + sb.AppendLine("}"); + sb.AppendLine(); + } + + private static void GenerateTaskWrapperType( + StringBuilder sb, + string returnTypeName, + SourceTypeInfo sourceTypeInfo, + ImmutableArray typeParameters) + { + var typeParamList = GetTypeParameterList(typeParameters); + var constraints = GetTypeConstraints(typeParameters); + var sourceType = sourceTypeInfo.MinimalTypeName; + // Ensure nullable type for tuple first element + var nullableSourceType = sourceTypeInfo.IsNullable ? sourceType : sourceType + "?"; + + sb.AppendLine($"/// "); + sb.AppendLine($"/// Task wrapper for {returnTypeName}. Implements IAssertionSource for awaiting an already-started task."); + sb.AppendLine($"/// "); + sb.AppendLine($"public class Task{returnTypeName}{typeParamList} : IAssertionSource<{sourceType}>{constraints}"); + sb.AppendLine("{"); + sb.AppendLine($" public AssertionContext<{sourceType}> Context {{ get; }}"); + sb.AppendLine(); + sb.AppendLine($" public Task{returnTypeName}(Task<{sourceType}> task, string? expression)"); + sb.AppendLine(" {"); + sb.AppendLine(" var expressionBuilder = new StringBuilder();"); + sb.AppendLine(" expressionBuilder.Append($\"Assert.That({expression ?? \"?\"})\");"); + sb.AppendLine(); + sb.AppendLine($" var evaluationContext = new EvaluationContext<{sourceType}>(async () =>"); + sb.AppendLine(" {"); + sb.AppendLine(" try"); + sb.AppendLine(" {"); + sb.AppendLine(" var result = await task.ConfigureAwait(false);"); + sb.AppendLine($" return (({nullableSourceType})result, (Exception?)null);"); + sb.AppendLine(" }"); + sb.AppendLine(" catch (Exception ex)"); + sb.AppendLine(" {"); + sb.AppendLine($" return (default({nullableSourceType}), ex);"); + sb.AppendLine(" }"); + sb.AppendLine(" });"); + sb.AppendLine(); + sb.AppendLine($" Context = new AssertionContext<{sourceType}>(evaluationContext, expressionBuilder);"); + sb.AppendLine(" }"); + sb.AppendLine(); + GenerateIAssertionSourceMethods(sb, sourceType); + sb.AppendLine("}"); + sb.AppendLine(); + } + + private static void GenerateValueTaskWrapperType( + StringBuilder sb, + string returnTypeName, + SourceTypeInfo sourceTypeInfo, + ImmutableArray typeParameters) + { + var typeParamList = GetTypeParameterList(typeParameters); + var constraints = GetTypeConstraints(typeParameters); + var sourceType = sourceTypeInfo.MinimalTypeName; + // Ensure nullable type for tuple first element + var nullableSourceType = sourceTypeInfo.IsNullable ? sourceType : sourceType + "?"; + + sb.AppendLine($"/// "); + sb.AppendLine($"/// ValueTask wrapper for {returnTypeName}. Implements IAssertionSource for awaiting a ValueTask."); + sb.AppendLine($"/// "); + sb.AppendLine($"public class ValueTask{returnTypeName}{typeParamList} : IAssertionSource<{sourceType}>{constraints}"); + sb.AppendLine("{"); + sb.AppendLine($" public AssertionContext<{sourceType}> Context {{ get; }}"); + sb.AppendLine(); + sb.AppendLine($" public ValueTask{returnTypeName}(ValueTask<{sourceType}> valueTask, string? expression)"); + sb.AppendLine(" {"); + sb.AppendLine(" var expressionBuilder = new StringBuilder();"); + sb.AppendLine(" expressionBuilder.Append($\"Assert.That({expression ?? \"?\"})\");"); + sb.AppendLine(); + sb.AppendLine($" var evaluationContext = new EvaluationContext<{sourceType}>(async () =>"); + sb.AppendLine(" {"); + sb.AppendLine(" try"); + sb.AppendLine(" {"); + sb.AppendLine(" var result = await valueTask.ConfigureAwait(false);"); + sb.AppendLine($" return (({nullableSourceType})result, (Exception?)null);"); + sb.AppendLine(" }"); + sb.AppendLine(" catch (Exception ex)"); + sb.AppendLine(" {"); + sb.AppendLine($" return (default({nullableSourceType}), ex);"); + sb.AppendLine(" }"); + sb.AppendLine(" });"); + sb.AppendLine(); + sb.AppendLine($" Context = new AssertionContext<{sourceType}>(evaluationContext, expressionBuilder);"); + sb.AppendLine(" }"); + sb.AppendLine(); + GenerateIAssertionSourceMethods(sb, sourceType); + sb.AppendLine("}"); + sb.AppendLine(); + } + + private static void GenerateIAssertionSourceMethods(StringBuilder sb, string sourceType) + { + // Generate the IAssertionSource interface methods + sb.AppendLine($" /// "); + sb.AppendLine($" public TypeOfAssertion<{sourceType}, TExpected> IsTypeOf()"); + sb.AppendLine(" {"); + sb.AppendLine(" Context.ExpressionBuilder.Append($\".IsTypeOf<{typeof(TExpected).Name}>()\");"); + sb.AppendLine($" return new TypeOfAssertion<{sourceType}, TExpected>(Context);"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine($" /// "); + sb.AppendLine($" public IsNotTypeOfAssertion<{sourceType}, TExpected> IsNotTypeOf()"); + sb.AppendLine(" {"); + sb.AppendLine(" Context.ExpressionBuilder.Append($\".IsNotTypeOf<{typeof(TExpected).Name}>()\");"); + sb.AppendLine($" return new IsNotTypeOfAssertion<{sourceType}, TExpected>(Context);"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine($" /// "); + sb.AppendLine($" public IsAssignableToAssertion IsAssignableTo()"); + sb.AppendLine(" {"); + sb.AppendLine(" Context.ExpressionBuilder.Append($\".IsAssignableTo<{typeof(TTarget).Name}>()\");"); + sb.AppendLine($" return new IsAssignableToAssertion(Context);"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine($" /// "); + sb.AppendLine($" public IsNotAssignableToAssertion IsNotAssignableTo()"); + sb.AppendLine(" {"); + sb.AppendLine(" Context.ExpressionBuilder.Append($\".IsNotAssignableTo<{typeof(TTarget).Name}>()\");"); + sb.AppendLine($" return new IsNotAssignableToAssertion(Context);"); + sb.AppendLine(" }"); + } + + private static void GenerateFuncOverload( + StringBuilder sb, + string returnTypeName, + SourceTypeInfo sourceTypeInfo, + ImmutableArray typeParameters, + int priority) + { + var typeParamList = GetTypeParameterList(typeParameters); + var constraints = GetTypeConstraints(typeParameters); + var sourceType = sourceTypeInfo.MinimalTypeName; + + if (priority != 0) + { + sb.AppendLine($" [OverloadResolutionPriority({priority})]"); + } + + sb.AppendLine($" public static Func{returnTypeName}{typeParamList} That{typeParamList}("); + sb.AppendLine($" Func<{sourceType}> func,"); + sb.AppendLine($" [CallerArgumentExpression(nameof(func))] string? expression = null){constraints}"); + sb.AppendLine(" {"); + sb.AppendLine($" return new Func{returnTypeName}{typeParamList}(func, expression);"); + sb.AppendLine(" }"); + sb.AppendLine(); + } + + private static void GenerateAsyncFuncOverload( + StringBuilder sb, + string returnTypeName, + SourceTypeInfo sourceTypeInfo, + ImmutableArray typeParameters, + int priority) + { + var typeParamList = GetTypeParameterList(typeParameters); + var constraints = GetTypeConstraints(typeParameters); + var sourceType = sourceTypeInfo.MinimalTypeName; + + if (priority != 0) + { + sb.AppendLine($" [OverloadResolutionPriority({priority})]"); + } + + sb.AppendLine($" public static AsyncFunc{returnTypeName}{typeParamList} That{typeParamList}("); + sb.AppendLine($" Func> func,"); + sb.AppendLine($" [CallerArgumentExpression(nameof(func))] string? expression = null){constraints}"); + sb.AppendLine(" {"); + sb.AppendLine($" return new AsyncFunc{returnTypeName}{typeParamList}(func, expression);"); + sb.AppendLine(" }"); + sb.AppendLine(); + } + + private static void GenerateValueTaskAsyncFuncOverload( + StringBuilder sb, + string returnTypeName, + SourceTypeInfo sourceTypeInfo, + ImmutableArray typeParameters, + int priority) + { + var typeParamList = GetTypeParameterList(typeParameters); + var constraints = GetTypeConstraints(typeParameters); + var sourceType = sourceTypeInfo.MinimalTypeName; + + if (priority != 0) + { + sb.AppendLine($" [OverloadResolutionPriority({priority})]"); + } + + sb.AppendLine($" public static ValueTaskAsyncFunc{returnTypeName}{typeParamList} That{typeParamList}("); + sb.AppendLine($" Func> func,"); + sb.AppendLine($" [CallerArgumentExpression(nameof(func))] string? expression = null){constraints}"); + sb.AppendLine(" {"); + sb.AppendLine($" return new ValueTaskAsyncFunc{returnTypeName}{typeParamList}(func, expression);"); + sb.AppendLine(" }"); + sb.AppendLine(); + } + + private static void GenerateTaskOverload( + StringBuilder sb, + string returnTypeName, + SourceTypeInfo sourceTypeInfo, + ImmutableArray typeParameters, + int priority) + { + var typeParamList = GetTypeParameterList(typeParameters); + var constraints = GetTypeConstraints(typeParameters); + var sourceType = sourceTypeInfo.MinimalTypeName; + + if (priority != 0) + { + sb.AppendLine($" [OverloadResolutionPriority({priority})]"); + } + + sb.AppendLine($" public static Task{returnTypeName}{typeParamList} That{typeParamList}("); + sb.AppendLine($" Task<{sourceType}> task,"); + sb.AppendLine($" [CallerArgumentExpression(nameof(task))] string? expression = null){constraints}"); + sb.AppendLine(" {"); + sb.AppendLine($" return new Task{returnTypeName}{typeParamList}(task, expression);"); + sb.AppendLine(" }"); + sb.AppendLine(); + } + + private static void GenerateValueTaskOverload( + StringBuilder sb, + string returnTypeName, + SourceTypeInfo sourceTypeInfo, + ImmutableArray typeParameters, + int priority) + { + var typeParamList = GetTypeParameterList(typeParameters); + var constraints = GetTypeConstraints(typeParameters); + var sourceType = sourceTypeInfo.MinimalTypeName; + + if (priority != 0) + { + sb.AppendLine($" [OverloadResolutionPriority({priority})]"); + } + + sb.AppendLine($" public static ValueTask{returnTypeName}{typeParamList} That{typeParamList}("); + sb.AppendLine($" ValueTask<{sourceType}> valueTask,"); + sb.AppendLine($" [CallerArgumentExpression(nameof(valueTask))] string? expression = null){constraints}"); + sb.AppendLine(" {"); + sb.AppendLine($" return new ValueTask{returnTypeName}{typeParamList}(valueTask, expression);"); + sb.AppendLine(" }"); + sb.AppendLine(); + } + + private record OverloadMethodData( + IMethodSymbol Method, + int Priority, + bool GenerateFunc, + bool GenerateFuncTask, + bool GenerateFuncValueTask, + bool GenerateTask, + bool GenerateValueTask); + + private record SourceTypeInfo( + string FullTypeName, + string MinimalTypeName, + string SafeTypeName, + bool IsNullable, + string UnderlyingFullTypeName, + string UnderlyingMinimalTypeName); +} diff --git a/TUnit.Assertions.Tests/DictionaryAssertionTests.cs b/TUnit.Assertions.Tests/DictionaryAssertionTests.cs new file mode 100644 index 0000000000..9c79c846c7 --- /dev/null +++ b/TUnit.Assertions.Tests/DictionaryAssertionTests.cs @@ -0,0 +1,563 @@ +using System.Collections.Concurrent; + +namespace TUnit.Assertions.Tests; + +/// +/// Integration tests for dictionary assertion methods (ContainsKey, DoesNotContainKey, ContainsValue, DoesNotContainValue). +/// Tests cover IDictionary, IReadOnlyDictionary, and concrete dictionary types with chaining and failure scenarios. +/// +public class DictionaryAssertionTests +{ + #region IDictionary Direct Assertions + + [Test] + public async Task IDictionary_ContainsKey_Passes_When_Key_Exists() + { + IDictionary dictionary = new Dictionary + { + ["key1"] = 1, + ["key2"] = 2 + }; + + await Assert.That(dictionary).ContainsKey("key1"); + } + + [Test] + public async Task IDictionary_ContainsKey_Fails_When_Key_Missing() + { + IDictionary dictionary = new Dictionary + { + ["key1"] = 1 + }; + + await Assert.ThrowsAsync(async () => + await Assert.That(dictionary).ContainsKey("missing")); + } + + [Test] + public async Task IDictionary_DoesNotContainKey_Passes_When_Key_Missing() + { + IDictionary dictionary = new Dictionary + { + ["key1"] = 1 + }; + + await Assert.That(dictionary).DoesNotContainKey("missing"); + } + + [Test] + public async Task IDictionary_DoesNotContainKey_Fails_When_Key_Exists() + { + IDictionary dictionary = new Dictionary + { + ["key1"] = 1 + }; + + await Assert.ThrowsAsync(async () => + await Assert.That(dictionary).DoesNotContainKey("key1")); + } + + [Test] + public async Task IDictionary_ContainsValue_Passes_When_Value_Exists() + { + IDictionary dictionary = new Dictionary + { + ["key1"] = 42, + ["key2"] = 100 + }; + + await Assert.That(dictionary).ContainsValue(42); + } + + [Test] + public async Task IDictionary_ContainsValue_Fails_When_Value_Missing() + { + IDictionary dictionary = new Dictionary + { + ["key1"] = 1 + }; + + await Assert.ThrowsAsync(async () => + await Assert.That(dictionary).ContainsValue(999)); + } + + [Test] + public async Task IDictionary_DoesNotContainValue_Passes_When_Value_Missing() + { + IDictionary dictionary = new Dictionary + { + ["key1"] = 1 + }; + + await Assert.That(dictionary).DoesNotContainValue(999); + } + + [Test] + public async Task IDictionary_DoesNotContainValue_Fails_When_Value_Exists() + { + IDictionary dictionary = new Dictionary + { + ["key1"] = 42 + }; + + await Assert.ThrowsAsync(async () => + await Assert.That(dictionary).DoesNotContainValue(42)); + } + + #endregion + + #region IReadOnlyDictionary Assertions + + [Test] + public async Task IReadOnlyDictionary_ContainsKey_Passes_When_Key_Exists() + { + IReadOnlyDictionary dictionary = new Dictionary + { + ["key1"] = 1, + ["key2"] = 2 + }; + + await Assert.That(dictionary).ContainsKey("key1"); + } + + [Test] + public async Task IReadOnlyDictionary_ContainsKey_Fails_When_Key_Missing() + { + IReadOnlyDictionary dictionary = new Dictionary + { + ["key1"] = 1 + }; + + await Assert.ThrowsAsync(async () => + await Assert.That(dictionary).ContainsKey("missing")); + } + + [Test] + public async Task IReadOnlyDictionary_DoesNotContainKey_Passes_When_Key_Missing() + { + IReadOnlyDictionary dictionary = new Dictionary + { + ["key1"] = 1 + }; + + await Assert.That(dictionary).DoesNotContainKey("missing"); + } + + [Test] + public async Task IReadOnlyDictionary_DoesNotContainKey_Fails_When_Key_Exists() + { + IReadOnlyDictionary dictionary = new Dictionary + { + ["key1"] = 1 + }; + + await Assert.ThrowsAsync(async () => + await Assert.That(dictionary).DoesNotContainKey("key1")); + } + + [Test] + public async Task IReadOnlyDictionary_ContainsValue_Passes_When_Value_Exists() + { + IReadOnlyDictionary dictionary = new Dictionary + { + ["key1"] = 42, + ["key2"] = 100 + }; + + await Assert.That(dictionary).ContainsValue(42); + } + + [Test] + public async Task IReadOnlyDictionary_ContainsValue_Fails_When_Value_Missing() + { + IReadOnlyDictionary dictionary = new Dictionary + { + ["key1"] = 1 + }; + + await Assert.ThrowsAsync(async () => + await Assert.That(dictionary).ContainsValue(999)); + } + + [Test] + public async Task IReadOnlyDictionary_DoesNotContainValue_Passes_When_Value_Missing() + { + IReadOnlyDictionary dictionary = new Dictionary + { + ["key1"] = 1 + }; + + await Assert.That(dictionary).DoesNotContainValue(999); + } + + [Test] + public async Task IReadOnlyDictionary_DoesNotContainValue_Fails_When_Value_Exists() + { + IReadOnlyDictionary dictionary = new Dictionary + { + ["key1"] = 42 + }; + + await Assert.ThrowsAsync(async () => + await Assert.That(dictionary).DoesNotContainValue(42)); + } + + #endregion + + #region Concrete Dictionary Types + + [Test] + public async Task Dictionary_ContainsKey_Works() + { + var dictionary = new Dictionary + { + ["key1"] = 1, + ["key2"] = 2 + }; + + await Assert.That(dictionary).ContainsKey("key1"); + } + + [Test] + public async Task Dictionary_ContainsValue_Works() + { + var dictionary = new Dictionary + { + ["key1"] = 42, + ["key2"] = 100 + }; + + await Assert.That(dictionary).ContainsValue(42); + } + + [Test] + public async Task Dictionary_DoesNotContainKey_Works() + { + var dictionary = new Dictionary + { + ["key1"] = 1 + }; + + await Assert.That(dictionary).DoesNotContainKey("missing"); + } + + [Test] + public async Task Dictionary_DoesNotContainValue_Works() + { + var dictionary = new Dictionary + { + ["key1"] = 1 + }; + + await Assert.That(dictionary).DoesNotContainValue(999); + } + + [Test] + public async Task ConcurrentDictionary_ContainsKey_Works() + { + var dictionary = new ConcurrentDictionary(); + dictionary["key1"] = 1; + dictionary["key2"] = 2; + + await Assert.That(dictionary).ContainsKey("key1"); + } + + [Test] + public async Task ConcurrentDictionary_ContainsValue_Works() + { + var dictionary = new ConcurrentDictionary(); + dictionary["key1"] = 42; + dictionary["key2"] = 100; + + await Assert.That(dictionary).ContainsValue(42); + } + + [Test] + public async Task ConcurrentDictionary_DoesNotContainKey_Works() + { + var dictionary = new ConcurrentDictionary(); + dictionary["key1"] = 1; + + await Assert.That(dictionary).DoesNotContainKey("missing"); + } + + [Test] + public async Task ConcurrentDictionary_DoesNotContainValue_Works() + { + var dictionary = new ConcurrentDictionary(); + dictionary["key1"] = 1; + + await Assert.That(dictionary).DoesNotContainValue(999); + } + + [Test] + public async Task SortedDictionary_ContainsKey_Works() + { + var dictionary = new SortedDictionary + { + ["alpha"] = 1, + ["beta"] = 2, + ["gamma"] = 3 + }; + + await Assert.That(dictionary).ContainsKey("beta"); + } + + [Test] + public async Task SortedDictionary_ContainsValue_Works() + { + var dictionary = new SortedDictionary + { + ["alpha"] = 1, + ["beta"] = 2, + ["gamma"] = 3 + }; + + await Assert.That(dictionary).ContainsValue(2); + } + + #endregion + + #region Chained Assertions + + [Test] + public async Task Dictionary_Chained_ContainsKey_And_ContainsKey() + { + var dictionary = new Dictionary + { + ["a"] = 1, + ["b"] = 2, + ["c"] = 3 + }; + + await Assert.That(dictionary) + .ContainsKey("a") + .And.ContainsKey("b") + .And.DoesNotContainKey("missing"); + } + + [Test] + public async Task Dictionary_Chained_ContainsValue_And_DoesNotContainValue() + { + var dictionary = new Dictionary + { + ["a"] = 1, + ["b"] = 2, + ["c"] = 3 + }; + + await Assert.That(dictionary) + .ContainsValue(1) + .And.ContainsValue(2) + .And.DoesNotContainValue(999); + } + + [Test] + public async Task Dictionary_Chained_Mixed_Key_And_Value_Assertions() + { + var dictionary = new Dictionary + { + ["key1"] = 100, + ["key2"] = 200 + }; + + await Assert.That(dictionary) + .ContainsKey("key1") + .And.ContainsValue(100) + .And.DoesNotContainKey("missing") + .And.DoesNotContainValue(999); + } + + [Test] + public async Task Dictionary_Or_Chain_Works() + { + var dictionary = new Dictionary + { + ["key1"] = 1 + }; + + // Either condition is true - passes because "key1" exists + await Assert.That(dictionary) + .ContainsKey("nonexistent") + .Or.ContainsKey("key1"); + } + + [Test] + public async Task Dictionary_Chained_With_Collection_Assertions() + { + var dictionary = new Dictionary + { + ["a"] = 1, + ["b"] = 2 + }; + + await Assert.That(dictionary) + .ContainsKey("a") + .And.IsNotEmpty() + .And.HasCount(2); + } + + #endregion + + #region Failure Messages + + [Test] + public async Task ContainsKey_Failure_Has_Meaningful_Message() + { + var dictionary = new Dictionary + { + ["existing"] = 1 + }; + + var exception = await Assert.ThrowsAsync(async () => + await Assert.That(dictionary).ContainsKey("missing")); + + // Verify that the failure message is meaningful + await Assert.That(exception.Message).Contains("contain key"); + } + + [Test] + public async Task DoesNotContainKey_Failure_Has_Meaningful_Message() + { + var dictionary = new Dictionary + { + ["existing"] = 1 + }; + + var exception = await Assert.ThrowsAsync(async () => + await Assert.That(dictionary).DoesNotContainKey("existing")); + + // Verify that the failure message is meaningful + await Assert.That(exception.Message).Contains("not contain key"); + } + + [Test] + public async Task ContainsValue_Failure_Has_Meaningful_Message() + { + var dictionary = new Dictionary + { + ["key"] = 1 + }; + + var exception = await Assert.ThrowsAsync(async () => + await Assert.That(dictionary).ContainsValue(999)); + + // Verify that the failure message is meaningful + await Assert.That(exception.Message).Contains("contain value"); + } + + [Test] + public async Task DoesNotContainValue_Failure_Has_Meaningful_Message() + { + var dictionary = new Dictionary + { + ["key"] = 42 + }; + + var exception = await Assert.ThrowsAsync(async () => + await Assert.That(dictionary).DoesNotContainValue(42)); + + // Verify that the failure message is meaningful + await Assert.That(exception.Message).Contains("not contain value"); + } + + #endregion + + #region Edge Cases + + [Test] + public async Task Empty_Dictionary_DoesNotContainKey_Passes() + { + var dictionary = new Dictionary(); + + await Assert.That(dictionary).DoesNotContainKey("any"); + } + + [Test] + public async Task Empty_Dictionary_DoesNotContainValue_Passes() + { + var dictionary = new Dictionary(); + + await Assert.That(dictionary).DoesNotContainValue(42); + } + + [Test] + public async Task Dictionary_With_Null_Value_ContainsValue_Works() + { + var dictionary = new Dictionary + { + ["key1"] = "value1", + ["key2"] = null + }; + + await Assert.That(dictionary).ContainsValue(null); + } + + [Test] + public async Task Dictionary_With_Int_Keys_ContainsKey_Works() + { + var dictionary = new Dictionary + { + [1] = "one", + [2] = "two", + [3] = "three" + }; + + await Assert.That(dictionary).ContainsKey(2); + } + + [Test] + public async Task Dictionary_With_Complex_Value_ContainsValue_Works() + { + var person1 = new Person("Alice", 30); + var person2 = new Person("Bob", 25); + + var dictionary = new Dictionary + { + ["alice"] = person1, + ["bob"] = person2 + }; + + // Reference equality check - should find the exact same instance + await Assert.That(dictionary).ContainsValue(person1); + } + + [Test] + public async Task Null_Dictionary_ContainsKey_Fails() + { + IDictionary? dictionary = null; + + await Assert.ThrowsAsync(async () => + await Assert.That(dictionary!).ContainsKey("key")); + } + + #endregion + + #region Multiple Dictionary Entries + + [Test] + public async Task Dictionary_With_Many_Entries_ContainsKey_Works() + { + var dictionary = Enumerable.Range(1, 100) + .ToDictionary(i => $"key{i}", i => i); + + await Assert.That(dictionary).ContainsKey("key50"); + await Assert.That(dictionary).ContainsKey("key1"); + await Assert.That(dictionary).ContainsKey("key100"); + } + + [Test] + public async Task Dictionary_With_Many_Entries_ContainsValue_Works() + { + var dictionary = Enumerable.Range(1, 100) + .ToDictionary(i => $"key{i}", i => i * 10); + + await Assert.That(dictionary).ContainsValue(500); // key50 + await Assert.That(dictionary).ContainsValue(10); // key1 + await Assert.That(dictionary).ContainsValue(1000); // key100 + } + + #endregion + + private record Person(string Name, int Age); +} diff --git a/TUnit.Assertions/Attributes/GenerateAssertOverloadsAttribute.cs b/TUnit.Assertions/Attributes/GenerateAssertOverloadsAttribute.cs new file mode 100644 index 0000000000..18d1e713a0 --- /dev/null +++ b/TUnit.Assertions/Attributes/GenerateAssertOverloadsAttribute.cs @@ -0,0 +1,95 @@ +using System; + +namespace TUnit.Assertions.Attributes; + +/// +/// Marks an Assert.That() method for automatic generation of wrapper overloads. +/// Generates Func, Task, and ValueTask variants automatically. +/// +/// +/// +/// When applied to an Assert.That() method, the source generator will create: +/// - Wrapper types (FuncXxxAssertion, TaskXxxAssertion, etc.) implementing IAssertionSource<T> +/// - Assert.That() overloads for Func<T>, Func<Task<T>>, Func<ValueTask<T>>, Task<T>, and ValueTask<T> variants +/// +/// +/// This allows assertions to work seamlessly with lazy evaluation (Func), async operations (Task/ValueTask), +/// and async factories (Func<Task>/Func<ValueTask>). +/// +/// +/// +/// +/// public static partial class Assert +/// { +/// [GenerateAssertOverloads(Priority = 3)] +/// public static DictionaryAssertion<TKey, TValue> That<TKey, TValue>( +/// IReadOnlyDictionary<TKey, TValue>? value, +/// [CallerArgumentExpression(nameof(value))] string? expression = null) +/// { +/// return new DictionaryAssertion<TKey, TValue>(value, expression); +/// } +/// } +/// +/// // Generates overloads like: +/// // Assert.That(Func<IReadOnlyDictionary<TKey, TValue>?> func, ...) +/// // Assert.That(Task<IReadOnlyDictionary<TKey, TValue>?> task, ...) +/// // Assert.That(ValueTask<IReadOnlyDictionary<TKey, TValue>?> valueTask, ...) +/// +/// +[AttributeUsage(AttributeTargets.Method, AllowMultiple = false, Inherited = false)] +public sealed class GenerateAssertOverloadsAttribute : Attribute +{ + /// + /// Overload resolution priority for generated overloads. + /// Higher values take precedence in overload resolution. + /// + /// + /// Use this to control which overload is selected when multiple could match. + /// When set to a non-zero value, the generated overloads will have the + /// [OverloadResolutionPriority] attribute applied with this value. + /// + public int Priority { get; set; } = 0; + + /// + /// Generate Func<T> overload. Default: true. + /// + /// + /// When true, generates an overload that accepts a Func<T> for lazy evaluation. + /// The value is only evaluated when the assertion runs. + /// + public bool Func { get; set; } = true; + + /// + /// Generate Func<Task<T>> overload. Default: true. + /// + /// + /// When true, generates an overload that accepts a Func<Task<T>> for async factory evaluation. + /// Useful when the async operation should be started fresh for each assertion. + /// + public bool FuncTask { get; set; } = true; + + /// + /// Generate Func<ValueTask<T>> overload. Default: true. + /// + /// + /// When true, generates an overload that accepts a Func<ValueTask<T>> for async factory evaluation. + /// Useful when the async operation should be started fresh for each assertion with ValueTask semantics. + /// + public bool FuncValueTask { get; set; } = true; + + /// + /// Generate Task<T> overload. Default: true. + /// + /// + /// When true, generates an overload that accepts a Task<T> for awaiting an already-started async operation. + /// + public bool Task { get; set; } = true; + + /// + /// Generate ValueTask<T> overload. Default: true. + /// + /// + /// When true, generates an overload that accepts a ValueTask<T> for awaiting an already-started async operation. + /// + public bool ValueTask { get; set; } = true; +} diff --git a/TUnit.Assertions/Conditions/DictionaryAssertionExtensions.cs b/TUnit.Assertions/Conditions/DictionaryAssertionExtensions.cs new file mode 100644 index 0000000000..ae61cfa1d0 --- /dev/null +++ b/TUnit.Assertions/Conditions/DictionaryAssertionExtensions.cs @@ -0,0 +1,30 @@ +using TUnit.Assertions.Attributes; + +namespace TUnit.Assertions.Conditions; + +/// +/// Source-generated assertions for dictionary types using [GenerateAssertion] attributes. +/// These wrap dictionary checks as extension methods. +/// +file static partial class DictionaryAssertionExtensions +{ + [GenerateAssertion(ExpectationMessage = "to contain key {expectedKey}", InlineMethodBody = true)] + public static bool ContainsKey( + this IReadOnlyDictionary dictionary, + TKey expectedKey) => dictionary.ContainsKey(expectedKey); + + [GenerateAssertion(ExpectationMessage = "to not contain key {expectedKey}", InlineMethodBody = true)] + public static bool DoesNotContainKey( + this IReadOnlyDictionary dictionary, + TKey expectedKey) => !dictionary.ContainsKey(expectedKey); + + [GenerateAssertion(ExpectationMessage = "to contain value {expectedValue}", InlineMethodBody = true)] + public static bool ContainsValue( + this IReadOnlyDictionary dictionary, + TValue expectedValue) => dictionary.Values.Contains(expectedValue); + + [GenerateAssertion(ExpectationMessage = "to not contain value {expectedValue}", InlineMethodBody = true)] + public static bool DoesNotContainValue( + this IReadOnlyDictionary dictionary, + TValue expectedValue) => !dictionary.Values.Contains(expectedValue); +} diff --git a/TUnit.Assertions/Extensions/Assert.cs b/TUnit.Assertions/Extensions/Assert.cs index c1835304b7..b6c2c51658 100644 --- a/TUnit.Assertions/Extensions/Assert.cs +++ b/TUnit.Assertions/Extensions/Assert.cs @@ -1,9 +1,11 @@ using System.Collections; using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; +using TUnit.Assertions.Attributes; using TUnit.Assertions.Conditions; using TUnit.Assertions.Exceptions; using TUnit.Assertions.Sources; +using TUnit.Assertions.Wrappers; namespace TUnit.Assertions; @@ -12,13 +14,45 @@ namespace TUnit.Assertions; /// Provides Assert.That() overloads for different source types. /// [SuppressMessage("Usage", "TUnitAssertions0002:Assert statement not awaited")] -public static class Assert +public static partial class Assert { + /// + /// Creates an assertion for an IDictionary value. + /// Wraps IDictionary as IReadOnlyDictionary since IDictionary doesn't inherit from it in .NET < 9. + /// Example: await Assert.That(dict).ContainsKey("key"); + /// + /// + /// Note: Func/Task/ValueTask overloads are not generated for IDictionary because it cannot be + /// implicitly converted to IReadOnlyDictionary. Use IReadOnlyDictionary or Dictionary directly + /// for async/lazy assertions, or explicitly wrap with ReadOnlyDictionaryWrapper. + /// + [OverloadResolutionPriority(4)] + public static DictionaryAssertion That( + IDictionary? value, + [CallerArgumentExpression(nameof(value))] string? expression = null) + { + if (value is null) + { + return new DictionaryAssertion(null!, expression); + } + + // If it already implements IReadOnlyDictionary, use it directly + if (value is IReadOnlyDictionary readOnly) + { + return new DictionaryAssertion(readOnly, expression); + } + + // Wrap IDictionary as IReadOnlyDictionary + var wrapper = new ReadOnlyDictionaryWrapper(value); + return new DictionaryAssertion(wrapper, expression); + } + /// /// Creates an assertion for an IReadOnlyDictionary value. /// This overload enables better type inference for dictionary operations like ContainsKey. /// Example: await Assert.That(dict).ContainsKey("key"); /// + [GenerateAssertOverloads(Priority = 3)] [OverloadResolutionPriority(3)] public static DictionaryAssertion That( IReadOnlyDictionary value, diff --git a/TUnit.Assertions/Wrappers/ReadOnlyDictionaryWrapper.cs b/TUnit.Assertions/Wrappers/ReadOnlyDictionaryWrapper.cs new file mode 100644 index 0000000000..5d6a9b47d4 --- /dev/null +++ b/TUnit.Assertions/Wrappers/ReadOnlyDictionaryWrapper.cs @@ -0,0 +1,39 @@ +using System.Collections; +using System.Diagnostics.CodeAnalysis; + +namespace TUnit.Assertions.Wrappers; + +/// +/// Wraps an IDictionary as IReadOnlyDictionary for assertion purposes. +/// Preserves the original reference for identity comparisons. +/// +internal sealed class ReadOnlyDictionaryWrapper : IReadOnlyDictionary +{ + private readonly IDictionary _dictionary; + + /// + /// The original IDictionary reference, for IsSameReferenceAs assertions. + /// + public object OriginalReference => _dictionary; + + public ReadOnlyDictionaryWrapper(IDictionary dictionary) + => _dictionary = dictionary ?? throw new ArgumentNullException(nameof(dictionary)); + + public TValue this[TKey key] => _dictionary[key]; + public IEnumerable Keys => _dictionary.Keys; + public IEnumerable Values => _dictionary.Values; + public int Count => _dictionary.Count; + public bool ContainsKey(TKey key) => _dictionary.ContainsKey(key); + +#if NETSTANDARD2_0 + public bool TryGetValue(TKey key, out TValue value) +#else + public bool TryGetValue(TKey key, [MaybeNullWhen(false)] out TValue value) +#endif + { + return _dictionary.TryGetValue(key, out value!); + } + + public IEnumerator> GetEnumerator() => _dictionary.GetEnumerator(); + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); +}