From c0c8d72b44a642be38a7a3311e70c886b71729f9 Mon Sep 17 00:00:00 2001
From: Brian Warner <warner@allmydata.com>
Date: Wed, 16 Apr 2008 14:49:47 -0700
Subject: [PATCH] mutable WIP: rewrite ServerMap data structure, add tests

---
 src/allmydata/encode.py            |   1 -
 src/allmydata/mutable/node.py      |  15 ++-
 src/allmydata/mutable/publish.py   |  54 ++++-----
 src/allmydata/mutable/retrieve.py  |   1 -
 src/allmydata/mutable/servermap.py | 172 ++++++++++++++++-------------
 src/allmydata/test/test_mutable.py | 145 ++++++++++++++++--------
 6 files changed, 229 insertions(+), 159 deletions(-)

diff --git a/src/allmydata/encode.py b/src/allmydata/encode.py
index 1441b9c6..affb181f 100644
--- a/src/allmydata/encode.py
+++ b/src/allmydata/encode.py
@@ -61,7 +61,6 @@ hash tree is put into the URI.
 """
 
 class NotEnoughSharesError(Exception):
-    worth_retrying = False
     servermap = None
     pass
 
diff --git a/src/allmydata/mutable/node.py b/src/allmydata/mutable/node.py
index 4c901ace..8aecf6ca 100644
--- a/src/allmydata/mutable/node.py
+++ b/src/allmydata/mutable/node.py
@@ -280,8 +280,8 @@ class MutableFileNode:
         d.addCallback(_done)
         return d
 
-    def _update_and_retrieve_best(self, old_map=None):
-        d = self.update_servermap(old_map=old_map, mode=MODE_ENOUGH)
+    def _update_and_retrieve_best(self, old_map=None, mode=MODE_ENOUGH):
+        d = self.update_servermap(old_map=old_map, mode=mode)
         def _updated(smap):
             goal = smap.best_recoverable_version()
             if not goal:
@@ -296,13 +296,12 @@ class MutableFileNode:
         def _maybe_retry(f):
             f.trap(NotEnoughSharesError)
             e = f.value
-            if not e.worth_retrying:
-                return f
             # the download is worth retrying once. Make sure to use the old
-            # servermap, since it is what remembers the bad shares. TODO:
-            # consider allowing this to retry multiple times.. this approach
-            # will let us tolerate about 8 bad shares, I think.
-            return self._update_and_retrieve_best(e.servermap)
+            # servermap, since it is what remembers the bad shares, but use
+            # MODE_WRITE to make it look for even more shares. TODO: consider
+            # allowing this to retry multiple times.. this approach will let
+            # us tolerate about 8 bad shares, I think.
+            return self._update_and_retrieve_best(e.servermap, mode=MODE_WRITE)
         d.addErrback(_maybe_retry)
         d.addBoth(self.release_lock)
         return d
diff --git a/src/allmydata/mutable/publish.py b/src/allmydata/mutable/publish.py
index 4d2aa447..b53c639a 100644
--- a/src/allmydata/mutable/publish.py
+++ b/src/allmydata/mutable/publish.py
@@ -202,10 +202,9 @@ class Publish:
 
         # we use the servermap to populate the initial goal: this way we will
         # try to update each existing share in place.
-        for (peerid, shares) in self._servermap.servermap.items():
-            for (shnum, versionid, timestamp) in shares:
-                self.goal.add( (peerid, shnum) )
-                self.connections[peerid] = self._servermap.connections[peerid]
+        for (peerid, shnum) in self._servermap.servermap:
+            self.goal.add( (peerid, shnum) )
+            self.connections[peerid] = self._servermap.connections[peerid]
 
         # create the shares. We'll discard these as they are delivered. SMDF:
         # we're allowed to hold everything in memory.
@@ -476,23 +475,23 @@ class Publish:
         all_tw_vectors = {} # maps peerid to tw_vectors
         sm = self._servermap.servermap
 
-        for (peerid, shnum) in needed:
-            testvs = []
-            for (old_shnum, old_versionid, old_timestamp) in sm.get(peerid,[]):
-                if old_shnum == shnum:
-                    # an old version of that share already exists on the
-                    # server, according to our servermap. We will create a
-                    # request that attempts to replace it.
-                    (old_seqnum, old_root_hash, old_salt, old_segsize,
-                     old_datalength, old_k, old_N, old_prefix,
-                     old_offsets_tuple) = old_versionid
-                    old_checkstring = pack_checkstring(old_seqnum,
-                                                       old_root_hash,
-                                                       old_salt)
-                    testv = (0, len(old_checkstring), "eq", old_checkstring)
-                    testvs.append(testv)
-                    break
-            if not testvs:
+        for key in needed:
+            (peerid, shnum) = key
+
+            if key in sm:
+                # an old version of that share already exists on the
+                # server, according to our servermap. We will create a
+                # request that attempts to replace it.
+                old_versionid, old_timestamp = sm[key]
+                (old_seqnum, old_root_hash, old_salt, old_segsize,
+                 old_datalength, old_k, old_N, old_prefix,
+                 old_offsets_tuple) = old_versionid
+                old_checkstring = pack_checkstring(old_seqnum,
+                                                   old_root_hash,
+                                                   old_salt)
+                testv = (0, len(old_checkstring), "eq", old_checkstring)
+
+            else:
                 # add a testv that requires the share not exist
                 #testv = (0, 1, 'eq', "")
 
@@ -506,8 +505,8 @@ class Publish:
                 # (constant) tuple, by creating a new copy of this vector
                 # each time. This bug is fixed in later versions of foolscap.
                 testv = tuple([0, 1, 'eq', ""])
-                testvs.append(testv)
 
+            testvs = [testv]
             # the write vector is simply the share
             writev = [(0, self.shares[shnum])]
 
@@ -568,7 +567,6 @@ class Publish:
                       idlib.shortnodeid_b2a(peerid))
         for shnum in shnums:
             self.outstanding.discard( (peerid, shnum) )
-        sm = self._servermap.servermap
 
         wrote, read_data = answer
 
@@ -599,13 +597,9 @@ class Publish:
 
         for shnum in shnums:
             self.placed.add( (peerid, shnum) )
-            # and update the servermap. We strip the old entry out..
-            newset = set([ t
-                           for t in sm.get(peerid, [])
-                           if t[0] != shnum ])
-            sm[peerid] = newset
-            # and add a new one
-            sm[peerid].add( (shnum, self.versioninfo, started) )
+            # and update the servermap
+            self._servermap.add_new_share(peerid, shnum,
+                                          self.versioninfo, started)
 
         surprise_shares = set(read_data.keys()) - set(shnums)
         if surprise_shares:
diff --git a/src/allmydata/mutable/retrieve.py b/src/allmydata/mutable/retrieve.py
index b28203cb..71a51335 100644
--- a/src/allmydata/mutable/retrieve.py
+++ b/src/allmydata/mutable/retrieve.py
@@ -394,7 +394,6 @@ class Retrieve:
                          "update the servermap and try again to check "
                          "more peers",
                          level=log.WEIRD)
-                err.worth_retrying = True
                 err.servermap = self.servermap
             raise err
 
diff --git a/src/allmydata/mutable/servermap.py b/src/allmydata/mutable/servermap.py
index fb38a12f..5a157020 100644
--- a/src/allmydata/mutable/servermap.py
+++ b/src/allmydata/mutable/servermap.py
@@ -26,6 +26,13 @@ class ServerMap:
     has changed since I last retrieved this data'. This reduces the chances
     of clobbering a simultaneous (uncoordinated) write.
 
+    @ivar servermap: a dictionary, mapping a (peerid, shnum) tuple to a
+                     (versionid, timestamp) tuple. Each 'versionid' is a
+                     tuple of (seqnum, root_hash, IV, segsize, datalength,
+                     k, N, signed_prefix, offsets)
+
+    @ivar connections: maps peerid to a RemoteReference
+
     @ivar bad_shares: a sequence of (peerid, shnum) tuples, describing
                       shares that I should ignore (because a previous user of
                       the servermap determined that they were invalid). The
@@ -36,11 +43,8 @@ class ServerMap:
     """
 
     def __init__(self):
