]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/commitdiff
add tests for bad/inconsistent plaintext/crypttext merkle tree hashes
authorBrian Warner <warner@allmydata.com>
Fri, 8 Jun 2007 02:32:29 +0000 (19:32 -0700)
committerBrian Warner <warner@allmydata.com>
Fri, 8 Jun 2007 02:32:29 +0000 (19:32 -0700)
src/allmydata/download.py
src/allmydata/test/test_encode.py

index f4c6fbee32cf30e4346c9ddcd71b60a47a05ccfd..79bd2e7211f87472cf6dd070b7c8bd31ef6061ae 100644 (file)
@@ -277,6 +277,13 @@ class FileDownloader:
 
         self._thingA_data = None
 
+        self._fetch_failures = {"thingA": 0,
+                                "plaintext_hashroot": 0,
+                                "plaintext_hashtree": 0,
+                                "crypttext_hashroot": 0,
+                                "crypttext_hashtree": 0,
+                                }
+
     def start(self):
         log.msg("starting download [%s]" % idlib.b2a(self._storage_index))
 
@@ -345,6 +352,7 @@ class FileDownloader:
         def _validate(proposal, bucket):
             h = hashtree.thingA_hash(proposal)
             if h != self._thingA_hash:
+                self._fetch_failures["thingA"] += 1
                 msg = ("The copy of thingA we received from %s was bad" %
                        bucket)
                 raise BadThingAHashValue(msg)
@@ -357,14 +365,17 @@ class FileDownloader:
     def _obtain_validated_thing(self, ignored, sources, name, methname, args,
                                 validatorfunc):
         if not sources:
-            raise NotEnoughPeersError("ran out of peers while fetching %s" %
-                                      name)
+            raise NotEnoughPeersError("started with zero peers while fetching "
+                                      "%s" % name)
         bucket = sources[0]
         sources = sources[1:]
         d = bucket.callRemote(methname, *args)
         d.addCallback(validatorfunc, bucket)
         def _bad(f):
             log.msg("%s from vbucket %s failed: %s" % (name, bucket, f)) # WEIRD
+            if not sources:
+                raise NotEnoughPeersError("ran out of peers, last error was %s"
+                                          % (f,))
             # try again with a different one
             return self._obtain_validated_thing(None, sources, name,
                                                 methname, args, validatorfunc)
@@ -402,10 +413,20 @@ class FileDownloader:
     def _get_plaintext_hashtrees(self):
         def _validate_plaintext_hashtree(proposal, bucket):
             if proposal[0] != self._thingA_data['plaintext_root_hash']:
+                self._fetch_failures["plaintext_hashroot"] += 1
                 msg = ("The copy of the plaintext_root_hash we received from"
                        " %s was bad" % bucket)
                 raise BadPlaintextHashValue(msg)
