]> git.michaelhowe.org Git - packages/p/paho-mqtt.git/commitdiff
Separate in/out message queues.
authorRoger Light <roger@atchoo.org>
Fri, 20 Dec 2013 22:44:30 +0000 (22:44 +0000)
committerRoger Light <roger@atchoo.org>
Mon, 3 Feb 2014 21:20:23 +0000 (21:20 +0000)
src/paho/mqtt/client.py

index e4b083bfed711a584cff3a05acd1b3b05a614421..e1d0f5c2ed0bdad5951a54244018bfe610a28a2a 100755 (executable)
@@ -94,11 +94,6 @@ mqtt_cs_connected = 1
 mqtt_cs_disconnecting = 2
 mqtt_cs_connect_async = 3
 
-# Message direction
-mqtt_md_invalid = 0
-mqtt_md_in = 1
-mqtt_md_out = 2
-
 # Message state
 mqtt_ms_invalid = 0,
 mqtt_ms_wait_puback = 1
@@ -254,7 +249,6 @@ class MQTTMessage:
     """
     def __init__(self):
         self.timestamp = 0
-        self.direction = mqtt_md_invalid
         self.state = mqtt_ms_invalid
         self.dup = False
         self.mid = 0
@@ -264,7 +258,7 @@ class MQTTMessage:
         self.retain = False
 
 
-class Client:
+class Client(object):
     """MQTT version 3.1 client class.
 
     This is the main class for use communicating with an MQTT broker.
@@ -398,7 +392,8 @@ class Client:
         self._ping_t = 0
         self._last_mid = 0
         self._state = mqtt_cs_new
-        self._messages = []
+        self._out_messages = []
+        self._in_messages = []
         self._max_inflight_messages = 20
         self._inflight_messages = 0
         self._will = False
@@ -423,7 +418,8 @@ class Client:
         self._out_packet_mutex = threading.Lock()
         self._current_out_packet_mutex = threading.Lock()
         self._msgtime_mutex = threading.Lock()
-        self._message_mutex = threading.Lock()
+        self._out_message_mutex = threading.Lock()
+        self._in_message_mutex = threading.Lock()
         self._thread = None
         self._thread_terminate = False
         self._ssl = None
@@ -803,7 +799,6 @@ class Client:
         else:
             message = MQTTMessage()
             message.timestamp = time.time()
-            message.direction = mqtt_md_out
 
             message.mid = local_mid
             message.topic = topic
@@ -816,19 +811,19 @@ class Client:
             message.retain = retain
             message.dup = False
 
-            self._message_mutex.acquire()
-            self._messages.append(message)
+            self._out_message_mutex.acquire()
+            self._out_messages.append(message)
             if self._max_inflight_messages == 0 or self._inflight_messages < self._max_inflight_messages:
                 self._inflight_messages = self._inflight_messages+1
                 if qos == 1:
                     message.state = mqtt_ms_wait_puback
                 elif qos == 2:
                     message.state = mqtt_ms_wait_pubrec
-                self._message_mutex.release()
+                self._out_message_mutex.release()
 
                 rc = self._send_publish(message.mid, message.topic, message.payload, message.qos, message.retain, message.dup)
                 return (rc, local_mid)
-            self._message_mutex.release()
+            self._out_message_mutex.release()
             return (MQTT_ERR_SUCCESS, local_mid)
 
     def username_pw_set(self, username, password=None):
@@ -974,7 +969,7 @@ class Client:
         if self._sock is None and self._ssl is None:
             return MQTT_ERR_NO_CONN
 
-        max_packets = len(self._messages)
+        max_packets = len(self._out_messages) + len(self._in_messages)
         if max_packets < 1:
             max_packets = 1
 
@@ -999,7 +994,7 @@ class Client:
         if self._sock is None and self._ssl is None:
             return MQTT_ERR_NO_CONN
 
-        max_packets = len(self._messages)
+        max_packets = len(self._out_messages) + len(self._in_messages)
         if max_packets < 1:
             max_packets = 1
 
@@ -1155,7 +1150,7 @@ class Client:
                 if (self._thread_terminate is True
                         and self._current_out_packet is None
                         and len(self._out_packet) == 0
-                        and len(self._messages) == 0):
+                        and len(self._out_messages) == 0):
 
                     rc = 1
                     run = False