-        # 'servermap' maps peerid to sets of (shnum, versionid, timestamp)
-        # tuples. Each 'versionid' is a (seqnum, root_hash, IV, segsize,
-        # datalength, k, N, signed_prefix, offsets) tuple
-        self.servermap = DictOfSets()
-        self.connections = {} # maps peerid to a RemoteReference
+        self.servermap = {}
+        self.connections = {}
         self.unreachable_peers = set() # peerids that didn't respond to queries
         self.problems = [] # mostly for debugging
         self.bad_shares = set()
@@ -53,77 +57,74 @@ class ServerMap:
         it from our list of useful shares, and remember that it is bad so we
         don't add it back again later.
         """
-        self.bad_shares.add( (peerid, shnum) )
-        self._remove_share(peerid, shnum)
-
-    def _remove_share(self, peerid, shnum):
-        #(s_shnum, s_verinfo, s_timestamp) = share
-        to_remove = [share
-                     for share in self.servermap[peerid]
-                     if share[0] == shnum]
-        for share in to_remove:
-            self.servermap[peerid].discard(share)
-        if not self.servermap[peerid]:
-            del self.servermap[peerid]
+        key = (peerid, shnum)
+        self.bad_shares.add(key)
+        self.servermap.pop(key, None)
 
     def add_new_share(self, peerid, shnum, verinfo, timestamp):
         """We've written a new share out, replacing any that was there
         before."""
-        self.bad_shares.discard( (peerid, shnum) )
-        self._remove_share(peerid, shnum)
-        self.servermap.add(peerid, (shnum, verinfo, timestamp) )
+        key = (peerid, shnum)
+        self.bad_shares.discard(key)
+        self.servermap[key] = (verinfo, timestamp)
 
     def dump(self, out=sys.stdout):
         print >>out, "servermap:"
-        for (peerid, shares) in self.servermap.items():
-            for (shnum, versionid, timestamp) in sorted(shares):
-                (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
-                 offsets_tuple) = versionid
-                print >>out, ("[%s]: sh#%d seq%d-%s %d-of-%d len%d" %
-                              (idlib.shortnodeid_b2a(peerid), shnum,
-                               seqnum, base32.b2a(root_hash)[:4], k, N,
-                               datalength))
+
+        for ( (peerid, shnum), (verinfo, timestamp) ) in self.servermap.items():
+            (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
+             offsets_tuple) = verinfo
+            print >>out, ("[%s]: sh#%d seq%d-%s %d-of-%d len%d" %
+                          (idlib.shortnodeid_b2a(peerid), shnum,
+                           seqnum, base32.b2a(root_hash)[:4], k, N,
+                           datalength))
         return out
 
+    def all_peers(self):
+        return set([peerid
+                    for (peerid, shnum)
+                    in self.servermap])
+
     def make_versionmap(self):
         """Return a dict that maps versionid to sets of (shnum, peerid,
         timestamp) tuples."""
         versionmap = DictOfSets()
-        for (peerid, shares) in self.servermap.items():
-            for (shnum, verinfo, timestamp) in shares:
-                versionmap.add(verinfo, (shnum, peerid, timestamp))
+        for ( (peerid, shnum), (verinfo, timestamp) ) in self.servermap.items():
+            versionmap.add(verinfo, (shnum, peerid, timestamp))
         return versionmap
 
     def shares_on_peer(self, peerid):
         return set([shnum
-                    for (shnum, versionid, timestamp)
-                    in self.servermap.get(peerid, [])])
+                    for (s_peerid, shnum)
+                    in self.servermap
+                    if s_peerid == peerid])
 
     def version_on_peer(self, peerid, shnum):
-        shares = self.servermap.get(peerid, [])
-        for (sm_shnum, sm_versionid, sm_timestamp) in shares:
-            if sm_shnum == shnum:
-                return sm_versionid
+        key = (peerid, shnum)
+        if key in self.servermap:
+            (verinfo, timestamp) = self.servermap[key]
+            return verinfo
+        return None
         return None
 
     def shares_available(self):
-        """Return a dict that maps versionid to tuples of
+        """Return a dict that maps verinfo to tuples of
         (num_distinct_shares, k) tuples."""
         versionmap = self.make_versionmap()
         all_shares = {}