-            self._plaintext_hashes = proposal
+            pt_hashtree = hashtree.IncompleteHashTree(self._total_segments)
+            pt_hashes = dict(list(enumerate(proposal)))
+            try:
+                pt_hashtree.set_hashes(pt_hashes)
+            except hashtree.BadHashError:
+                # the hashes they gave us were not self-consistent, even
+                # though the root matched what we saw in the thingA block
+                self._fetch_failures["plaintext_hashtree"] += 1
+                raise
+            self._plaintext_hashtree = pt_hashtree
         d = self._obtain_validated_thing(None,
                                          self._thingA_sources,
                                          "plaintext_hashes",
@@ -416,10 +437,19 @@ class FileDownloader:
     def _get_crypttext_hashtrees(self, res):
         def _validate_crypttext_hashtree(proposal, bucket):
             if proposal[0] != self._thingA_data['crypttext_root_hash']:
+                self._fetch_failures["crypttext_hashroot"] += 1
                 msg = ("The copy of the crypttext_root_hash we received from"
                        " %s was bad" % bucket)
                 raise BadCrypttextHashValue(msg)
-            self._crypttext_hashes = proposal
+            ct_hashtree = hashtree.IncompleteHashTree(self._total_segments)
+            ct_hashes = dict(list(enumerate(proposal)))
+            try:
+                ct_hashtree.set_hashes(ct_hashes)
+            except hashtree.BadHashError:
+                self._fetch_failures["crypttext_hashtree"] += 1
+                raise
+            ct_hashtree.set_hashes(ct_hashes)
+            self._crypttext_hashtree = ct_hashtree
         d = self._obtain_validated_thing(None,
                                          self._thingA_sources,
                                          "crypttext_hashes",
@@ -428,13 +458,8 @@ class FileDownloader:
         return d
 
     def _setup_hashtrees(self, res):
-        plaintext_hashtree = hashtree.IncompleteHashTree(self._total_segments)
-        plaintext_hashes = dict(list(enumerate(self._plaintext_hashes)))
-        plaintext_hashtree.set_hashes(plaintext_hashes)
-        crypttext_hashtree = hashtree.IncompleteHashTree(self._total_segments)
-        crypttext_hashes = dict(list(enumerate(self._crypttext_hashes)))
-        crypttext_hashtree.set_hashes(crypttext_hashes)
-        self._output.setup_hashtrees(plaintext_hashtree, crypttext_hashtree)
+        self._output.setup_hashtrees(self._plaintext_hashtree,
+                                     self._crypttext_hashtree)
 
 
     def _create_validated_buckets(self, ignored=None):
index d108b9bfa64f72b75ef8f62133a6644979db4ecf..b39751aaade0068ff0e6ff71017908e1d3554a8d 100644 (file)
@@ -96,7 +96,7 @@ class FakeBucketWriter:
         assert not self.closed
         self.closed = True
 
-    def flip_bit(self, good):
+    def flip_bit(self, good): # flips the last bit
         return good[:-1] + chr(ord(good[-1]) ^ 0x01)
 
     def get_block(self, blocknum):
@@ -106,17 +106,20 @@ class FakeBucketWriter:
         return self.blocks[blocknum]
 
     def get_plaintext_hashes(self):
-        if self.mode == "bad plaintexthash":
-            hashes = self.plaintext_hashes[:]
+        hashes = self.plaintext_hashes[:]
+        if self.mode == "bad plaintext hashroot":
+            hashes[0] = self.flip_bit(hashes[0])
+        if self.mode == "bad plaintext hash":
             hashes[1] = self.flip_bit(hashes[1])
-            return hashes
-        return self.plaintext_hashes
+        return hashes
+
     def get_crypttext_hashes(self):
-        if self.mode == "bad crypttexthash":
-            hashes = self.crypttext_hashes[:]
+        hashes = self.crypttext_hashes[:]
+        if self.mode == "bad crypttext hashroot":
+            hashes[0] = self.flip_bit(hashes[0])
+        if self.mode == "bad crypttext hash":
             hashes[1] = self.flip_bit(hashes[1])
-            return hashes
-        return self.crypttext_hashes
+        return hashes
 
     def get_block_hashes(self):
         if self.mode == "bad blockhash":
@@ -136,6 +139,11 @@ class FakeBucketWriter:
             return []
         return self.share_hashes
 
+    def get_thingA(self):
+        if self.mode == "bad thingA":
+            return self.flip_bit(self.thingA)
+        return self.thingA
+
 
 def make_data(length):
     data = "happy happy joy joy" * 100
@@ -250,6 +258,7 @@ class Roundtrip(unittest.TestCase):
                          datalen=76,
                          max_segment_size=25,
                          bucket_modes={},
+                         recover_mode="recover",
                          ):
         if AVAILABLE_SHARES is None:
             AVAILABLE_SHARES = k_and_happy_and_n[2]
@@ -257,10 +266,16 @@ class Roundtrip(unittest.TestCase):
         d = self.send(k_and_happy_and_n, AVAILABLE_SHARES,
                       max_segment_size, bucket_modes, data)
         # that fires with (thingA_hash, e, shareholders)
-        d.addCallback(self.recover, AVAILABLE_SHARES)
+        if recover_mode == "recover":
+            d.addCallback(self.recover, AVAILABLE_SHARES)
+        elif recover_mode == "thingA":
+            d.addCallback(self.recover_with_thingA, AVAILABLE_SHARES)
+        else:
+            raise RuntimeError, "unknown recover_mode '%s'" % recover_mode
         # that fires with newdata
-        def _downloaded(newdata):
+        def _downloaded((newdata, fd)):
             self.failUnless(newdata == data)
+            return fd
         d.addCallback(_downloaded)
         return d
 
@@ -305,8 +320,17 @@ class Roundtrip(unittest.TestCase):
         client = None
         target = download.Data()
         fd = download.FileDownloader(client, URI, target)
-        fd.check_verifierid = False
-        fd.check_fileid = False
+
+        # we manually cycle the FileDownloader through a number of steps that
+        # would normally be sequenced by a Deferred chain in
+        # FileDownloader.start(), to give us more control over the process.
+        # In particular, by bypassing _get_all_shareholders, we skip
+        # permuted-peerlist selection.
+        for shnum, bucket in shareholders.items():
+            if shnum < AVAILABLE_SHARES and bucket.closed:
+                fd.add_share_bucket(shnum, bucket)
+        fd._got_all_shareholders(None)
+
         # grab a copy of thingA from one of the shareholders
         thingA = shareholders[0].thingA
         thingA_data = bencode.bdecode(thingA)
@@ -321,13 +345,64 @@ class Roundtrip(unittest.TestCase):
                   'total_shares': e.num_shares,
                   }
         fd._got_thingA(thingA_data)
