]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/commitdiff
test_encode.py: even more testing of merkle trees, getting fairly comprehensive now
authorBrian Warner <warner@allmydata.com>
Fri, 8 Jun 2007 04:24:39 +0000 (21:24 -0700)
committerBrian Warner <warner@allmydata.com>
Fri, 8 Jun 2007 04:24:39 +0000 (21:24 -0700)
src/allmydata/encode.py
src/allmydata/test/test_encode.py

index 83126fe50d745a240eb241c980288de25d9daa64..7879955e24eebfc6b295c36d42658786b092a816 100644 (file)
@@ -122,6 +122,8 @@ class Encoder(object):
 
         data['size'] = self.file_size
         data['segment_size'] = self.segment_size
+        data['num_segments'] = mathutil.div_ceil(self.file_size,
+                                                 self.segment_size)
         data['needed_shares'] = self.required_shares
         data['total_shares'] = self.num_shares
 
index b39751aaade0068ff0e6ff71017908e1d3554a8d..fd5bfa4d77edcc82e7a39c5e00afbccd42d637b1 100644 (file)
@@ -1,14 +1,18 @@
-#! /usr/bin/env python
 
 from twisted.trial import unittest
 from twisted.internet import defer
 from twisted.python.failure import Failure
 from foolscap import eventual
-from allmydata import encode, download
-from allmydata.util import bencode
+from allmydata import encode, download, hashtree
+from allmydata.util import hashutil
 from allmydata.uri import pack_uri
+from allmydata.Crypto.Cipher import AES
+import sha
 from cStringIO import StringIO
 
+def netstring(s):
+    return "%d:%s," % (len(s), s)
+
 class FakePeer:
     def __init__(self, mode="good"):
         self.ss = FakeStorageServer(mode)
@@ -44,6 +48,9 @@ class FakeStorageServer:
 class LostPeerError(Exception):
     pass
 
+def flip_bit(good): # flips the last bit
+    return good[:-1] + chr(ord(good[-1]) ^ 0x01)
+
 class FakeBucketWriter:
     # these are used for both reading and writing
     def __init__(self, mode="good"):
@@ -96,41 +103,38 @@ class FakeBucketWriter:
         assert not self.closed
         self.closed = True
 
-    def flip_bit(self, good): # flips the last bit
-        return good[:-1] + chr(ord(good[-1]) ^ 0x01)
-
     def get_block(self, blocknum):
         assert isinstance(blocknum, (int, long))
         if self.mode == "bad block":
-            return self.flip_bit(self.blocks[blocknum])
+            return flip_bit(self.blocks[blocknum])
         return self.blocks[blocknum]
 
     def get_plaintext_hashes(self):
         hashes = self.plaintext_hashes[:]
         if self.mode == "bad plaintext hashroot":
-            hashes[0] = self.flip_bit(hashes[0])
+            hashes[0] = flip_bit(hashes[0])
         if self.mode == "bad plaintext hash":
-            hashes[1] = self.flip_bit(hashes[1])
+            hashes[1] = flip_bit(hashes[1])
         return hashes
 
     def get_crypttext_hashes(self):
         hashes = self.crypttext_hashes[:]
         if self.mode == "bad crypttext hashroot":
-            hashes[0] = self.flip_bit(hashes[0])
+            hashes[0] = flip_bit(hashes[0])
         if self.mode == "bad crypttext hash":
-            hashes[1] = self.flip_bit(hashes[1])
+            hashes[1] = flip_bit(hashes[1])
         return hashes
 
     def get_block_hashes(self):
         if self.mode == "bad blockhash":
             hashes = self.block_hashes[:]
-            hashes[1] = self.flip_bit(hashes[1])
+            hashes[1] = flip_bit(hashes[1])
             return hashes
         return self.block_hashes
     def get_share_hashes(self):
         if self.mode == "bad sharehash":
             hashes = self.share_hashes[:]
