From 40edccf9c50e8d1b507dd8b0a16f4bc505d4e504 Mon Sep 17 00:00:00 2001
From: Zooko O'Whielacronx <zooko@zooko.com>
Date: Mon, 10 Dec 2007 20:22:59 -0700
Subject: [PATCH] fix IntroducerClient.when_enough_peers() add
 IntroducerClient.when_few_enough_peers(), fix and improve test_introducer

---
 src/allmydata/introducer.py           | 87 +++++++++++++++++++--------
 src/allmydata/test/test_introducer.py | 45 +++++++-------
 2 files changed, 86 insertions(+), 46 deletions(-)

diff --git a/src/allmydata/introducer.py b/src/allmydata/introducer.py
index 86015e00..6fbc9c94 100644
--- a/src/allmydata/introducer.py
+++ b/src/allmydata/introducer.py
@@ -71,14 +71,24 @@ class IntroducerClient(service.Service, Referenceable):
 
         # The N'th element of _observers_of_enough_peers is None if nobody has
         # asked to be informed when N peers become connected, it is a
-        # OneShotObserverList if someone has asked to be informed, and that
-        # list is fired when N peers next become connected (or immediately if
-        # N peers are already connected when someone asks), and the N'th
-        # element is replaced by None when the number of connected peers falls
-        # below N.  _observers_of_enough_peers is always just long enough to
-        # hold the highest-numbered N that anyone is interested in (i.e.,
-        # there are never trailing Nones in _observers_of_enough_peers).
+        # OneShotObserverList if someone has asked to be informed, and that list
+        # is fired when N peers next become connected (or immediately if N peers
+        # are already connected when they asked), and the N'th element is
+        # replaced by None when the number of connected peers falls below N.
+        # _observers_of_enough_peers is always just long enough to hold the
+        # highest-numbered N that anyone is interested in (i.e., there are never
+        # trailing Nones in _observers_of_enough_peers).
         self._observers_of_enough_peers = []
+        # The N'th element of _observers_of_fewer_than_peers is None if nobody
+        # has asked to be informed when we become connected to fewer than N
+        # peers, it is a OneShotObserverList if someone has asked to be
+        # informed, and that list is fired when we become connected to fewer
+        # than N peers (or immediately if we are already connected to fewer than
+        # N peers when they asked).  _observers_of_fewer_than_peers is always
+        # just long enough to hold the highest-numbered N that anyone is
+        # interested in (i.e., there are never trailing Nones in
+        # _observers_of_fewer_than_peers).
+        self._observers_of_fewer_than_peers = []
 
     def startService(self):
         service.Service.startService(self)
@@ -107,6 +117,27 @@ class IntroducerClient(service.Service, Referenceable):
         for reconnector in self.reconnectors.itervalues():
             reconnector.stopConnecting()
 
+    def _notify_observers_of_enough_peers(self, numpeers):
+        if len(self._observers_of_enough_peers) > numpeers:
+            osol = self._observers_of_enough_peers[numpeers]
+            if osol:
+                osol.fire(None)
+
+    def _remove_observers_of_enough_peers(self, numpeers):
+        if len(self._observers_of_enough_peers) > numpeers:
+            self._observers_of_enough_peers[numpeers] = None
+            while self._observers_of_enough_peers and (not self._observers_of_enough_peers[-1]):
+                self._observers_of_enough_peers.pop()
+
+    def _notify_observers_of_fewer_than_peers(self, numpeers):
+        if len(self._observers_of_fewer_than_peers) > numpeers:
+            osol = self._observers_of_fewer_than_peers[numpeers]
+            if osol:
+                osol.fire(None)
+                self._observers_of_fewer_than_peers[numpeers] = None
+                while len(self._observers_of_fewer_than_peers) > numpeers and (not self._observers_of_fewer_than_peers[-1]):
+                    self._observers_of_fewer_than_peers.pop()
+
     def _new_peer(self, furl):
         if furl in self.reconnectors:
             return
