001 /**
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements. See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License. You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018 package org.apache.activemq.transport.nio;
019
020 import java.io.DataInputStream;
021 import java.io.DataOutputStream;
022 import java.io.EOFException;
023 import java.io.IOException;
024 import java.net.Socket;
025 import java.net.URI;
026 import java.net.UnknownHostException;
027 import java.nio.ByteBuffer;
028 import java.security.cert.X509Certificate;
029
030 import javax.net.SocketFactory;
031 import javax.net.ssl.SSLContext;
032 import javax.net.ssl.SSLEngine;
033 import javax.net.ssl.SSLEngineResult;
034 import javax.net.ssl.SSLPeerUnverifiedException;
035 import javax.net.ssl.SSLSession;
036
037 import org.apache.activemq.command.Command;
038 import org.apache.activemq.command.ConnectionInfo;
039 import org.apache.activemq.openwire.OpenWireFormat;
040 import org.apache.activemq.thread.TaskRunnerFactory;
041 import org.apache.activemq.util.IOExceptionSupport;
042 import org.apache.activemq.util.ServiceStopper;
043 import org.apache.activemq.wireformat.WireFormat;
044
045 public class NIOSSLTransport extends NIOTransport {
046
047 protected boolean needClientAuth;
048 protected boolean wantClientAuth;
049 protected String[] enabledCipherSuites;
050
051 protected SSLContext sslContext;
052 protected SSLEngine sslEngine;
053 protected SSLSession sslSession;
054
055 protected volatile boolean handshakeInProgress = false;
056 protected SSLEngineResult.Status status = null;
057 protected SSLEngineResult.HandshakeStatus handshakeStatus = null;
058 protected TaskRunnerFactory taskRunnerFactory;
059
060 public NIOSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException {
061 super(wireFormat, socketFactory, remoteLocation, localLocation);
062 }
063
064 public NIOSSLTransport(WireFormat wireFormat, Socket socket) throws IOException {
065 super(wireFormat, socket);
066 }
067
068 public void setSslContext(SSLContext sslContext) {
069 this.sslContext = sslContext;
070 }
071
072 @Override
073 protected void initializeStreams() throws IOException {
074 try {
075 channel = socket.getChannel();
076 channel.configureBlocking(false);
077
078 if (sslContext == null) {
079 sslContext = SSLContext.getDefault();
080 }
081
082 // initialize engine, the initial sslSession we get will need to be
083 // updated once the ssl handshake process is completed.
084 sslEngine = sslContext.createSSLEngine();
085 sslEngine.setUseClientMode(false);
086 if (enabledCipherSuites != null) {
087 sslEngine.setEnabledCipherSuites(enabledCipherSuites);
088 }
089 sslEngine.setNeedClientAuth(needClientAuth);
090 sslEngine.setWantClientAuth(wantClientAuth);
091
092 sslSession = sslEngine.getSession();
093
094 inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize());
095 inputBuffer.clear();
096 currentBuffer = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
097
098 NIOOutputStream outputStream = new NIOOutputStream(channel);
099 outputStream.setEngine(sslEngine);
100 this.dataOut = new DataOutputStream(outputStream);
101 this.buffOut = outputStream;
102 sslEngine.beginHandshake();
103 handshakeStatus = sslEngine.getHandshakeStatus();
104 doHandshake();
105 } catch (Exception e) {
106 throw new IOException(e);
107 }
108 }
109
110 protected void finishHandshake() throws Exception {
111 if (handshakeInProgress) {
112 handshakeInProgress = false;
113 nextFrameSize = -1;
114
115 // Once handshake completes we need to ask for the now real sslSession
116 // otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the
117 // cipher suite.
118 sslSession = sslEngine.getSession();
119
120 // listen for events telling us when the socket is readable.
121 selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() {
122 public void onSelect(SelectorSelection selection) {
123 serviceRead();
124 }
125
126 public void onError(SelectorSelection selection, Throwable error) {
127 if (error instanceof IOException) {
128 onException((IOException) error);
129 } else {
130 onException(IOExceptionSupport.create(error));
131 }
132 }
133 });
134 }
135 }
136
137 protected void serviceRead() {
138 try {
139 if (handshakeInProgress) {
140 doHandshake();
141 }
142
143 ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
144 plain.position(plain.limit());
145
146 while(true) {
147 if (!plain.hasRemaining()) {
148
149 if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
150 plain.clear();
151 } else {
152 plain.compact();
153 }
154 int readCount = secureRead(plain);
155
156
157 if (readCount == 0)
158 break;
159
160 // channel is closed, cleanup
161 if (readCount== -1) {
162 onException(new EOFException());
163 selection.close();
164 break;
165 }
166 }
167
168 if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
169 processCommand(plain);
170 }
171 }
172 } catch (IOException e) {
173 onException(e);
174 } catch (Throwable e) {
175 onException(IOExceptionSupport.create(e));
176 }
177 }
178
179 protected void processCommand(ByteBuffer plain) throws Exception {
180 nextFrameSize = plain.getInt();
181 if (wireFormat instanceof OpenWireFormat) {
182 long maxFrameSize = ((OpenWireFormat) wireFormat).getMaxFrameSize();
183 if (nextFrameSize > maxFrameSize) {
184 throw new IOException("Frame size of " + (nextFrameSize / (1024 * 1024)) + " MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB");
185 }
186 }
187 currentBuffer = ByteBuffer.allocate(nextFrameSize + 4);
188 currentBuffer.putInt(nextFrameSize);
189 if (currentBuffer.hasRemaining()) {
190 if (currentBuffer.remaining() >= plain.remaining()) {
191 currentBuffer.put(plain);
192 } else {
193 byte[] fill = new byte[currentBuffer.remaining()];
194 plain.get(fill);
195 currentBuffer.put(fill);
196 }
197 }
198
199 if (currentBuffer.hasRemaining()) {
200 return;
201 } else {
202 currentBuffer.flip();
203 Object command = wireFormat.unmarshal(new DataInputStream(new NIOInputStream(currentBuffer)));
204 doConsume((Command) command);
205 nextFrameSize = -1;
206 }
207 }
208
209 protected int secureRead(ByteBuffer plain) throws Exception {
210
211 if (!(inputBuffer.position() != 0 && inputBuffer.hasRemaining()) || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
212 int bytesRead = channel.read(inputBuffer);
213
214 if (bytesRead == -1) {
215 sslEngine.closeInbound();
216 if (inputBuffer.position() == 0 ||
217 status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
218 return -1;
219 }
220 }
221 }
222
223 plain.clear();
224
225 inputBuffer.flip();
226 SSLEngineResult res;
227 do {
228 res = sslEngine.unwrap(inputBuffer, plain);
229 } while (res.getStatus() == SSLEngineResult.Status.OK &&
230 res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP &&
231 res.bytesProduced() == 0);
232
233 if (res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) {
234 finishHandshake();
235 }
236
237 status = res.getStatus();
238 handshakeStatus = res.getHandshakeStatus();
239
240 //TODO deal with BUFFER_OVERFLOW
241
242 if (status == SSLEngineResult.Status.CLOSED) {
243 sslEngine.closeInbound();
244 return -1;
245 }
246
247 inputBuffer.compact();
248 plain.flip();
249
250 return plain.remaining();
251 }
252
253 protected void doHandshake() throws Exception {
254 handshakeInProgress = true;
255 while (true) {
256 switch (sslEngine.getHandshakeStatus()) {
257 case NEED_UNWRAP:
258 secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize()));
259 break;
260 case NEED_TASK:
261 Runnable task;
262 while ((task = sslEngine.getDelegatedTask()) != null) {
263 taskRunnerFactory.execute(task);
264 }
265 break;
266 case NEED_WRAP:
267 ((NIOOutputStream)buffOut).write(ByteBuffer.allocate(0));
268 break;
269 case FINISHED:
270 case NOT_HANDSHAKING:
271 finishHandshake();
272 return;
273 }
274 }
275 }
276
277 @Override
278 protected void doStart() throws Exception {
279 taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task");
280 // no need to init as we can delay that until demand (eg in doHandshake)
281 super.doStart();
282 }
283
284 @Override
285 protected void doStop(ServiceStopper stopper) throws Exception {
286 if (taskRunnerFactory != null) {
287 taskRunnerFactory.shutdownNow();
288 taskRunnerFactory = null;
289 }
290 if (channel != null) {
291 channel.close();
292 channel = null;
293 }
294 super.doStop(stopper);
295 }
296
297 /**
298 * Overriding in order to add the client's certificates to ConnectionInfo Commmands.
299 *
300 * @param command The Command coming in.
301 */
302 @Override
303 public void doConsume(Object command) {
304 if (command instanceof ConnectionInfo) {
305 ConnectionInfo connectionInfo = (ConnectionInfo)command;
306 connectionInfo.setTransportContext(getPeerCertificates());
307 }
308 super.doConsume(command);
309 }
310
311 /**
312 * @return peer certificate chain associated with the ssl socket
313 */
314 public X509Certificate[] getPeerCertificates() {
315
316 X509Certificate[] clientCertChain = null;
317 try {
318 if (sslSession != null) {
319 clientCertChain = (X509Certificate[])sslSession.getPeerCertificates();
320 }
321 } catch (SSLPeerUnverifiedException e) {
322 }
323
324 return clientCertChain;
325 }
326
327 public boolean isNeedClientAuth() {
328 return needClientAuth;
329 }
330
331 public void setNeedClientAuth(boolean needClientAuth) {
332 this.needClientAuth = needClientAuth;
333 }
334
335 public boolean isWantClientAuth() {
336 return wantClientAuth;
337 }
338
339 public void setWantClientAuth(boolean wantClientAuth) {
340 this.wantClientAuth = wantClientAuth;
341 }
342
343 public String[] getEnabledCipherSuites() {
344 return enabledCipherSuites;
345 }
346
347 public void setEnabledCipherSuites(String[] enabledCipherSuites) {
348 this.enabledCipherSuites = enabledCipherSuites;
349 }
350 }