mutable: fix multiple-versions-interfering-with-each-other bug. replace() tests now...
authorBrian Warner <warner@lothar.com>
Thu, 8 Nov 2007 11:07:33 +0000 (04:07 -0700)
committerBrian Warner <warner@lothar.com>
Thu, 8 Nov 2007 11:07:33 +0000 (04:07 -0700)
src/allmydata/mutable.py
src/allmydata/test/test_system.py

index 7aea078fbd5e9212add1307f9d53a90314aff4bd..89422e9da27a00694de11f1f97ad5e77e557bc8d 100644 (file)
@@ -176,8 +176,8 @@ class Retrieve:
         self._last_failure = None
 
     def log(self, msg):
-        #self._node._client.log(msg)
-        pass
+        prefix = idlib.b2a(self._node.get_storage_index())[:6]
+        self._node._client.log("Retrieve(%s): %s" % (prefix, msg))
 
     def log_err(self, f):
         log.err(f)
@@ -227,8 +227,6 @@ class Retrieve:
         # we might not know how many shares we need yet.
         self._required_shares = self._node.get_required_shares()
         self._total_shares = self._node.get_total_shares()
-        self._segsize = None
-        self._datalength = None
 
         # self._valid_versions is a dictionary in which the keys are
         # 'verinfo' tuples (seqnum, root_hash, IV). Every time we hear about
@@ -343,7 +341,7 @@ class Retrieve:
                 self._pubkey = self._deserialize_pubkey(pubkey_s)
                 self._node._populate_pubkey(self._pubkey)
 
-            verinfo = (seqnum, root_hash, IV)
+            verinfo = (seqnum, root_hash, IV, segsize, datalength)
             if verinfo not in self._valid_versions:
                 # it's a new pair. Verify the signature.
                 valid = self._pubkey.verify(prefix, signature)
@@ -352,6 +350,10 @@ class Retrieve:
                                             "signature is invalid")
                 # ok, it's a valid verinfo. Add it to the list of validated
                 # versions.
+                self.log("found valid version %d-%s from %s-sh%d: %d-%d/%d/%d"
+                         % (seqnum, idlib.b2a(root_hash)[:4],
+                            idlib.shortnodeid_b2a(peerid), shnum,
+                            k, N, segsize, datalength))
                 self._valid_versions[verinfo] = (prefix, DictOfSets())
 
                 # and make a note of the other parameters we've just learned
@@ -361,10 +363,6 @@ class Retrieve:
                 if self._total_shares is None:
                     self._total_shares = N
                     self._node._populate_total_shares(N)
-                if self._segsize is None:
-                    self._segsize = segsize
-                if self._datalength is None:
-                    self._datalength = datalength
 
             # we've already seen this pair, and checked the signature so we
             # know it's a valid candidate. Accumulate the share info, if
@@ -404,7 +402,19 @@ class Retrieve:
             # len(sharemap) is the number of distinct shares that appear to
             # be available.
             if len(sharemap) >= self._required_shares:
-                # this one looks retrievable
+                # this one looks retrievable. TODO: our policy of decoding
+                # the first version that we can get is a bit troublesome: in
+                # a small grid with a large expansion factor, a single
+                # out-of-date server can cause us to retrieve an older
+                # version. Fixing this is equivalent to protecting ourselves
+                # against a rollback attack, and the best approach is
+                # probably to say that we won't do _attempt_decode until:
+                #  (we've received at least k+EPSILON shares or
+                #   we've received at least k shares and ran out of servers)
+                # in that case, identify the verinfos that are decodeable and
+                # attempt the one with the highest (seqnum,R) value. If the
+                # highest seqnum can't be recovered, only then might we fall
+                # back to an older version.
                 d = defer.maybeDeferred(self._attempt_decode, verinfo, sharemap)
                 def _problem(f):
                     self._last_failure = f
@@ -463,7 +473,7 @@ class Retrieve:
 
     def _attempt_decode(self, verinfo, sharemap):
         # sharemap is a dict which maps shnum to [(peerid,data)..] sets.
-        (seqnum, root_hash, IV) = verinfo
+        (seqnum, root_hash, IV, segsize, datalength) = verinfo
 
         # first, validate each share that we haven't validated yet. We use
         # self._valid_shares to remember which ones we've already checked.
@@ -500,7 +510,7 @@ class Retrieve:
         # at this point, all shares in the sharemap are valid, and they're
         # all for the same seqnum+root_hash version, so it's now down to
         # doing FEC and decrypt.
