4040import java .net .StandardSocketOptions ;
4141import java .nio .ByteBuffer ;
4242import java .nio .channels .CompletionHandler ;
43+ import java .nio .channels .InterruptedByTimeoutException ;
4344import java .nio .channels .SelectionKey ;
4445import java .nio .channels .Selector ;
4546import java .nio .channels .SocketChannel ;
4950import java .util .concurrent .ExecutorService ;
5051import java .util .concurrent .Future ;
5152import java .util .concurrent .TimeUnit ;
53+ import java .util .concurrent .atomic .AtomicReference ;
5254
5355import static com .mongodb .assertions .Assertions .assertTrue ;
5456import static com .mongodb .assertions .Assertions .isTrue ;
@@ -97,21 +99,40 @@ public void close() {
9799 group .shutdown ();
98100 }
99101
102+ /**
103+ * Monitors `OP_CONNECT` events for socket connections.
104+ */
100105 private static class SelectorMonitor implements Closeable {
101106
102- private static final class Pair {
107+ static final class SocketRegistration {
103108 private final SocketChannel socketChannel ;
104- private final Runnable attachment ;
109+ private final AtomicReference < Runnable > afterConnectAction ;
105110
106- private Pair (final SocketChannel socketChannel , final Runnable attachment ) {
111+ SocketRegistration (final SocketChannel socketChannel , final Runnable afterConnectAction ) {
107112 this .socketChannel = socketChannel ;
108- this .attachment = attachment ;
113+ this .afterConnectAction = new AtomicReference <>(afterConnectAction );
114+ }
115+
116+ boolean tryCancelPendingConnection () {
117+ return tryTakeAction () != null ;
118+ }
119+
120+ void runAfterConnectActionIfNotCanceled () {
121+ Runnable afterConnectActionToExecute = tryTakeAction ();
122+ if (afterConnectActionToExecute != null ) {
123+ afterConnectActionToExecute .run ();
124+ }
125+ }
126+
127+ @ Nullable
128+ private Runnable tryTakeAction () {
129+ return afterConnectAction .getAndSet (null );
109130 }
110131 }
111132
112133 private final Selector selector ;
113134 private volatile boolean isClosed ;
114- private final ConcurrentLinkedDeque <Pair > pendingRegistrations = new ConcurrentLinkedDeque <>();
135+ private final ConcurrentLinkedDeque <SocketRegistration > pendingRegistrations = new ConcurrentLinkedDeque <>();
115136
116137 SelectorMonitor () {
117138 try {
@@ -127,17 +148,14 @@ void start() {
127148 while (!isClosed ) {
128149 try {
129150 selector .select ();
130-
131151 for (SelectionKey selectionKey : selector .selectedKeys ()) {
132152 selectionKey .cancel ();
133- Runnable runnable = (Runnable ) selectionKey .attachment ();
134- runnable .run ();
153+ ((SocketRegistration ) selectionKey .attachment ()).runAfterConnectActionIfNotCanceled ();
135154 }
136155
137- for (Iterator <Pair > iter = pendingRegistrations .iterator (); iter .hasNext ();) {
138- Pair pendingRegistration = iter .next ();
139- pendingRegistration .socketChannel .register (selector , SelectionKey .OP_CONNECT ,
140- pendingRegistration .attachment );
156+ for (Iterator <SocketRegistration > iter = pendingRegistrations .iterator (); iter .hasNext ();) {
157+ SocketRegistration pendingRegistration = iter .next ();
158+ pendingRegistration .socketChannel .register (selector , SelectionKey .OP_CONNECT , pendingRegistration );
141159 iter .remove ();
142160 }
143161 } catch (Exception e ) {
@@ -156,8 +174,8 @@ void start() {
156174 selectorThread .start ();
157175 }
158176
159- void register (final SocketChannel channel , final Runnable attachment ) {
160- pendingRegistrations .add (new Pair ( channel , attachment ) );
177+ void register (final SocketRegistration registration ) {
178+ pendingRegistrations .add (registration );
161179 selector .wakeup ();
162180 }
163181
@@ -200,44 +218,79 @@ public void openAsync(final OperationContext operationContext, final AsyncComple
200218 if (getSettings ().getSendBufferSize () > 0 ) {
201219 socketChannel .setOption (StandardSocketOptions .SO_SNDBUF , getSettings ().getSendBufferSize ());
202220 }
203-
221+ //getConnectTimeoutMs MUST be called before connection attempt, as it might throw MongoOperationTimeout exception.
222+ int connectTimeoutMs = operationContext .getTimeoutContext ().getConnectTimeoutMs ();
204223 socketChannel .connect (getSocketAddresses (getServerAddress (), inetAddressResolver ).get (0 ));
224+ SelectorMonitor .SocketRegistration socketRegistration = new SelectorMonitor .SocketRegistration (
225+ socketChannel , () -> initializeTslChannel (handler , socketChannel ));
205226
206- selectorMonitor .register (socketChannel , () -> {
207- try {
208- if (!socketChannel .finishConnect ()) {
209- throw new MongoSocketOpenException ("Failed to finish connect" , getServerAddress ());
210- }
227+ if (connectTimeoutMs > 0 ) {
228+ scheduleTimeoutInterruption (handler , socketRegistration , connectTimeoutMs );
229+ }
230+ selectorMonitor .register (socketRegistration );
231+ } catch (IOException e ) {
232+ handler .failed (new MongoSocketOpenException ("Exception opening socket" , getServerAddress (), e ));
233+ } catch (Throwable t ) {
234+ handler .failed (t );
235+ }
236+ }
211237
212- SSLEngine sslEngine = getSslContext ().createSSLEngine (getServerAddress ().getHost (),
213- getServerAddress ().getPort ());
214- sslEngine .setUseClientMode (true );
238+ private void scheduleTimeoutInterruption (final AsyncCompletionHandler <Void > handler ,
239+ final SelectorMonitor .SocketRegistration socketRegistration ,
240+ final int connectTimeoutMs ) {
241+ group .getTimeoutExecutor ().schedule (() -> {
242+ if (socketRegistration .tryCancelPendingConnection ()) {
243+ closeAndTimeout (handler , socketRegistration .socketChannel );
244+ }
245+ }, connectTimeoutMs , TimeUnit .MILLISECONDS );
246+ }
215247
216- SSLParameters sslParameters = sslEngine .getSSLParameters ();
217- enableSni (getServerAddress ().getHost (), sslParameters );
248+ private void closeAndTimeout (final AsyncCompletionHandler <Void > handler , final SocketChannel socketChannel ) {
249+ // We check if this stream was closed before timeout exception.
250+ boolean streamClosed = isClosed ();
251+ InterruptedByTimeoutException timeoutException = new InterruptedByTimeoutException ();
252+ try {
253+ socketChannel .close ();
254+ } catch (Exception e ) {
255+ timeoutException .addSuppressed (e );
256+ }
218257
219- if (!sslSettings .isInvalidHostNameAllowed ()) {
220- enableHostNameVerification (sslParameters );
221- }
222- sslEngine .setSSLParameters (sslParameters );
258+ if (streamClosed ) {
259+ handler .completed (null );
260+ } else {
261+ handler .failed (new MongoSocketOpenException ("Exception opening socket" , getAddress (), timeoutException ));
262+ }
263+ }
223264
224- BufferAllocator bufferAllocator = new BufferProviderAllocator ();
265+ private void initializeTslChannel (final AsyncCompletionHandler <Void > handler , final SocketChannel socketChannel ) {
266+ try {
267+ if (!socketChannel .finishConnect ()) {
268+ throw new MongoSocketOpenException ("Failed to finish connect" , getServerAddress ());
269+ }
225270
226- TlsChannel tlsChannel = ClientTlsChannel .newBuilder (socketChannel , sslEngine )
227- .withEncryptedBufferAllocator (bufferAllocator )
228- .withPlainBufferAllocator (bufferAllocator )
229- .build ();
271+ SSLEngine sslEngine = getSslContext ().createSSLEngine (getServerAddress ().getHost (),
272+ getServerAddress ().getPort ());
273+ sslEngine .setUseClientMode (true );
230274
231- // build asynchronous channel, based in the TLS channel and associated with the global group.
232- setChannel ( new AsynchronousTlsChannelAdapter ( new AsynchronousTlsChannel ( group , tlsChannel , socketChannel )) );
275+ SSLParameters sslParameters = sslEngine . getSSLParameters ();
276+ enableSni ( getServerAddress (). getHost (), sslParameters );
233277
234- handler .completed (null );
235- } catch (IOException e ) {
236- handler .failed (new MongoSocketOpenException ("Exception opening socket" , getServerAddress (), e ));
237- } catch (Throwable t ) {
238- handler .failed (t );
239- }
240- });
278+ if (!sslSettings .isInvalidHostNameAllowed ()) {
279+ enableHostNameVerification (sslParameters );
280+ }
281+ sslEngine .setSSLParameters (sslParameters );
282+
283+ BufferAllocator bufferAllocator = new BufferProviderAllocator ();
284+
285+ TlsChannel tlsChannel = ClientTlsChannel .newBuilder (socketChannel , sslEngine )
286+ .withEncryptedBufferAllocator (bufferAllocator )
287+ .withPlainBufferAllocator (bufferAllocator )
288+ .build ();
289+
290+ // build asynchronous channel, based in the TLS channel and associated with the global group.
291+ setChannel (new AsynchronousTlsChannelAdapter (new AsynchronousTlsChannel (group , tlsChannel , socketChannel )));
292+
293+ handler .completed (null );
241294 } catch (IOException e ) {
242295 handler .failed (new MongoSocketOpenException ("Exception opening socket" , getServerAddress (), e ));
243296 } catch (Throwable t ) {
0 commit comments