-        for versionid, shares in versionmap.items():
+        for verinfo, shares in versionmap.items():
             s = set()
             for (shnum, peerid, timestamp) in shares:
                 s.add(shnum)
             (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
-             offsets_tuple) = versionid
-            all_shares[versionid] = (len(s), k)
+             offsets_tuple) = verinfo
+            all_shares[verinfo] = (len(s), k)
         return all_shares
 
     def highest_seqnum(self):
         available = self.shares_available()
-        seqnums = [versionid[0]
-                   for versionid in available.keys()]
+        seqnums = [verinfo[0]
+                   for verinfo in available.keys()]
         seqnums.append(0)
         return max(seqnums)
 
@@ -306,7 +307,7 @@ class ServermapUpdater:
     def _build_initial_querylist(self):
         initial_peers_to_query = {}
         must_query = set()
-        for peerid in self._servermap.servermap.keys():
+        for peerid in self._servermap.all_peers():
             ss = self._servermap.connections[peerid]
             # we send queries to everyone who was already in the sharemap
             initial_peers_to_query[peerid] = ss
@@ -365,7 +366,7 @@ class ServermapUpdater:
         self._must_query.discard(peerid)
         self._queries_completed += 1
         if not self._running:
-            self.log("but we're not running, so we'll ignore it")
+            self.log("but we're not running, so we'll ignore it", parent=lp)
             return
 
         if datavs:
