001    /**
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.activemq.transport.nio;
019    
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.net.Socket;
import java.net.URI;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;

import javax.net.SocketFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSession;

import org.apache.activemq.command.Command;
import org.apache.activemq.command.ConnectionInfo;
import org.apache.activemq.openwire.OpenWireFormat;
import org.apache.activemq.thread.TaskRunnerFactory;
import org.apache.activemq.util.IOExceptionSupport;
import org.apache.activemq.util.ServiceStopper;
import org.apache.activemq.wireformat.WireFormat;
044    
045    public class NIOSSLTransport extends NIOTransport  {
046    
047        protected boolean needClientAuth;
048        protected boolean wantClientAuth;
049        protected String[] enabledCipherSuites;
050    
051        protected SSLContext sslContext;
052        protected SSLEngine sslEngine;
053        protected SSLSession sslSession;
054    
055        protected volatile boolean handshakeInProgress = false;
056        protected SSLEngineResult.Status status = null;
057        protected SSLEngineResult.HandshakeStatus handshakeStatus = null;
058        protected TaskRunnerFactory taskRunnerFactory;
059    
060        public NIOSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException {
061            super(wireFormat, socketFactory, remoteLocation, localLocation);
062        }
063    
064        public NIOSSLTransport(WireFormat wireFormat, Socket socket) throws IOException {
065            super(wireFormat, socket);
066        }
067    
068        public void setSslContext(SSLContext sslContext) {
069            this.sslContext = sslContext;
070        }
071    
072        @Override
073        protected void initializeStreams() throws IOException {
074            try {
075                channel = socket.getChannel();
076                channel.configureBlocking(false);
077    
078                if (sslContext == null) {
079                    sslContext = SSLContext.getDefault();
080                }
081    
082                // initialize engine, the initial sslSession we get will need to be
083                // updated once the ssl handshake process is completed.
084                sslEngine = sslContext.createSSLEngine();
085                sslEngine.setUseClientMode(false);
086                if (enabledCipherSuites != null) {
087                    sslEngine.setEnabledCipherSuites(enabledCipherSuites);
088                }
089                sslEngine.setNeedClientAuth(needClientAuth);
090                sslEngine.setWantClientAuth(wantClientAuth);
091    
092                sslSession = sslEngine.getSession();
093    
094                inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize());
095                inputBuffer.clear();
096                currentBuffer = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
097    
098                NIOOutputStream outputStream = new NIOOutputStream(channel);
099                outputStream.setEngine(sslEngine);
100                this.dataOut = new DataOutputStream(outputStream);
101                this.buffOut = outputStream;
102                sslEngine.beginHandshake();
103                handshakeStatus = sslEngine.getHandshakeStatus();
104                doHandshake();
105            } catch (Exception e) {
106                throw new IOException(e);
107            }
108        }
109    
110        protected void finishHandshake() throws Exception  {
111              if (handshakeInProgress) {
112                  handshakeInProgress = false;
113                  nextFrameSize = -1;
114    
115                  // Once handshake completes we need to ask for the now real sslSession
116                  // otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the
117                  // cipher suite.
118                  sslSession = sslEngine.getSession();
119    
120                  // listen for events telling us when the socket is readable.
121                  selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() {
122                      public void onSelect(SelectorSelection selection) {
123                          serviceRead();
124                      }
125    
126                      public void onError(SelectorSelection selection, Throwable error) {
127                          if (error instanceof IOException) {
128                              onException((IOException) error);
129                          } else {
130                              onException(IOExceptionSupport.create(error));
131                          }
132                      }
133                  });
134              }
135        }
136    
137        protected void serviceRead() {
138            try {
139                if (handshakeInProgress) {
140                    doHandshake();
141                }
142    
143                ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
144                plain.position(plain.limit());
145    
146                while(true) {
147                    if (!plain.hasRemaining()) {
148    
149                        if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
150                            plain.clear();
151                        } else {
152                            plain.compact();
153                        }
154                        int readCount = secureRead(plain);
155    
156    
157                        if (readCount == 0)
158                            break;
159    
160                        // channel is closed, cleanup
161                        if (readCount== -1) {
162                            onException(new EOFException());
163                            selection.close();
164                            break;
165                        }
166                    }
167    
168                    if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
169                        processCommand(plain);
170                    }
171                }
172            } catch (IOException e) {
173                onException(e);
174            } catch (Throwable e) {
175                onException(IOExceptionSupport.create(e));
176            }
177        }
178    
179        protected void processCommand(ByteBuffer plain) throws Exception {
180            nextFrameSize = plain.getInt();
181            if (wireFormat instanceof OpenWireFormat) {
182                long maxFrameSize = ((OpenWireFormat) wireFormat).getMaxFrameSize();
183                if (nextFrameSize > maxFrameSize) {
184                    throw new IOException("Frame size of " + (nextFrameSize / (1024 * 1024)) + " MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB");
185                }
186            }
187            currentBuffer = ByteBuffer.allocate(nextFrameSize + 4);
188            currentBuffer.putInt(nextFrameSize);
189            if (currentBuffer.hasRemaining()) {
190                if (currentBuffer.remaining() >= plain.remaining()) {
191                    currentBuffer.put(plain);
192                } else {
193                    byte[] fill = new byte[currentBuffer.remaining()];
194                    plain.get(fill);
195                    currentBuffer.put(fill);
196                }
197            }
198    
199            if (currentBuffer.hasRemaining()) {
200                return;
201            } else {
202                currentBuffer.flip();
203                Object command = wireFormat.unmarshal(new DataInputStream(new NIOInputStream(currentBuffer)));
204                doConsume((Command) command);
205                nextFrameSize = -1;
206            }
207        }
208    
209        protected int secureRead(ByteBuffer plain) throws Exception {
210    
211            if (!(inputBuffer.position() != 0 && inputBuffer.hasRemaining()) || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
212                int bytesRead = channel.read(inputBuffer);
213    
214                if (bytesRead == -1) {
215                    sslEngine.closeInbound();
216                    if (inputBuffer.position() == 0 ||
217                            status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
218                        return -1;
219                    }
220                }
221            }
222    
223            plain.clear();
224    
225            inputBuffer.flip();
226            SSLEngineResult res;
227            do {
228                res = sslEngine.unwrap(inputBuffer, plain);
229            } while (res.getStatus() == SSLEngineResult.Status.OK &&
230                    res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP &&
231                    res.bytesProduced() == 0);
232    
233            if (res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) {
234               finishHandshake();
235            }
236    
237            status = res.getStatus();
238            handshakeStatus = res.getHandshakeStatus();
239    
240            //TODO deal with BUFFER_OVERFLOW
241    
242            if (status == SSLEngineResult.Status.CLOSED) {
243                sslEngine.closeInbound();
244                return -1;
245            }
246    
247            inputBuffer.compact();
248            plain.flip();
249    
250            return plain.remaining();
251        }
252    
253        protected void doHandshake() throws Exception {
254            handshakeInProgress = true;
255            while (true) {
256                switch (sslEngine.getHandshakeStatus()) {
257                    case NEED_UNWRAP:
258                        secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize()));
259                        break;
260                    case NEED_TASK:
261                        Runnable task;
262                        while ((task = sslEngine.getDelegatedTask()) != null) {
263                            taskRunnerFactory.execute(task);
264                        }
265                        break;
266                    case NEED_WRAP:
267                        ((NIOOutputStream)buffOut).write(ByteBuffer.allocate(0));
268                        break;
269                    case FINISHED:
270                    case NOT_HANDSHAKING:
271                        finishHandshake();
272                        return;
273                }
274            }
275        }
276    
277        @Override
278        protected void doStart() throws Exception {
279            taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task");
280            // no need to init as we can delay that until demand (eg in doHandshake)
281            super.doStart();
282        }
283    
284        @Override
285        protected void doStop(ServiceStopper stopper) throws Exception {
286            if (taskRunnerFactory != null) {
287                taskRunnerFactory.shutdownNow();
288                taskRunnerFactory = null;
289            }
290            if (channel != null) {
291                channel.close();
292                channel = null;
293            }
294            super.doStop(stopper);
295        }
296    
297        /**
298         * Overriding in order to add the client's certificates to ConnectionInfo Commmands.
299         *
300         * @param command The Command coming in.
301         */
302        @Override
303        public void doConsume(Object command) {
304            if (command instanceof ConnectionInfo) {
305                ConnectionInfo connectionInfo = (ConnectionInfo)command;
306                connectionInfo.setTransportContext(getPeerCertificates());
307            }
308            super.doConsume(command);
309        }
310    
311        /**
312         * @return peer certificate chain associated with the ssl socket
313         */
314        public X509Certificate[] getPeerCertificates() {
315    
316            X509Certificate[] clientCertChain = null;
317            try {
318                if (sslSession != null) {
319                    clientCertChain = (X509Certificate[])sslSession.getPeerCertificates();
320                }
321            } catch (SSLPeerUnverifiedException e) {
322            }
323    
324            return clientCertChain;
325        }
326    
327        public boolean isNeedClientAuth() {
328            return needClientAuth;
329        }
330    
331        public void setNeedClientAuth(boolean needClientAuth) {
332            this.needClientAuth = needClientAuth;
333        }
334    
335        public boolean isWantClientAuth() {
336            return wantClientAuth;
337        }
338    
339        public void setWantClientAuth(boolean wantClientAuth) {
340            this.wantClientAuth = wantClientAuth;
341        }
342    
343        public String[] getEnabledCipherSuites() {
344            return enabledCipherSuites;
345        }
346    
347        public void setEnabledCipherSuites(String[] enabledCipherSuites) {
348            this.enabledCipherSuites = enabledCipherSuites;
349        }
350    }