1- // Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved.
1+ // Copyright (c) 2007-2021 VMware, Inc. or its affiliates. All rights reserved.
22//
33// This software, the RabbitMQ Java client library, is triple-licensed under the
44// Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2
2323import java .nio .channels .ReadableByteChannel ;
2424import java .nio .channels .SocketChannel ;
2525import java .nio .channels .WritableByteChannel ;
26+ import org .slf4j .Logger ;
27+ import org .slf4j .LoggerFactory ;
2628
2729import static javax .net .ssl .SSLEngineResult .HandshakeStatus .FINISHED ;
30+ import static javax .net .ssl .SSLEngineResult .HandshakeStatus .NEED_TASK ;
31+ import static javax .net .ssl .SSLEngineResult .HandshakeStatus .NEED_WRAP ;
2832import static javax .net .ssl .SSLEngineResult .HandshakeStatus .NOT_HANDSHAKING ;
2933
3034/**
3135 *
3236 */
3337public class SslEngineHelper {
3438
39+ private static final Logger LOGGER = LoggerFactory .getLogger (SslEngineHelper .class );
40+
3541 public static boolean doHandshake (SocketChannel socketChannel , SSLEngine engine ) throws IOException {
3642
3743 ByteBuffer plainOut = ByteBuffer .allocate (engine .getSession ().getApplicationBufferSize ());
3844 ByteBuffer plainIn = ByteBuffer .allocate (engine .getSession ().getApplicationBufferSize ());
3945 ByteBuffer cipherOut = ByteBuffer .allocate (engine .getSession ().getPacketBufferSize ());
4046 ByteBuffer cipherIn = ByteBuffer .allocate (engine .getSession ().getPacketBufferSize ());
4147
48+ LOGGER .debug ("Starting TLS handshake" );
49+
4250 SSLEngineResult .HandshakeStatus handshakeStatus = engine .getHandshakeStatus ();
51+ LOGGER .debug ("Initial handshake status is {}" , handshakeStatus );
4352 while (handshakeStatus != FINISHED && handshakeStatus != NOT_HANDSHAKING ) {
53+ LOGGER .debug ("Handshake status is {}" , handshakeStatus );
4454 switch (handshakeStatus ) {
4555 case NEED_TASK :
56+ LOGGER .debug ("Running tasks" );
4657 handshakeStatus = runDelegatedTasks (engine );
4758 break ;
4859 case NEED_UNWRAP :
60+ LOGGER .debug ("Unwrapping..." );
4961 handshakeStatus = unwrap (cipherIn , plainIn , socketChannel , engine );
5062 break ;
5163 case NEED_WRAP :
64+ LOGGER .debug ("Wrapping..." );
5265 handshakeStatus = wrap (plainOut , cipherOut , socketChannel , engine );
5366 break ;
67+ case FINISHED :
68+ break ;
69+ case NOT_HANDSHAKING :
70+ break ;
71+ default :
72+ throw new SSLException ("Unexpected handshake status " + handshakeStatus );
5473 }
5574 }
75+
76+
77+ LOGGER .debug ("TLS handshake completed" );
5678 return true ;
5779 }
5880
5981 private static SSLEngineResult .HandshakeStatus runDelegatedTasks (SSLEngine sslEngine ) {
6082 // FIXME run in executor?
6183 Runnable runnable ;
6284 while ((runnable = sslEngine .getDelegatedTask ()) != null ) {
85+ LOGGER .debug ("Running delegated task" );
6386 runnable .run ();
6487 }
6588 return sslEngine .getHandshakeStatus ();
@@ -68,29 +91,57 @@ private static SSLEngineResult.HandshakeStatus runDelegatedTasks(SSLEngine sslEn
6891 private static SSLEngineResult .HandshakeStatus unwrap (ByteBuffer cipherIn , ByteBuffer plainIn ,
6992 ReadableByteChannel channel , SSLEngine sslEngine ) throws IOException {
7093 SSLEngineResult .HandshakeStatus handshakeStatus = sslEngine .getHandshakeStatus ();
71-
72- if (channel .read (cipherIn ) < 0 ) {
73- throw new SSLException ("Could not read from socket channel" );
94+ LOGGER .debug ("Handshake status is {} before unwrapping" , handshakeStatus );
95+
96+ LOGGER .debug ("Cipher in position {}" , cipherIn .position ());
97+ int read ;
98+ if (cipherIn .position () == 0 ) {
99+ LOGGER .debug ("Reading from channel" );
100+ read = channel .read (cipherIn );
101+ LOGGER .debug ("Read {} byte(s) from channel" , read );
102+ if (read < 0 ) {
103+ throw new SSLException ("Could not read from socket channel" );
104+ }
105+ cipherIn .flip ();
106+ } else {
107+ LOGGER .debug ("Not reading" );
74108 }
75- cipherIn .flip ();
76109
77110 SSLEngineResult .Status status ;
111+ SSLEngineResult unwrapResult ;
78112 do {
79- SSLEngineResult unwrapResult = sslEngine .unwrap (cipherIn , plainIn );
113+ int positionBeforeUnwrapping = cipherIn .position ();
114+ unwrapResult = sslEngine .unwrap (cipherIn , plainIn );
115+ LOGGER .debug ("SSL engine result is {} after unwrapping" , unwrapResult );
80116 status = unwrapResult .getStatus ();
81117 switch (status ) {
82118 case OK :
83119 plainIn .clear ();
84- handshakeStatus = runDelegatedTasks (sslEngine );
120+ if (unwrapResult .getHandshakeStatus () == NEED_TASK ) {
121+ handshakeStatus = runDelegatedTasks (sslEngine );
122+ int newPosition = positionBeforeUnwrapping + unwrapResult .bytesConsumed ();
123+ if (newPosition == cipherIn .limit ()) {
124+ LOGGER .debug ("Clearing cipherIn because all bytes have been read and unwrapped" );
125+ cipherIn .clear ();
126+ } else {
127+ LOGGER .debug ("Setting cipherIn position to {} (limit is {})" , newPosition , cipherIn .limit ());
128+ cipherIn .position (positionBeforeUnwrapping + unwrapResult .bytesConsumed ());
129+ }
130+ } else {
131+ handshakeStatus = unwrapResult .getHandshakeStatus ();
132+ }
85133 break ;
86134 case BUFFER_OVERFLOW :
87135 throw new SSLException ("Buffer overflow during handshake" );
88136 case BUFFER_UNDERFLOW :
137+ LOGGER .debug ("Buffer underflow" );
89138 cipherIn .compact ();
90- int read = NioHelper .read (channel , cipherIn );
139+ LOGGER .debug ("Reading from channel..." );
140+ read = NioHelper .read (channel , cipherIn );
91141 if (read <= 0 ) {
92142 retryRead (channel , cipherIn );
93143 }
144+ LOGGER .debug ("Done reading from channel..." );
94145 cipherIn .flip ();
95146 break ;
96147 case CLOSED :
@@ -100,9 +151,9 @@ private static SSLEngineResult.HandshakeStatus unwrap(ByteBuffer cipherIn, ByteB
100151 throw new SSLException ("Unexpected status from " + unwrapResult );
101152 }
102153 }
103- while (cipherIn . hasRemaining () );
154+ while (unwrapResult . getHandshakeStatus () != NEED_WRAP && unwrapResult . getHandshakeStatus () != FINISHED );
104155
105- cipherIn . compact ( );
156+ LOGGER . debug ( " cipherIn position after unwrap {}" , cipherIn . position () );
106157 return handshakeStatus ;
107158 }
108159
@@ -127,36 +178,32 @@ private static int retryRead(ReadableByteChannel channel, ByteBuffer buffer) thr
127178 private static SSLEngineResult .HandshakeStatus wrap (ByteBuffer plainOut , ByteBuffer cipherOut ,
128179 WritableByteChannel channel , SSLEngine sslEngine ) throws IOException {
129180 SSLEngineResult .HandshakeStatus handshakeStatus = sslEngine .getHandshakeStatus ();
130- SSLEngineResult .Status status = sslEngine .wrap (plainOut , cipherOut ).getStatus ();
131- switch (status ) {
181+ LOGGER .debug ("Handshake status is {} before wrapping" , handshakeStatus );
182+ SSLEngineResult result = sslEngine .wrap (plainOut , cipherOut );
183+ LOGGER .debug ("SSL engine result is {} after wrapping" , result );
184+ switch (result .getStatus ()) {
132185 case OK :
133- handshakeStatus = runDelegatedTasks (sslEngine );
134186 cipherOut .flip ();
135187 while (cipherOut .hasRemaining ()) {
136- channel .write (cipherOut );
188+ int written = channel .write (cipherOut );
189+ LOGGER .debug ("Wrote {} byte(s)" , written );
137190 }
138191 cipherOut .clear ();
192+ if (result .getHandshakeStatus () == NEED_TASK ) {
193+ handshakeStatus = runDelegatedTasks (sslEngine );
194+ } else {
195+ handshakeStatus = result .getHandshakeStatus ();
196+ }
197+
139198 break ;
140199 case BUFFER_OVERFLOW :
141200 throw new SSLException ("Buffer overflow during handshake" );
142201 default :
143- throw new SSLException ("Unexpected status " + status );
202+ throw new SSLException ("Unexpected status " + result . getStatus () );
144203 }
145204 return handshakeStatus ;
146205 }
147206
148- static int bufferCopy (ByteBuffer from , ByteBuffer to ) {
149- int maxTransfer = Math .min (to .remaining (), from .remaining ());
150-
151- ByteBuffer temporaryBuffer = from .duplicate ();
152- temporaryBuffer .limit (temporaryBuffer .position () + maxTransfer );
153- to .put (temporaryBuffer );
154-
155- from .position (from .position () + maxTransfer );
156-
157- return maxTransfer ;
158- }
159-
160207 public static void write (WritableByteChannel socketChannel , SSLEngine engine , ByteBuffer plainOut , ByteBuffer cypherOut ) throws IOException {
161208 while (plainOut .hasRemaining ()) {
162209 cypherOut .clear ();
0 commit comments