@@ -384,7 +385,8 @@ class ServermapUpdater:
             except CorruptShareError, e:
                 # log it and give the other shares a chance to be processed
                 f = failure.Failure()
-                self.log("bad share: %s %s" % (f, f.value), level=log.WEIRD)
+                self.log("bad share: %s %s" % (f, f.value),
+                         parent=lp, level=log.WEIRD)
                 self._bad_peers.add(peerid)
                 self._last_failure = f
                 self._servermap.problems.append(f)
@@ -412,9 +414,9 @@ class ServermapUpdater:
         self.log("_got_results done", parent=lp)
 
     def _got_results_one_share(self, shnum, data, peerid):
-        self.log(format="_got_results: got shnum #%(shnum)d from peerid %(peerid)s",
-                 shnum=shnum,
-                 peerid=idlib.shortnodeid_b2a(peerid))
+        lp = self.log(format="_got_results: got shnum #%(shnum)d from peerid %(peerid)s",
+                      shnum=shnum,
+                      peerid=idlib.shortnodeid_b2a(peerid))
 
         # this might raise NeedMoreDataError, if the pubkey and signature
         # live at some weird offset. That shouldn't happen, so I'm going to
@@ -451,18 +453,21 @@ class ServermapUpdater:
             self.log(" found valid version %d-%s from %s-sh%d: %d-%d/%d/%d"
                      % (seqnum, base32.b2a(root_hash)[:4],
                         idlib.shortnodeid_b2a(peerid), shnum,
-                        k, N, segsize, datalength))
+                        k, N, segsize, datalength),
+                     parent=lp)
             self._valid_versions.add(verinfo)
         # We now know that this is a valid candidate verinfo.
 
-        if (peerid, shnum, verinfo) in self._servermap.bad_shares:
+        if (peerid, shnum) in self._servermap.bad_shares:
             # we've been told that the rest of the data in this share is
             # unusable, so don't add it to the servermap.
+            self.log("but we've been told this is a bad share",
+                     parent=lp, level=log.UNUSUAL)
             return verinfo
 
         # Add the info to our servermap.
         timestamp = time.time()
-        self._servermap.servermap.add(peerid, (shnum, verinfo, timestamp))
+        self._servermap.add_new_share(peerid, shnum, verinfo, timestamp)
         # and the versionmap
         self.versionmap.add(verinfo, (shnum, peerid, timestamp))
         return verinfo
@@ -556,28 +561,35 @@ class ServermapUpdater:
         #  return self._done() : all done
         #  return : keep waiting, no new queries
 
-        self.log(format=("_check_for_done, mode is '%(mode)s', "
-                         "%(outstanding)d queries outstanding, "
-                         "%(extra)d extra peers available, "
-                         "%(must)d 'must query' peers left"
-                         ),
-                 mode=self.mode,
-                 outstanding=len(self._queries_outstanding),
-                 extra=len(self.extra_peers),
-                 must=len(self._must_query),
-                 )
+        lp = self.log(format=("_check_for_done, mode is '%(mode)s', "
+                              "%(outstanding)d queries outstanding, "
+                              "%(extra)d extra peers available, "
+                              "%(must)d 'must query' peers left"
+                              ),
+                      mode=self.mode,
+                      outstanding=len(self._queries_outstanding),
+                      extra=len(self.extra_peers),
+                      must=len(self._must_query),
+                      level=log.NOISY,
+                      )
+
+        if not self._running:
+            self.log("but we're not running", parent=lp, level=log.NOISY)
+            return
 
         if self._must_query:
             # we are still waiting for responses from peers that used to have
             # a share, so we must continue to wait. No additional queries are
             # required at this time.