-        d = defer.maybeDeferred(self._decode, shares)
+        d = defer.maybeDeferred(self._decode, shares, segsize, datalength)
         d.addCallback(self._decrypt, IV, seqnum, root_hash)
         return d
 
@@ -528,9 +538,8 @@ class Retrieve:
         self.log(" data valid! len=%d" % len(share_data))
         return share_data
 
-    def _decode(self, shares_dict):
+    def _decode(self, shares_dict, segsize, datalength):
         # we ought to know these values by now
-        assert self._segsize is not None
         assert self._required_shares is not None
         assert self._total_shares is not None
 
@@ -546,7 +555,7 @@ class Retrieve:
         shares = shares[:self._required_shares]
 
         fec = codec.CRSDecoder()
-        params = "%d-%d-%d" % (self._segsize,
+        params = "%d-%d-%d" % (segsize,
                                self._required_shares, self._total_shares)
         fec.set_serialized_params(params)
 
@@ -556,7 +565,9 @@ class Retrieve:
         def _done(buffers):
             self.log(" decode done, %d buffers" % len(buffers))
             segment = "".join(buffers)
-            segment = segment[:self._datalength]
+            self.log(" joined length %d, datalength %d" %
+                     (len(segment), datalength))
+            segment = segment[:datalength]
             self.log(" segment len=%d" % len(segment))
             return segment
         def _err(f):
@@ -597,7 +608,7 @@ class Publish:
 
     def log(self, msg):
         prefix = idlib.b2a(self._node.get_storage_index())[:6]
-        self._node._client.log("%s: %s" % (prefix, msg))
+        self._node._client.log("Publish(%s): %s" % (prefix, msg))
 
     def log_err(self, f):
         log.err(f)
@@ -617,7 +628,7 @@ class Publish:
         # 4a: may need to run recovery algorithm
         # 5: when enough responses are back, we're done
 
-        self.log("starting publish")
+        self.log("starting publish, data is %r" % (newdata,))
 
         self._storage_index = self._node.get_storage_index()
         self._writekey = self._node.get_writekey()
@@ -666,14 +677,19 @@ class Publish:
         self.log("_query_peers")
 
         storage_index = self._storage_index
+
+        # we need to include ourselves in the list for two reasons. The most
+        # important is so that any shares which already exist on our own
+        # server get updated. The second is to ensure that we leave a share
+        # on our own server, so we're more likely to have the signing key
+        # around later. This way, even if all the servers die and the
+        # directory contents are unrecoverable, at least we can still push
+        # out a new copy with brand-new contents. TODO: it would be nice if
+        # the share we use for ourselves didn't count against the N total..
+        # maybe use N+1 if we find ourselves in the permuted list?
+
         peerlist = self._node._client.get_permuted_peers(storage_index,
-                                                         include_myself=False)
-        # we don't include ourselves in the N peers, but we *do* push an
-        # extra copy of share[0] to ourselves so we're more likely to have
-        # the signing key around later. This way, even if all the servers die
-        # and the directory contents are unrecoverable, at least we can still
-        # push out a new copy with brand-new contents.
-        # TODO: actually push this copy
+                                                         include_myself=True)
 
         current_share_peers = DictOfSets()
         reachable_peers = {}
@@ -938,6 +954,7 @@ class Publish:
                                               for i in needed_hashes ] )
         root_hash = share_hash_tree[0]
         assert len(root_hash) == 32
