From 1b07d307619049ddf394b69c4b68b7171643b328 Mon Sep 17 00:00:00 2001
From: david-sarah <david-sarah@jacaranda.org>
Date: Fri, 15 Jun 2012 01:48:55 +0000
Subject: [PATCH] After a server disconnects, make the IServer retain the dead
 RemoteReference, and continue to return it to anyone who calls get_rref().
 This removes the need for callers to guard against receiving a None (as long
 as the server was connected at least once, which is always the case for
 servers returned by get_servers_for_psi(), which is how all upload/download
 code gets servers). Includes test. fixes #1636

---
 src/allmydata/client.py           |  1 +
 src/allmydata/interfaces.py       |  7 ++++++-
 src/allmydata/storage_client.py   | 14 ++++++++++++--
 src/allmydata/test/test_system.py | 30 ++++++++++++++++++++++++++++++
 4 files changed, 49 insertions(+), 3 deletions(-)

diff --git a/src/allmydata/client.py b/src/allmydata/client.py
index 1e4479bb..239fdc1f 100644
--- a/src/allmydata/client.py
+++ b/src/allmydata/client.py
@@ -213,6 +213,7 @@ class Client(node.Node, pollmixin.PollMixin):
         sk,vk_vs = keyutil.parse_privkey(sk_vs.strip())
         self.write_config("node.pubkey", vk_vs+"\n")
         self._server_key = sk
+        self.node_key_s = vk_vs
 
     def _init_permutation_seed(self, ss):
         seed = self.get_config_from_file("permutation-seed")
diff --git a/src/allmydata/interfaces.py b/src/allmydata/interfaces.py
index b1c34a79..b9255524 100644
--- a/src/allmydata/interfaces.py
+++ b/src/allmydata/interfaces.py
@@ -444,7 +444,12 @@ class IServer(IDisplayableServer):
     def start_connecting(tub, trigger_cb):
         pass
     def get_rref():
-        pass
+        """Once a server is connected, I return a RemoteReference.
+        Before a server is connected for the first time, I return None.
+
+        After the connection is lost I keep returning that same rref (never
+        None), but calls made through it will fail with DeadReferenceError.
+        """
 
 
 class IMutableSlotWriter(Interface):
diff --git a/src/allmydata/storage_client.py b/src/allmydata/storage_client.py
index 68823f01..b536c674 100644
--- a/src/allmydata/storage_client.py
+++ b/src/allmydata/storage_client.py
@@ -77,6 +77,7 @@ class StorageFarmBroker:
     def test_add_rref(self, serverid, rref, ann):
         s = NativeStorageServer(serverid, ann.copy())
         s.rref = rref
+        s._is_connected = True
         self.servers[serverid] = s
 
     def test_add_server(self, serverid, s):
@@ -129,7 +130,7 @@ class StorageFarmBroker:
         return frozenset(self.servers.keys())
 
     def get_connected_servers(self):
-        return frozenset([s for s in self.servers.values() if s.get_rref()])
+        return frozenset([s for s in self.servers.values() if s.is_connected()])
 
     def get_known_servers(self):
         return frozenset(self.servers.values())
@@ -215,6 +216,7 @@ class NativeStorageServer:
         self.last_loss_time = None
         self.remote_host = None
         self.rref = None
+        self._is_connected = False
         self._reconnector = None
         self._trigger_cb = None
 
@@ -254,6 +256,8 @@ class NativeStorageServer:
         return self.announcement
     def get_remote_host(self):
         return self.remote_host
+    def is_connected(self):
+        return self._is_connected
     def get_last_connect_time(self):
         return self.last_connect_time
     def get_last_loss_time(self):
@@ -287,6 +291,7 @@ class NativeStorageServer:
         self.last_connect_time = time.time()
         self.remote_host = rref.getPeer()
         self.rref = rref
+        self._is_connected = True
         rref.notifyOnDisconnect(self._lost)
 
     def get_rref(self):
@@ -296,7 +301,12 @@ class NativeStorageServer:
         log.msg(format="lost connection to %(name)s", name=self.get_name(),
                 facility="tahoe.storage_broker", umid="zbRllw")
         self.last_loss_time = time.time()
-        self.rref = None
+        # self.rref is now stale: all callRemote()s will get a
+        # DeadReferenceError. We leave the stale reference in place so that
+        # uploader/downloader code (which received this IServer through
+        # get_connected_servers() or get_servers_for_psi()) can continue to
+        # use s.get_rref().callRemote() and not worry about it being None.
+        self._is_connected = False
         self.remote_host = None
 
     def stop_connecting(self):
diff --git a/src/allmydata/test/test_system.py b/src/allmydata/test/test_system.py
index f3d618d6..a9d20eaf 100644
--- a/src/allmydata/test/test_system.py
+++ b/src/allmydata/test/test_system.py
@@ -1883,3 +1883,33 @@ class SystemTest(SystemTestMixin, RunBinTahoeMixin, unittest.TestCase):
             return d
         d.addCallback(_got_lit_filenode)
         return d
+
+class Connections(SystemTestMixin, unittest.TestCase):
+    def test_rref(self):
+        self.basedir = "system/Connections/rref"
+        d = self.set_up_nodes(2)
+        def _start(ign):
+            self.c0 = self.clients[0]
+            for s in self.c0.storage_broker.get_connected_servers():
+                if "pub-"+s.get_longname() != self.c0.node_key_s:
+                    break
+            self.s1 = s # s1 is the server, not c0
+            self.s1_rref = s.get_rref()
+            self.failIfEqual(self.s1_rref, None)
+            self.failUnless(self.s1.is_connected())
+        d.addCallback(_start)
+
+        # now shut down the server
+        d.addCallback(lambda ign: self.clients[1].disownServiceParent())
+        # and wait for the client to notice
+        def _poll():
+            return len(self.c0.storage_broker.get_connected_servers()) < 2
+        d.addCallback(lambda ign: self.poll(_poll))
+
+        def _down(ign):
+            self.failIf(self.s1.is_connected())
+            rref = self.s1.get_rref()
+            self.failUnless(rref)
+            self.failUnlessIdentical(rref, self.s1_rref)
+        d.addCallback(_down)
+        return d
-- 
2.45.2