001/**
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.activemq.transport.nio;
019
020import java.io.DataInputStream;
021import java.io.DataOutputStream;
022import java.io.EOFException;
023import java.io.IOException;
024import java.net.Socket;
025import java.net.SocketTimeoutException;
026import java.net.URI;
027import java.net.UnknownHostException;
028import java.nio.ByteBuffer;
029import java.nio.channels.SelectionKey;
030import java.nio.channels.Selector;
031import java.security.cert.X509Certificate;
032
033import javax.net.SocketFactory;
034import javax.net.ssl.SSLContext;
035import javax.net.ssl.SSLEngine;
036import javax.net.ssl.SSLEngineResult;
037import javax.net.ssl.SSLEngineResult.HandshakeStatus;
038import javax.net.ssl.SSLPeerUnverifiedException;
039import javax.net.ssl.SSLSession;
040
041import org.apache.activemq.command.ConnectionInfo;
042import org.apache.activemq.openwire.OpenWireFormat;
043import org.apache.activemq.thread.TaskRunnerFactory;
044import org.apache.activemq.util.IOExceptionSupport;
045import org.apache.activemq.util.ServiceStopper;
046import org.apache.activemq.wireformat.WireFormat;
047import org.slf4j.Logger;
048import org.slf4j.LoggerFactory;
049
050public class NIOSSLTransport extends NIOTransport {
051
052    private static final Logger LOG = LoggerFactory.getLogger(NIOSSLTransport.class);
053
054    protected boolean needClientAuth;
055    protected boolean wantClientAuth;
056    protected String[] enabledCipherSuites;
057    protected String[] enabledProtocols;
058
059    protected SSLContext sslContext;
060    protected SSLEngine sslEngine;
061    protected SSLSession sslSession;
062
063    protected volatile boolean handshakeInProgress = false;
064    protected SSLEngineResult.Status status = null;
065    protected SSLEngineResult.HandshakeStatus handshakeStatus = null;
066    protected TaskRunnerFactory taskRunnerFactory;
067
068    public NIOSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException {
069        super(wireFormat, socketFactory, remoteLocation, localLocation);
070    }
071
072    public NIOSSLTransport(WireFormat wireFormat, Socket socket, SSLEngine engine, InitBuffer initBuffer,
073            ByteBuffer inputBuffer) throws IOException {
074        super(wireFormat, socket, initBuffer);
075        this.sslEngine = engine;
076        if (engine != null) {
077            this.sslSession = engine.getSession();
078        }
079        this.inputBuffer = inputBuffer;
080    }
081
082    public void setSslContext(SSLContext sslContext) {
083        this.sslContext = sslContext;
084    }
085
086    volatile boolean hasSslEngine = false;
087
088    @Override
089    protected void initializeStreams() throws IOException {
090        if (sslEngine != null) {
091            hasSslEngine = true;
092        }
093        NIOOutputStream outputStream = null;
094        try {
095            channel = socket.getChannel();
096            channel.configureBlocking(false);
097
098            if (sslContext == null) {
099                sslContext = SSLContext.getDefault();
100            }
101
102            String remoteHost = null;
103            int remotePort = -1;
104
105            try {
106                URI remoteAddress = new URI(this.getRemoteAddress());
107                remoteHost = remoteAddress.getHost();
108                remotePort = remoteAddress.getPort();
109            } catch (Exception e) {
110            }
111
112            // initialize engine, the initial sslSession we get will need to be
113            // updated once the ssl handshake process is completed.
114            if (!hasSslEngine) {
115                if (remoteHost != null && remotePort != -1) {
116                    sslEngine = sslContext.createSSLEngine(remoteHost, remotePort);
117                } else {
118                    sslEngine = sslContext.createSSLEngine();
119                }
120
121                sslEngine.setUseClientMode(false);
122                if (enabledCipherSuites != null) {
123                    sslEngine.setEnabledCipherSuites(enabledCipherSuites);
124                }
125
126                if (enabledProtocols != null) {
127                    sslEngine.setEnabledProtocols(enabledProtocols);
128                }
129
130                if (wantClientAuth) {
131                    sslEngine.setWantClientAuth(wantClientAuth);
132                }
133
134                if (needClientAuth) {
135                    sslEngine.setNeedClientAuth(needClientAuth);
136                }
137
138                sslSession = sslEngine.getSession();
139
140                inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize());
141                inputBuffer.clear();
142            }
143
144            outputStream = new NIOOutputStream(channel);
145            outputStream.setEngine(sslEngine);
146            this.dataOut = new DataOutputStream(outputStream);
147            this.buffOut = outputStream;
148
149            //If the sslEngine was not passed in, then handshake
150            if (!hasSslEngine) {
151                sslEngine.beginHandshake();
152            }
153            handshakeStatus = sslEngine.getHandshakeStatus();
154            if (!hasSslEngine) {
155                doHandshake();
156            }
157
158           // if (hasSslEngine) {
159            selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() {
160                @Override
161                public void onSelect(SelectorSelection selection) {
162                    serviceRead();
163                }
164
165                @Override
166                public void onError(SelectorSelection selection, Throwable error) {
167                    if (error instanceof IOException) {
168                        onException((IOException) error);
169                    } else {
170                        onException(IOExceptionSupport.create(error));
171                    }
172                }
173            });
174            doInit();
175
176        } catch (Exception e) {
177            try {
178                if(outputStream != null) {
179                    outputStream.close();
180                }
181                super.closeStreams();
182            } catch (Exception ex) {}
183            throw new IOException(e);
184        }
185    }
186
187    protected void doInit() throws Exception {
188
189    }
190
191    protected void doOpenWireInit() throws Exception {
192        //Do this later to let wire format negotiation happen
193        if (initBuffer != null && this.wireFormat instanceof OpenWireFormat) {
194            initBuffer.buffer.flip();
195            if (initBuffer.buffer.hasRemaining()) {
196                nextFrameSize = -1;
197                receiveCounter += initBuffer.readSize;
198                processCommand(initBuffer.buffer);
199                processCommand(initBuffer.buffer);
200                initBuffer.buffer.clear();
201            }
202        }
203    }
204
205    protected void finishHandshake() throws Exception {
206        if (handshakeInProgress) {
207            handshakeInProgress = false;
208            nextFrameSize = -1;
209
210            // Once handshake completes we need to ask for the now real sslSession
211            // otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the
212            // cipher suite.
213            sslSession = sslEngine.getSession();
214
215            // listen for events telling us when the socket is readable.
216            selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() {
217                @Override
218                public void onSelect(SelectorSelection selection) {
219                    serviceRead();
220                }
221
222                @Override
223                public void onError(SelectorSelection selection, Throwable error) {
224                    if (error instanceof IOException) {
225                        onException((IOException) error);
226                    } else {
227                        onException(IOExceptionSupport.create(error));
228                    }
229                }
230            });
231        }
232    }
233
234    @Override
235    public void serviceRead() {
236        try {
237            if (handshakeInProgress) {
238                doHandshake();
239            }
240
241            doOpenWireInit();
242
243            ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
244            plain.position(plain.limit());
245
246            while (true) {
247                if (!plain.hasRemaining()) {
248
249                    int readCount = secureRead(plain);
250
251                    if (readCount == 0) {
252                        break;
253                    }
254
255                    // channel is closed, cleanup
256                    if (readCount == -1) {
257                        onException(new EOFException());
258                        selection.close();
259                        break;
260                    }
261
262                    receiveCounter += readCount;
263                }
264
265                if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
266                    processCommand(plain);
267                }
268            }
269        } catch (IOException e) {
270            onException(e);
271        } catch (Throwable e) {
272            onException(IOExceptionSupport.create(e));
273        }
274    }
275
276    protected void processCommand(ByteBuffer plain) throws Exception {
277
278        // Are we waiting for the next Command or are we building on the current one
279        if (nextFrameSize == -1) {
280
281            // We can get small packets that don't give us enough for the frame size
282            // so allocate enough for the initial size value and
283            if (plain.remaining() < Integer.SIZE) {
284                if (currentBuffer == null) {
285                    currentBuffer = ByteBuffer.allocate(4);
286                }
287
288                // Go until we fill the integer sized current buffer.
289                while (currentBuffer.hasRemaining() && plain.hasRemaining()) {
290                    currentBuffer.put(plain.get());
291                }
292
293                // Didn't we get enough yet to figure out next frame size.
294                if (currentBuffer.hasRemaining()) {
295                    return;
296                } else {
297                    currentBuffer.flip();
298                    nextFrameSize = currentBuffer.getInt();
299                }
300
301            } else {
302
303                // Either we are completing a previous read of the next frame size or its
304                // fully contained in plain already.
305                if (currentBuffer != null) {
306
307                    // Finish the frame size integer read and get from the current buffer.
308                    while (currentBuffer.hasRemaining()) {
309                        currentBuffer.put(plain.get());
310                    }
311
312                    currentBuffer.flip();
313                    nextFrameSize = currentBuffer.getInt();
314
315                } else {
316                    nextFrameSize = plain.getInt();
317                }
318            }
319
320            if (wireFormat instanceof OpenWireFormat) {
321                long maxFrameSize = ((OpenWireFormat) wireFormat).getMaxFrameSize();
322                if (nextFrameSize > maxFrameSize) {
323                    throw new IOException("Frame size of " + (nextFrameSize / (1024 * 1024)) +
324                                          " MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB");
325                }
326            }
327
328            // now we got the data, lets reallocate and store the size for the marshaler.
329            // if there's more data in plain, then the next call will start processing it.
330            currentBuffer = ByteBuffer.allocate(nextFrameSize + 4);
331            currentBuffer.putInt(nextFrameSize);
332
333        } else {
334            // If its all in one read then we can just take it all, otherwise take only
335            // the current frame size and the next iteration starts a new command.
336            if (currentBuffer != null) {
337                if (currentBuffer.remaining() >= plain.remaining()) {
338                    currentBuffer.put(plain);
339                } else {
340                    byte[] fill = new byte[currentBuffer.remaining()];
341                    plain.get(fill);
342                    currentBuffer.put(fill);
343                }
344
345                // Either we have enough data for a new command or we have to wait for some more.
346                if (currentBuffer.hasRemaining()) {
347                    return;
348                } else {
349                    currentBuffer.flip();
350                    Object command = wireFormat.unmarshal(new DataInputStream(new NIOInputStream(currentBuffer)));
351                    doConsume(command);
352                    nextFrameSize = -1;
353                    currentBuffer = null;
354               }
355            }
356        }
357    }
358
359    protected int secureRead(ByteBuffer plain) throws Exception {
360
361        if (!(inputBuffer.position() != 0 && inputBuffer.hasRemaining()) || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
362            int bytesRead = channel.read(inputBuffer);
363
364            if (bytesRead == 0 && !(sslEngine.getHandshakeStatus().equals(SSLEngineResult.HandshakeStatus.NEED_UNWRAP))) {
365                return 0;
366            }
367
368            if (bytesRead == -1) {
369                sslEngine.closeInbound();
370                if (inputBuffer.position() == 0 || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
371                    return -1;
372                }
373            }
374        }
375
376        plain.clear();
377
378        inputBuffer.flip();
379        SSLEngineResult res;
380        do {
381            res = sslEngine.unwrap(inputBuffer, plain);
382        } while (res.getStatus() == SSLEngineResult.Status.OK && res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP
383                && res.bytesProduced() == 0);
384
385        if (res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) {
386            finishHandshake();
387        }
388
389        status = res.getStatus();
390        handshakeStatus = res.getHandshakeStatus();
391
392        // TODO deal with BUFFER_OVERFLOW
393
394        if (status == SSLEngineResult.Status.CLOSED) {
395            sslEngine.closeInbound();
396            return -1;
397        }
398
399        inputBuffer.compact();
400        plain.flip();
401
402        return plain.remaining();
403    }
404
405    protected void doHandshake() throws Exception {
406        handshakeInProgress = true;
407        Selector selector = null;
408        SelectionKey key = null;
409        boolean readable = true;
410        try {
411            while (true) {
412                HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
413                switch (handshakeStatus) {
414                    case NEED_UNWRAP:
415                        if (readable) {
416                            secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize()));
417                        }
418                        if (this.status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
419                            long now = System.currentTimeMillis();
420                            if (selector == null) {
421                                selector = Selector.open();
422                                key = channel.register(selector, SelectionKey.OP_READ);
423                            } else {
424                                key.interestOps(SelectionKey.OP_READ);
425                            }
426                            int keyCount = selector.select(this.getSoTimeout());
427                            if (keyCount == 0 && this.getSoTimeout() > 0 && ((System.currentTimeMillis() - now) >= this.getSoTimeout())) {
428                                throw new SocketTimeoutException("Timeout during handshake");
429                            }
430                            readable = key.isReadable();
431                        }
432                        break;
433                    case NEED_TASK:
434                        Runnable task;
435                        while ((task = sslEngine.getDelegatedTask()) != null) {
436                            task.run();
437                        }
438                        break;
439                    case NEED_WRAP:
440                        ((NIOOutputStream) buffOut).write(ByteBuffer.allocate(0));
441                        break;
442                    case FINISHED:
443                    case NOT_HANDSHAKING:
444                        finishHandshake();
445                        return;
446                }
447            }
448        } finally {
449            if (key!=null) try {key.cancel();} catch (Exception ignore) {}
450            if (selector!=null) try {selector.close();} catch (Exception ignore) {}
451        }
452    }
453
454    @Override
455    protected void doStart() throws Exception {
456        taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task");
457        // no need to init as we can delay that until demand (eg in doHandshake)
458        super.doStart();
459    }
460
461    @Override
462    protected void doStop(ServiceStopper stopper) throws Exception {
463        if (taskRunnerFactory != null) {
464            taskRunnerFactory.shutdownNow();
465            taskRunnerFactory = null;
466        }
467        if (channel != null) {
468            channel.close();
469            channel = null;
470        }
471        super.doStop(stopper);
472    }
473
474    /**
475     * Overriding in order to add the client's certificates to ConnectionInfo Commands.
476     *
477     * @param command
478     *            The Command coming in.
479     */
480    @Override
481    public void doConsume(Object command) {
482        if (command instanceof ConnectionInfo) {
483            ConnectionInfo connectionInfo = (ConnectionInfo) command;
484            connectionInfo.setTransportContext(getPeerCertificates());
485        }
486        super.doConsume(command);
487    }
488
489    /**
490     * @return peer certificate chain associated with the ssl socket
491     */
492    public X509Certificate[] getPeerCertificates() {
493
494        X509Certificate[] clientCertChain = null;
495        try {
496            if (sslEngine.getSession() != null) {
497                clientCertChain = (X509Certificate[]) sslEngine.getSession().getPeerCertificates();
498            }
499        } catch (SSLPeerUnverifiedException e) {
500            if (LOG.isTraceEnabled()) {
501                LOG.trace("Failed to get peer certificates.", e);
502            }
503        }
504
505        return clientCertChain;
506    }
507
508    public boolean isNeedClientAuth() {
509        return needClientAuth;
510    }
511
512    public void setNeedClientAuth(boolean needClientAuth) {
513        this.needClientAuth = needClientAuth;
514    }
515
516    public boolean isWantClientAuth() {
517        return wantClientAuth;
518    }
519
520    public void setWantClientAuth(boolean wantClientAuth) {
521        this.wantClientAuth = wantClientAuth;
522    }
523
524    public String[] getEnabledCipherSuites() {
525        return enabledCipherSuites;
526    }
527
528    public void setEnabledCipherSuites(String[] enabledCipherSuites) {
529        this.enabledCipherSuites = enabledCipherSuites;
530    }
531
532    public String[] getEnabledProtocols() {
533        return enabledProtocols;
534    }
535
536    public void setEnabledProtocols(String[] enabledProtocols) {
537        this.enabledProtocols = enabledProtocols;
538    }
539}