001/**
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.activemq.transport.nio;
019
020import java.io.DataInputStream;
021import java.io.DataOutputStream;
022import java.io.EOFException;
023import java.io.IOException;
024import java.net.Socket;
025import java.net.SocketTimeoutException;
026import java.net.URI;
027import java.net.UnknownHostException;
028import java.nio.ByteBuffer;
029import java.nio.channels.SelectionKey;
030import java.nio.channels.Selector;
031import java.security.cert.X509Certificate;
032import java.util.concurrent.CountDownLatch;
033
034import javax.net.SocketFactory;
035import javax.net.ssl.SSLContext;
036import javax.net.ssl.SSLEngine;
037import javax.net.ssl.SSLEngineResult;
038import javax.net.ssl.SSLEngineResult.HandshakeStatus;
039import javax.net.ssl.SSLPeerUnverifiedException;
040import javax.net.ssl.SSLSession;
041
042import org.apache.activemq.command.ConnectionInfo;
043import org.apache.activemq.openwire.OpenWireFormat;
044import org.apache.activemq.thread.TaskRunnerFactory;
045import org.apache.activemq.util.IOExceptionSupport;
046import org.apache.activemq.util.ServiceStopper;
047import org.apache.activemq.wireformat.WireFormat;
048import org.slf4j.Logger;
049import org.slf4j.LoggerFactory;
050
051public class NIOSSLTransport extends NIOTransport {
052
053    private static final Logger LOG = LoggerFactory.getLogger(NIOSSLTransport.class);
054
055    protected boolean needClientAuth;
056    protected boolean wantClientAuth;
057    protected String[] enabledCipherSuites;
058    protected String[] enabledProtocols;
059
060    protected SSLContext sslContext;
061    protected SSLEngine sslEngine;
062    protected SSLSession sslSession;
063
064    protected volatile boolean handshakeInProgress = false;
065    protected SSLEngineResult.Status status = null;
066    protected SSLEngineResult.HandshakeStatus handshakeStatus = null;
067    protected TaskRunnerFactory taskRunnerFactory;
068
069    public NIOSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException {
070        super(wireFormat, socketFactory, remoteLocation, localLocation);
071    }
072
073    public NIOSSLTransport(WireFormat wireFormat, Socket socket, SSLEngine engine, InitBuffer initBuffer,
074            ByteBuffer inputBuffer) throws IOException {
075        super(wireFormat, socket, initBuffer);
076        this.sslEngine = engine;
077        if (engine != null) {
078            this.sslSession = engine.getSession();
079        }
080        this.inputBuffer = inputBuffer;
081    }
082
083    public void setSslContext(SSLContext sslContext) {
084        this.sslContext = sslContext;
085    }
086
087    volatile boolean hasSslEngine = false;
088
089    @Override
090    protected void initializeStreams() throws IOException {
091        if (sslEngine != null) {
092            hasSslEngine = true;
093        }
094        NIOOutputStream outputStream = null;
095        try {
096            channel = socket.getChannel();
097            channel.configureBlocking(false);
098
099            if (sslContext == null) {
100                sslContext = SSLContext.getDefault();
101            }
102
103            String remoteHost = null;
104            int remotePort = -1;
105
106            try {
107                URI remoteAddress = new URI(this.getRemoteAddress());
108                remoteHost = remoteAddress.getHost();
109                remotePort = remoteAddress.getPort();
110            } catch (Exception e) {
111            }
112
113            // initialize engine, the initial sslSession we get will need to be
114            // updated once the ssl handshake process is completed.
115            if (!hasSslEngine) {
116                if (remoteHost != null && remotePort != -1) {
117                    sslEngine = sslContext.createSSLEngine(remoteHost, remotePort);
118                } else {
119                    sslEngine = sslContext.createSSLEngine();
120                }
121
122                sslEngine.setUseClientMode(false);
123                if (enabledCipherSuites != null) {
124                    sslEngine.setEnabledCipherSuites(enabledCipherSuites);
125                }
126
127                if (enabledProtocols != null) {
128                    sslEngine.setEnabledProtocols(enabledProtocols);
129                }
130
131                if (wantClientAuth) {
132                    sslEngine.setWantClientAuth(wantClientAuth);
133                }
134
135                if (needClientAuth) {
136                    sslEngine.setNeedClientAuth(needClientAuth);
137                }
138
139                sslSession = sslEngine.getSession();
140
141                inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize());
142                inputBuffer.clear();
143            }
144
145            outputStream = new NIOOutputStream(channel);
146            outputStream.setEngine(sslEngine);
147            this.dataOut = new DataOutputStream(outputStream);
148            this.buffOut = outputStream;
149
150            //If the sslEngine was not passed in, then handshake
151            if (!hasSslEngine) {
152                sslEngine.beginHandshake();
153            }
154            handshakeStatus = sslEngine.getHandshakeStatus();
155            if (!hasSslEngine) {
156                doHandshake();
157            }
158
159           // if (hasSslEngine) {
160            selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() {
161                @Override
162                public void onSelect(SelectorSelection selection) {
163                    try {
164                        initialized.await();
165                    } catch (InterruptedException error) {
166                        onException(IOExceptionSupport.create(error));
167                    }
168                    serviceRead();
169                }
170
171                @Override
172                public void onError(SelectorSelection selection, Throwable error) {
173                    if (error instanceof IOException) {
174                        onException((IOException) error);
175                    } else {
176                        onException(IOExceptionSupport.create(error));
177                    }
178                }
179            });
180            doInit();
181
182        } catch (Exception e) {
183            try {
184                if(outputStream != null) {
185                    outputStream.close();
186                }
187                super.closeStreams();
188            } catch (Exception ex) {}
189            throw new IOException(e);
190        }
191    }
192
193    final protected CountDownLatch initialized = new CountDownLatch(1);
194
195    protected void doInit() throws Exception {
196        taskRunnerFactory.execute(new Runnable() {
197
198            @Override
199            public void run() {
200                //Need to start in new thread to let startup finish first
201                //We can trigger a read because we know the channel is ready since the SSL handshake
202                //already happened
203                serviceRead();
204                initialized.countDown();
205            }
206        });
207    }
208
209    //Only used for the auto transport to abort the openwire init method early if already initialized
210    boolean openWireInititialized = false;
211
212    protected void doOpenWireInit() throws Exception {
213        //Do this later to let wire format negotiation happen
214        if (initBuffer != null && !openWireInititialized && this.wireFormat instanceof OpenWireFormat) {
215            initBuffer.buffer.flip();
216            if (initBuffer.buffer.hasRemaining()) {
217                nextFrameSize = -1;
218                receiveCounter += initBuffer.readSize;
219                processCommand(initBuffer.buffer);
220                processCommand(initBuffer.buffer);
221                initBuffer.buffer.clear();
222                openWireInititialized = true;
223            }
224        }
225    }
226
227    protected void finishHandshake() throws Exception {
228        if (handshakeInProgress) {
229            handshakeInProgress = false;
230            nextFrameSize = -1;
231
232            // Once handshake completes we need to ask for the now real sslSession
233            // otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the
234            // cipher suite.
235            sslSession = sslEngine.getSession();
236
237            // listen for events telling us when the socket is readable.
238            selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() {
239                @Override
240                public void onSelect(SelectorSelection selection) {
241                    serviceRead();
242                }
243
244                @Override
245                public void onError(SelectorSelection selection, Throwable error) {
246                    if (error instanceof IOException) {
247                        onException((IOException) error);
248                    } else {
249                        onException(IOExceptionSupport.create(error));
250                    }
251                }
252            });
253        }
254    }
255
256    @Override
257    public void serviceRead() {
258        try {
259            if (handshakeInProgress) {
260                doHandshake();
261            }
262
263            doOpenWireInit();
264
265            ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
266            plain.position(plain.limit());
267
268            while (true) {
269                if (!plain.hasRemaining()) {
270
271                    int readCount = secureRead(plain);
272
273                    if (readCount == 0) {
274                        break;
275                    }
276
277                    // channel is closed, cleanup
278                    if (readCount == -1) {
279                        onException(new EOFException());
280                        selection.close();
281                        break;
282                    }
283
284                    receiveCounter += readCount;
285                }
286
287                if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
288                    processCommand(plain);
289                }
290            }
291        } catch (IOException e) {
292            onException(e);
293        } catch (Throwable e) {
294            onException(IOExceptionSupport.create(e));
295        }
296    }
297
298    protected void processCommand(ByteBuffer plain) throws Exception {
299
300        // Are we waiting for the next Command or are we building on the current one
301        if (nextFrameSize == -1) {
302
303            // We can get small packets that don't give us enough for the frame size
304            // so allocate enough for the initial size value and
305            if (plain.remaining() < Integer.SIZE) {
306                if (currentBuffer == null) {
307                    currentBuffer = ByteBuffer.allocate(4);
308                }
309
310                // Go until we fill the integer sized current buffer.
311                while (currentBuffer.hasRemaining() && plain.hasRemaining()) {
312                    currentBuffer.put(plain.get());
313                }
314
315                // Didn't we get enough yet to figure out next frame size.
316                if (currentBuffer.hasRemaining()) {
317                    return;
318                } else {
319                    currentBuffer.flip();
320                    nextFrameSize = currentBuffer.getInt();
321                }
322
323            } else {
324
325                // Either we are completing a previous read of the next frame size or its
326                // fully contained in plain already.
327                if (currentBuffer != null) {
328
329                    // Finish the frame size integer read and get from the current buffer.
330                    while (currentBuffer.hasRemaining()) {
331                        currentBuffer.put(plain.get());
332                    }
333
334                    currentBuffer.flip();
335                    nextFrameSize = currentBuffer.getInt();
336
337                } else {
338                    nextFrameSize = plain.getInt();
339                }
340            }
341
342            if (wireFormat instanceof OpenWireFormat) {
343                long maxFrameSize = ((OpenWireFormat) wireFormat).getMaxFrameSize();
344                if (nextFrameSize > maxFrameSize) {
345                    throw new IOException("Frame size of " + (nextFrameSize / (1024 * 1024)) +
346                                          " MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB");
347                }
348            }
349
350            // now we got the data, lets reallocate and store the size for the marshaler.
351            // if there's more data in plain, then the next call will start processing it.
352            currentBuffer = ByteBuffer.allocate(nextFrameSize + 4);
353            currentBuffer.putInt(nextFrameSize);
354
355        } else {
356            // If its all in one read then we can just take it all, otherwise take only
357            // the current frame size and the next iteration starts a new command.
358            if (currentBuffer != null) {
359                if (currentBuffer.remaining() >= plain.remaining()) {
360                    currentBuffer.put(plain);
361                } else {
362                    byte[] fill = new byte[currentBuffer.remaining()];
363                    plain.get(fill);
364                    currentBuffer.put(fill);
365                }
366
367                // Either we have enough data for a new command or we have to wait for some more.
368                if (currentBuffer.hasRemaining()) {
369                    return;
370                } else {
371                    currentBuffer.flip();
372                    Object command = wireFormat.unmarshal(new DataInputStream(new NIOInputStream(currentBuffer)));
373                    doConsume(command);
374                    nextFrameSize = -1;
375                    currentBuffer = null;
376               }
377            }
378        }
379    }
380
381    protected int secureRead(ByteBuffer plain) throws Exception {
382
383        if (!(inputBuffer.position() != 0 && inputBuffer.hasRemaining()) || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
384            int bytesRead = channel.read(inputBuffer);
385
386            if (bytesRead == 0 && !(sslEngine.getHandshakeStatus().equals(SSLEngineResult.HandshakeStatus.NEED_UNWRAP))) {
387                return 0;
388            }
389
390            if (bytesRead == -1) {
391                sslEngine.closeInbound();
392                if (inputBuffer.position() == 0 || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
393                    return -1;
394                }
395            }
396        }
397
398        plain.clear();
399
400        inputBuffer.flip();
401        SSLEngineResult res;
402        do {
403            res = sslEngine.unwrap(inputBuffer, plain);
404        } while (res.getStatus() == SSLEngineResult.Status.OK && res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP
405                && res.bytesProduced() == 0);
406
407        if (res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) {
408            finishHandshake();
409        }
410
411        status = res.getStatus();
412        handshakeStatus = res.getHandshakeStatus();
413
414        // TODO deal with BUFFER_OVERFLOW
415
416        if (status == SSLEngineResult.Status.CLOSED) {
417            sslEngine.closeInbound();
418            return -1;
419        }
420
421        inputBuffer.compact();
422        plain.flip();
423
424        return plain.remaining();
425    }
426
427    protected void doHandshake() throws Exception {
428        handshakeInProgress = true;
429        Selector selector = null;
430        SelectionKey key = null;
431        boolean readable = true;
432        try {
433            while (true) {
434                HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
435                switch (handshakeStatus) {
436                    case NEED_UNWRAP:
437                        if (readable) {
438                            secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize()));
439                        }
440                        if (this.status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
441                            long now = System.currentTimeMillis();
442                            if (selector == null) {
443                                selector = Selector.open();
444                                key = channel.register(selector, SelectionKey.OP_READ);
445                            } else {
446                                key.interestOps(SelectionKey.OP_READ);
447                            }
448                            int keyCount = selector.select(this.getSoTimeout());
449                            if (keyCount == 0 && this.getSoTimeout() > 0 && ((System.currentTimeMillis() - now) >= this.getSoTimeout())) {
450                                throw new SocketTimeoutException("Timeout during handshake");
451                            }
452                            readable = key.isReadable();
453                        }
454                        break;
455                    case NEED_TASK:
456                        Runnable task;
457                        while ((task = sslEngine.getDelegatedTask()) != null) {
458                            task.run();
459                        }
460                        break;
461                    case NEED_WRAP:
462                        ((NIOOutputStream) buffOut).write(ByteBuffer.allocate(0));
463                        break;
464                    case FINISHED:
465                    case NOT_HANDSHAKING:
466                        finishHandshake();
467                        return;
468                }
469            }
470        } finally {
471            if (key!=null) try {key.cancel();} catch (Exception ignore) {}
472            if (selector!=null) try {selector.close();} catch (Exception ignore) {}
473        }
474    }
475
476    @Override
477    protected void doStart() throws Exception {
478        taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task");
479        // no need to init as we can delay that until demand (eg in doHandshake)
480        super.doStart();
481    }
482
483    @Override
484    protected void doStop(ServiceStopper stopper) throws Exception {
485        initialized.countDown();
486
487        if (taskRunnerFactory != null) {
488            taskRunnerFactory.shutdownNow();
489            taskRunnerFactory = null;
490        }
491        if (channel != null) {
492            channel.close();
493            channel = null;
494        }
495        super.doStop(stopper);
496    }
497
498    /**
499     * Overriding in order to add the client's certificates to ConnectionInfo Commands.
500     *
501     * @param command
502     *            The Command coming in.
503     */
504    @Override
505    public void doConsume(Object command) {
506        if (command instanceof ConnectionInfo) {
507            ConnectionInfo connectionInfo = (ConnectionInfo) command;
508            connectionInfo.setTransportContext(getPeerCertificates());
509        }
510        super.doConsume(command);
511    }
512
513    /**
514     * @return peer certificate chain associated with the ssl socket
515     */
516    @Override
517    public X509Certificate[] getPeerCertificates() {
518
519        X509Certificate[] clientCertChain = null;
520        try {
521            if (sslEngine.getSession() != null) {
522                clientCertChain = (X509Certificate[]) sslEngine.getSession().getPeerCertificates();
523            }
524        } catch (SSLPeerUnverifiedException e) {
525            if (LOG.isTraceEnabled()) {
526                LOG.trace("Failed to get peer certificates.", e);
527            }
528        }
529
530        return clientCertChain;
531    }
532
533    public boolean isNeedClientAuth() {
534        return needClientAuth;
535    }
536
537    public void setNeedClientAuth(boolean needClientAuth) {
538        this.needClientAuth = needClientAuth;
539    }
540
541    public boolean isWantClientAuth() {
542        return wantClientAuth;
543    }
544
545    public void setWantClientAuth(boolean wantClientAuth) {
546        this.wantClientAuth = wantClientAuth;
547    }
548
549    public String[] getEnabledCipherSuites() {
550        return enabledCipherSuites;
551    }
552
553    public void setEnabledCipherSuites(String[] enabledCipherSuites) {
554        this.enabledCipherSuites = enabledCipherSuites;
555    }
556
557    public String[] getEnabledProtocols() {
558        return enabledProtocols;
559    }
560
561    public void setEnabledProtocols(String[] enabledProtocols) {
562        this.enabledProtocols = enabledProtocols;
563    }
564}