+        # we skip _get_hashtrees here, and the lack of hashtree attributes
+        # will cause the download.Output object to skip the
+        # plaintext/crypttext merkle tree checks. We instruct the downloader
+        # to skip the full-file checks as well.
+        fd.check_verifierid = False
+        fd.check_fileid = False
+
+        fd._create_validated_buckets(None)
+        d = fd._download_all_segments(None)
+        d.addCallback(fd._done)
+        def _done(newdata):
+            return (newdata, fd)
+        d.addCallback(_done)
+        return d
+
+    def recover_with_thingA(self, (thingA_hash, e, shareholders),
+                            AVAILABLE_SHARES):
+        URI = pack_uri(storage_index="S" * 20,
+                       key=e.key,
+                       thingA_hash=thingA_hash,
+                       needed_shares=e.required_shares,
+                       total_shares=e.num_shares,
+                       size=e.file_size)
+        client = None
+        target = download.Data()
+        fd = download.FileDownloader(client, URI, target)
+
+        # we manually cycle the FileDownloader through a number of steps that
+        # would normally be sequenced by a Deferred chain in
+        # FileDownloader.start(), to give us more control over the process.
+        # In particular, by bypassing _get_all_shareholders, we skip
+        # permuted-peerlist selection.
         for shnum, bucket in shareholders.items():
             if shnum < AVAILABLE_SHARES and bucket.closed:
                 fd.add_share_bucket(shnum, bucket)
         fd._got_all_shareholders(None)
-        fd._create_validated_buckets(None)
-        d = fd._download_all_segments(None)
+
+        # ask shareholders for thingA as usual, validating the responses.
+        # Arrange for shareholders[0] to be the first, so we can selectively
+        # corrupt the data it returns.
+        fd._thingA_sources = shareholders.values()
+        fd._thingA_sources.remove(shareholders[0])
+        fd._thingA_sources.insert(0, shareholders[0])
+        # the thingA block contains plaintext/crypttext hash trees, but does
+        # not have a fileid or verifierid, so we have to disable those checks
+        fd.check_verifierid = False
+        fd.check_fileid = False
+
+        d = fd._obtain_thingA(None)
+        d.addCallback(fd._got_thingA)
+
+        d.addCallback(fd._get_hashtrees)
+        d.addCallback(fd._create_validated_buckets)
+        d.addCallback(fd._download_all_segments)
         d.addCallback(fd._done)
