From d6f2dbbac7edbd09889606e84d8c49f7f040921e Mon Sep 17 00:00:00 2001
From: Brian Warner <warner@allmydata.com>
Date: Tue, 13 Nov 2007 23:08:15 -0700
Subject: [PATCH] mutable: handle bad hashes, improve test coverage, rearrange
 slightly to facilitate these

---
 src/allmydata/mutable.py           | 144 ++++++++++++++++++++---------
 src/allmydata/scripts/debug.py     |   3 +-
 src/allmydata/test/test_mutable.py |  11 +--
 src/allmydata/test/test_system.py  |  98 +++++++++++++++-----
 4 files changed, 183 insertions(+), 73 deletions(-)

diff --git a/src/allmydata/mutable.py b/src/allmydata/mutable.py
index fadcea5a..ba6b68e4 100644
--- a/src/allmydata/mutable.py
+++ b/src/allmydata/mutable.py
@@ -31,10 +31,11 @@ class UncoordinatedWriteError(Exception):
 
 class CorruptShareError(Exception):
     def __init__(self, peerid, shnum, reason):
+        self.args = (peerid, shnum, reason)
         self.peerid = peerid
         self.shnum = shnum
         self.reason = reason
-    def __repr__(self):
+    def __str__(self):
         short_peerid = idlib.nodeid_b2a(self.peerid)[:8]
         return "<CorruptShareError peerid=%s shnum[%d]: %s" % (short_peerid,
                                                                self.shnum,
@@ -104,6 +105,7 @@ def unpack_share(data):
         chunk = share_hash_chain_s[i:i+hsize]
         (hid, h) = struct.unpack(share_hash_format, chunk)
         share_hash_chain.append( (hid, h) )
+    share_hash_chain = dict(share_hash_chain)
     block_hash_tree_s = data[o['block_hash_tree']:o['share_data']]
     assert len(block_hash_tree_s) % 32 == 0, len(block_hash_tree_s)
     block_hash_tree = []
@@ -167,6 +169,32 @@ def pack_offsets(verification_key_length, signature_length,
                        offsets['enc_privkey'],
                        offsets['EOF'])
 
+def pack_share(prefix, verification_key, signature,
+               share_hash_chain, block_hash_tree,
+               share_data, encprivkey):
+    share_hash_chain_s = "".join([struct.pack(">H32s", i, share_hash_chain[i])
+                                  for i in sorted(share_hash_chain.keys())])
+    for h in block_hash_tree:
+        assert len(h) == 32
+    block_hash_tree_s = "".join(block_hash_tree)
+
+    offsets = pack_offsets(len(verification_key),
+                           len(signature),
+                           len(share_hash_chain_s),
+                           len(block_hash_tree_s),
+                           len(share_data),
+                           len(encprivkey))
+    final_share = "".join([prefix,
+                           offsets,
+                           verification_key,
+                           signature,
+                           share_hash_chain_s,
+                           block_hash_tree_s,
+                           share_data,
+                           encprivkey])
+    return final_share
+
+
 class Retrieve:
     def __init__(self, filenode):
         self._node = filenode
@@ -224,6 +252,8 @@ class Retrieve:
         # 7: if we discover corrupt shares during the reconstruction process,
         #    remove that share from the sharemap.  and start step#6 again.
 
+        self.log("starting retrieval")
+
         initial_query_count = 5
         self._read_size = 2000
 
@@ -245,11 +275,12 @@ class Retrieve:
         # continuing through the last byte of sharedata.
         self._valid_versions = {}
 
-        # self._valid_shares is a set (peerid,data) tuples. Each time we
-        # examine the hash chains inside a share and validate them against a
-        # signed root_hash, we add the share to self._valid_shares . We use
-        # this to avoid re-checking the hashes over and over again.
-        self._valid_shares = set()
+        # self._valid_shares is a dict mapping (peerid,data) tuples to
+        # validated sharedata strings. Each time we examine the hash chains
+        # inside a share and validate them against a signed root_hash, we add
+        # the share to self._valid_shares . We use this to avoid re-checking
+        # the hashes over and over again.
+        self._valid_shares = {}
 
         self._done_deferred = defer.Deferred()
 
@@ -332,6 +363,8 @@ class Retrieve:
 
         for shnum,datav in datavs.items():
             data = datav[0]
+            self.log("_got_results: got shnum #%d from peerid %s"
+                     % (shnum, idlib.shortnodeid_b2a(peerid)))
             (seqnum, root_hash, IV, k, N, segsize, datalength,
              pubkey_s, signature, prefix) = unpack_prefix_and_signature(data)
 
@@ -339,7 +372,7 @@ class Retrieve:
                 fingerprint = hashutil.ssk_pubkey_fingerprint_hash(pubkey_s)
                 if fingerprint != self._node._fingerprint:
                     # bad share
-                    raise CorruptShareError(peerid,
+                    raise CorruptShareError(peerid, shnum,
                                             "pubkey doesn't match fingerprint")
                 self._pubkey = self._deserialize_pubkey(pubkey_s)
                 self._node._populate_pubkey(self._pubkey)
@@ -349,11 +382,11 @@ class Retrieve:
                 # it's a new pair. Verify the signature.
                 valid = self._pubkey.verify(prefix, signature)
                 if not valid:
-                    raise CorruptShareError(peerid,
+                    raise CorruptShareError(peerid, shnum,
                                             "signature is invalid")
                 # ok, it's a valid verinfo. Add it to the list of validated
                 # versions.
-                self.log("found valid version %d-%s from %s-sh%d: %d-%d/%d/%d"
+                self.log(" found valid version %d-%s from %s-sh%d: %d-%d/%d/%d"
                          % (seqnum, idlib.b2a(root_hash)[:4],
                             idlib.shortnodeid_b2a(peerid), shnum,
                             k, N, segsize, datalength))
@@ -372,6 +405,7 @@ class Retrieve:
             # there's enough data present. If not, raise NeedMoreDataError,
             # which will trigger a re-fetch.
             _ignored = unpack_share(data)
+            self.log(" found enough data to add share contents")
             self._valid_versions[verinfo][1].add(shnum, (peerid, data))
 
 
@@ -391,9 +425,11 @@ class Retrieve:
         self._bad_peerids.add(peerid)
         short_sid = idlib.b2a(self._storage_index)[:6]
         if f.check(CorruptShareError):
-            self.log("WEIRD: bad share for %s: %s" % (short_sid, f))
+            self.log("WEIRD: bad share for %s: %s %s" % (short_sid, f,
+                                                         f.value))
         else:
-            self.log("WEIRD: other error for %s: %s" % (short_sid, f))
+            self.log("WEIRD: other error for %s: %s %s" % (short_sid, f,
+                                                           f.value))
 
     def _check_for_done(self, res):
         if not self._running:
@@ -422,13 +458,17 @@ class Retrieve:
                 def _problem(f):
                     self._last_failure = f
                     if f.check(CorruptShareError):
-                        # log(WEIRD)
+                        self.log("WEIRD: saw corrupt share, rescheduling")
                         # _attempt_decode is responsible for removing the bad
                         # share, so we can just try again
-                        eventually(self._check_for_done)
+                        eventually(self._check_for_done, None)
                         return
                     return f
                 d.addCallbacks(self._done, _problem)
+                # TODO: create an errback-routing mechanism to make sure that
+                # weird coding errors will cause the retrieval to fail rather
+                # than hanging forever. Any otherwise-unhandled exceptions
+                # should follow this path.
                 return
 
         # we don't have enough shares yet. Should we send out more queries?
@@ -478,13 +518,26 @@ class Retrieve:
         # sharemap is a dict which maps shnum to [(peerid,data)..] sets.
         (seqnum, root_hash, IV, segsize, datalength) = verinfo
 
+        assert len(sharemap) >= self._required_shares, len(sharemap)
+
+        shares_s = []
+        for shnum in sorted(sharemap.keys()):
+            for shareinfo in sharemap[shnum]:
+                shares_s.append("#%d" % shnum)
+        shares_s = ",".join(shares_s)
+        self.log("_attempt_decode: version %d-%s, shares: %s" %
+                 (seqnum, idlib.b2a(root_hash)[:4], shares_s))
+
         # first, validate each share that we haven't validated yet. We use
         # self._valid_shares to remember which ones we've already checked.
 
         shares = {}
         for shnum, shareinfos in sharemap.items():
+            assert len(shareinfos) > 0
             for shareinfo in shareinfos:
+                # have we already validated the hashes on this share?
                 if shareinfo not in self._valid_shares:
+                    # nope: must check the hashes and extract the actual data
                     (peerid,data) = shareinfo
                     try:
                         # The (seqnum+root_hash+IV) tuple for this share was
@@ -500,24 +553,36 @@ class Retrieve:
                         # validate the prefix on all shares) from using
                         # anything else in the share.
                         validator = self._validate_share_and_extract_data
-                        sharedata = validator(root_hash, shnum, data)
+                        sharedata = validator(peerid, root_hash, shnum, data)
                         assert isinstance(sharedata, str)
                     except CorruptShareError, e:
                         self.log("WEIRD: share was corrupt: %s" % e)
                         sharemap[shnum].discard(shareinfo)
+                        if not sharemap[shnum]:
+                            # remove the key so the test in _check_for_done
+                            # can accurately decide that we don't have enough
+                            # shares to try again right now.
+                            del sharemap[shnum]
                         # If there are enough remaining shares,
                         # _check_for_done() will try again
                         raise
-                    self._valid_shares.add(shareinfo)
-                    shares[shnum] = sharedata
-        # at this point, all shares in the sharemap are valid, and they're
-        # all for the same seqnum+root_hash version, so it's now down to
-        # doing FEC and decrypt.
+                    # share is valid: remember it so we won't need to check
+                    # (or extract) it again
+                    self._valid_shares[shareinfo] = sharedata
+
+                # the share is now in _valid_shares, so just copy over the
+                # sharedata
+                shares[shnum] = self._valid_shares[shareinfo]
+
+        # now that the big loop is done, all shares in the sharemap are
+        # valid, and they're all for the same seqnum+root_hash version, so
+        # it's now down to doing FEC and decrypt.
+        assert len(shares) >= self._required_shares, len(shares)
         d = defer.maybeDeferred(self._decode, shares, segsize, datalength)
         d.addCallback(self._decrypt, IV, seqnum, root_hash)
         return d
 
-    def _validate_share_and_extract_data(self, root_hash, shnum, data):
+    def _validate_share_and_extract_data(self, peerid, root_hash, shnum, data):
         # 'data' is the whole SMDF share
         self.log("_validate_share_and_extract_data[%d]" % shnum)
         assert data[0] == "\x00"
@@ -531,7 +596,7 @@ class Retrieve:
         leaves = [hashutil.block_hash(share_data)]
         t = hashtree.HashTree(leaves)
         if list(t) != block_hash_tree:
-            raise CorruptShareError("block hash tree failure")
+            raise CorruptShareError(peerid, shnum, "block hash tree failure")
         share_hash_leaf = t[0]
         # t2 = hashtree.IncompleteHashTree()
         # TODO: use shnum, share_hash_leaf, share_hash_chain to compare against
@@ -553,6 +618,7 @@ class Retrieve:
             shareids.append(shareid)
             shares.append(share)
 
+        assert len(shareids) >= self._required_shares, len(shareids)
         # zfec really doesn't want extra shares
         shareids = shareids[:self._required_shares]
         shares = shares[:self._required_shares]
@@ -650,7 +716,7 @@ class Publish:
         # 4a: may need to run recovery algorithm
         # 5: when enough responses are back, we're done
 
-        self.log("got enough peers")
+        self.log("starting publish, data is %r" % (newdata,))
 
         self._storage_index = self._node.get_storage_index()
         self._writekey = self._node.get_writekey()
@@ -747,9 +813,12 @@ class Publish:
 
     def _got_query_results(self, datavs, peerid, permutedid,
                            reachable_peers, current_share_peers):
+        self.log("_got_query_results")
+
         assert isinstance(datavs, dict)
         reachable_peers[peerid] = permutedid
         for shnum, datav in datavs.items():
+            self.log(" peer has shnum %d" % shnum)
             assert len(datav) == 1
             data = datav[0]
             # We want (seqnum, root_hash, IV) from all servers to know what
@@ -999,29 +1068,14 @@ class Publish:
 
         final_shares = {}
         for shnum in range(total_shares):
-            shc = share_hash_chain[shnum]
-            share_hash_chain_s = "".join([struct.pack(">H32s", i, shc[i])
-                                          for i in sorted(shc.keys())])
-            bht = block_hash_trees[shnum]
-            for h in bht:
-                assert len(h) == 32
-            block_hash_tree_s = "".join(bht)
-            share_data = all_shares[shnum]
-            offsets = pack_offsets(len(verification_key),
-                                   len(signature),
-                                   len(share_hash_chain_s),
-                                   len(block_hash_tree_s),
-                                   len(share_data),
-                                   len(encprivkey))
-
-            final_shares[shnum] = "".join([prefix,
-                                           offsets,
-                                           verification_key,
-                                           signature,
-                                           share_hash_chain_s,
-                                           block_hash_tree_s,
-                                           share_data,
-                                           encprivkey])
+            final_share = pack_share(prefix,
+                                     verification_key,
+                                     signature,
+                                     share_hash_chain[shnum],
+                                     block_hash_trees[shnum],
+                                     all_shares[shnum],
+                                     encprivkey)
+            final_shares[shnum] = final_share
         return (seqnum, root_hash, final_shares, target_info)
 
 
diff --git a/src/allmydata/scripts/debug.py b/src/allmydata/scripts/debug.py
index 9c34d91a..aa302799 100644
--- a/src/allmydata/scripts/debug.py
+++ b/src/allmydata/scripts/debug.py
@@ -179,7 +179,8 @@ def dump_SDMF_share(offset, length, config, out, err):
     print >>out, "  total_shares: %d" % N
     print >>out, "  segsize: %d" % segsize
     print >>out, "  datalen: %d" % datalen
-    share_hash_ids = ",".join([str(hid) for (hid,hash) in share_hash_chain])
+    share_hash_ids = ",".join(sorted([str(hid)
+                                      for hid in share_hash_chain.keys()]))
     print >>out, "  share_hash_chain: %s" % share_hash_ids
     print >>out, "  block_hash_tree: %d nodes" % len(block_hash_tree)
 
diff --git a/src/allmydata/test/test_mutable.py b/src/allmydata/test/test_mutable.py
index dccdaa0d..ca4e7ae6 100644
--- a/src/allmydata/test/test_mutable.py
+++ b/src/allmydata/test/test_mutable.py
@@ -264,13 +264,12 @@ class Publish(unittest.TestCase):
                                            k, N, segsize, datalen)
                 self.failUnlessEqual(signature,
                                      FakePrivKey(0).sign(sig_material))
-                self.failUnless(isinstance(share_hash_chain, list))
+                self.failUnless(isinstance(share_hash_chain, dict))
                 self.failUnlessEqual(len(share_hash_chain), 4) # ln2(10)++
-                for i in share_hash_chain:
-                    self.failUnless(isinstance(i, tuple))
-                    self.failUnless(isinstance(i[0], int))
-                    self.failUnless(isinstance(i[1], str))
-                    self.failUnlessEqual(len(i[1]), 32)
+                for shnum,share_hash in share_hash_chain.items():
+                    self.failUnless(isinstance(shnum, int))
+                    self.failUnless(isinstance(share_hash, str))
+                    self.failUnlessEqual(len(share_hash), 32)
                 self.failUnless(isinstance(block_hash_tree, list))
                 self.failUnlessEqual(len(block_hash_tree), 1) # very small tree
                 self.failUnlessEqual(IV, "IV"*8)
diff --git a/src/allmydata/test/test_system.py b/src/allmydata/test/test_system.py
index 7f56c411..d404f9ba 100644
--- a/src/allmydata/test/test_system.py
+++ b/src/allmydata/test/test_system.py
@@ -6,7 +6,7 @@ from twisted.trial import unittest
 from twisted.internet import defer, reactor
 from twisted.internet import threads # CLI tests use deferToThread
 from twisted.application import service
-from allmydata import client, uri, download, upload
+from allmydata import client, uri, download, upload, storage, mutable
 from allmydata.introducer import IntroducerNode
 from allmydata.util import deferredutil, fileutil, idlib, mathutil, testutil
 from allmydata.scripts import runner
@@ -237,6 +237,72 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase):
         return d
     test_upload_and_download.timeout = 4800
 
+    def _find_shares(self, basedir):
+        shares = []
+        for (dirpath, dirnames, filenames) in os.walk(basedir):
+            if "storage" not in dirpath:
+                continue
+            if not filenames:
+                continue
+            pieces = dirpath.split(os.sep)
+            if pieces[-3] == "storage" and pieces[-2] == "shares":
+                # we're sitting in .../storage/shares/$SINDEX , and there
+                # are sharefiles here
+                assert pieces[-4].startswith("client")
+                client_num = int(pieces[-4][-1])
+                storage_index_s = pieces[-1]
+                storage_index = idlib.a2b(storage_index_s)
+                for sharename in filenames:
+                    shnum = int(sharename)
+                    filename = os.path.join(dirpath, sharename)
+                    data = (client_num, storage_index, filename, shnum)
+                    shares.append(data)
+        if not shares:
+            self.fail("unable to find any share files in %s" % basedir)
+        return shares
+
+    def _corrupt_mutable_share(self, filename, which):
+        msf = storage.MutableShareFile(filename)
+        datav = msf.readv([ (0, 1000000) ])
+        final_share = datav[0]
+        assert len(final_share) < 1000000 # ought to be truncated
+        pieces = mutable.unpack_share(final_share)
+        (seqnum, root_hash, IV, k, N, segsize, datalen,
+         verification_key, signature, share_hash_chain, block_hash_tree,
+         share_data, enc_privkey) = pieces
+
+        if which == "seqnum":
+            seqnum = seqnum + 15
+        elif which == "R":
+            root_hash = self.flip_bit(root_hash)
+        elif which == "IV":
+            IV = self.flip_bit(IV)
+        elif which == "segsize":
+            segsize = segsize + 15
+        elif which == "pubkey":
+            verification_key = self.flip_bit(verification_key)
+        elif which == "signature":
+            signature = self.flip_bit(signature)
+        elif which == "share_hash_chain":
+            nodenum = share_hash_chain.keys()[0]
+            share_hash_chain[nodenum] = self.flip_bit(share_hash_chain[nodenum])
+        elif which == "block_hash_tree":
+            block_hash_tree[-1] = self.flip_bit(block_hash_tree[-1])
+        elif which == "share_data":
+            share_data = self.flip_bit(share_data)
+        elif which == "encprivkey":
+            enc_privkey = self.flip_bit(enc_privkey)
+
+        prefix = mutable.pack_prefix(seqnum, root_hash, IV, k, N,
+                                     segsize, datalen)
+        final_share = mutable.pack_share(prefix,
+                                         verification_key,
+                                         signature,
+                                         share_hash_chain,
+                                         block_hash_tree,
+                                         share_data,
+                                         enc_privkey)
+        msf.writev( [(0, final_share)], None)
 
     def test_mutable(self):
         self.basedir = "system/SystemTest/test_mutable"
@@ -260,22 +326,8 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase):
         def _test_debug(res):
             # find a share. It is important to run this while there is only
             # one slot in the grid.
-            for (dirpath, dirnames, filenames) in os.walk(self.basedir):
-                if "storage" not in dirpath:
-                    continue
-                if not filenames:
-                    continue
-                pieces = dirpath.split(os.sep)
-                if pieces[-3] == "storage" and pieces[-2] == "shares":
-                    # we're sitting in .../storage/shares/$SINDEX , and there
-                    # are sharefiles here
-                    assert pieces[-4].startswith("client")
-                    client_num = int(pieces[-4][-1])
-                    filename = os.path.join(dirpath, filenames[0])
-                    break
-            else:
-                self.fail("unable to find any share files in %s"
-                          % self.basedir)
+            shares = self._find_shares(self.basedir)
+            (client_num, storage_index, filename, shnum) = shares[0]
             log.msg("test_system.SystemTest.test_mutable._test_debug using %s"
                     % filename)
             log.msg(" for clients[%d]" % client_num)
@@ -367,6 +419,7 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase):
             uri = self._mutable_node_1.get_uri()
             newnode1 = self.clients[2].create_node_from_uri(uri)
             newnode2 = self.clients[3].create_node_from_uri(uri)