@@ -126,24 +157,18 @@ class IntroducerClient(service.Service, Referenceable):
             self.log("connected to %s" % b32encode(nodeid).lower()[:8])
             self.connection_observers.notify(nodeid, rref)
             self.connections[nodeid] = rref
-            if len(self._observers_of_enough_peers) > len(self.connections):
-                osol = self._observers_of_enough_peers[len(self.connections)]
-                if osol:
-                    osol.fire(None)
+            self._notify_observers_of_enough_peers(len(self.connections))
+            self._notify_observers_of_fewer_than_peers(len(self.connections))
             def _lost():
                 # TODO: notifyOnDisconnect uses eventually(), but connects do
                 # not. Could this cause a problem?
+
+                # We know that this observer list must have been fired, since we
+                # had enough peers before this one was lost.
+                self._remove_observers_of_enough_peers(len(self.connections))
+                self._notify_observers_of_fewer_than_peers(len(self.connections)+1)
+
                 del self.connections[nodeid]
-                if len(self._observers_of_enough_peers) > len(self.connections):
-                    self._observers_of_enough_peers[len(self.connections)] = None
-                    while self._observers_of_enough_peers and (not self._observers_of_enough_peers[-1]):
-                        self._observers_of_enough_peers.pop()
-                for numpeers in self._observers_of_enough_peers:
-                    if len(self.connections) == (numpeers-1):
-                        # We know that this observer list must have been
-                        # fired, since we had enough peers before this one was
-                        # lost.
-                        del self._observers_of_enough_peers[numpeers]
 
             rref.notifyOnDisconnect(_lost)
         self.log("connecting to %s" % b32encode(nodeid).lower()[:8])
@@ -177,9 +202,9 @@ class IntroducerClient(service.Service, Referenceable):
 
     def when_enough_peers(self, numpeers):
         """
-        I return a deferred that fires the next time that at least numpeers
-        are connected, or fires immediately if numpeers are currently
-        available.
+        I return a deferred that fires the next time that at least
+        numpeers are connected, or fires immediately if numpeers are
+        currently connected.
         """
         self._observers_of_enough_peers.extend([None]*(numpeers+1-len(self._observers_of_enough_peers)))
         if not self._observers_of_enough_peers[numpeers]:
@@ -187,3 +212,17 @@ class IntroducerClient(service.Service, Referenceable):
             if len(self.connections) >= numpeers:
                 self._observers_of_enough_peers[numpeers].fire(self)
         return self._observers_of_enough_peers[numpeers].when_fired()
+
+    def when_fewer_than_peers(self, numpeers):
+        """
+        I return a deferred that fires the next time that fewer than numpeers
+        are connected, or fires immediately if fewer than numpeers are currently
+        connected.
+        """
+        if len(self.connections) < numpeers:
+            return defer.succeed(None)
+        else:
+            self._observers_of_fewer_than_peers.extend([None]*(numpeers+1-len(self._observers_of_fewer_than_peers)))
+            if not self._observers_of_fewer_than_peers[numpeers]:
+                self._observers_of_fewer_than_peers[numpeers] = observer.OneShotObserverList()
+            return self._observers_of_fewer_than_peers[numpeers].when_fired()
diff --git a/src/allmydata/test/test_introducer.py b/src/allmydata/test/test_introducer.py
index 35e805be..7a8e6f09 100644
--- a/src/allmydata/test/test_introducer.py
+++ b/src/allmydata/test/test_introducer.py
@@ -90,14 +90,6 @@ class TestIntroducer(unittest.TestCase, testutil.PollMixin):
         iurl = tub.registerReference(i)
         NUMCLIENTS = 5
 
-        self.waiting_for_connections = NUMCLIENTS*NUMCLIENTS
-        d = self._done_counting = defer.Deferred()
-        def _count(nodeid, rref):
-            log.msg("NEW CONNECTION! %s %s" % (b32encode(nodeid).lower(), rref))
-            self.waiting_for_connections -= 1
-            if self.waiting_for_connections == 0:
-                self._done_counting.callback("done!")
-
         clients = []
         tubs = {}
         for i in range(NUMCLIENTS):
