]> git.michaelhowe.org Git - packages/p/paho-mqtt.git/commitdiff
Implement inflight message handling.
authorRoger Light <roger@atchoo.org>
Sat, 22 Jun 2013 21:38:49 +0000 (22:38 +0100)
committerRoger Light <roger@atchoo.org>
Mon, 3 Feb 2014 21:17:07 +0000 (21:17 +0000)
src/paho/mqtt/client.py

index 5a1f609f8f256c093afeb09821080ddf00e5687b..60644b49878cbab3e0cef8080c0e5bd4a367cdf7 100755 (executable)
@@ -393,6 +393,8 @@ class Client:
         self._last_mid = 0
         self._state = mqtt_cs_new
         self._messages = []
+        self._max_inflight_messages = 20
+        self._inflight_messages = 0
         self._will = False
         self._will_topic = ""
         self._will_payload = None
@@ -415,6 +417,7 @@ class Client:
         self._out_packet_mutex = threading.Lock()
         self._current_out_packet_mutex = threading.Lock()
         self._msgtime_mutex = threading.Lock()
+        self._message_mutex = threading.Lock()
         self._thread = None
         self._thread_terminate = False
         self._ssl = None
@@ -721,10 +724,6 @@ class Client:
             message = MQTTMessage()
             message.timestamp = time.time()
             message.direction = mqtt_md_out
-            if qos == 1:
-                message.state = mqtt_ms_wait_puback
-            elif qos == 2:
-                message.state = mqtt_ms_wait_pubrec
 
             message.mid = local_mid
             message.topic = topic
@@ -737,9 +736,20 @@ class Client:
             message.retain = retain
             message.dup = False
 
+            self._message_mutex.acquire()
             self._messages.append(message)