+        self.log("my new root_hash is %s" % idlib.b2a(root_hash))
 
         prefix = pack_prefix(seqnum, root_hash, IV,
                              required_shares, total_shares,
@@ -1051,19 +1068,20 @@ class Publish:
         wrote, read_data = answer
         surprised = False
 
-        if not wrote:
+        (new_seqnum,new_root_hash,new_IV) = unpack_checkstring(my_checkstring)
+
+        if wrote:
+            for shnum in tw_vectors:
+                dispatch_map.add(shnum, (peerid, new_seqnum, new_root_hash))
+        else:
             # surprise! our testv failed, so the write did not happen
             surprised = True
 
         for shnum, (old_cs,) in read_data.items():
             (old_seqnum, old_root_hash, IV) = unpack_checkstring(old_cs)
-            if wrote and shnum in tw_vectors:
-                cur_cs = my_checkstring
-            else:
-                cur_cs = old_cs
 
-            (cur_seqnum, cur_root_hash, IV) = unpack_checkstring(cur_cs)
-            dispatch_map.add(shnum, (peerid, cur_seqnum, cur_root_hash))
+            if not wrote:
+                dispatch_map.add(shnum, (peerid, old_seqnum, old_root_hash))
 
             if shnum not in expected_old_shares:
                 # surprise! there was a share we didn't know about
@@ -1077,9 +1095,19 @@ class Publish:
         if surprised:
             self._surprised = True
 
+    def log_dispatch_map(self, dispatch_map):
+        for shnum, places in dispatch_map.items():
+            sent_to = [(idlib.shortnodeid_b2a(peerid),
+                        seqnum,
+                        idlib.b2a(root_hash)[:4])
+                       for (peerid,seqnum,root_hash) in places]
+            self.log(" share %d sent to: %s" % (shnum, sent_to))
+
     def _maybe_recover(self, (surprised, dispatch_map)):
-        self.log("_maybe_recover")
+        self.log("_maybe_recover, surprised=%s, dispatch_map:" % surprised)
+        self.log_dispatch_map(dispatch_map)
         if not surprised:
+            self.log(" no recovery needed")
             return
         print "RECOVERY NOT YET IMPLEMENTED"
         # but dispatch_map will help us do it
index 5f35da8ac93bd591a7128bb2209f3ecc7b7b8d9d..6c309d1154a0d4c7127b60de2b9e00b3a2c77f47 100644 (file)
@@ -248,16 +248,13 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase):
         d = self.set_up_nodes()
 
         def _create_mutable(res):
-            from allmydata.mutable import MutableFileNode
-            #print "CREATING MUTABLE FILENODE"
             c = self.clients[0]
-            n = MutableFileNode(c)
-            d1 = n.create(DATA)
+            log.msg("starting create_mutable_file")
+            d1 = c.create_mutable_file(DATA)
             def _done(res):
                 log.msg("DONE: %s" % (res,))
                 self._mutable_node_1 = res
                 uri = res.get_uri()
-                #print "DONE", uri
             d1.addCallback(_done)
             return d1
         d.addCallback(_create_mutable)
@@ -335,57 +332,53 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase):
 
         d.addCallback(lambda res: self._mutable_node_1.download_to_data())
         def _check_download_1(res):
-            #print "_check_download_1"
             self.failUnlessEqual(res, DATA)
             # now we see if we can retrieve the data from a new node,
             # constructed using the URI of the original one. We do this test
             # on the same client that uploaded the data.
-            #print "download1 good, starting download2"
             uri = self._mutable_node_1.get_uri()
+            log.msg("starting retrieve1")
             newnode = self.clients[0].create_mutable_file_from_uri(uri)
             return newnode.download_to_data()
-            return d
         d.addCallback(_check_download_1)
 
         def _check_download_2(res):
-            #print "_check_download_2"
             self.failUnlessEqual(res, DATA)
             # same thing, but with a different client
-            #print "starting download 3"
             uri = self._mutable_node_1.get_uri()
             newnode = self.clients[1].create_mutable_file_from_uri(uri)
+            log.msg("starting retrieve2")
             d1 = newnode.download_to_data()
             d1.addCallback(lambda res: (res, newnode))
             return d1
         d.addCallback(_check_download_2)
 
         def _check_download_3((res, newnode)):
-            #print "_check_download_3"
             self.failUnlessEqual(res, DATA)
             # replace the data
-            #print "REPLACING"
+            log.msg("starting replace1")
             d1 = newnode.replace(NEWDATA)
             d1.addCallback(lambda res: newnode.download_to_data())
             return d1
         d.addCallback(_check_download_3)
 
         def _check_download_4(res):
-            print "_check_download_4"
             self.failUnlessEqual(res, NEWDATA)
             # now create an even newer node and replace the data on it. This
             # new node has never been used for download before.
             uri = self._mutable_node_1.get_uri()
             newnode1 = self.clients[2].create_mutable_file_from_uri(uri)
             newnode2 = self.clients[3].create_mutable_file_from_uri(uri)
+            log.msg("starting replace2")
             d1 = newnode1.replace(NEWERDATA)
             d1.addCallback(lambda res: newnode2.download_to_data())
             return d1
-        #d.addCallback(_check_download_4)
+        d.addCallback(_check_download_4)
 
         def _check_download_5(res):
-            print "_check_download_5"
+            log.msg("finished replace2")
             self.failUnlessEqual(res, NEWERDATA)
-        #d.addCallback(_check_download_5)
+        d.addCallback(_check_download_5)
 
         return d