-            self.log("%d 'must query' peers left" % len(self._must_query))
+            self.log("%d 'must query' peers left" % len(self._must_query),
+                     parent=lp)
             return
 
         if (not self._queries_outstanding and not self.extra_peers):
             # all queries have retired, and we have no peers left to ask. No
             # more progress can be made, therefore we are done.
-            self.log("all queries are retired, no extra peers: done")
+            self.log("all queries are retired, no extra peers: done",
+                     parent=lp)
             return self._done()
 
         recoverable_versions = self._servermap.recoverable_versions()
@@ -588,13 +600,15 @@ class ServermapUpdater:
         if self.mode == MODE_ANYTHING:
             if recoverable_versions:
                 self.log("MODE_ANYTHING and %d recoverable versions: done"
-                         % len(recoverable_versions))
+                         % len(recoverable_versions),
+                         parent=lp)
                 return self._done()
 
         if self.mode == MODE_CHECK:
             # we used self._must_query, and we know there aren't any
             # responses still waiting, so that means we must be done
-            self.log("MODE_CHECK: done")
+            self.log("MODE_CHECK: done",
+                     parent=lp)
             return self._done()
 
         MAX_IN_FLIGHT = 5
@@ -606,10 +620,12 @@ class ServermapUpdater:
             if self._queries_completed < self.num_peers_to_query:
                 self.log(format="ENOUGH, %(completed)d completed, %(query)d to query: need more",
                          completed=self._queries_completed,
-                         query=self.num_peers_to_query)
+                         query=self.num_peers_to_query,
+                         parent=lp)
                 return self._send_more_queries(MAX_IN_FLIGHT)
             if not recoverable_versions:
-                self.log("ENOUGH, no recoverable versions: need more")
+                self.log("ENOUGH, no recoverable versions: need more",
+                         parent=lp)
                 return self._send_more_queries(MAX_IN_FLIGHT)
             highest_recoverable = max(recoverable_versions)
             highest_recoverable_seqnum = highest_recoverable[0]
@@ -623,7 +639,8 @@ class ServermapUpdater:
                     return self._send_more_queries(MAX_IN_FLIGHT)
             # all the unrecoverable versions were old or concurrent with a
             # recoverable version. Good enough.
-            self.log("ENOUGH: no higher-seqnum: done")
+            self.log("ENOUGH: no higher-seqnum: done",
+                     parent=lp)
             return self._done()
 
         if self.mode == MODE_WRITE:
@@ -633,7 +650,8 @@ class ServermapUpdater:
             # every server in the world.
 
             if not recoverable_versions:
-                self.log("WRITE, no recoverable versions: need more")
+                self.log("WRITE, no recoverable versions: need more",
+                         parent=lp)
                 return self._send_more_queries(MAX_IN_FLIGHT)
 
             last_found = -1
@@ -656,7 +674,8 @@ class ServermapUpdater:
                         num_not_found += 1
                         if num_not_found >= self.EPSILON:
                             self.log("MODE_WRITE: found our boundary, %s" %
-                                     "".join(states))
+                                     "".join(states),
+                                     parent=lp)
                             found_boundary = True
                             break
 
@@ -678,10 +697,12 @@ class ServermapUpdater:
                 # everybody to the left of here
                 if last_not_responded == -1:
                     # we're done
-                    self.log("have all our answers")
+                    self.log("have all our answers",
+                             parent=lp)
                     # .. unless we're still waiting on the privkey
                     if self._need_privkey:
-                        self.log("but we're still waiting for the privkey")
+                        self.log("but we're still waiting for the privkey",
+                                 parent=lp)
                         # if we found the boundary but we haven't yet found
                         # the privkey, we may need to look further. If
                         # somehow all the privkeys were corrupted (but the
@@ -694,13 +715,14 @@ class ServermapUpdater:
 
             # if we hit here, we didn't find our boundary, so we're still
             # waiting for peers
-            self.log("MODE_WRITE: no boundary yet, %s" % "".join(states))
+            self.log("MODE_WRITE: no boundary yet, %s" % "".join(states),
+                     parent=lp)
             return self._send_more_queries(MAX_IN_FLIGHT)
 
         # otherwise, keep up to 5 queries in flight. TODO: this is pretty
         # arbitrary, really I want this to be something like k -
         # max(known_version_sharecounts) + some extra
-        self.log("catchall: need more")
+        self.log("catchall: need more", parent=lp)
         return self._send_more_queries(MAX_IN_FLIGHT)
 
     def _send_more_queries(self, num_outstanding):
