From baa11a0ad4f4bafee790c7480d35d5ad9e872f4c Mon Sep 17 00:00:00 2001
From: david-sarah <david-sarah@jacaranda.org>
Date: Fri, 29 Jan 2010 04:38:45 -0800
Subject: [PATCH] New tests for #928

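Besides adding tests, this also changes CiphertextDownloader so that
when every get_buckets query has completed (either succeeded or
failed) without yielding enough buckets, the wait-for-enough-buckets
Deferred is fired with False (via reactor.callLater(0, ...)) instead
of being errbacked, and _got_all_shareholders is attached with
addBoth. no_network.py grows a hang_server() method that delays a
server's remote calls until a given Deferred fires; the new
test_hung_server.py uses it to exercise immutable downloads against
hung, broken, and share-less servers.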
---
 src/allmydata/immutable/download.py    |  56 ++++-----
 src/allmydata/test/no_network.py       |  26 +++-
 src/allmydata/test/test_hung_server.py | 157 +++++++++++++++++++++++++
 3 files changed, 206 insertions(+), 33 deletions(-)
 create mode 100644 src/allmydata/test/test_hung_server.py

diff --git a/src/allmydata/immutable/download.py b/src/allmydata/immutable/download.py
index 05e126b0..d5e29b8d 100644
--- a/src/allmydata/immutable/download.py
+++ b/src/allmydata/immutable/download.py
@@ -1,6 +1,6 @@
 import random, weakref, itertools, time
 from zope.interface import implements
-from twisted.internet import defer
+from twisted.internet import defer, reactor
 from twisted.internet.interfaces import IPushProducer, IConsumer
 from foolscap.api import DeadReferenceError, RemoteException, eventually
 
@@ -835,7 +835,7 @@ class CiphertextDownloader(log.PrefixingLogMixin):
 
         # first step: who should we download from?
         d = defer.maybeDeferred(self._get_all_shareholders)
-        d.addCallback(self._got_all_shareholders)
+        d.addBoth(self._got_all_shareholders)
         # now get the uri_extension block from somebody and integrity check
         # it and parse and validate its contents
         d.addCallback(self._obtain_uri_extension)
@@ -872,46 +872,55 @@ class CiphertextDownloader(log.PrefixingLogMixin):
         """ Once the number of buckets that I know about is >= K then I
         callback the Deferred that I return.
 
-        If all of the get_buckets deferreds have fired (whether callback or
-        errback) and I still don't have enough buckets then I'll callback the
-        Deferred that I return.
+        If all of the get_buckets deferreds have fired (whether callback
+        or errback) and I still don't have enough buckets then I'll also
+        callback -- not errback -- the Deferred that I return.
         """
-        self._wait_for_enough_buckets_d = defer.Deferred()
+        wait_for_enough_buckets_d = defer.Deferred()
+        self._wait_for_enough_buckets_d = wait_for_enough_buckets_d
 
-        self._queries_sent = 0
-        self._responses_received = 0
-        self._queries_failed = 0
         sb = self._storage_broker
         servers = sb.get_servers_for_index(self._storage_index)
         if not servers:
             raise NoServersError("broker gave us no servers!")
+
+        self._total_queries = len(servers)
+        self._responses_received = 0
+        self._queries_failed = 0
         for (peerid,ss) in servers:
             self.log(format="sending DYHB to [%(peerid)s]",
                      peerid=idlib.shortnodeid_b2a(peerid),
                      level=log.NOISY, umid="rT03hg")
-            self._queries_sent += 1
             d = ss.callRemote("get_buckets", self._storage_index)
             d.addCallbacks(self._got_response, self._got_error,
                            callbackArgs=(peerid,))
+            d.addBoth(self._check_got_all_responses)
+
         if self._status:
             self._status.set_status("Locating Shares (%d/%d)" %
-                                    (self._responses_received,
-                                     self._queries_sent))
-        return self._wait_for_enough_buckets_d
+                                    (len(self._share_buckets),
+                                     self._verifycap.needed_shares))
+        return wait_for_enough_buckets_d
+
+    def _check_got_all_responses(self, ignored=None):
+        assert (self._responses_received+self._queries_failed) <= self._total_queries
+        if self._wait_for_enough_buckets_d and (self._responses_received+self._queries_failed) == self._total_queries:
+            reactor.callLater(0, self._wait_for_enough_buckets_d.callback, False)
+            self._wait_for_enough_buckets_d = None
 
     def _got_response(self, buckets, peerid):
+        self._responses_received += 1
         self.log(format="got results from [%(peerid)s]: shnums %(shnums)s",
                  peerid=idlib.shortnodeid_b2a(peerid),
                  shnums=sorted(buckets.keys()),
                  level=log.NOISY, umid="o4uwFg")
-        self._responses_received += 1
         if self._results:
             elapsed = time.time() - self._started
             self._results.timings["servers_peer_selection"][peerid] = elapsed
         if self._status:
             self._status.set_status("Locating Shares (%d/%d)" %
                                     (self._responses_received,
-                                     self._queries_sent))
+                                     self._total_queries))
         for sharenum, bucket in buckets.iteritems():
             b = layout.ReadBucketProxy(bucket, peerid, self._storage_index)
             self.add_share_bucket(sharenum, b)
@@ -919,14 +928,7 @@ class CiphertextDownloader(log.PrefixingLogMixin):
             # deferred. Then remove it from self so that we don't fire it
             # again.
             if self._wait_for_enough_buckets_d and len(self._share_buckets) >= self._verifycap.needed_shares:
-                self._wait_for_enough_buckets_d.callback(True)
-                self._wait_for_enough_buckets_d = None
-
-            # Else, if we ran out of outstanding requests then fire it and
-            # remove it from self.
-            assert (self._responses_received+self._queries_failed) <= self._queries_sent
-            if self._wait_for_enough_buckets_d and (self._responses_received+self._queries_failed) == self._queries_sent:
-                self._wait_for_enough_buckets_d.callback(False)
+                reactor.callLater(0, self._wait_for_enough_buckets_d.callback, True)
                 self._wait_for_enough_buckets_d = None
 
             if self._results:
@@ -939,18 +941,12 @@ class CiphertextDownloader(log.PrefixingLogMixin):
         self._share_buckets.setdefault(sharenum, []).append(bucket)
 
     def _got_error(self, f):
+        self._queries_failed += 1
         level = log.WEIRD
         if f.check(DeadReferenceError):
             level = log.UNUSUAL
         self.log("Error during get_buckets", failure=f, level=level,
                          umid="3uuBUQ")
-        # If we ran out of outstanding requests then errback it and remove it
-        # from self.
-        self._queries_failed += 1
-        assert (self._responses_received+self._queries_failed) <= self._queries_sent
-        if self._wait_for_enough_buckets_d and self._responses_received == self._queries_sent:
-            self._wait_for_enough_buckets_d.errback()
-            self._wait_for_enough_buckets_d = None
 
     def bucket_failed(self, vbucket):
         shnum = vbucket.sharenum
diff --git a/src/allmydata/test/no_network.py b/src/allmydata/test/no_network.py
index ab13bb92..714653f8 100644
--- a/src/allmydata/test/no_network.py
+++ b/src/allmydata/test/no_network.py
@@ -16,7 +16,7 @@
 import os.path
 from zope.interface import implements
 from twisted.application import service
-from twisted.internet import reactor
+from twisted.internet import defer, reactor
 from twisted.python.failure import Failure
 from foolscap.api import Referenceable, fireEventually, RemoteException
 from base64 import b32encode
@@ -38,6 +38,7 @@ class LocalWrapper:
     def __init__(self, original):
         self.original = original
         self.broken = False
+        self.hung_until = None
         self.post_call_notifier = None
         self.disconnectors = {}
 
@@ -57,11 +58,25 @@ class LocalWrapper:
                 return a
         args = tuple([wrap(a) for a in args])
         kwargs = dict([(k,wrap(kwargs[k])) for k in kwargs])
+
+        def _really_call():
+            meth = getattr(self.original, "remote_" + methname)
+            return meth(*args, **kwargs)
+
         def _call():
             if self.broken:
                 raise IntentionalError("I was asked to break")
-            meth = getattr(self.original, "remote_" + methname)
-            return meth(*args, **kwargs)
+            if self.hung_until:
+                d2 = defer.Deferred()
+                self.hung_until.addCallback(lambda ign: _really_call())
+                self.hung_until.addCallback(lambda res: d2.callback(res))
+                def _err(res):
+                    d2.errback(res)
+                    return res
+                self.hung_until.addErrback(_err)
+                return d2
+            return _really_call()
+
         d = fireEventually()
         d.addCallback(lambda res: _call())
         def _wrap_exception(f):
@@ -240,6 +255,11 @@ class NoNetworkGrid(service.MultiService):
         # asked to hold a share
         self.servers_by_id[serverid].broken = True
 
+    def hang_server(self, serverid, until=defer.Deferred()):
+        # hang the given server until 'until' fires
+        self.servers_by_id[serverid].hung_until = until
+
+
 class GridTestMixin:
     def setUp(self):
         self.s = service.MultiService()
diff --git a/src/allmydata/test/test_hung_server.py b/src/allmydata/test/test_hung_server.py
new file mode 100644
index 00000000..5f855106
--- /dev/null
+++ b/src/allmydata/test/test_hung_server.py
@@ -0,0 +1,157 @@
+
+import os, shutil
+from twisted.trial import unittest
+from twisted.internet import defer, reactor
+from twisted.python import failure
+from allmydata import uri
+from allmydata.util.consumer import download_to_data
+from allmydata.immutable import upload
+from allmydata.storage.common import storage_index_to_dir
+from allmydata.test.no_network import GridTestMixin
+from allmydata.test.common import ShouldFailMixin
+from allmydata.interfaces import NotEnoughSharesError
+
+immutable_plaintext = "data" * 10000
+mutable_plaintext = "muta" * 10000
+
+class HungServerDownloadTest(GridTestMixin, ShouldFailMixin, unittest.TestCase):
+    timeout = 30
+
+    def _break(self, servers):
+        for (id, ss) in servers:
+            self.g.break_server(id)
+
+    def _hang(self, servers, **kwargs):
+        for (id, ss) in servers:
+            self.g.hang_server(id, **kwargs)
+
+    def _delete_all_shares_from(self, servers):
+        serverids = [id for (id, ss) in servers]
+        for (i_shnum, i_serverid, i_sharefile) in self.shares:
+            if i_serverid in serverids:
+                os.unlink(i_sharefile)
+
+    # untested
+    def _pick_a_share_from(self, server):
+        (id, ss) = server
+        for (i_shnum, i_serverid, i_sharefile) in self.shares:
+            if i_serverid == id:
+                return (i_shnum, i_sharefile)
+        raise AssertionError("server %r had no shares" % (server,))
+
+    # untested
+    def _copy_all_shares_from(self, from_servers, to_server):
+        serverids = [id for (id, ss) in from_servers]
+        for (i_shnum, i_serverid, i_sharefile) in self.shares:
+            if i_serverid in serverids:
+                self._copy_share((i_shnum, i_sharefile), to_server)
+
+    # untested
+    def _copy_share(self, share, to_server):
+        (sharenum, sharefile) = share
+        (id, ss) = to_server
+        # FIXME: this doesn't work because we only have a LocalWrapper
+        shares_dir = os.path.join(ss.storedir, "shares")
+        si = uri.from_string(self.uri).get_storage_index()
+        si_dir = os.path.join(shares_dir, storage_index_to_dir(si))
+        if not os.path.exists(si_dir):
+            os.makedirs(si_dir)
+        new_sharefile = os.path.join(si_dir, str(sharenum))
+        shutil.copy(sharefile, new_sharefile)
+        self.shares = self.find_shares(self.uri)
+        # Make sure that the storage server has the share.
+        self.failUnless((sharenum, ss.my_nodeid, new_sharefile)
+                        in self.shares)
+
+    # untested
+    def _add_server(self, server_number, readonly=False):
+        ss = self.g.make_server(server_number, readonly)
+        self.g.add_server(server_number, ss)
+        self.shares = self.find_shares(self.uri)
+
+    def _set_up(self, testdir, num_clients=1, num_servers=10):
+        self.basedir = "download/" + testdir
+        self.set_up_grid(num_clients=num_clients, num_servers=num_servers)
+
+        self.c0 = self.g.clients[0]
+        sb = self.c0.nodemaker.storage_broker
+        self.servers = [(id, ss) for (id, ss) in sb.get_all_servers()]
+
+        data = upload.Data(immutable_plaintext, convergence="")
+        d = self.c0.upload(data)
+        def _uploaded(ur):
+            self.uri = ur.uri
+            self.shares = self.find_shares(self.uri)
+        d.addCallback(_uploaded)
+        return d
+
+    def test_10_good_sanity_check(self):
+        d = self._set_up("test_10_good_sanity_check")
+        d.addCallback(lambda ign: self.download_immutable())
+        return d
+
+    def test_3_good_7_hung(self):
+        d = self._set_up("test_3_good_7_hung")
+        d.addCallback(lambda ign: self._hang(self.servers[3:]))
+        d.addCallback(lambda ign: self.download_immutable())
+        return d
+
+    def test_3_good_7_noshares(self):
+        d = self._set_up("test_3_good_7_noshares")
+        d.addCallback(lambda ign: self._delete_all_shares_from(self.servers[3:]))
+        d.addCallback(lambda ign: self.download_immutable())
+        return d
+
+    def test_2_good_8_broken_fail(self):
+        d = self._set_up("test_2_good_8_broken_fail")
+        d.addCallback(lambda ign: self._break(self.servers[2:]))
+        d.addCallback(lambda ign:
+                      self.shouldFail(NotEnoughSharesError, "test_2_good_8_broken_fail",
+                                      "Failed to get enough shareholders: have 2, need 3",
+                                      self.download_immutable))
+        return d
+
+    def test_2_good_8_noshares_fail(self):
+        d = self._set_up("test_2_good_8_noshares_fail")
+        d.addCallback(lambda ign: self._delete_all_shares_from(self.servers[2:]))
+        d.addCallback(lambda ign:
+                      self.shouldFail(NotEnoughSharesError, "test_2_good_8_noshares_fail",
+                                      "Failed to get enough shareholders: have 2, need 3",
+                                      self.download_immutable))
+        return d
+
+    def test_2_good_8_hung_then_1_recovers(self):
+        recovered = defer.Deferred()
+        d = self._set_up("test_2_good_8_hung_then_1_recovers")
+        d.addCallback(lambda ign: self._hang(self.servers[2:3], until=recovered))
+        d.addCallback(lambda ign: self._hang(self.servers[3:]))
+        d.addCallback(lambda ign: self.download_immutable())
+        reactor.callLater(5, recovered.callback, None)
+        return d
+
+    def test_2_good_8_hung_then_1_recovers_with_2_shares(self):
+        recovered = defer.Deferred()
+        d = self._set_up("test_2_good_8_hung_then_1_recovers_with_2_shares")
+        d.addCallback(lambda ign: self._copy_all_shares_from(self.servers[0:1], self.servers[2]))
+        d.addCallback(lambda ign: self._hang(self.servers[2:3], until=recovered))
+        d.addCallback(lambda ign: self._hang(self.servers[3:]))
+        d.addCallback(lambda ign: self.download_immutable())
+        reactor.callLater(5, recovered.callback, None)
+        return d
+
+    def download_immutable(self):
+        n = self.c0.create_node_from_uri(self.uri)
+        d = download_to_data(n)
+        def _got_data(data):
+            self.failUnlessEqual(data, immutable_plaintext)
+        d.addCallback(_got_data)
+        return d
+
+    # unused
+    def download_mutable(self):
+        n = self.c0.create_node_from_uri(self.uri)
+        d = n.download_best_version()
+        def _got_data(data):
+            self.failUnlessEqual(data, mutable_plaintext)
+        d.addCallback(_got_data)
+        return d
-- 
2.45.2