@@ -1668,22 +1663,10 @@ class Client:
             self._pack_str16(packet, t)
         return (self._packet_queue(command, packet, local_mid, 1), local_mid)
 
-    def _message_update(self, mid, direction, state):
-        self._message_mutex.acquire()
-        for m in self._messages:
-            if m.mid == mid and m.direction == direction:
-                m.state = state
-                m.timestamp = time.time()
-                self._message_mutex.release()
-                return MQTT_ERR_SUCCESS
-
-        self._message_mutex.release()
-        return MQTT_ERR_NOT_FOUND
-
-    def _message_retry_check(self):
-        self._message_mutex.acquire()
+    def _message_retry_check_actual(self, messages, mutex):
+        mutex.acquire()
         now = time.time()
-        for m in self._messages:
+        for m in messages:
             if m.timestamp + self._message_retry < now:
                 if m.state == mqtt_ms_wait_puback or m.state == mqtt_ms_wait_pubrec:
                     m.timestamp = now
@@ -1697,28 +1680,40 @@ class Client:
                     m.timestamp = now
                     m.dup = True
                     self._send_pubrel(m.mid, True)
-        self._message_mutex.release()
+        mutex.release()
 
-    def _messages_reconnect_reset(self):
-        self._message_mutex.acquire()
-        for m in self._messages:
+    def _message_retry_check(self):
+        self._message_retry_check_actual(self._out_messages, self._out_message_mutex)
+        self._message_retry_check_actual(self._in_messages, self._in_message_mutex)
+
+    def _messages_reconnect_reset_out(self):
+        self._out_message_mutex.acquire()
+        for m in self._out_messages:
             m.timestamp = 0
-            if m.direction == mqtt_md_out:
-                if self._max_inflight_messages == 0 or self._inflight_messages < self._max_inflight_messages:
-                    if m.qos == 1:
-                        m.state = mqtt_ms_wait_puback
-                    elif m.qos == 2:
-                        # Preserve current state
-                        pass
-                else:
-                    m.state = mqtt_ms_invalid
-            else:
-                if m.qos != 2:
-                    self._messages.pop(self._messages.index(m))
-                else:
+            if self._max_inflight_messages == 0 or self._inflight_messages < self._max_inflight_messages:
+                if m.qos == 1:
+                    m.state = mqtt_ms_wait_puback
+                elif m.qos == 2:
                     # Preserve current state
                     pass