diff --git a/src/allmydata/test/test_mutable.py b/src/allmydata/test/test_mutable.py
index c84fb96d..d0f3cb69 100644
--- a/src/allmydata/test/test_mutable.py
+++ b/src/allmydata/test/test_mutable.py
@@ -16,7 +16,7 @@ import sha
 
 from allmydata.mutable.node import MutableFileNode
 from allmydata.mutable.common import DictOfSets, ResponseCache, \
-     MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_ENOUGH
+     MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_ENOUGH, UnrecoverableFileError
 from allmydata.mutable.retrieve import Retrieve
 from allmydata.mutable.publish import Publish
 from allmydata.mutable.servermap import ServerMap, ServermapUpdater
@@ -199,6 +199,43 @@ class FakeClient:
         return d
 
 
+def flip_bit(original, byte_offset):
+    return (original[:byte_offset] +
+            chr(ord(original[byte_offset]) ^ 0x01) +
+            original[byte_offset+1:])
+
+def corrupt(res, s, offset, shnums_to_corrupt=None):
+    # if shnums_to_corrupt is None, corrupt all shares. Otherwise it is a
+    # list of shnums to corrupt.
+    for peerid in s._peers:
+        shares = s._peers[peerid]
+        for shnum in shares:
+            if (shnums_to_corrupt is not None
+                and shnum not in shnums_to_corrupt):
+                continue
+            data = shares[shnum]
+            (version,
+             seqnum,
+             root_hash,
+             IV,
+             k, N, segsize, datalen,
+             o) = unpack_header(data)
+            if isinstance(offset, tuple):
+                offset1, offset2 = offset
+            else:
+                offset1 = offset
+                offset2 = 0
+            if offset1 == "pubkey":
+                real_offset = 107
+            elif offset1 in o:
+                real_offset = o[offset1]
+            else:
+                real_offset = offset1
+            real_offset = int(real_offset) + offset2
+            assert isinstance(real_offset, int), offset
+            shares[shnum] = flip_bit(data, real_offset)
+    return res
+
 class Filenode(unittest.TestCase):
     def setUp(self):
         self.client = FakeClient()
@@ -428,6 +465,41 @@ class Servermap(unittest.TestCase):
 
         return d
 
+    def test_mark_bad(self):
+        d = defer.succeed(None)
+        ms = self.make_servermap
+        us = self.update_servermap
+
+        d.addCallback(lambda res: ms(mode=MODE_ENOUGH))
+        d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6))
+        def _made_map(sm):
+            v = sm.best_recoverable_version()
+            vm = sm.make_versionmap()
+            shares = list(vm[v])
+            self.failUnlessEqual(len(shares), 6)
+            self._corrupted = set()
+            # mark the first 5 shares as corrupt, then update the servermap.
+            # The map should not have the marked shares it in any more, and
+            # new shares should be found to replace the missing ones.
+            for (shnum, peerid, timestamp) in shares:
+                if shnum < 5:
+                    self._corrupted.add( (peerid, shnum) )
+                    sm.mark_bad_share(peerid, shnum)
+            return self.update_servermap(sm, MODE_WRITE)
+        d.addCallback(_made_map)
+        def _check_map(sm):
+            # this should find all 5 shares that weren't marked bad
+            v = sm.best_recoverable_version()
+            vm = sm.make_versionmap()
+            shares = list(vm[v])
+            for (peerid, shnum) in self._corrupted:
+                peer_shares = sm.shares_on_peer(peerid)
+                self.failIf(shnum in peer_shares,
+                            "%d was in %s" % (shnum, peer_shares))
+            self.failUnlessEqual(len(shares), 5)
+        d.addCallback(_check_map)
+        return d
+
     def failUnlessNoneRecoverable(self, sm):
         self.failUnlessEqual(len(sm.recoverable_versions()), 0)
         self.failUnlessEqual(len(sm.unrecoverable_versions()), 0)
@@ -565,11 +637,6 @@ class Roundtrip(unittest.TestCase):
         d.addCallback(_retrieved)
         return d
 
-    def flip_bit(self, original, byte_offset):
-        return (original[:byte_offset] +
-                chr(ord(original[byte_offset]) ^ 0x01) +
-                original[byte_offset+1:])
-
 
     def shouldFail(self, expected_failure, which, substring,
                     callable, *args, **kwargs):