-            hashes[1] = (hashes[1][0], self.flip_bit(hashes[1][1]))
+            hashes[1] = (hashes[1][0], flip_bit(hashes[1][1]))
             return hashes
         if self.mode == "missing sharehash":
             # one sneaky attack would be to pretend we don't know our own
@@ -141,7 +145,7 @@ class FakeBucketWriter:
 
     def get_thingA(self):
         if self.mode == "bad thingA":
-            return self.flip_bit(self.thingA)
+            return flip_bit(self.thingA)
         return self.thingA
 
 
@@ -266,12 +270,7 @@ class Roundtrip(unittest.TestCase):
         d = self.send(k_and_happy_and_n, AVAILABLE_SHARES,
                       max_segment_size, bucket_modes, data)
         # that fires with (thingA_hash, e, shareholders)
-        if recover_mode == "recover":
-            d.addCallback(self.recover, AVAILABLE_SHARES)
-        elif recover_mode == "thingA":
-            d.addCallback(self.recover_with_thingA, AVAILABLE_SHARES)
-        else:
-            raise RuntimeError, "unknown recover_mode '%s'" % recover_mode
+        d.addCallback(self.recover, AVAILABLE_SHARES, recover_mode)
         # that fires with newdata
         def _downloaded((newdata, fd)):
             self.failUnless(newdata == data)
@@ -301,8 +300,15 @@ class Roundtrip(unittest.TestCase):
             peer = FakeBucketWriter(mode)
             shareholders[shnum] = peer
         e.set_shareholders(shareholders)
