4
4
namespace Tvl . VisualStudio . MouseFastScroll . IntegrationTests
5
5
{
6
6
using System ;
7
+ using System . Runtime . InteropServices ;
7
8
using System . Threading ;
8
9
using System . Threading . Tasks ;
9
10
using System . Windows . Automation ;
11
+ using Microsoft . VisualStudio ;
12
+ using Microsoft . Win32 . SafeHandles ;
10
13
using Xunit ;
14
+ using IMessageFilter = Microsoft . VisualStudio . OLE . Interop . IMessageFilter ;
15
+ using INTERFACEINFO = Microsoft . VisualStudio . OLE . Interop . INTERFACEINFO ;
16
+ using PENDINGMSG = Microsoft . VisualStudio . OLE . Interop . PENDINGMSG ;
17
+ using SERVERCALL = Microsoft . VisualStudio . OLE . Interop . SERVERCALL ;
11
18
12
19
[ CaptureTestName ]
13
20
[ Collection ( nameof ( SharedIntegrationHostFixture ) ) ]
14
21
public abstract class AbstractIntegrationTest : IAsyncLifetime , IDisposable
15
22
{
23
+ private readonly MessageFilter _messageFilter ;
16
24
private readonly VisualStudioInstanceFactory _instanceFactory ;
17
25
private readonly Version _version ;
18
26
private VisualStudioInstanceContext _visualStudioContext ;
19
27
20
28
protected AbstractIntegrationTest ( VisualStudioInstanceFactory instanceFactory , Version version )
21
29
{
22
30
Assert . Equal ( ApartmentState . STA , Thread . CurrentThread . GetApartmentState ( ) ) ;
31
+
32
+ // Install a COM message filter to handle retry operations when the first attempt fails
33
+ _messageFilter = RegisterMessageFilter ( ) ;
23
34
_instanceFactory = instanceFactory ;
24
35
_version = version ;
25
- Automation . TransactionTimeout = 20000 ;
36
+
37
+ try
38
+ {
39
+ Automation . TransactionTimeout = 20000 ;
40
+ }
41
+ catch
42
+ {
43
+ _messageFilter . Dispose ( ) ;
44
+ throw ;
45
+ }
26
46
}
27
47
28
48
public VisualStudioInstance VisualStudio => _visualStudioContext ? . Instance ;
29
49
30
50
public virtual async Task InitializeAsync ( )
31
51
{
32
- _visualStudioContext = await _instanceFactory . GetNewOrUsedInstanceAsync ( _version , SharedIntegrationHostFixture . RequiredPackageIds ) . ConfigureAwait ( false ) ;
52
+ try
53
+ {
54
+ _visualStudioContext = await _instanceFactory . GetNewOrUsedInstanceAsync ( _version , SharedIntegrationHostFixture . RequiredPackageIds ) . ConfigureAwait ( false ) ;
55
+ }
56
+ catch
57
+ {
58
+ _messageFilter . Dispose ( ) ;
59
+ throw ;
60
+ }
33
61
}
34
62
35
63
public Task DisposeAsync ( )
@@ -43,11 +71,126 @@ public void Dispose()
43
71
GC . SuppressFinalize ( this ) ;
44
72
}
45
73
74
+ protected virtual MessageFilter RegisterMessageFilter ( )
75
+ => new MessageFilter ( ) ;
76
+
46
77
protected virtual void Dispose ( bool disposing )
47
78
{
48
79
if ( disposing )
49
80
{
50
- _visualStudioContext . Dispose ( ) ;
81
+ try
82
+ {
83
+ _visualStudioContext . Dispose ( ) ;
84
+ }
85
+ finally
86
+ {
87
+ _messageFilter . Dispose ( ) ;
88
+ }
89
+ }
90
+ }
91
+
92
+ protected class MessageFilter : IMessageFilter , IDisposable
93
+ {
94
+ protected const uint CancelCall = ~ 0U ;
95
+
96
+ private readonly MessageFilterSafeHandle _messageFilterRegistration ;
97
+ private readonly TimeSpan _timeout ;
98
+ private readonly TimeSpan _retryDelay ;
99
+
100
+ public MessageFilter ( )
101
+ : this ( timeout : TimeSpan . FromSeconds ( 60 ) , retryDelay : TimeSpan . FromMilliseconds ( 150 ) )
102
+ {
103
+ }
104
+
105
+ public MessageFilter ( TimeSpan timeout , TimeSpan retryDelay )
106
+ {
107
+ _timeout = timeout ;
108
+ _retryDelay = retryDelay ;
109
+ _messageFilterRegistration = MessageFilterSafeHandle . Register ( this ) ;
110
+ }
111
+
112
+ public virtual uint HandleInComingCall ( uint dwCallType , IntPtr htaskCaller , uint dwTickCount , INTERFACEINFO [ ] lpInterfaceInfo )
113
+ {
114
+ return ( uint ) SERVERCALL . SERVERCALL_ISHANDLED ;
115
+ }
116
+
117
+ public virtual uint RetryRejectedCall ( IntPtr htaskCallee , uint dwTickCount , uint dwRejectType )
118
+ {
119
+ if ( ( SERVERCALL ) dwRejectType != SERVERCALL . SERVERCALL_RETRYLATER
120
+ && ( SERVERCALL ) dwRejectType != SERVERCALL . SERVERCALL_REJECTED )
121
+ {
122
+ return CancelCall ;
123
+ }
124
+
125
+ if ( dwTickCount >= _timeout . TotalMilliseconds )
126
+ {
127
+ return CancelCall ;
128
+ }
129
+
130
+ return ( uint ) _retryDelay . TotalMilliseconds ;
131
+ }
132
+
133
+ public virtual uint MessagePending ( IntPtr htaskCallee , uint dwTickCount , uint dwPendingType )
134
+ {
135
+ return ( uint ) PENDINGMSG . PENDINGMSG_WAITDEFPROCESS ;
136
+ }
137
+
138
+ protected virtual void Dispose ( bool disposing )
139
+ {
140
+ if ( disposing )
141
+ {
142
+ _messageFilterRegistration . Dispose ( ) ;
143
+ }
144
+ }
145
+
146
+ public void Dispose ( )
147
+ {
148
+ Dispose ( true ) ;
149
+ GC . SuppressFinalize ( this ) ;
150
+ }
151
+ }
152
+
153
+ private sealed class MessageFilterSafeHandle : SafeHandleMinusOneIsInvalid
154
+ {
155
+ private readonly IntPtr _oldFilter ;
156
+
157
+ private MessageFilterSafeHandle ( IntPtr handle )
158
+ : base ( true )
159
+ {
160
+ SetHandle ( handle ) ;
161
+
162
+ try
163
+ {
164
+ if ( CoRegisterMessageFilter ( handle , out _oldFilter ) != VSConstants . S_OK )
165
+ {
166
+ throw new InvalidOperationException ( "Failed to register a new message filter" ) ;
167
+ }
168
+ }
169
+ catch
170
+ {
171
+ SetHandleAsInvalid ( ) ;
172
+ throw ;
173
+ }
174
+ }
175
+
176
+ [ DllImport ( "ole32" , SetLastError = true ) ]
177
+ private static extern int CoRegisterMessageFilter ( IntPtr messageFilter , out IntPtr oldMessageFilter ) ;
178
+
179
+ public static MessageFilterSafeHandle Register < T > ( T messageFilter )
180
+ where T : IMessageFilter
181
+ {
182
+ var handle = Marshal . GetComInterfaceForObject < T , IMessageFilter > ( messageFilter ) ;
183
+ return new MessageFilterSafeHandle ( handle ) ;
184
+ }
185
+
186
+ protected override bool ReleaseHandle ( )
187
+ {
188
+ if ( CoRegisterMessageFilter ( _oldFilter , out _ ) == VSConstants . S_OK )
189
+ {
190
+ Marshal . Release ( handle ) ;
191
+ }
192
+
193
+ return true ;
51
194
}
52
195
}
53
196
}
0 commit comments