-            rc = self._send_publish(message.mid, message.topic, message.payload, message.qos, message.retain, message.dup)
-            return (rc, local_mid)
+            if self._max_inflight_messages == 0 or self._inflight_messages < self._max_inflight_messages:
+                self._inflight_messages = self._inflight_messages+1
+                if qos == 1:
+                    message.state = mqtt_ms_wait_puback
+                elif qos == 2:
+                    message.state = mqtt_ms_wait_pubrec
+                self._message_mutex.release()
+
+                rc = self._send_publish(message.mid, message.topic, message.payload, message.qos, message.retain, message.dup)
+                return (rc, local_mid)
+            self._message_mutex.release()
+            return (MQTT_ERR_SUCCESS, local_mid)
 
     def username_pw_set(self, username, password=None):
         """Set a username and optionally a password for broker authentication.
@@ -907,6 +917,13 @@ class Client:
 
         return MQTT_ERR_SUCCESS
 
+    def max_inflight_messages_set(self, inflight):
+        """Set the maximum number of messages with QoS>0 that can be part way
+        through their network flow at once. Defaults to 20."""
+        if inflight < 0:
+            raise ValueError('Invalid inflight.')
+        self._max_inflight_messages = inflight
+
     def message_retry_set(self, retry):
         """Set the timeout in seconds before a message with QoS>0 is retried.
         20 seconds by default."""
@@ -1455,15 +1472,19 @@ class Client:
         return (self._packet_queue(command, packet, local_mid, 1), local_mid)
 
     def _message_update(self, mid, direction, state):
+        self._message_mutex.acquire()
         for m in self._messages:
             if m.mid == mid and m.direction == direction:
                 m.state = state
                 m.timestamp = time.time()
+                self._message_mutex.release()
                 return MQTT_ERR_SUCCESS
 
+        self._message_mutex.release()
         return MQTT_ERR_NOT_FOUND
 
     def _message_retry_check(self):
+        self._message_mutex.acquire()
         now = time.time()
         for m in self._messages:
             if m.timestamp + self._message_retry < now:
@@ -1479,17 +1500,23 @@ class Client:
                     m.timestamp = now
                     m.dup = True
                     self._send_pubrel(m.mid, True)
+        self._message_mutex.release()
 
     def _messages_reconnect_reset(self):
+        self._message_mutex.acquire()
         for m in self._messages:
             m.timestamp = 0
             if m.direction == mqtt_md_out:
-                if m.qos == 1:
-                    m.state = mqtt_ms_wait_puback
-                elif m.qos == 2:
-                    m.state = mqtt_ms_wait_pubrec
+                if self._max_inflight_messages == 0 or self._inflight_messages < self._max_inflight_messages:
+                    if m.qos == 1:
+                        m.state = mqtt_ms_wait_puback
+                    elif m.qos == 2:
+                        m.state = mqtt_ms_wait_pubrec
+                else:
+                    m.state = mqtt_ms_invalid
             else:
                 self._messages.pop(self._messages.index(m))
+        self._message_mutex.release()
 
     def _packet_queue(self, command, packet, mid, qos):
         mpkt = MQTTPacket(command, packet, mid, qos)
@@ -1648,7 +1675,9 @@ class Client:
         elif message.qos == 2:
             rc = self._send_pubrec(message.mid)
             message.state = mqtt_ms_wait_pubrel
+            self._message_mutex.acquire()
             self._messages.append(message)
+            self._message_mutex.release()
             return rc
         else:
             return MQTT_ERR_PROTOCOL
@@ -1665,6 +1694,7 @@ class Client:
         mid = mid[0]
         self._easy_log(MQTT_LOG_DEBUG, "Received PUBREL (Mid: "+str(mid)+")")
         
+        self._message_mutex.acquire()
         for i in range(len(self._messages)):
             if self._messages[i].direction == mqtt_md_in and self._messages[i].mid == mid:
 
@@ -1677,9 +1707,34 @@ class Client:
                     self._in_callback = False
                 self._callback_mutex.release()
                 self._messages.pop(i)
-
+                self._inflight_messages = self._inflight_messages - 1
+                if self._max_inflight_messages > 0:
+                    rc = self._update_inflight()
+                    if rc != MQTT_ERR_SUCCESS:
+                        self._message_mutex.release()
+                        return rc
+
+                self._message_mutex.release()
                 return self._send_pubcomp(mid)
 
+        self._message_mutex.release()
+        return MQTT_ERR_SUCCESS
+
+    def _update_inflight(self):
+        # Dont lock message_mutex here
+        for m in self._messages:
+            if self._inflight_messages < self._max_inflight_messages:
+                if m.qos > 0 and m.state == mqtt_ms_invalid and m.direction == mqtt_md_out:
+                    self._inflight_messages = self._inflight_messages + 1
+                    if m.qos == 1:
+                        m.state = mqtt_ms_wait_puback
+                    elif m.qos == 2:
+                        m.state = mqtt_ms_wait_pubrec
+                    rc = self._send_publish(m.mid, m.topic, m.payload, m.qos, m.retain, m.dup)
+                    if rc != 0:
+                        return rc
+            else:
+                return MQTT_ERR_SUCCESS
         return MQTT_ERR_SUCCESS
 
     def _handle_pubrec(self):
@@ -1691,12 +1746,15 @@ class Client:
         mid = mid[0]
         self._easy_log(MQTT_LOG_DEBUG, "Received PUBREC (Mid: "+str(mid)+")")
         
-        for i in range(len(self._messages)):
-            if self._messages[i].direction == mqtt_md_out and self._messages[i].mid == mid:
-                self._messages[i].state = mqtt_ms_wait_pubcomp
-                self._messages[i].timestamp = time.time()
+        self._message_mutex.acquire()
+        for m in self._messages:
+            if m.direction == mqtt_md_out and m.mid == mid:
+                m.state = mqtt_ms_wait_pubcomp
+                m.timestamp = time.time()
+                self._message_mutex.release()
                 return self._send_pubrel(mid, False)
         
+        self._message_mutex.release()
         return MQTT_ERR_SUCCESS
 
     def _handle_unsuback(self):
@@ -1724,6 +1782,7 @@ class Client:
         mid = mid[0]
         self._easy_log(MQTT_LOG_DEBUG, "Received "+cmd+" (Mid: "+str(mid)+")")
         
+        self._message_mutex.acquire()
         for i in range(len(self._messages)):
             try:
                 if self._messages[i].direction == mqtt_md_out and self._messages[i].mid == mid:
@@ -1736,11 +1795,20 @@ class Client:
 
                     self._callback_mutex.release()
                     self._messages.pop(i)
+                    self._inflight_messages = self._inflight_messages - 1
+                    if self._max_inflight_messages > 0:
+                        rc = self._update_inflight()
+                        if rc != MQTT_ERR_SUCCESS:
+                            self._message_mutex.release()
+                            return rc
+                    self._message_mutex.release()
+                    return MQTT_ERR_SUCCESS
             except IndexError:
                 # Have removed item so i>count.
                 # Not really an error.
                 pass
 
+        self._message_mutex.release()
         return MQTT_ERR_SUCCESS
 
     def _thread_main(self):