@@ -112,12 +104,19 @@ class TestIntroducer(unittest.TestCase, testutil.PollMixin):
             n = MyNode()
             node_furl = tub.registerReference(n)
             c = IntroducerClient(tub, iurl, node_furl)
-            c.notify_on_new_connection(_count)
+
             c.setServiceParent(self.parent)
             clients.append(c)
             tubs[c] = tub
 
-        # d will fire once everybody is connected
+        def _wait_for_all_connections(res):
+            dl = [] # list of when_enough_peers() for each peer
+            # will fire once everybody is connected
+            for c in clients:
+                dl.append(c.when_enough_peers(NUMCLIENTS))
+            return defer.DeferredList(dl, fireOnOneErrback=True)
+
+        d = _wait_for_all_connections(None)
 
         def _check1(res):
             log.msg("doing _check1")
@@ -125,11 +124,9 @@ class TestIntroducer(unittest.TestCase, testutil.PollMixin):
                 self.failUnlessEqual(len(c.connections), NUMCLIENTS)
                 self.failUnless(c._connected) # to the introducer
         d.addCallback(_check1)
+        origin_c = clients[0]
         def _disconnect_somebody_else(res):
             # now disconnect somebody's connection to someone else
-            self.waiting_for_connections = 2
-            d2 = self._done_counting = defer.Deferred()
-            origin_c = clients[0]
             # find a target that is not themselves
             for nodeid,rref in origin_c.connections.items():
                 if b32encode(nodeid).lower() != tubs[origin_c].tubID:
@@ -138,19 +135,23 @@ class TestIntroducer(unittest.TestCase, testutil.PollMixin):
             log.msg(" disconnecting %s->%s" % (tubs[origin_c].tubID, victim))
             victim.tracker.broker.transport.loseConnection()
             log.msg(" did disconnect")
-            return d2
         d.addCallback(_disconnect_somebody_else)
+        def _wait_til_he_notices(res):
+            # wait til the origin_c notices the loss
+            log.msg(" waiting until peer notices the disconnection")
+            return origin_c.when_fewer_than_peers(NUMCLIENTS)
+        d.addCallback(_wait_til_he_notices)
+        def _wait_for_reconnection(res):
+            log.msg(" doing _wait_for_reconnection()")
+            return origin_c.when_enough_peers(NUMCLIENTS)
+        d.addCallback(_wait_for_reconnection)
         def _check2(res):
             log.msg("doing _check2")
             for c in clients:
                 self.failUnlessEqual(len(c.connections), NUMCLIENTS)
         d.addCallback(_check2)
         def _disconnect_yourself(res):
-            # now disconnect somebody's connection to themselves. This will
-            # only result in one new connection, since it is a loopback.
-            self.waiting_for_connections = 1
-            d2 = self._done_counting = defer.Deferred()
-            origin_c = clients[0]
+            # now disconnect somebody's connection to themselves.
             # find a target that *is* themselves
             for nodeid,rref in origin_c.connections.items():
                 if b32encode(nodeid).lower() == tubs[origin_c].tubID:
@@ -158,9 +159,10 @@ class TestIntroducer(unittest.TestCase, testutil.PollMixin):
                     break
             log.msg(" disconnecting %s->%s" % (tubs[origin_c].tubID, victim))
             victim.tracker.broker.transport.loseConnection()
-            log.msg(" did disconnect")
-            return d2
+            log.msg(" did disconnect from self")
         d.addCallback(_disconnect_yourself)
+        d.addCallback(_wait_til_he_notices)
+        d.addCallback(_wait_for_all_connections)
         def _check3(res):
             log.msg("doing _check3")
             for c in clients:
@@ -271,4 +273,3 @@ class TestIntroducer(unittest.TestCase, testutil.PollMixin):
         d.addCallback(_check_again)
         return d
     del test_system_this_one_breaks_too
-
-- 
2.45.2