@@ -20,7 +20,7 @@ namespace Xunit.Analyzers.Fixes;
2020[ ExportCodeFixProvider ( LanguageNames . CSharp ) , Shared ]
2121public class AssertSingleShouldBeUsedForSingleParameterFixer : XunitCodeFixProvider
2222{
23- private const string DefaultParameterName = "item" ;
23+ const string DefaultParameterName = "item" ;
2424 public const string Key_UseSingleMethod = "xUnit2023_UseSingleMethod" ;
2525
2626 public AssertSingleShouldBeUsedForSingleParameterFixer ( ) :
@@ -51,13 +51,19 @@ static IEnumerable<SyntaxNode> GetLambdaStatements(SimpleLambdaExpressionSyntax
5151
5252 static SyntaxNode GetMethodInvocation (
5353 IdentifierNameSyntax methodExpression ,
54- string parameterName ) =>
55- ExpressionStatement (
56- InvocationExpression (
57- methodExpression ,
58- ArgumentList ( SingletonSeparatedList ( Argument ( IdentifierName ( parameterName ) ) ) )
59- )
60- ) ;
54+ string parameterName ,
55+ bool needAwait )
56+ {
57+ ExpressionSyntax invocation = InvocationExpression (
58+ methodExpression ,
59+ ArgumentList ( SingletonSeparatedList ( Argument ( IdentifierName ( parameterName ) ) ) )
60+ ) ;
61+
62+ if ( needAwait )
63+ invocation = AwaitExpression ( invocation ) ;
64+
65+ return ExpressionStatement ( invocation ) ;
66+ }
6167
6268 static LocalDeclarationStatementSyntax OneItemVariableStatement (
6369 string parameterName ,
@@ -86,17 +92,16 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context)
8692 if ( invocation is null )
8793 return ;
8894
89- var diagnostic = context . Diagnostics . FirstOrDefault ( ) ;
90- if ( diagnostic is null )
95+ if ( context . Diagnostics . FirstOrDefault ( ) is not Diagnostic diagnostic )
9196 return ;
92- if ( ! diagnostic . Properties . TryGetValue ( Constants . Properties . Replacement , out var replacement ) )
97+ if ( ! diagnostic . Properties . TryGetValue ( Constants . Properties . AssertMethodName , out var assertMethodName ) || assertMethodName is null )
9398 return ;
94- if ( replacement is null )
99+ if ( ! diagnostic . Properties . TryGetValue ( Constants . Properties . Replacement , out var replacement ) || replacement is null )
95100 return ;
96101
97102 context . RegisterCodeFix (
98103 XunitCodeAction . Create (
99- ct => UseSingleMethod ( context . Document , invocation , replacement , ct ) ,
104+ ct => UseSingleMethod ( context . Document , invocation , assertMethodName , replacement , ct ) ,
100105 Key_UseSingleMethod ,
101106 "Use Assert.{0}" , replacement
102107 ) ,
@@ -107,13 +112,13 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context)
107112 static async Task < Document > UseSingleMethod (
108113 Document document ,
109114 InvocationExpressionSyntax invocation ,
115+ string assertMethodName ,
110116 string replacementMethod ,
111117 CancellationToken cancellationToken )
112118 {
113119 var editor = await DocumentEditor . CreateAsync ( document , cancellationToken ) . ConfigureAwait ( false ) ;
114120
115- if ( invocation . Expression is MemberAccessExpressionSyntax memberAccess &&
116- invocation . ArgumentList . Arguments [ 0 ] . Expression is IdentifierNameSyntax collectionVariable )
121+ if ( invocation . Expression is MemberAccessExpressionSyntax memberAccess )
117122 {
118123 var semanticModel = await document . GetSemanticModelAsync ( cancellationToken ) . ConfigureAwait ( false ) ;
119124 if ( semanticModel is not null && invocation . Parent is not null )
@@ -123,9 +128,13 @@ static async Task<Document> UseSingleMethod(
123128 var localSymbols = semanticModel . LookupSymbols ( startLocation ) . OfType < ILocalSymbol > ( ) . Select ( s => s . Name ) . ToImmutableHashSet ( ) ;
124129 var replacementNode =
125130 invocation
126- . WithArgumentList ( ArgumentList ( SeparatedList ( [ Argument ( collectionVariable ) ] ) ) )
131+ . WithArgumentList ( ArgumentList ( SeparatedList ( [ Argument ( invocation . ArgumentList . Arguments [ 0 ] . Expression ) ] ) ) )
127132 . WithExpression ( memberAccess . WithName ( IdentifierName ( replacementMethod ) ) ) ;
128133
134+ // We want to replace the whole expression, because it may include an unnecessary await, as we may be
135+ // converting from Assert.CollectionAsync (which needs await) to Assert.Single (which does not).
136+ var nodeToReplace = invocation . FirstAncestorOrSelf < ExpressionStatementSyntax > ( ) ?? invocation . Parent ;
137+
129138 if ( invocation . ArgumentList . Arguments [ 1 ] . Expression is SimpleLambdaExpressionSyntax lambdaExpression )
130139 {
131140 var originalParameterName = lambdaExpression . Parameter . Identifier . Text ;
@@ -139,11 +148,12 @@ static async Task<Document> UseSingleMethod(
139148 . DescendantTokens ( )
140149 . Where ( t => t . IsKind ( SyntaxKind . IdentifierToken ) && t . Text == originalParameterName )
141150 . ToArray ( ) ;
151+
142152 body = body . ReplaceTokens ( tokens , ( t1 , t2 ) => Identifier ( t2 . LeadingTrivia , parameterName , t2 . TrailingTrivia ) ) ;
143153 lambdaExpression = lambdaExpression . WithBody ( body ) ;
144154 }
145155
146- statements . Add ( OneItemVariableStatement ( parameterName , replacementNode ) . WithTriviaFrom ( invocation . Parent ) ) ;
156+ statements . Add ( OneItemVariableStatement ( parameterName , replacementNode ) . WithTriviaFrom ( nodeToReplace ) ) ;
147157 statements . AddRange ( GetLambdaStatements ( lambdaExpression ) ) ;
148158 }
149159 else if ( invocation . ArgumentList . Arguments [ 1 ] . Expression is IdentifierNameSyntax identifierExpression )
@@ -153,17 +163,13 @@ static async Task<Document> UseSingleMethod(
153163 {
154164 var parameterName = GetSafeVariableName ( DefaultParameterName , localSymbols ) ;
155165
156- var oneItemVariableStatement =
157- OneItemVariableStatement ( parameterName , replacementNode )
158- . WithLeadingTrivia ( invocation . Parent . GetLeadingTrivia ( ) ) ;
159-
160- statements . Add ( OneItemVariableStatement ( parameterName , replacementNode ) . WithTriviaFrom ( invocation . Parent ) ) ;
161- statements . Add ( GetMethodInvocation ( identifierExpression , parameterName ) ) ;
166+ statements . Add ( OneItemVariableStatement ( parameterName , replacementNode ) . WithTriviaFrom ( nodeToReplace ) ) ;
167+ statements . Add ( GetMethodInvocation ( identifierExpression , parameterName , needAwait : assertMethodName == Constants . Asserts . CollectionAsync ) ) ;
162168 }
163169 }
164170
165- editor . InsertBefore ( invocation . Parent , statements ) ;
166- editor . RemoveNode ( invocation . Parent ) ;
171+ editor . InsertBefore ( nodeToReplace , statements ) ;
172+ editor . RemoveNode ( nodeToReplace ) ;
167173 }
168174 }
169175
0 commit comments