-        e.set_thingA_data({'verifierid': "V" * 20,
-                           'fileid': "F" * 20,
+        fileid_hasher = sha.new(netstring("allmydata_fileid_v1"))
+        fileid_hasher.update(data)
+        cryptor = AES.new(key=nonkey, mode=AES.MODE_CTR,
+                          counterstart="\x00"*16)
+        verifierid_hasher = sha.new(netstring("allmydata_verifierid_v1"))
+        verifierid_hasher.update(cryptor.encrypt(data))
+
+        e.set_thingA_data({'verifierid': verifierid_hasher.digest(),
+                           'fileid': fileid_hasher.digest(),
                            })
         d = e.start()
         def _sent(thingA_hash):
@@ -310,60 +316,14 @@ class Roundtrip(unittest.TestCase):
         d.addCallback(_sent)
         return d
 
-    def recover(self, (thingA_hash, e, shareholders), AVAILABLE_SHARES):
-        URI = pack_uri(storage_index="S" * 20,
-                       key=e.key,
-                       thingA_hash=thingA_hash,
-                       needed_shares=e.required_shares,
-                       total_shares=e.num_shares,
-                       size=e.file_size)
-        client = None
-        target = download.Data()
-        fd = download.FileDownloader(client, URI, target)
+    def recover(self, (thingA_hash, e, shareholders), AVAILABLE_SHARES,
+                recover_mode):
+        key = e.key
+        if "corrupt_key" in recover_mode:
+            key = flip_bit(key)
 
-        # we manually cycle the FileDownloader through a number of steps that
-        # would normally be sequenced by a Deferred chain in
-        # FileDownloader.start(), to give us more control over the process.
-        # In particular, by bypassing _get_all_shareholders, we skip
-        # permuted-peerlist selection.
-        for shnum, bucket in shareholders.items():
-            if shnum < AVAILABLE_SHARES and bucket.closed:
-                fd.add_share_bucket(shnum, bucket)
-        fd._got_all_shareholders(None)
-
-        # grab a copy of thingA from one of the shareholders
-        thingA = shareholders[0].thingA
-        thingA_data = bencode.bdecode(thingA)
-        NOTthingA = {'codec_name': e._codec.get_encoder_type(),
-                  'codec_params': e._codec.get_serialized_params(),
-                  'tail_codec_params': e._tail_codec.get_serialized_params(),
-                  'verifierid': "V" * 20,
-                  'fileid': "F" * 20,
-                     #'share_root_hash': roothash,
-                  'segment_size': e.segment_size,
-                  'needed_shares': e.required_shares,
-                  'total_shares': e.num_shares,
-                  }
-        fd._got_thingA(thingA_data)
-        # we skip _get_hashtrees here, and the lack of hashtree attributes
-        # will cause the download.Output object to skip the
-        # plaintext/crypttext merkle tree checks. We instruct the downloader
-        # to skip the full-file checks as well.
-        fd.check_verifierid = False
-        fd.check_fileid = False
-
-        fd._create_validated_buckets(None)
-        d = fd._download_all_segments(None)
-        d.addCallback(fd._done)
-        def _done(newdata):
-            return (newdata, fd)
-        d.addCallback(_done)
-        return d
-
-    def recover_with_thingA(self, (thingA_hash, e, shareholders),
-                            AVAILABLE_SHARES):
         URI = pack_uri(storage_index="S" * 20,
-                       key=e.key,
+                       key=key,
                        thingA_hash=thingA_hash,
                        needed_shares=e.required_shares,
                        total_shares=e.num_shares,
@@ -382,21 +342,39 @@ class Roundtrip(unittest.TestCase):
                 fd.add_share_bucket(shnum, bucket)
         fd._got_all_shareholders(None)
 
-        # ask shareholders for thingA as usual, validating the responses.
-        # Arrange for shareholders[0] to be the first, so we can selectively
-        # corrupt the data it returns.
+        # Make it possible to obtain thingA from the shareholders. Arrange
+        # for shareholders[0] to be the first, so we can selectively corrupt
+        # the data it returns.
         fd._thingA_sources = shareholders.values()
         fd._thingA_sources.remove(shareholders[0])
         fd._thingA_sources.insert(0, shareholders[0])
-        # the thingA block contains plaintext/crypttext hash trees, but does
-        # not have a fileid or verifierid, so we have to disable those checks
-        fd.check_verifierid = False
-        fd.check_fileid = False
 
-        d = fd._obtain_thingA(None)
+        d = defer.succeed(None)
+
+        # have the FileDownloader retrieve a copy of thingA itself
+        d.addCallback(fd._obtain_thingA)
+
+        if "corrupt_crypttext_hashes" in recover_mode:
+            # replace everybody's crypttext hash trees with a different one
+            # (computed over a different file), then modify our thingA to
+            # reflect the new crypttext hash tree root
+            def _corrupt_crypttext_hashes(thingA):
+                assert isinstance(thingA, dict)
+                assert 'crypttext_root_hash' in thingA
+                badhash = hashutil.tagged_hash("bogus", "data")
+                bad_crypttext_hashes = [badhash] * thingA['num_segments']
+                badtree = hashtree.HashTree(bad_crypttext_hashes)
+                for bucket in shareholders.values():
+                    bucket.crypttext_hashes = list(badtree)
+                thingA['crypttext_root_hash'] = badtree[0]
+                return thingA
+            d.addCallback(_corrupt_crypttext_hashes)
+
         d.addCallback(fd._got_thingA)
 
+        # also have the FileDownloader ask for hash trees
         d.addCallback(fd._get_hashtrees)
+
         d.addCallback(fd._create_validated_buckets)
         d.addCallback(fd._download_all_segments)
         d.addCallback(fd._done)
@@ -505,12 +483,11 @@ class Roundtrip(unittest.TestCase):
             expected[where] += 1
         self.failUnlessEqual(fd._fetch_failures, expected)
 
-    def test_good_thingA(self):
-        # exercise recover_mode="thingA", just to make sure the test works
-        modemap = dict([(i, "good") for i in range(1)] +
-                       [(i, "good") for i in range(1, 10)])
-        d = self.send_and_recover((4,8,10), bucket_modes=modemap,
-                                  recover_mode="thingA")
+    def test_good(self):
+        # just to make sure the test harness works when we aren't
+        # intentionally causing failures
+        modemap = dict([(i, "good") for i in range(0, 10)])
+        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
         d.addCallback(self.assertFetchFailureIn, None)
         return d
 
@@ -519,8 +496,7 @@ class Roundtrip(unittest.TestCase):
         # different server.
         modemap = dict([(i, "bad thingA") for i in range(1)] +
                        [(i, "good") for i in range(1, 10)])
-        d = self.send_and_recover((4,8,10), bucket_modes=modemap,
-                                  recover_mode="thingA")
+        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
         d.addCallback(self.assertFetchFailureIn, "thingA")
         return d
 
@@ -529,8 +505,7 @@ class Roundtrip(unittest.TestCase):
         # to a different server.
         modemap = dict([(i, "bad plaintext hashroot") for i in range(1)] +
                        [(i, "good") for i in range(1, 10)])
-        d = self.send_and_recover((4,8,10), bucket_modes=modemap,
-                                  recover_mode="thingA")
+        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
         d.addCallback(self.assertFetchFailureIn, "plaintext_hashroot")
         return d
 
@@ -539,8 +514,7 @@ class Roundtrip(unittest.TestCase):
         # over to a different server.
         modemap = dict([(i, "bad crypttext hashroot") for i in range(1)] +
                        [(i, "good") for i in range(1, 10)])
-        d = self.send_and_recover((4,8,10), bucket_modes=modemap,
-                                  recover_mode="thingA")
+        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
         d.addCallback(self.assertFetchFailureIn, "crypttext_hashroot")
         return d
 
@@ -549,8 +523,7 @@ class Roundtrip(unittest.TestCase):
         # over to a different server.
         modemap = dict([(i, "bad plaintext hash") for i in range(1)] +
                        [(i, "good") for i in range(1, 10)])
-        d = self.send_and_recover((4,8,10), bucket_modes=modemap,
-                                  recover_mode="thingA")
+        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
         d.addCallback(self.assertFetchFailureIn, "plaintext_hashtree")
         return d
 
@@ -559,11 +532,39 @@ class Roundtrip(unittest.TestCase):
         # over to a different server.
         modemap = dict([(i, "bad crypttext hash") for i in range(1)] +
                        [(i, "good") for i in range(1, 10)])
-        d = self.send_and_recover((4,8,10), bucket_modes=modemap,
-                                  recover_mode="thingA")
+        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
         d.addCallback(self.assertFetchFailureIn, "crypttext_hashtree")
         return d
 
+    def test_bad_crypttext_hashes_failure(self):
+        # to test that the crypttext merkle tree is really being applied, we
+        # sneak into the download process and corrupt two things: we replace
+        # everybody's crypttext hashtree with a bad version (computed over
+        # bogus data), and we modify the supposedly-validated thingA block to
+        # match the new crypttext hashtree root. The download process should
+        # notice that the crypttext coming out of FEC doesn't match the tree,
+        # and fail.
+
+        modemap = dict([(i, "good") for i in range(0, 10)])
+        d = self.send_and_recover((4,8,10), bucket_modes=modemap,
+                                  recover_mode=("corrupt_crypttext_hashes"))
+        def _done(res):
+            self.failUnless(isinstance(res, Failure))
+            self.failUnless(res.check(hashtree.BadHashError), res)
+        d.addBoth(_done)
+        return d
+
+
+    def test_bad_plaintext(self):
+        # faking a decryption failure is easier: just corrupt the key
+        modemap = dict([(i, "good") for i in range(0, 10)])
+        d = self.send_and_recover((4,8,10), bucket_modes=modemap,
+                                  recover_mode=("corrupt_key"))
+        def _done(res):
+            self.failUnless(isinstance(res, Failure))
+            self.failUnless(res.check(hashtree.BadHashError))
+        d.addBoth(_done)
+        return d
 
     def test_bad_sharehashes_failure(self):
         # the first 7 servers have bad block hashes, so the sharehash tree