-        self._message_mutex.release()
+            else:
+                m.state = mqtt_ms_invalid
+        self._out_message_mutex.release()
+
+    def _messages_reconnect_reset_in(self):
+        self._in_message_mutex.acquire()
+        for m in self._in_messages:
+            m.timestamp = 0
+            if m.qos != 2:
+                self._in_messages.pop(self._in_messages.index(m))
+            else:
+                # Preserve current state
+                pass
+        self._in_message_mutex.release()
+
+    def _messages_reconnect_reset(self):
+        self._messages_reconnect_reset_out()
+        self._messages_reconnect_reset_in()
 
     def _packet_queue(self, command, packet, mid, qos):
         mpkt = dict(
@@ -1832,7 +1827,6 @@ class Client:
 
         header = self._in_packet['command']
         message = MQTTMessage()
-        message.direction = mqtt_md_in
         message.dup = (header & 0x08)>>3
         message.qos = (header & 0x06)>>1
         message.retain = (header & 0x01)
@@ -1884,9 +1878,9 @@ class Client:
         elif message.qos == 2:
             rc = self._send_pubrec(message.mid)
             message.state = mqtt_ms_wait_pubrel
-            self._message_mutex.acquire()
-            self._messages.append(message)
-            self._message_mutex.release()
+            self._in_message_mutex.acquire()
+            self._in_messages.append(message)
+            self._in_message_mutex.release()
             return rc
         else:
             return MQTT_ERR_PROTOCOL
@@ -1903,37 +1897,39 @@ class Client:
         mid = mid[0]
         self._easy_log(MQTT_LOG_DEBUG, "Received PUBREL (Mid: "+str(mid)+")")
 
-        self._message_mutex.acquire()
-        for i in range(len(self._messages)):
-            if self._messages[i].direction == mqtt_md_in and self._messages[i].mid == mid:
+        self._in_message_mutex.acquire()
+        for i in range(len(self._in_messages)):
+            if self._in_messages[i].mid == mid:
 
                 # Only pass the message on if we have removed it from the queue - this
                 # prevents multiple callbacks for the same message.
                 self._callback_mutex.acquire()
                 if self.on_message:
                     self._in_callback = True
-                    self.on_message(self, self._userdata, self._messages[i])
+                    self.on_message(self, self._userdata, self._in_messages[i])
                     self._in_callback = False
                 self._callback_mutex.release()
-                self._messages.pop(i)
+                self._in_messages.pop(i)
                 self._inflight_messages = self._inflight_messages - 1
                 if self._max_inflight_messages > 0:
+                    self._out_message_mutex.acquire()
                     rc = self._update_inflight()
+                    self._out_message_mutex.release()
                     if rc != MQTT_ERR_SUCCESS:
-                        self._message_mutex.release()
+                        self._in_message_mutex.release()
                         return rc
 
-                self._message_mutex.release()
+                self._in_message_mutex.release()
                 return self._send_pubcomp(mid)
 
-        self._message_mutex.release()
+        self._in_message_mutex.release()
         return MQTT_ERR_SUCCESS
 
     def _update_inflight(self):
         # Dont lock message_mutex here
-        for m in self._messages:
+        for m in self._out_messages:
             if self._inflight_messages < self._max_inflight_messages:
-                if m.qos > 0 and m.state == mqtt_ms_invalid and m.direction == mqtt_md_out:
+                if m.qos > 0 and m.state == mqtt_ms_invalid:
                     self._inflight_messages = self._inflight_messages + 1
                     if m.qos == 1:
                         m.state = mqtt_ms_wait_puback
@@ -1955,15 +1951,15 @@ class Client:
         mid = mid[0]
         self._easy_log(MQTT_LOG_DEBUG, "Received PUBREC (Mid: "+str(mid)+")")
 
-        self._message_mutex.acquire()
-        for m in self._messages:
-            if m.direction == mqtt_md_out and m.mid == mid:
+        self._out_message_mutex.acquire()
+        for m in self._out_messages:
+            if m.mid == mid:
                 m.state = mqtt_ms_wait_pubcomp
                 m.timestamp = time.time()
-                self._message_mutex.release()
+                self._out_message_mutex.release()
                 return self._send_pubrel(mid, False)
 
-        self._message_mutex.release()
+        self._out_message_mutex.release()
         return MQTT_ERR_SUCCESS
 
     def _handle_unsuback(self):
@@ -1991,10 +1987,10 @@ class Client:
         mid = mid[0]
         self._easy_log(MQTT_LOG_DEBUG, "Received "+cmd+" (Mid: "+str(mid)+")")
 
-        self._message_mutex.acquire()
-        for i in range(len(self._messages)):
+        self._out_message_mutex.acquire()
+        for i in range(len(self._out_messages)):
             try:
-                if self._messages[i].direction == mqtt_md_out and self._messages[i].mid == mid:
+                if self._out_messages[i].mid == mid:
                     # Only inform the client the message has been sent once.
                     self._callback_mutex.acquire()
                     if self.on_publish:
@@ -2003,21 +1999,21 @@ class Client:
                         self._in_callback = False
 
                     self._callback_mutex.release()
-                    self._messages.pop(i)
+                    self._out_messages.pop(i)
                     self._inflight_messages = self._inflight_messages - 1
                     if self._max_inflight_messages > 0:
                         rc = self._update_inflight()
                         if rc != MQTT_ERR_SUCCESS:
-                            self._message_mutex.release()
+                            self._out_message_mutex.release()
                             return rc
-                    self._message_mutex.release()
+                    self._out_message_mutex.release()
                     return MQTT_ERR_SUCCESS
             except IndexError:
                 # Have removed item so i>count.
                 # Not really an error.
                 pass
 
-        self._message_mutex.release()
+        self._out_message_mutex.release()
         return MQTT_ERR_SUCCESS
 
     def _thread_main(self):