+        def _done(newdata):
+            return (newdata, fd)
+        d.addCallback(_done)
         return d
 
     def test_not_enough_shares(self):
@@ -419,6 +494,77 @@ class Roundtrip(unittest.TestCase):
                           for i in range(6, 10)])
         return self.send_and_recover((4,8,10), bucket_modes=modemap)
 
+    def assertFetchFailureIn(self, fd, where):
+        expected = {"thingA": 0,
+                    "plaintext_hashroot": 0,
+                    "plaintext_hashtree": 0,
+                    "crypttext_hashroot": 0,
+                    "crypttext_hashtree": 0,
+                    }
+        if where is not None:
+            expected[where] += 1
+        self.failUnlessEqual(fd._fetch_failures, expected)
+
+    def test_good_thingA(self):
+        # exercise recover_mode="thingA", just to make sure the test works
+        modemap = dict([(i, "good") for i in range(1)] +
+                       [(i, "good") for i in range(1, 10)])
+        d = self.send_and_recover((4,8,10), bucket_modes=modemap,
+                                  recover_mode="thingA")
+        d.addCallback(self.assertFetchFailureIn, None)
+        return d
+
+    def test_bad_thingA(self):
+        # the first server has a bad thingA block, so we will fail over to a
+        # different server.
+        modemap = dict([(i, "bad thingA") for i in range(1)] +
+                       [(i, "good") for i in range(1, 10)])
+        d = self.send_and_recover((4,8,10), bucket_modes=modemap,
+                                  recover_mode="thingA")
+        d.addCallback(self.assertFetchFailureIn, "thingA")
+        return d
+
+    def test_bad_plaintext_hashroot(self):
+        # the first server has a bad plaintext hashroot, so we will fail over
+        # to a different server.
+        modemap = dict([(i, "bad plaintext hashroot") for i in range(1)] +
+                       [(i, "good") for i in range(1, 10)])
+        d = self.send_and_recover((4,8,10), bucket_modes=modemap,
+                                  recover_mode="thingA")
+        d.addCallback(self.assertFetchFailureIn, "plaintext_hashroot")
+        return d
+
+    def test_bad_crypttext_hashroot(self):
+        # the first server has a bad crypttext hashroot, so we will fail
+        # over to a different server.
+        modemap = dict([(i, "bad crypttext hashroot") for i in range(1)] +
+                       [(i, "good") for i in range(1, 10)])
+        d = self.send_and_recover((4,8,10), bucket_modes=modemap,
+                                  recover_mode="thingA")
+        d.addCallback(self.assertFetchFailureIn, "crypttext_hashroot")
+        return d
+
+    def test_bad_plaintext_hashes(self):
+        # the first server has a bad plaintext hash block, so we will fail
+        # over to a different server.
+        modemap = dict([(i, "bad plaintext hash") for i in range(1)] +
+                       [(i, "good") for i in range(1, 10)])
+        d = self.send_and_recover((4,8,10), bucket_modes=modemap,
+                                  recover_mode="thingA")
+        d.addCallback(self.assertFetchFailureIn, "plaintext_hashtree")
+        return d
+
+    def test_bad_crypttext_hashes(self):
+        # the first server has a bad crypttext hash block, so we will fail
+        # over to a different server.
+        modemap = dict([(i, "bad crypttext hash") for i in range(1)] +
+                       [(i, "good") for i in range(1, 10)])
+        d = self.send_and_recover((4,8,10), bucket_modes=modemap,
+                                  recover_mode="thingA")
+        d.addCallback(self.assertFetchFailureIn, "crypttext_hashtree")
+        return d
+
+
     def test_bad_sharehashes_failure(self):
         # the first 7 servers have bad block hashes, so the sharehash tree
         # will not validate, and the download will fail