001/**
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.activemq.transport.amqp;
018
019import java.io.DataInput;
020import java.io.DataInputStream;
021import java.io.DataOutput;
022import java.io.DataOutputStream;
023import java.io.IOException;
024import java.io.OutputStream;
025import java.nio.ByteBuffer;
026import java.nio.channels.Channels;
027import java.nio.channels.WritableByteChannel;
028
029import org.apache.activemq.transport.amqp.message.InboundTransformer;
030import org.apache.activemq.util.ByteArrayInputStream;
031import org.apache.activemq.util.ByteArrayOutputStream;
032import org.apache.activemq.util.ByteSequence;
033import org.apache.activemq.wireformat.WireFormat;
034import org.fusesource.hawtbuf.Buffer;
035import org.slf4j.Logger;
036import org.slf4j.LoggerFactory;
037
038public class AmqpWireFormat implements WireFormat {
039
040    private static final Logger LOG = LoggerFactory.getLogger(AmqpWireFormat.class);
041
042    public static final long DEFAULT_MAX_FRAME_SIZE = Long.MAX_VALUE;
043    public static final int NO_AMQP_MAX_FRAME_SIZE = -1;
044    public static final int DEFAULT_CONNECTION_TIMEOUT = 30000;
045    public static final int DEFAULT_IDLE_TIMEOUT = 30000;
046    public static final int DEFAULT_PRODUCER_CREDIT = 1000;
047    public static final boolean DEFAULT_ALLOW_NON_SASL_CONNECTIONS = false;
048    public static final int DEFAULT_ANQP_FRAME_SIZE = 128 * 1024;
049
050    private static final int SASL_PROTOCOL = 3;
051
052    private int version = 1;
053    private long maxFrameSize = DEFAULT_MAX_FRAME_SIZE;
054    private int maxAmqpFrameSize = DEFAULT_ANQP_FRAME_SIZE;
055    private int connectAttemptTimeout = DEFAULT_CONNECTION_TIMEOUT;
056    private int idelTimeout = DEFAULT_IDLE_TIMEOUT;
057    private int producerCredit = DEFAULT_PRODUCER_CREDIT;
058    private String transformer = InboundTransformer.TRANSFORMER_JMS;
059    private boolean allowNonSaslConnections = DEFAULT_ALLOW_NON_SASL_CONNECTIONS;
060
061    private boolean magicRead = false;
062    private ResetListener resetListener;
063
064    public interface ResetListener {
065        void onProtocolReset();
066    }
067
068    @Override
069    public ByteSequence marshal(Object command) throws IOException {
070        ByteArrayOutputStream baos = new ByteArrayOutputStream();
071        DataOutputStream dos = new DataOutputStream(baos);
072        marshal(command, dos);
073        dos.close();
074        return baos.toByteSequence();
075    }
076
077    @Override
078    public Object unmarshal(ByteSequence packet) throws IOException {
079        ByteArrayInputStream stream = new ByteArrayInputStream(packet);
080        DataInputStream dis = new DataInputStream(stream);
081        return unmarshal(dis);
082    }
083
084    @Override
085    public void marshal(Object command, DataOutput dataOut) throws IOException {
086        if (command instanceof ByteBuffer) {
087            ByteBuffer buffer = (ByteBuffer) command;
088
089            if (dataOut instanceof OutputStream) {
090                WritableByteChannel channel = Channels.newChannel((OutputStream) dataOut);
091                channel.write(buffer);
092            } else {
093                while (buffer.hasRemaining()) {
094                    dataOut.writeByte(buffer.get());
095                }
096            }
097        } else {
098            Buffer frame = (Buffer) command;
099            frame.writeTo(dataOut);
100        }
101    }
102
103    @Override
104    public Object unmarshal(DataInput dataIn) throws IOException {
105        if (!magicRead) {
106            Buffer magic = new Buffer(8);
107            magic.readFrom(dataIn);
108            magicRead = true;
109            return new AmqpHeader(magic, false);
110        } else {
111            int size = dataIn.readInt();
112            if (size > maxFrameSize) {
113                throw new AmqpProtocolException("Frame size exceeded max frame length.");
114            } else if (size <= 0) {
115                throw new AmqpProtocolException("Frame size value was invalid: " + size);
116            }
117            Buffer frame = new Buffer(size);
118            frame.bigEndianEditor().writeInt(size);
119            frame.readFrom(dataIn);
120            frame.clear();
121            return frame;
122        }
123    }
124
125    /**
126     * Given an AMQP header validate that the AMQP magic is present and
127     * if so that the version and protocol values align with what we support.
128     *
129     * In the case where authentication occurs the client sends us two AMQP
130     * headers, the first being the SASL initial header which triggers the
131     * authentication process and then if that succeeds we should get a second
132     * AMQP header that does not contain the SASL protocol ID indicating the
133     * connection process should follow the normal path.  We validate that the
134     * header align with these expectations.
135     *
136     * @param header
137     *        the header instance received from the client.
138     * @param authenticated
139     *        has the client already authenticated already.
140     *
141     * @return true if the header is valid against the current WireFormat.
142     */
143    public boolean isHeaderValid(AmqpHeader header, boolean authenticated) {
144        if (!header.hasValidPrefix()) {
145            LOG.trace("AMQP Header arrived with invalid prefix: {}", header);
146            return false;
147        }
148
149        if (!(header.getProtocolId() == 0 || header.getProtocolId() == SASL_PROTOCOL)) {
150            LOG.trace("AMQP Header arrived with invalid protocol ID: {}", header);
151            return false;
152        }
153
154        if (!authenticated && !isAllowNonSaslConnections() && header.getProtocolId() != SASL_PROTOCOL) {
155            LOG.trace("AMQP Header arrived without SASL and server requires SASL: {}", header);
156            return false;
157        }
158
159        if (header.getMajor() != 1 || header.getMinor() != 0 || header.getRevision() != 0) {
160            LOG.trace("AMQP Header arrived invalid version: {}", header);
161            return false;
162        }
163
164        return true;
165    }
166
167    /**
168     * Returns an AMQP Header object that represents the minimally protocol
169     * versions supported by this transport.  A client that attempts to
170     * connect with an AMQP version that doesn't at least meat this value
171     * will receive this prior to the connection being closed.
172     *
173     * @return the minimal AMQP version needed from the client.
174     */
175    public AmqpHeader getMinimallySupportedHeader() {
176        AmqpHeader header = new AmqpHeader();
177        if (!isAllowNonSaslConnections()) {
178            header.setProtocolId(3);
179        }
180
181        return header;
182    }
183
184    @Override
185    public void setVersion(int version) {
186        this.version = version;
187    }
188
189    @Override
190    public int getVersion() {
191        return this.version;
192    }
193
194    public void resetMagicRead() {
195        this.magicRead = false;
196        if (resetListener != null) {
197            resetListener.onProtocolReset();
198        }
199    }
200
201    public void setProtocolResetListener(ResetListener listener) {
202        this.resetListener = listener;
203    }
204
205    public boolean isMagicRead() {
206        return this.magicRead;
207    }
208
209    public long getMaxFrameSize() {
210        return maxFrameSize;
211    }
212
213    public void setMaxFrameSize(long maxFrameSize) {
214        this.maxFrameSize = maxFrameSize;
215    }
216
217    public int getMaxAmqpFrameSize() {
218        return maxAmqpFrameSize;
219    }
220
221    public void setMaxAmqpFrameSize(int maxAmqpFrameSize) {
222        this.maxAmqpFrameSize = maxAmqpFrameSize;
223    }
224
225    public boolean isAllowNonSaslConnections() {
226        return allowNonSaslConnections;
227    }
228
229    public void setAllowNonSaslConnections(boolean allowNonSaslConnections) {
230        this.allowNonSaslConnections = allowNonSaslConnections;
231    }
232
233    public int getConnectAttemptTimeout() {
234        return connectAttemptTimeout;
235    }
236
237    public void setConnectAttemptTimeout(int connectAttemptTimeout) {
238        this.connectAttemptTimeout = connectAttemptTimeout;
239    }
240
241    public void setProducerCredit(int producerCredit) {
242        this.producerCredit = producerCredit;
243    }
244
245    public int getProducerCredit() {
246        return producerCredit;
247    }
248
249    public String getTransformer() {
250        return transformer;
251    }
252
253    public void setTransformer(String transformer) {
254        this.transformer = transformer;
255    }
256
257    public int getIdleTimeout() {
258        return idelTimeout;
259    }
260
261    public void setIdleTimeout(int idelTimeout) {
262        this.idelTimeout = idelTimeout;
263    }
264}