@@ -588,46 +655,14 @@ class Roundtrip(unittest.TestCase):
         d.addBoth(done)
         return d
 
-    def _corrupt(self, res, s, offset, shnums_to_corrupt=None):
-        # if shnums_to_corrupt is None, corrupt all shares. Otherwise it is a
-        # list of shnums to corrupt.
-        for peerid in s._peers:
-            shares = s._peers[peerid]
-            for shnum in shares:
-                if (shnums_to_corrupt is not None
-                    and shnum not in shnums_to_corrupt):
-                    continue
-                data = shares[shnum]
-                (version,
-                 seqnum,
-                 root_hash,
-                 IV,
-                 k, N, segsize, datalen,
-                 o) = unpack_header(data)
-                if isinstance(offset, tuple):
-                    offset1, offset2 = offset
-                else:
-                    offset1 = offset
-                    offset2 = 0
-                if offset1 == "pubkey":
-                    real_offset = 107
-                elif offset1 in o:
-                    real_offset = o[offset1]
-                else:
-                    real_offset = offset1
-                real_offset = int(real_offset) + offset2
-                assert isinstance(real_offset, int), offset
-                shares[shnum] = self.flip_bit(data, real_offset)
-        return res
-
     def _test_corrupt_all(self, offset, substring,
                           should_succeed=False, corrupt_early=True):
         d = defer.succeed(None)
         if corrupt_early:
-            d.addCallback(self._corrupt, self._storage, offset)
+            d.addCallback(corrupt, self._storage, offset)
         d.addCallback(lambda res: self.make_servermap())
         if not corrupt_early:
-            d.addCallback(self._corrupt, self._storage, offset)
+            d.addCallback(corrupt, self._storage, offset)
         def _do_retrieve(servermap):
             ver = servermap.best_recoverable_version()
             if ver is None and not should_succeed:
@@ -637,9 +672,8 @@ class Roundtrip(unittest.TestCase):
                     allproblems = [str(f) for f in servermap.problems]
                     self.failUnless(substring in "".join(allproblems))
                 return
-            r = Retrieve(self._fn, servermap, ver)
             if should_succeed:
-                d1 = r.download()
+                d1 = self._fn.download_to_data()
                 d1.addCallback(lambda new_contents:
                                self.failUnlessEqual(new_contents, self.CONTENTS))
                 return d1
@@ -647,7 +681,7 @@ class Roundtrip(unittest.TestCase):
                 return self.shouldFail(NotEnoughSharesError,
                                        "_corrupt_all(offset=%s)" % (offset,),
                                        substring,
-                                       r.download)
+                                       self._fn.download_to_data)
         d.addCallback(_do_retrieve)
         return d
 
@@ -732,7 +766,7 @@ class Roundtrip(unittest.TestCase):
         k = self._fn.get_required_shares()
         N = self._fn.get_total_shares()
         d = defer.succeed(None)
-        d.addCallback(self._corrupt, self._storage, "pubkey",
+        d.addCallback(corrupt, self._storage, "pubkey",
                       shnums_to_corrupt=range(0, N-k))
         d.addCallback(lambda res: self.make_servermap())
         def _do_retrieve(servermap):
@@ -747,6 +781,29 @@ class Roundtrip(unittest.TestCase):
                       self.failUnlessEqual(new_contents, self.CONTENTS))
         return d
 
+    def test_corrupt_some(self):
+        # corrupt the data of first five shares (so the servermap thinks
+        # they're good but retrieve marks them as bad), so that the
+        # MODE_ENOUGH set of 6 will be insufficient, forcing node.download to
+        # retry with more servers.
+        corrupt(None, self._storage, "share_data", range(5))
+        d = self.make_servermap()
+        def _do_retrieve(servermap):
+            ver = servermap.best_recoverable_version()
+            self.failUnless(ver)
+            return self._fn.download_to_data()
+        d.addCallback(_do_retrieve)
+        d.addCallback(lambda new_contents:
+                      self.failUnlessEqual(new_contents, self.CONTENTS))
+        return d
+
+    def test_download_fails(self):
+        corrupt(None, self._storage, "signature")
+        d = self.shouldFail(UnrecoverableFileError, "test_download_anyway",
+                            "no recoverable versions",
+                            self._fn.download_to_data)
+        return d
+
 
 class MultipleEncodings(unittest.TestCase):
     def setUp(self):
-- 
2.45.2