Skip to content

Commit 5a0ba2f

Browse files
committed
Update xUnit2023 to support Assert.CollectionAsync
1 parent 301689b commit 5a0ba2f

File tree

4 files changed

+348
-210
lines changed

4 files changed

+348
-210
lines changed

src/xunit.analyzers.fixes/X2000/AssertSingleShouldBeUsedForSingleParameterFixer.cs

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ namespace Xunit.Analyzers.Fixes;
2020
[ExportCodeFixProvider(LanguageNames.CSharp), Shared]
2121
public class AssertSingleShouldBeUsedForSingleParameterFixer : XunitCodeFixProvider
2222
{
23-
private const string DefaultParameterName = "item";
23+
const string DefaultParameterName = "item";
2424
public const string Key_UseSingleMethod = "xUnit2023_UseSingleMethod";
2525

2626
public AssertSingleShouldBeUsedForSingleParameterFixer() :
@@ -51,13 +51,19 @@ static IEnumerable<SyntaxNode> GetLambdaStatements(SimpleLambdaExpressionSyntax
5151

5252
static SyntaxNode GetMethodInvocation(
5353
IdentifierNameSyntax methodExpression,
54-
string parameterName) =>
55-
ExpressionStatement(
56-
InvocationExpression(
57-
methodExpression,
58-
ArgumentList(SingletonSeparatedList(Argument(IdentifierName(parameterName))))
59-
)
60-
);
54+
string parameterName,
55+
bool needAwait)
56+
{
57+
ExpressionSyntax invocation = InvocationExpression(
58+
methodExpression,
59+
ArgumentList(SingletonSeparatedList(Argument(IdentifierName(parameterName))))
60+
);
61+
62+
if (needAwait)
63+
invocation = AwaitExpression(invocation);
64+
65+
return ExpressionStatement(invocation);
66+
}
6167

6268
static LocalDeclarationStatementSyntax OneItemVariableStatement(
6369
string parameterName,
@@ -86,17 +92,16 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context)
8692
if (invocation is null)
8793
return;
8894

89-
var diagnostic = context.Diagnostics.FirstOrDefault();
90-
if (diagnostic is null)
95+
if (context.Diagnostics.FirstOrDefault() is not Diagnostic diagnostic)
9196
return;
92-
if (!diagnostic.Properties.TryGetValue(Constants.Properties.Replacement, out var replacement))
97+
if (!diagnostic.Properties.TryGetValue(Constants.Properties.AssertMethodName, out var assertMethodName) || assertMethodName is null)
9398
return;
94-
if (replacement is null)
99+
if (!diagnostic.Properties.TryGetValue(Constants.Properties.Replacement, out var replacement) || replacement is null)
95100
return;
96101

97102
context.RegisterCodeFix(
98103
XunitCodeAction.Create(
99-
ct => UseSingleMethod(context.Document, invocation, replacement, ct),
104+
ct => UseSingleMethod(context.Document, invocation, assertMethodName, replacement, ct),
100105
Key_UseSingleMethod,
101106
"Use Assert.{0}", replacement
102107
),
@@ -107,13 +112,13 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context)
107112
static async Task<Document> UseSingleMethod(
108113
Document document,
109114
InvocationExpressionSyntax invocation,
115+
string assertMethodName,
110116
string replacementMethod,
111117
CancellationToken cancellationToken)
112118
{
113119
var editor = await DocumentEditor.CreateAsync(document, cancellationToken).ConfigureAwait(false);
114120

115-
if (invocation.Expression is MemberAccessExpressionSyntax memberAccess &&
116-
invocation.ArgumentList.Arguments[0].Expression is IdentifierNameSyntax collectionVariable)
121+
if (invocation.Expression is MemberAccessExpressionSyntax memberAccess)
117122
{
118123
var semanticModel = await document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);
119124
if (semanticModel is not null && invocation.Parent is not null)
@@ -123,9 +128,13 @@ static async Task<Document> UseSingleMethod(
123128
var localSymbols = semanticModel.LookupSymbols(startLocation).OfType<ILocalSymbol>().Select(s => s.Name).ToImmutableHashSet();
124129
var replacementNode =
125130
invocation
126-
.WithArgumentList(ArgumentList(SeparatedList([Argument(collectionVariable)])))
131+
.WithArgumentList(ArgumentList(SeparatedList([Argument(invocation.ArgumentList.Arguments[0].Expression)])))
127132
.WithExpression(memberAccess.WithName(IdentifierName(replacementMethod)));
128133

134+
// We want to replace the whole expression, because it may include an unnecessary await, as we may be
135+
// converting from Assert.CollectionAsync (which needs await) to Assert.Single (which does not).
136+
var nodeToReplace = invocation.FirstAncestorOrSelf<ExpressionStatementSyntax>() ?? invocation.Parent;
137+
129138
if (invocation.ArgumentList.Arguments[1].Expression is SimpleLambdaExpressionSyntax lambdaExpression)
130139
{
131140
var originalParameterName = lambdaExpression.Parameter.Identifier.Text;
@@ -139,11 +148,12 @@ static async Task<Document> UseSingleMethod(
139148
.DescendantTokens()
140149
.Where(t => t.IsKind(SyntaxKind.IdentifierToken) && t.Text == originalParameterName)
141150
.ToArray();
151+
142152
body = body.ReplaceTokens(tokens, (t1, t2) => Identifier(t2.LeadingTrivia, parameterName, t2.TrailingTrivia));
143153
lambdaExpression = lambdaExpression.WithBody(body);
144154
}
145155

146-
statements.Add(OneItemVariableStatement(parameterName, replacementNode).WithTriviaFrom(invocation.Parent));
156+
statements.Add(OneItemVariableStatement(parameterName, replacementNode).WithTriviaFrom(nodeToReplace));
147157
statements.AddRange(GetLambdaStatements(lambdaExpression));
148158
}
149159
else if (invocation.ArgumentList.Arguments[1].Expression is IdentifierNameSyntax identifierExpression)
@@ -153,17 +163,13 @@ static async Task<Document> UseSingleMethod(
153163
{
154164
var parameterName = GetSafeVariableName(DefaultParameterName, localSymbols);
155165

156-
var oneItemVariableStatement =
157-
OneItemVariableStatement(parameterName, replacementNode)
158-
.WithLeadingTrivia(invocation.Parent.GetLeadingTrivia());
159-
160-
statements.Add(OneItemVariableStatement(parameterName, replacementNode).WithTriviaFrom(invocation.Parent));
161-
statements.Add(GetMethodInvocation(identifierExpression, parameterName));
166+
statements.Add(OneItemVariableStatement(parameterName, replacementNode).WithTriviaFrom(nodeToReplace));
167+
statements.Add(GetMethodInvocation(identifierExpression, parameterName, needAwait: assertMethodName == Constants.Asserts.CollectionAsync));
162168
}
163169
}
164170

165-
editor.InsertBefore(invocation.Parent, statements);
166-
editor.RemoveNode(invocation.Parent);
171+
editor.InsertBefore(nodeToReplace, statements);
172+
editor.RemoveNode(nodeToReplace);
167173
}
168174
}
169175

Lines changed: 71 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,93 @@
11
using System.Threading.Tasks;
2+
using Microsoft.CodeAnalysis.Testing;
23
using Xunit;
34
using Verify = CSharpVerifier<Xunit.Analyzers.AssertSingleShouldBeUsedForSingleParameter>;
45

56
public class AssertSingleShouldBeUsedForSingleParameterTests
67
{
7-
[Theory]
8-
[InlineData("default(IEnumerable<int>)")]
9-
#if NETCOREAPP3_0_OR_GREATER
10-
[InlineData("default(IAsyncEnumerable<int>)")]
11-
#endif
12-
public async Task ForSingleItemCollectionCheck_Triggers(string collection)
8+
[Fact]
9+
public async ValueTask EnumerableAcceptanceTest()
1310
{
14-
var code = string.Format(/* lang=c#-test */ """
15-
using Xunit;
11+
var code = /* lang=c#-test */ """
1612
using System.Collections.Generic;
13+
using System.Threading.Tasks;
14+
using Xunit;
1715
18-
public class TestClass {{
16+
public class TestClass {
1917
[Fact]
20-
public void TestMethod() {{
21-
{{|#0:Assert.Collection({0}, item => Assert.NotNull(item))|}};
22-
}}
23-
}}
24-
""", collection);
25-
var expected = Verify.Diagnostic().WithLocation(0).WithArguments("Collection");
18+
public async Task TestMethod() {
19+
{|#0:Assert.Collection(
20+
default(IEnumerable<object>),
21+
item => Assert.NotNull(item)
22+
)|};
23+
Assert.Collection(
24+
default(IEnumerable<object>),
25+
item => Assert.NotNull(item),
26+
item => Assert.NotNull(item)
27+
);
28+
29+
await {|#1:Assert.CollectionAsync(
30+
default(IEnumerable<Task<int>>),
31+
async item => Assert.NotNull(item)
32+
)|};
33+
await Assert.CollectionAsync(
34+
default(IEnumerable<Task<int>>),
35+
async item => Assert.Equal(42, await item),
36+
async item => Assert.Equal(2112, await item)
37+
);
38+
}
39+
}
40+
""";
41+
var expected = new DiagnosticResult[] {
42+
Verify.Diagnostic().WithLocation(0).WithArguments("Collection"),
43+
Verify.Diagnostic().WithLocation(1).WithArguments("CollectionAsync"),
44+
};
2645

2746
await Verify.VerifyAnalyzer(code, expected);
2847
}
2948

30-
[Theory]
31-
[InlineData("default(IEnumerable<int>)")]
3249
#if NETCOREAPP3_0_OR_GREATER
33-
[InlineData("default(IAsyncEnumerable<int>)")]
34-
#endif
35-
public async Task ForMultipleItemCollectionCheck_DoesNotTrigger(string collection)
50+
51+
[Fact]
52+
public async ValueTask AsyncEnumerableAcceptanceTest()
3653
{
37-
var code = string.Format(/* lang=c#-test */ """
38-
using Xunit;
54+
var code = /* lang=c#-test */ """
3955
using System.Collections.Generic;
56+
using System.Threading.Tasks;
57+
using Xunit;
4058
41-
public class TestClass {{
59+
public class TestClass {
4260
[Fact]
43-
public void TestMethod() {{
44-
Assert.Collection({0}, item1 => Assert.NotNull(item1), item2 => Assert.NotNull(item2));
45-
}}
46-
}}
47-
""", collection);
61+
public async Task TestMethod() {
62+
{|#0:Assert.Collection(
63+
default(IAsyncEnumerable<object>),
64+
item => Assert.NotNull(item)
65+
)|};
66+
Assert.Collection(
67+
default(IAsyncEnumerable<object>),
68+
item => Assert.NotNull(item),
69+
item => Assert.NotNull(item)
70+
);
71+
72+
await {|#1:Assert.CollectionAsync(
73+
default(IAsyncEnumerable<Task<int>>),
74+
async item => Assert.NotNull(item)
75+
)|};
76+
await Assert.CollectionAsync(
77+
default(IAsyncEnumerable<Task<int>>),
78+
async item => Assert.Equal(42, await item),
79+
async item => Assert.Equal(2112, await item)
80+
);
81+
}
82+
}
83+
""";
84+
var expected = new DiagnosticResult[] {
85+
Verify.Diagnostic().WithLocation(0).WithArguments("Collection"),
86+
Verify.Diagnostic().WithLocation(1).WithArguments("CollectionAsync"),
87+
};
4888

49-
await Verify.VerifyAnalyzer(code);
89+
await Verify.VerifyAnalyzer(code, expected);
5090
}
91+
92+
#endif // NETCOREAPP3_0_OR_GREATER
5193
}

0 commit comments

Comments
 (0)