From e29323f68fc5447b9f2fd69839e4375d28287852 Mon Sep 17 00:00:00 2001
From: Kevan Carstensen <kevan@isnotajoke.com>
Date: Tue, 27 Dec 2011 21:33:58 -0800
Subject: [PATCH] mutable publish: track multiple servers per share. Fixes some
 of #1628.

The remaining work is to write additional tests.

src/allmydata/test/no_network.py:

 This supports tests in which servers leave the grid and later return
 with their shares intact.
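
 For illustration, a test can now capture the removed server and restore
 it later. A minimal sketch of the pattern (the names below are local to
 the example; grid is a NoNetworkGrid):

     def bounce_server(grid, serverid):
         # remove_server() now returns the wrapped storage server, so
         # the test can hold on to it while the grid runs without it...
         ss = grid.remove_server(serverid)
         # ...and later re-add it, shares intact, under a new number.
         grid.add_server(len(grid.servers_by_number), ss)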

src/allmydata/test/test_mutable.py:

 The UCWEs in the incident reports attached to #1628 all seem to involve
 shares that the servermap knows about but that aren't accounted for
 during the publish process. Specifically, the publisher appears to keep
 track of only a single storage server for a given share. This makes the
 repair process worse than it was pre-MDMF at updating all of the shares
 of a particular file to the newest version, and can also cause spurious
 UCWEs. This test simulates such a layout and fails if a UCWE is thrown.
 We still need to write another test to ensure that all copies of a
 share are updated to the latest version (or to alter this test to do
 that), so that the test suite doesn't pass unless both regressions are
 fixed.
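
 Concretely, the layout that triggers the bug puts the same share number
 on more than one server. In the servermap, whose known-shares mapping
 is keyed by (server, shnum) pairs, that looks roughly like this
 (illustrative values):

     known_shares = {
         (server_a, 0): (versionid_1, timestamp),
         (server_b, 0): (versionid_1, timestamp),  # share 0, 2nd copy
     }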

 We want the publisher to follow the existing share placement when uploading
 a new version of a mutable file, and we don't want this test to pass unless
 it does.

src/allmydata/mutable/publish.py:

 Before this commit, the publisher kept track of only a single writer
 for each share. That is insufficient for updates in which a single
 share may live on multiple servers. At best, such an update will
 refresh only one of the existing copies of a share instead of all of
 them. At worst, the update will encounter the other copies while
 placing some other share, interpret them as a sign of an uncoordinated
 write, and fail. Keeping track of all of the writers helps ensure that
 every existing copy is updated, and helps avoid spurious uncoordinated
 write errors.
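
 The shape of the change, condensed from the hunks below (surrounding
 setup elided):

     # Each share number now maps to a list of writers, one per server
     # that holds (or will hold) a copy of that share...
     self.writers.setdefault(shnum, []).append(writer)

     # ...and every push operation fans out to all copies of a share:
     for shnum, writers in self.writers.iteritems():
         for writer in writers:
             writer.put_salt(salt)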
---
 src/allmydata/mutable/publish.py   | 142 ++++++++++++++++-------------
 src/allmydata/test/no_network.py   |   1 +
 src/allmydata/test/test_mutable.py |  39 ++++++++
 3 files changed, 119 insertions(+), 63 deletions(-)

diff --git a/src/allmydata/mutable/publish.py b/src/allmydata/mutable/publish.py
index b028779b..1e2de76d 100644
--- a/src/allmydata/mutable/publish.py
+++ b/src/allmydata/mutable/publish.py
@@ -269,31 +269,33 @@ class Publish:
             cancel_secret = self._node.get_cancel_secret(server)
             secrets = (write_enabler, renew_secret, cancel_secret)
 
-            self.writers[shnum] =  writer_class(shnum,
-                                                server.get_rref(),
-                                                self._storage_index,
-                                                secrets,
-                                                self._new_seqnum,
-                                                self.required_shares,
-                                                self.total_shares,
-                                                self.segment_size,
-                                                self.datalength)
-            self.writers[shnum].server = server
+            writer = writer_class(shnum,
+                                  server.get_rref(),
+                                  self._storage_index,
+                                  secrets,
+                                  self._new_seqnum,
+                                  self.required_shares,
+                                  self.total_shares,
+                                  self.segment_size,
+                                  self.datalength)
+
+            self.writers.setdefault(shnum, []).append(writer)
+            writer.server = server
             known_shares = self._servermap.get_known_shares()
             assert (server, shnum) in known_shares
             old_versionid, old_timestamp = known_shares[(server,shnum)]
             (old_seqnum, old_root_hash, old_salt, old_segsize,
              old_datalength, old_k, old_N, old_prefix,
              old_offsets_tuple) = old_versionid
-            self.writers[shnum].set_checkstring(old_seqnum,
-                                                old_root_hash,
-                                                old_salt)
+            writer.set_checkstring(old_seqnum,
+                                   old_root_hash,
+                                   old_salt)
 
         # Our remote shares will not have a complete checkstring until
         # after we are done writing share data and have started to write
         # blocks. In the meantime, we need to know what to look for when
         # writing, so that we can detect UncoordinatedWriteErrors.
-        self._checkstring = self.writers.values()[0].get_checkstring()
+        self._checkstring = self.writers.values()[0][0].get_checkstring()
 
         # Now, we start pushing shares.
         self._status.timings["setup"] = time.time() - self._started
@@ -466,34 +468,35 @@ class Publish:
             cancel_secret = self._node.get_cancel_secret(server)
             secrets = (write_enabler, renew_secret, cancel_secret)
 
-            self.writers[shnum] =  writer_class(shnum,
-                                                server.get_rref(),
-                                                self._storage_index,
-                                                secrets,
-                                                self._new_seqnum,
-                                                self.required_shares,
-                                                self.total_shares,
-                                                self.segment_size,
-                                                self.datalength)
-            self.writers[shnum].server = server
+            writer = writer_class(shnum,
+                                  server.get_rref(),
+                                  self._storage_index,
+                                  secrets,
+                                  self._new_seqnum,
+                                  self.required_shares,
+                                  self.total_shares,
+                                  self.segment_size,
+                                  self.datalength)
+            self.writers.setdefault(shnum, []).append(writer)
+            writer.server = server
             known_shares = self._servermap.get_known_shares()
             if (server, shnum) in known_shares:
                 old_versionid, old_timestamp = known_shares[(server,shnum)]
                 (old_seqnum, old_root_hash, old_salt, old_segsize,
                  old_datalength, old_k, old_N, old_prefix,
                  old_offsets_tuple) = old_versionid
-                self.writers[shnum].set_checkstring(old_seqnum,
-                                                    old_root_hash,
-                                                    old_salt)
+                writer.set_checkstring(old_seqnum,
+                                       old_root_hash,
+                                       old_salt)
             elif (server, shnum) in self.bad_share_checkstrings:
                 old_checkstring = self.bad_share_checkstrings[(server, shnum)]
-                self.writers[shnum].set_checkstring(old_checkstring)
+                writer.set_checkstring(old_checkstring)
 
         # Our remote shares will not have a complete checkstring until
         # after we are done writing share data and have started to write
         # blocks. In the meantime, we need to know what to look for when
         # writing, so that we can detect UncoordinatedWriteErrors.
-        self._checkstring = self.writers.values()[0].get_checkstring()
+        self._checkstring = self.writers.values()[0][0].get_checkstring()
 
         # Now, we start pushing shares.
         self._status.timings["setup"] = time.time() - self._started
@@ -620,7 +623,10 @@ class Publish:
         # Can we still successfully publish this file?
         # TODO: Keep track of outstanding queries before aborting the
         #       process.
-        if len(self.writers) < self.required_shares or self.surprised:
+        # Count the shares that still have at least one live writer; a
+        # share with copies on several servers must count only once.
+        live_shnums = [s for (s, ws) in self.writers.iteritems() if ws]
+        if len(live_shnums) < self.required_shares or self.surprised:
             return self._failure()
 
         # Figure out what we need to do next. Each of these needs to
@@ -675,8 +681,9 @@ class Publish:
         salt = os.urandom(16)
         assert self._version == SDMF_VERSION
 
-        for writer in self.writers.itervalues():
-            writer.put_salt(salt)
+        for shnum, writers in self.writers.iteritems():
+            for writer in writers:
+                writer.put_salt(salt)
 
 
     def _encode_segment(self, segnum):
@@ -751,8 +758,9 @@ class Publish:
             block_hash = hashutil.block_hash(hashed)
             self.blockhashes[shareid][segnum] = block_hash
             # find the writer for this share
-            writer = self.writers[shareid]
-            writer.put_block(sharedata, segnum, salt)
+            writers = self.writers[shareid]
+            for writer in writers:
+                writer.put_block(sharedata, segnum, salt)
 
 
     def push_everything_else(self):
@@ -775,8 +783,9 @@ class Publish:
     def push_encprivkey(self):
         encprivkey = self._encprivkey
         self._status.set_status("Pushing encrypted private key")
-        for writer in self.writers.itervalues():
-            writer.put_encprivkey(encprivkey)
+        for shnum, writers in self.writers.iteritems():
+            for writer in writers:
+                writer.put_encprivkey(encprivkey)
 
 
     def push_blockhashes(self):
@@ -788,8 +797,9 @@ class Publish:
             # set the leaf for future use.
             self.sharehash_leaves[shnum] = t[0]
 
-            writer = self.writers[shnum]
-            writer.put_blockhashes(self.blockhashes[shnum])
+            writers = self.writers[shnum]
+            for writer in writers:
+                writer.put_blockhashes(self.blockhashes[shnum])
 
 
     def push_sharehashes(self):
@@ -799,8 +809,9 @@ class Publish:
             needed_indices = share_hash_tree.needed_hashes(shnum)
             self.sharehashes[shnum] = dict( [ (i, share_hash_tree[i])
                                              for i in needed_indices] )
-            writer = self.writers[shnum]
-            writer.put_sharehashes(self.sharehashes[shnum])
+            writers = self.writers[shnum]
+            for writer in writers:
+                writer.put_sharehashes(self.sharehashes[shnum])
         self.root_hash = share_hash_tree[0]
 
 
@@ -811,8 +822,9 @@ class Publish:
         #   - Push the signature
         self._status.set_status("Pushing root hashes and signature")
         for shnum in xrange(self.total_shares):
-            writer = self.writers[shnum]
-            writer.put_root_hash(self.root_hash)
+            writers = self.writers[shnum]
+            for writer in writers:
+                writer.put_root_hash(self.root_hash)
         self._update_checkstring()
         self._make_and_place_signature()
 
@@ -825,7 +837,7 @@ class Publish:
         uncoordinated writes. SDMF files will have the same checkstring,
         so we need not do anything.
         """
-        self._checkstring = self.writers.values()[0].get_checkstring()
+        self._checkstring = self.writers.values()[0][0].get_checkstring()
 
 
     def _make_and_place_signature(self):
@@ -834,11 +846,12 @@ class Publish:
         """
         started = time.time()
         self._status.set_status("Signing prefix")
-        signable = self.writers[0].get_signable()
+        signable = self.writers.values()[0][0].get_signable()
         self.signature = self._privkey.sign(signable)
 
-        for (shnum, writer) in self.writers.iteritems():
-            writer.put_signature(self.signature)
+        for (shnum, writers) in self.writers.iteritems():
+            for writer in writers:
+                writer.put_signature(self.signature)
         self._status.timings['sign'] = time.time() - started
 
 
@@ -851,25 +864,26 @@ class Publish:
         ds = []
         verification_key = self._pubkey.serialize()
 
-        for (shnum, writer) in self.writers.copy().iteritems():
-            writer.put_verification_key(verification_key)
-            self.num_outstanding += 1
-            def _no_longer_outstanding(res):
-                self.num_outstanding -= 1
-                return res
-
-            d = writer.finish_publishing()
-            d.addBoth(_no_longer_outstanding)
-            d.addErrback(self._connection_problem, writer)
-            d.addCallback(self._got_write_answer, writer, started)
-            ds.append(d)
+        for (shnum, writers) in self.writers.copy().iteritems():
+            for writer in writers:
+                writer.put_verification_key(verification_key)
+                self.num_outstanding += 1
+                def _no_longer_outstanding(res):
+                    self.num_outstanding -= 1
+                    return res
+
+                d = writer.finish_publishing()
+                d.addBoth(_no_longer_outstanding)
+                d.addErrback(self._connection_problem, writer)
+                d.addCallback(self._got_write_answer, writer, started)
+                ds.append(d)
         self._record_verinfo()
         self._status.timings['pack'] = time.time() - started
         return defer.DeferredList(ds)
 
 
     def _record_verinfo(self):
-        self.versioninfo = self.writers.values()[0].get_verinfo()
+        self.versioninfo = self.writers.values()[0][0].get_verinfo()
 
 
     def _connection_problem(self, f, writer):
@@ -879,7 +893,7 @@ class Publish:
         """
         self.log("found problem: %s" % str(f))
         self._last_failure = f
-        del(self.writers[writer.shnum])
+        self.writers[writer.shnum].remove(writer)
 
 
     def log_goal(self, goal, message=""):
@@ -988,9 +1002,11 @@ class Publish:
         # knowingly also writing to that server from other writers.
 
         # TODO: Precompute this.
-        known_shnums = [x.shnum for x in self.writers.values()
-                        if x.server == server]
-        surprise_shares -= set(known_shnums)
+        shares = []
+        for shnum, writers in self.writers.iteritems():
+            shares.extend([x.shnum for x in writers if x.server == server])
+        known_shnums = set(shares)
+        surprise_shares -= known_shnums
         self.log("found the following surprise shares: %s" %
                  str(surprise_shares))
 
diff --git a/src/allmydata/test/no_network.py b/src/allmydata/test/no_network.py
index bb9c8f2b..4bac7d1b 100644
--- a/src/allmydata/test/no_network.py
+++ b/src/allmydata/test/no_network.py
@@ -289,6 +289,7 @@ class NoNetworkGrid(service.MultiService):
         del self.wrappers_by_id[serverid]
         del self.proxies_by_id[serverid]
         self.rebuild_serverlist()
+        return ss
 
     def break_server(self, serverid):
         # mark the given server as broken, so it will throw exceptions when
diff --git a/src/allmydata/test/test_mutable.py b/src/allmydata/test/test_mutable.py
index 32602bd7..2722a3cf 100644
--- a/src/allmydata/test/test_mutable.py
+++ b/src/allmydata/test/test_mutable.py
@@ -2534,6 +2534,45 @@ class Problems(GridTestMixin, unittest.TestCase, testutil.ShouldFailMixin):
         d.addCallback(_created)
         return d
 
+    def test_multiply_placed_shares(self):
+        self.basedir = "mutable/Problems/test_multiply_placed_shares"
+        self.set_up_grid()
+        self.g.clients[0].DEFAULT_ENCODING_PARAMETERS['n'] = 75
+        nm = self.g.clients[0].nodemaker
+        d = nm.create_mutable_file(MutableData("contents 1"))
+        # swap one server for an empty one, then reupload the file.
+        def _created(n):
+            self._node = n
+
+            servers = self.g.get_all_serverids()
+            self.ss = self.g.remove_server(servers[-1])
+
+            new_server = self.g.make_server(len(servers)-1)
+            self.g.add_server(len(servers)-1, new_server)
+
+            return self._node.download_best_version()
+        d.addCallback(_created)
+        d.addCallback(lambda data: MutableData(data))
+        d.addCallback(lambda data: self._node.overwrite(data))
+
+        # restore the server we removed earlier, then download+upload
+        # the file again
+        def _overwritten(ign):
+            self.g.add_server(len(self.g.servers_by_number), self.ss)
+            return self._node.download_best_version()
+        d.addCallback(_overwritten)
+        d.addCallback(lambda data: MutableData(data))
+        d.addCallback(lambda data: self._node.overwrite(data))
+        d.addCallback(lambda ignored:
+            self._node.get_servermap(MODE_CHECK))
+        def _overwritten_again(smap):
+            # Make sure that all shares were updated by checking that
+            # there aren't any other versions in the servermap.
+            self.failUnlessEqual(len(smap.recoverable_versions()), 1)
+            self.failUnlessEqual(len(smap.unrecoverable_versions()), 0)
+        d.addCallback(_overwritten_again)
+        return d
+
     def test_bad_server(self):
         # Break one server, then create the file: the initial publish should
         # complete with an alternate server. Breaking a second server should
-- 
2.45.2