+            self._newnode3 = self.clients[3].create_node_from_uri(uri)
             log.msg("starting replace2")
             d1 = newnode1.replace(NEWERDATA, wait_for_numpeers=self.numclients)
             d1.addCallback(lambda res: newnode2.download_to_data())
@@ -376,13 +429,13 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase):
         def _check_download_5(res):
             log.msg("finished replace2")
             self.failUnlessEqual(res, NEWERDATA)
-            # Make sure we can create empty files -- this can screw up the
-            # segsize math.
-            d1 = self.clients[2].create_mutable_file("", wait_for_numpeers=self.numclients)
+            # make sure we can create empty files, this usually screws up the
+            # segsize math
+            d1 = self.clients[2].create_mutable_file("")
             d1.addCallback(lambda newnode: newnode.download_to_data())
             d1.addCallback(lambda res: self.failUnlessEqual("", res))
             return d1
-        d.addCallback(_check_download_5)
+        d.addCallback(_check_empty_file)
 
         d.addCallback(lambda res: self.clients[0].create_empty_dirnode(wait_for_numpeers=self.numclients))
         def _created_dirnode(dnode):
@@ -394,6 +447,9 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase):
             d1.addCallback(lambda res: dnode.set_node("see recursive", dnode, wait_for_numpeers=self.numclients))
             d1.addCallback(lambda res: dnode.has_child("see recursive"))
             d1.addCallback(lambda answer: self.failUnlessEqual(answer, True))
+            d1.addCallback(lambda res: dnode.build_manifest())
+            d1.addCallback(lambda manifest:
+                           self.failUnlessEqual(len(manifest), 1))
             return d1
         d.addCallback(_created_dirnode)
 
-- 
2.45.2