001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.activemq.transport.stomp;
018
019import java.io.DataInput;
020import java.io.DataInputStream;
021import java.io.DataOutput;
022import java.io.DataOutputStream;
023import java.io.IOException;
024import java.io.InputStream;
025import java.io.PushbackInputStream;
026import java.util.HashMap;
027import java.util.Map;
028import java.util.concurrent.atomic.AtomicLong;
029
030import org.apache.activemq.util.ByteArrayInputStream;
031import org.apache.activemq.util.ByteArrayOutputStream;
032import org.apache.activemq.util.ByteSequence;
033import org.apache.activemq.wireformat.WireFormat;
034
035/**
036 * Implements marshalling and unmarsalling the <a
037 * href="http://stomp.codehaus.org/">Stomp</a> protocol.
038 */
039public class StompWireFormat implements WireFormat {
040
041    private static final byte[] NO_DATA = new byte[] {};
042    private static final byte[] END_OF_FRAME = new byte[] {0, '\n'};
043
044    private static final int MAX_COMMAND_LENGTH = 1024;
045    private static final int MAX_HEADER_LENGTH = 1024 * 10;
046    private static final int MAX_HEADERS = 1000;
047
048    public static final int MAX_DATA_LENGTH = 1024 * 1024 * 100;
049    public static final long DEFAULT_MAX_FRAME_SIZE = Long.MAX_VALUE;
050    public static final long DEFAULT_CONNECTION_TIMEOUT = 30000;
051
052    private int version = 1;
053    private int maxDataLength = MAX_DATA_LENGTH;
054    private long maxFrameSize = DEFAULT_MAX_FRAME_SIZE;
055    private String stompVersion = Stomp.DEFAULT_VERSION;
056    private long connectionAttemptTimeout = DEFAULT_CONNECTION_TIMEOUT;
057
058    //The current frame size as it is unmarshalled from the stream
059    private final AtomicLong frameSize = new AtomicLong();
060
061    @Override
062    public ByteSequence marshal(Object command) throws IOException {
063        ByteArrayOutputStream baos = new ByteArrayOutputStream();
064        DataOutputStream dos = new DataOutputStream(baos);
065        marshal(command, dos);
066        dos.close();
067        return baos.toByteSequence();
068    }
069
070    @Override
071    public Object unmarshal(ByteSequence packet) throws IOException {
072        ByteArrayInputStream stream = new ByteArrayInputStream(packet);
073        DataInputStream dis = new DataInputStream(stream);
074        return unmarshal(dis);
075    }
076
077    private StringBuilder marshalHeaders(StompFrame stomp, StringBuilder buffer) throws IOException {
078        buffer.append(stomp.getAction());
079        buffer.append(Stomp.NEWLINE);
080
081        // Output the headers.
082        for (Map.Entry<String, String> entry : stomp.getHeaders().entrySet()) {
083            buffer.append(entry.getKey());
084            buffer.append(Stomp.Headers.SEPERATOR);
085            buffer.append(encodeHeader(entry.getValue()));
086            buffer.append(Stomp.NEWLINE);
087        }
088
089        // Add a newline to separate the headers from the content.
090        buffer.append(Stomp.NEWLINE);
091
092        return buffer;
093    }
094
095    @Override
096    public void marshal(Object command, DataOutput os) throws IOException {
097        StompFrame stomp = (org.apache.activemq.transport.stomp.StompFrame)command;
098
099        if (stomp.getAction().equals(Stomp.Commands.KEEPALIVE)) {
100            os.write(Stomp.BREAK);
101            return;
102        }
103
104        StringBuilder builder = new StringBuilder();
105
106        os.write(marshalHeaders(stomp, builder).toString().getBytes("UTF-8"));
107        os.write(stomp.getContent());
108        os.write(END_OF_FRAME);
109    }
110
111    public String marshalToString(StompFrame stomp) throws IOException {
112        if (stomp.getAction().equals(Stomp.Commands.KEEPALIVE)) {
113            return String.valueOf((char)Stomp.BREAK);
114        }
115
116        StringBuilder buffer = new StringBuilder();
117        marshalHeaders(stomp, buffer);
118
119        if (stomp.getContent() != null) {
120            String contentString = new String(stomp.getContent(), "UTF-8");
121            buffer.append(contentString);
122        }
123
124        buffer.append('\u0000');
125        return buffer.toString();
126    }
127
128    @Override
129    public Object unmarshal(DataInput in) throws IOException {
130        try {
131            // parse action
132            String action = parseAction(in, frameSize);
133
134            // Parse the headers
135            HashMap<String, String> headers = parseHeaders(in, frameSize);
136
137            // Read in the data part.
138            byte[] data = NO_DATA;
139            String contentLength = headers.get(Stomp.Headers.CONTENT_LENGTH);
140            if ((action.equals(Stomp.Commands.SEND) || action.equals(Stomp.Responses.MESSAGE)) && contentLength != null) {
141
142                // Bless the client, he's telling us how much data to read in.
143                int length = parseContentLength(contentLength, frameSize);
144
145                data = new byte[length];
146                in.readFully(data);
147
148                if (in.readByte() != 0) {
149                    throw new ProtocolException(Stomp.Headers.CONTENT_LENGTH + " bytes were read and " + "there was no trailing null byte", true);
150                }
151
152            } else {
153
154                // We don't know how much to read.. data ends when we hit a 0
155                byte b;
156                ByteArrayOutputStream baos = null;
157                while ((b = in.readByte()) != 0) {
158                    if (baos == null) {
159                        baos = new ByteArrayOutputStream();
160                    } else if (baos.size() > getMaxDataLength()) {
161                        throw new ProtocolException("The maximum data length was exceeded", true);
162                    } else {
163                        if (frameSize.incrementAndGet() > getMaxFrameSize()) {
164                            throw new ProtocolException("The maximum frame size was exceeded", true);
165                        }
166                    }
167
168                    baos.write(b);
169                }
170
171                if (baos != null) {
172                    baos.close();
173                    data = baos.toByteArray();
174                }
175            }
176
177            return new StompFrame(action, headers, data);
178
179        } catch (ProtocolException e) {
180            return new StompFrameError(e);
181        } finally {
182            frameSize.set(0);
183        }
184    }
185
186    private String readLine(DataInput in, int maxLength, String errorMessage) throws IOException {
187        ByteSequence sequence = readHeaderLine(in, maxLength, errorMessage);
188        return new String(sequence.getData(), sequence.getOffset(), sequence.getLength(), "UTF-8").trim();
189    }
190
191    private ByteSequence readHeaderLine(DataInput in, int maxLength, String errorMessage) throws IOException {
192        byte b;
193        ByteArrayOutputStream baos = new ByteArrayOutputStream(maxLength);
194        while ((b = in.readByte()) != '\n') {
195            if (baos.size() > maxLength) {
196                baos.close();
197                throw new ProtocolException(errorMessage, true);
198            }
199            baos.write(b);
200        }
201
202        baos.close();
203        ByteSequence line = baos.toByteSequence();
204
205        if (stompVersion.equals(Stomp.V1_0) || stompVersion.equals(Stomp.V1_2)) {
206            int lineLength = line.getLength();
207            if (lineLength > 0 && line.data[lineLength-1] == '\r') {
208                line.setLength(lineLength-1);
209            }
210        }
211
212        return line;
213    }
214
215    protected String parseAction(DataInput in, AtomicLong frameSize) throws IOException {
216        String action = null;
217
218        // skip white space to next real action line
219        while (true) {
220            action = readLine(in, MAX_COMMAND_LENGTH, "The maximum command length was exceeded");
221            if (action == null) {
222                throw new IOException("connection was closed");
223            } else {
224                action = action.trim();
225                if (action.length() > 0) {
226                    break;
227                }
228            }
229        }
230        frameSize.addAndGet(action.length());
231        return action;
232    }
233
234    protected HashMap<String, String> parseHeaders(DataInput in, AtomicLong frameSize) throws IOException {
235        HashMap<String, String> headers = new HashMap<>(25);
236        while (true) {
237            ByteSequence line = readHeaderLine(in, MAX_HEADER_LENGTH, "The maximum header length was exceeded");
238            if (line != null && line.length > 1) {
239
240                if (headers.size() > MAX_HEADERS) {
241                    throw new ProtocolException("The maximum number of headers was exceeded", true);
242                }
243                frameSize.addAndGet(line.length);
244
245                try {
246
247                    ByteArrayInputStream headerLine = new ByteArrayInputStream(line);
248                    ByteArrayOutputStream stream = new ByteArrayOutputStream(line.length);
249
250                    // First complete the name
251                    int result = -1;
252                    while ((result = headerLine.read()) != -1) {
253                        if (result != ':') {
254                            stream.write(result);
255                        } else {
256                            break;
257                        }
258                    }
259
260                    ByteSequence nameSeq = stream.toByteSequence();
261
262                    String name = new String(nameSeq.getData(), nameSeq.getOffset(), nameSeq.getLength(), "UTF-8");
263                    String value = decodeHeader(headerLine);
264                    if (stompVersion.equals(Stomp.V1_0)) {
265                        value = value.trim();
266                    }
267
268                    if (!headers.containsKey(name)) {
269                        headers.put(name, value);
270                    }
271
272                    stream.close();
273
274                } catch (Exception e) {
275                    throw new ProtocolException("Unable to parser header line [" + line + "]", true);
276                }
277            } else {
278                break;
279            }
280        }
281        return headers;
282    }
283
284    protected int parseContentLength(String contentLength, AtomicLong frameSize) throws ProtocolException {
285        int length;
286        try {
287            length = Integer.parseInt(contentLength.trim());
288        } catch (NumberFormatException e) {
289            throw new ProtocolException("Specified content-length is not a valid integer", true);
290        }
291
292        if (length > getMaxDataLength()) {
293            throw new ProtocolException("The maximum data length was exceeded", true);
294        }
295
296        if (frameSize.addAndGet(length) > getMaxFrameSize()) {
297            throw new ProtocolException("The maximum frame size was exceeded", true);
298        }
299
300        return length;
301    }
302
303    private String encodeHeader(String header) throws IOException {
304        String result = header;
305        if (!stompVersion.equals(Stomp.V1_0)) {
306            byte[] utf8buf = header.getBytes("UTF-8");
307            ByteArrayOutputStream stream = new ByteArrayOutputStream(utf8buf.length);
308            for(byte val : utf8buf) {
309                switch(val) {
310                case Stomp.ESCAPE:
311                    stream.write(Stomp.ESCAPE_ESCAPE_SEQ);
312                    break;
313                case Stomp.BREAK:
314                    stream.write(Stomp.NEWLINE_ESCAPE_SEQ);
315                    break;
316                case Stomp.COLON:
317                    stream.write(Stomp.COLON_ESCAPE_SEQ);
318                    break;
319                default:
320                    stream.write(val);
321                }
322            }
323            result =  new String(stream.toByteArray(), "UTF-8");
324            stream.close();
325        }
326
327        return result;
328    }
329
330    private String decodeHeader(InputStream header) throws IOException {
331        ByteArrayOutputStream decoded = new ByteArrayOutputStream();
332        PushbackInputStream stream = new PushbackInputStream(header);
333
334        int value = -1;
335        while( (value = stream.read()) != -1) {
336            if (value == 92) {
337
338                int next = stream.read();
339                if (next != -1) {
340                    switch(next) {
341                    case 110:
342                        decoded.write(Stomp.BREAK);
343                        break;
344                    case 99:
345                        decoded.write(Stomp.COLON);
346                        break;
347                    case 92:
348                        decoded.write(Stomp.ESCAPE);
349                        break;
350                    default:
351                        stream.unread(next);
352                        decoded.write(value);
353                    }
354                } else {
355                    decoded.write(value);
356                }
357
358            } else {
359                decoded.write(value);
360            }
361        }
362
363        decoded.close();
364
365        return new String(decoded.toByteArray(), "UTF-8");
366    }
367
368    @Override
369    public int getVersion() {
370        return version;
371    }
372
373    @Override
374    public void setVersion(int version) {
375        this.version = version;
376    }
377
378    public String getStompVersion() {
379        return stompVersion;
380    }
381
382    public void setStompVersion(String stompVersion) {
383        this.stompVersion = stompVersion;
384    }
385
386    public void setMaxDataLength(int maxDataLength) {
387        this.maxDataLength = maxDataLength;
388    }
389
390    public int getMaxDataLength() {
391        return maxDataLength;
392    }
393
394    public long getMaxFrameSize() {
395        return maxFrameSize;
396    }
397
398    public void setMaxFrameSize(long maxFrameSize) {
399        this.maxFrameSize = maxFrameSize;
400    }
401
402    public long getConnectionAttemptTimeout() {
403        return connectionAttemptTimeout;
404    }
405
406    public void setConnectionAttemptTimeout(long connectionAttemptTimeout) {
407        this.connectionAttemptTimeout = connectionAttemptTimeout;
408    }
409}