001    /**
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.activemq.transport.nio;
019    
020    import java.io.DataInputStream;
021    import java.io.DataOutputStream;
022    import java.io.EOFException;
023    import java.io.IOException;
024    import java.net.Socket;
025    import java.net.URI;
026    import java.net.UnknownHostException;
027    import java.nio.ByteBuffer;
028    import java.security.cert.X509Certificate;
029    
030    import javax.net.SocketFactory;
031    import javax.net.ssl.SSLContext;
032    import javax.net.ssl.SSLEngine;
033    import javax.net.ssl.SSLEngineResult;
034    import javax.net.ssl.SSLPeerUnverifiedException;
035    import javax.net.ssl.SSLSession;
036    
037    import org.apache.activemq.command.Command;
038    import org.apache.activemq.command.ConnectionInfo;
039    import org.apache.activemq.openwire.OpenWireFormat;
040    import org.apache.activemq.thread.TaskRunnerFactory;
041    import org.apache.activemq.util.IOExceptionSupport;
042    import org.apache.activemq.util.ServiceStopper;
043    import org.apache.activemq.wireformat.WireFormat;
044    import org.slf4j.Logger;
045    import org.slf4j.LoggerFactory;
046    
047    public class NIOSSLTransport extends NIOTransport {
048    
049        private static final Logger LOG = LoggerFactory.getLogger(NIOSSLTransport.class);
050    
051        protected boolean needClientAuth;
052        protected boolean wantClientAuth;
053        protected String[] enabledCipherSuites;
054    
055        protected SSLContext sslContext;
056        protected SSLEngine sslEngine;
057        protected SSLSession sslSession;
058    
059        protected volatile boolean handshakeInProgress = false;
060        protected SSLEngineResult.Status status = null;
061        protected SSLEngineResult.HandshakeStatus handshakeStatus = null;
062        protected TaskRunnerFactory taskRunnerFactory;
063    
064        public NIOSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException {
065            super(wireFormat, socketFactory, remoteLocation, localLocation);
066        }
067    
068        public NIOSSLTransport(WireFormat wireFormat, Socket socket) throws IOException {
069            super(wireFormat, socket);
070        }
071    
072        public void setSslContext(SSLContext sslContext) {
073            this.sslContext = sslContext;
074        }
075    
076        @Override
077        protected void initializeStreams() throws IOException {
078            try {
079                channel = socket.getChannel();
080                channel.configureBlocking(false);
081    
082                if (sslContext == null) {
083                    sslContext = SSLContext.getDefault();
084                }
085    
086                String remoteHost = null;
087                int remotePort = -1;
088    
089                try {
090                    URI remoteAddress = new URI(this.getRemoteAddress());
091                    remoteHost = remoteAddress.getHost();
092                    remotePort = remoteAddress.getPort();
093                } catch (Exception e) {
094                }
095    
096                // initialize engine, the initial sslSession we get will need to be
097                // updated once the ssl handshake process is completed.
098                if (remoteHost != null && remotePort != -1) {
099                    sslEngine = sslContext.createSSLEngine(remoteHost, remotePort);
100                } else {
101                    sslEngine = sslContext.createSSLEngine();
102                }
103    
104                sslEngine.setUseClientMode(false);
105                if (enabledCipherSuites != null) {
106                    sslEngine.setEnabledCipherSuites(enabledCipherSuites);
107                }
108    
109                if (wantClientAuth) {
110                    sslEngine.setWantClientAuth(wantClientAuth);
111                }
112    
113                if (needClientAuth) {
114                    sslEngine.setNeedClientAuth(needClientAuth);
115                }
116    
117                sslSession = sslEngine.getSession();
118    
119                inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize());
120                inputBuffer.clear();
121    
122                NIOOutputStream outputStream = new NIOOutputStream(channel);
123                outputStream.setEngine(sslEngine);
124                this.dataOut = new DataOutputStream(outputStream);
125                this.buffOut = outputStream;
126                sslEngine.beginHandshake();
127                handshakeStatus = sslEngine.getHandshakeStatus();
128                doHandshake();
129            } catch (Exception e) {
130                throw new IOException(e);
131            }
132        }
133    
134        protected void finishHandshake() throws Exception {
135            if (handshakeInProgress) {
136                handshakeInProgress = false;
137                nextFrameSize = -1;
138    
139                // Once handshake completes we need to ask for the now real sslSession
140                // otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the
141                // cipher suite.
142                sslSession = sslEngine.getSession();
143    
144                // listen for events telling us when the socket is readable.
145                selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() {
146                    public void onSelect(SelectorSelection selection) {
147                        serviceRead();
148                    }
149    
150                    public void onError(SelectorSelection selection, Throwable error) {
151                        if (error instanceof IOException) {
152                            onException((IOException) error);
153                        } else {
154                            onException(IOExceptionSupport.create(error));
155                        }
156                    }
157                });
158            }
159        }
160    
161        protected void serviceRead() {
162            try {
163                if (handshakeInProgress) {
164                    doHandshake();
165                }
166    
167                ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
168                plain.position(plain.limit());
169    
170                while (true) {
171                    if (!plain.hasRemaining()) {
172    
173                        int readCount = secureRead(plain);
174    
175                        if (readCount == 0) {
176                            break;
177                        }
178    
179                        // channel is closed, cleanup
180                        if (readCount == -1) {
181                            onException(new EOFException());
182                            selection.close();
183                            break;
184                        }
185    
186                        receiveCounter += readCount;
187                    }
188    
189                    if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
190                        processCommand(plain);
191                    }
192                }
193            } catch (IOException e) {
194                onException(e);
195            } catch (Throwable e) {
196                onException(IOExceptionSupport.create(e));
197            }
198        }
199    
200        protected void processCommand(ByteBuffer plain) throws Exception {
201    
202            // Are we waiting for the next Command or are we building on the current one
203            if (nextFrameSize == -1) {
204    
205                // We can get small packets that don't give us enough for the frame size
206                // so allocate enough for the initial size value and
207                if (plain.remaining() < Integer.SIZE) {
208                    if (currentBuffer == null) {
209                        currentBuffer = ByteBuffer.allocate(4);
210                    }
211    
212                    // Go until we fill the integer sized current buffer.
213                    while (currentBuffer.hasRemaining() && plain.hasRemaining()) {
214                        currentBuffer.put(plain.get());
215                    }
216    
217                    // Didn't we get enough yet to figure out next frame size.
218                    if (currentBuffer.hasRemaining()) {
219                        return;
220                    } else {
221                        currentBuffer.flip();
222                        nextFrameSize = currentBuffer.getInt();
223                    }
224    
225                } else {
226    
227                    // Either we are completing a previous read of the next frame size or its
228                    // fully contained in plain already.
229                    if (currentBuffer != null) {
230    
231                        // Finish the frame size integer read and get from the current buffer.
232                        while (currentBuffer.hasRemaining()) {
233                            currentBuffer.put(plain.get());
234                        }
235    
236                        currentBuffer.flip();
237                        nextFrameSize = currentBuffer.getInt();
238    
239                    } else {
240                        nextFrameSize = plain.getInt();
241                    }
242                }
243    
244                if (wireFormat instanceof OpenWireFormat) {
245                    long maxFrameSize = ((OpenWireFormat) wireFormat).getMaxFrameSize();
246                    if (nextFrameSize > maxFrameSize) {
247                        throw new IOException("Frame size of " + (nextFrameSize / (1024 * 1024)) +
248                                              " MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB");
249                    }
250                }
251    
252                // now we got the data, lets reallocate and store the size for the marshaler.
253                // if there's more data in plain, then the next call will start processing it.
254                currentBuffer = ByteBuffer.allocate(nextFrameSize + 4);
255                currentBuffer.putInt(nextFrameSize);
256    
257            } else {
258    
259                // If its all in one read then we can just take it all, otherwise take only
260                // the current frame size and the next iteration starts a new command.
261                if (currentBuffer.remaining() >= plain.remaining()) {
262                    currentBuffer.put(plain);
263                } else {
264                    byte[] fill = new byte[currentBuffer.remaining()];
265                    plain.get(fill);
266                    currentBuffer.put(fill);
267                }
268    
269                // Either we have enough data for a new command or we have to wait for some more.
270                if (currentBuffer.hasRemaining()) {
271                    return;
272                } else {
273                    currentBuffer.flip();
274                    Object command = wireFormat.unmarshal(new DataInputStream(new NIOInputStream(currentBuffer)));
275                    doConsume((Command) command);
276                    nextFrameSize = -1;
277                    currentBuffer = null;
278                }
279            }
280        }
281    
282        protected int secureRead(ByteBuffer plain) throws Exception {
283    
284            if (!(inputBuffer.position() != 0 && inputBuffer.hasRemaining()) || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
285                int bytesRead = channel.read(inputBuffer);
286    
287                if (bytesRead == 0) {
288                    return 0;
289                }
290    
291                if (bytesRead == -1) {
292                    sslEngine.closeInbound();
293                    if (inputBuffer.position() == 0 || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
294                        return -1;
295                    }
296                }
297            }
298    
299            plain.clear();
300    
301            inputBuffer.flip();
302            SSLEngineResult res;
303            do {
304                res = sslEngine.unwrap(inputBuffer, plain);
305            } while (res.getStatus() == SSLEngineResult.Status.OK && res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP
306                    && res.bytesProduced() == 0);
307    
308            if (res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) {
309                finishHandshake();
310            }
311    
312            status = res.getStatus();
313            handshakeStatus = res.getHandshakeStatus();
314    
315            // TODO deal with BUFFER_OVERFLOW
316    
317            if (status == SSLEngineResult.Status.CLOSED) {
318                sslEngine.closeInbound();
319                return -1;
320            }
321    
322            inputBuffer.compact();
323            plain.flip();
324    
325            return plain.remaining();
326        }
327    
328        protected void doHandshake() throws Exception {
329            handshakeInProgress = true;
330            while (true) {
331                switch (sslEngine.getHandshakeStatus()) {
332                case NEED_UNWRAP:
333                    secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize()));
334                    break;
335                case NEED_TASK:
336                    Runnable task;
337                    while ((task = sslEngine.getDelegatedTask()) != null) {
338                        taskRunnerFactory.execute(task);
339                    }
340                    break;
341                case NEED_WRAP:
342                    ((NIOOutputStream) buffOut).write(ByteBuffer.allocate(0));
343                    break;
344                case FINISHED:
345                case NOT_HANDSHAKING:
346                    finishHandshake();
347                    return;
348                }
349            }
350        }
351    
352        @Override
353        protected void doStart() throws Exception {
354            taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task");
355            // no need to init as we can delay that until demand (eg in doHandshake)
356            super.doStart();
357        }
358    
359        @Override
360        protected void doStop(ServiceStopper stopper) throws Exception {
361            if (taskRunnerFactory != null) {
362                taskRunnerFactory.shutdownNow();
363                taskRunnerFactory = null;
364            }
365            if (channel != null) {
366                channel.close();
367                channel = null;
368            }
369            super.doStop(stopper);
370        }
371    
372        /**
373         * Overriding in order to add the client's certificates to ConnectionInfo Commands.
374         *
375         * @param command
376         *            The Command coming in.
377         */
378        @Override
379        public void doConsume(Object command) {
380            if (command instanceof ConnectionInfo) {
381                ConnectionInfo connectionInfo = (ConnectionInfo) command;
382                connectionInfo.setTransportContext(getPeerCertificates());
383            }
384            super.doConsume(command);
385        }
386    
387        /**
388         * @return peer certificate chain associated with the ssl socket
389         */
390        public X509Certificate[] getPeerCertificates() {
391    
392            X509Certificate[] clientCertChain = null;
393            try {
394                if (sslEngine.getSession() != null) {
395                    clientCertChain = (X509Certificate[]) sslEngine.getSession().getPeerCertificates();
396                }
397            } catch (SSLPeerUnverifiedException e) {
398                if (LOG.isTraceEnabled()) {
399                    LOG.trace("Failed to get peer certificates.", e);
400                }
401            }
402    
403            return clientCertChain;
404        }
405    
406        public boolean isNeedClientAuth() {
407            return needClientAuth;
408        }
409    
410        public void setNeedClientAuth(boolean needClientAuth) {
411            this.needClientAuth = needClientAuth;
412        }
413    
414        public boolean isWantClientAuth() {
415            return wantClientAuth;
416        }
417    
418        public void setWantClientAuth(boolean wantClientAuth) {
419            this.wantClientAuth = wantClientAuth;
420        }
421    
422        public String[] getEnabledCipherSuites() {
423            return enabledCipherSuites;
424        }
425    
426        public void setEnabledCipherSuites(String[] enabledCipherSuites) {
427            this.enabledCipherSuites = enabledCipherSuites;
428        }
429    }