001/**
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.activemq.transport.mqtt;
018
019import java.util.Map;
020import java.util.concurrent.ConcurrentHashMap;
021
022import org.apache.activemq.Service;
023import org.apache.activemq.broker.BrokerService;
024import org.apache.activemq.command.ActiveMQMessage;
025import org.apache.activemq.util.LRUCache;
026import org.apache.activemq.util.ServiceStopper;
027import org.apache.activemq.util.ServiceSupport;
028import org.fusesource.mqtt.codec.PUBLISH;
029import org.slf4j.Logger;
030import org.slf4j.LoggerFactory;
031
032/**
033 * Manages PUBLISH packet ids for clients.
034 *
035 * @author Dhiraj Bokde
036 */
037public class MQTTPacketIdGenerator extends ServiceSupport {
038
039    private static final Logger LOG = LoggerFactory.getLogger(MQTTPacketIdGenerator.class);
040    private static final Object LOCK = new Object();
041
042    Map<String, PacketIdMaps> clientIdMap = new ConcurrentHashMap<String, PacketIdMaps>();
043
044    private final NonZeroSequenceGenerator messageIdGenerator = new NonZeroSequenceGenerator();
045
046    private MQTTPacketIdGenerator() {
047    }
048
049    @Override
050    protected void doStop(ServiceStopper stopper) throws Exception {
051        synchronized (this) {
052            clientIdMap = new ConcurrentHashMap<String, PacketIdMaps>();
053        }
054    }
055
056    @Override
057    protected void doStart() throws Exception {
058    }
059
060    public void startClientSession(String clientId) {
061        if (!clientIdMap.containsKey(clientId)) {
062            clientIdMap.put(clientId, new PacketIdMaps());
063        }
064    }
065
066    public boolean stopClientSession(String clientId) {
067        return clientIdMap.remove(clientId) != null;
068    }
069
070    public short setPacketId(String clientId, MQTTSubscription subscription, ActiveMQMessage message, PUBLISH publish) {
071        final PacketIdMaps idMaps = clientIdMap.get(clientId);
072        if (idMaps == null) {
073            // maybe its a cleansession=true client id, use session less message id
074            final short id = messageIdGenerator.getNextSequenceId();
075            publish.messageId(id);
076            return id;
077        } else {
078            return idMaps.setPacketId(subscription, message, publish);
079        }
080    }
081
082    public void ackPacketId(String clientId, short packetId) {
083        final PacketIdMaps idMaps = clientIdMap.get(clientId);
084        if (idMaps != null) {
085            idMaps.ackPacketId(packetId);
086        }
087    }
088
089    public short getNextSequenceId(String clientId) {
090        final PacketIdMaps idMaps = clientIdMap.get(clientId);
091        return idMaps != null ? idMaps.getNextSequenceId(): messageIdGenerator.getNextSequenceId();
092    }
093
094    public static MQTTPacketIdGenerator getMQTTPacketIdGenerator(BrokerService broker) {
095        MQTTPacketIdGenerator result = null;
096        if (broker != null) {
097            synchronized (LOCK) {
098                Service[] services = broker.getServices();
099                if (services != null) {
100                    for (Service service : services) {
101                        if (service instanceof MQTTPacketIdGenerator) {
102                            return (MQTTPacketIdGenerator) service;
103                        }
104                    }
105                }
106                result = new MQTTPacketIdGenerator();
107                broker.addService(result);
108                if (broker.isStarted()) {
109                    try {
110                        result.start();
111                    } catch (Exception e) {
112                        LOG.warn("Couldn't start MQTTPacketIdGenerator");
113                    }
114                }
115            }
116        }
117
118
119        return result;
120    }
121
122    private class PacketIdMaps {
123
124        private final NonZeroSequenceGenerator messageIdGenerator = new NonZeroSequenceGenerator();
125        final Map<String, Short> activemqToPacketIds = new LRUCache<String, Short>(MQTTProtocolConverter.DEFAULT_CACHE_SIZE);
126        final Map<Short, String> packetIdsToActivemq = new LRUCache<Short, String>(MQTTProtocolConverter.DEFAULT_CACHE_SIZE);
127
128        short setPacketId(MQTTSubscription subscription, ActiveMQMessage message, PUBLISH publish) {
129            // subscription key
130            final StringBuilder subscriptionKey = new StringBuilder();
131            subscriptionKey.append(subscription.getConsumerInfo().getDestination().getPhysicalName())
132                .append(':').append(message.getJMSMessageID());
133            final String keyStr = subscriptionKey.toString();
134            Short packetId;
135            synchronized (activemqToPacketIds) {
136                packetId = activemqToPacketIds.get(keyStr);
137                if (packetId == null) {
138                    packetId = getNextSequenceId();
139                    activemqToPacketIds.put(keyStr, packetId);
140                    packetIdsToActivemq.put(packetId, keyStr);
141                } else {
142                    // mark publish as duplicate!
143                    publish.dup(true);
144                }
145            }
146            publish.messageId(packetId);
147            return packetId;
148        }
149
150        void ackPacketId(short packetId) {
151            synchronized (activemqToPacketIds) {
152                final String subscriptionKey = packetIdsToActivemq.remove(packetId);
153                if (subscriptionKey != null) {
154                    activemqToPacketIds.remove(subscriptionKey);
155                }
156            }
157        }
158
159        short getNextSequenceId() {
160            return messageIdGenerator.getNextSequenceId();
161        }
162
163    }
164
165    private class NonZeroSequenceGenerator {
166
167        private short lastSequenceId;
168
169        public synchronized short getNextSequenceId() {
170            final short val = ++lastSequenceId;
171            return val != 0 ? val : ++lastSequenceId;
172        }
173
174    }
175
176}