From be5a6147b41d62043abe35ea69d6fce1779f72e8 Mon Sep 17 00:00:00 2001
From: Brian Warner <warner@allmydata.com>
Date: Mon, 10 Mar 2008 17:46:52 -0700
Subject: [PATCH] test_mutable: test all hash-failure cases except a corrupted
 encrypted private key

---
 src/allmydata/mutable.py           |  36 ++++---
 src/allmydata/test/test_mutable.py | 162 +++++++++++++++++++++++++++--
 2 files changed, 180 insertions(+), 18 deletions(-)

diff --git a/src/allmydata/mutable.py b/src/allmydata/mutable.py
index 5eb1b178..188075fc 100644
--- a/src/allmydata/mutable.py
+++ b/src/allmydata/mutable.py
@@ -49,11 +49,8 @@ SIGNED_PREFIX = ">BQ32s16s BBQQ" # this is covered by the signature
 HEADER = ">BQ32s16s BBQQ LLLLQQ" # includes offsets
 HEADER_LENGTH = struct.calcsize(HEADER)
 
-def unpack_prefix_and_signature(data):
-    assert len(data) >= HEADER_LENGTH
+def unpack_header(data):
     o = {}
-    prefix = data[:struct.calcsize(SIGNED_PREFIX)]
-
     (version,
      seqnum,
      root_hash,
@@ -65,6 +62,18 @@ def unpack_prefix_and_signature(data):
      o['share_data'],
      o['enc_privkey'],
      o['EOF']) = struct.unpack(HEADER, data[:HEADER_LENGTH])
+    return (version, seqnum, root_hash, IV, k, N, segsize, datalen, o)
+
+def unpack_prefix_and_signature(data):
+    assert len(data) >= HEADER_LENGTH
+    prefix = data[:struct.calcsize(SIGNED_PREFIX)]
+
+    (version,
+     seqnum,
+     root_hash,
+     IV,
+     k, N, segsize, datalen,
+     o) = unpack_header(data)
 
     assert version == 0
     if len(data) < o['share_hash_chain']:
@@ -535,7 +544,7 @@ class Retrieve:
             self._pubkey = self._deserialize_pubkey(pubkey_s)
             self._node._populate_pubkey(self._pubkey)
 
-        verinfo = (seqnum, root_hash, IV, segsize, datalength)
+        verinfo = (seqnum, root_hash, IV, segsize, datalength) #, k, N)
         self._status.sharemap[peerid].add(verinfo)
 
         if verinfo not in self._valid_versions:
@@ -694,12 +703,12 @@ class Retrieve:
                     # arbitrary, really I want this to be something like
                     # k - max(known_version_sharecounts) + some extra
                     break
-        new_search_distance = max(max(peer_indicies),
-                                  self._status.get_search_distance())
-        self._status.set_search_distance(new_search_distance)
         if new_query_peers:
             self.log("sending %d new queries (read %d bytes)" %
                      (len(new_query_peers), self._read_size), level=log.UNUSUAL)
+            new_search_distance = max(max(peer_indicies),
+                                      self._status.get_search_distance())
+            self._status.set_search_distance(new_search_distance)
             for (peerid, ss) in new_query_peers:
                 self._do_query(ss, peerid, self._storage_index, self._read_size)
             # we'll retrigger when those queries come back
@@ -802,7 +811,8 @@ class Retrieve:
         try:
             t2.set_hashes(hashes=share_hash_chain,
                           leaves={shnum: share_hash_leaf})
-        except (hashtree.BadHashError, hashtree.NotEnoughHashesError), e:
+        except (hashtree.BadHashError, hashtree.NotEnoughHashesError,
+                IndexError), e:
             msg = "corrupt hashes: %s" % (e,)
             raise CorruptShareError(peerid, shnum, msg)
         self.log(" data valid! len=%d" % len(share_data))
@@ -864,16 +874,18 @@ class Retrieve:
         self._node._populate_root_hash(root_hash)
         return plaintext
 
-    def _done(self, contents):
+    def _done(self, res):
+        # res is either the new contents, or a Failure
         self.log("DONE")
         self._running = False
         self._status.set_active(False)
         self._status.set_status("Done")
         self._status.set_progress(1.0)
-        self._status.set_size(len(contents))
+        if isinstance(res, str):
+            self._status.set_size(len(res))
         elapsed = time.time() - self._started
         self._status.timings["total"] = elapsed
-        eventually(self._done_deferred.callback, contents)
+        eventually(self._done_deferred.callback, res)
 
     def get_status(self):
         return self._status
diff --git a/src/allmydata/test/test_mutable.py b/src/allmydata/test/test_mutable.py
index e1c2cd10..83a78e2e 100644
--- a/src/allmydata/test/test_mutable.py
+++ b/src/allmydata/test/test_mutable.py
@@ -1,5 +1,5 @@
 
-import itertools, struct
+import itertools, struct, re
 from cStringIO import StringIO
 from twisted.trial import unittest
 from twisted.internet import defer
@@ -52,6 +52,10 @@ class FakeStorage:
     # tests to examine and manipulate the published shares. It also lets us
     # control the order in which read queries are answered, to exercise more
     # of the error-handling code in mutable.Retrieve .
+    #
+    # Note that we ignore the storage index: this FakeStorage instance can
+    # only be used for a single storage index.
+
 
     def __init__(self):
         self._peers = {}
@@ -177,6 +181,12 @@ class FakePubKey:
     def serialize(self):
         return "PUBKEY-%d" % self.count
     def verify(self, msg, signature):
+        if signature[:5] != "SIGN(":
+            return False
+        if signature[5:-1] != msg:
+            return False
+        if signature[-1] != ")":
+            return False
         return True
 
 class FakePrivKey:
@@ -433,6 +443,14 @@ class FakeRetrieve(mutable.Retrieve):
                 vector.append(shares[shnum][offset:offset+length])
         return defer.succeed(response)
 
+    def _deserialize_pubkey(self, pubkey_s):
+        mo = re.search(r"^PUBKEY-(\d+)$", pubkey_s)
+        if not mo:
+            raise RuntimeError("mangled pubkey")
+        count = mo.group(1)
+        return FakePubKey(int(count))
+
+
 class Roundtrip(unittest.TestCase):
 
     def setup_for_publish(self, num_peers):
@@ -444,16 +462,15 @@ class Roundtrip(unittest.TestCase):
         fn.create("")
         p = FakePublish(fn)
         p._storage = s
-        return c, fn, p
+        r = FakeRetrieve(fn)
+        r._storage = s
+        return c, s, fn, p, r
 
     def test_basic(self):
-        c, fn, p = self.setup_for_publish(20)
+        c, s, fn, p, r = self.setup_for_publish(20)
         contents = "New contents go here"
         d = p.publish(contents)
         def _published(res):
-            # TODO: examine peers and check on their shares
-            r = FakeRetrieve(fn)
-            r._storage = p._storage
             return r.retrieve()
         d.addCallback(_published)
         def _retrieved(new_contents):
@@ -461,3 +478,136 @@ class Roundtrip(unittest.TestCase):
         d.addCallback(_retrieved)
         return d
 
+    def flip_bit(self, original, byte_offset):
+        return (original[:byte_offset] +
+                chr(ord(original[byte_offset]) ^ 0x01) +
+                original[byte_offset+1:])
+
+
+    def shouldFail(self, expected_failure, which, substring,
+                    callable, *args, **kwargs):
+        assert substring is None or isinstance(substring, str)
+        d = defer.maybeDeferred(callable, *args, **kwargs)
+        def done(res):
+            if isinstance(res, failure.Failure):
+                res.trap(expected_failure)
+                if substring:
+                    self.failUnless(substring in str(res),
+                                    "substring '%s' not in '%s'"
+                                    % (substring, str(res)))
+            else:
+                self.fail("%s was supposed to raise %s, not get '%s'" %
+                          (which, expected_failure, res))
+        d.addBoth(done)
+        return d
+
+    def _corrupt_all(self, offset, substring, refetch_pubkey=False,
+                     should_succeed=False):
+        c, s, fn, p, r = self.setup_for_publish(20)
+        contents = "New contents go here"
+        d = p.publish(contents)
+        def _published(res):
+            if refetch_pubkey:
+                # clear the pubkey, to force a fetch
+                r._pubkey = None
+            for peerid in s._peers:
+                shares = s._peers[peerid]
+                for shnum in shares:
+                    data = shares[shnum]
+                    (version,
+                     seqnum,
+                     root_hash,
+                     IV,
+                     k, N, segsize, datalen,
+                     o) = mutable.unpack_header(data)
+                    if isinstance(offset, tuple):
+                        offset1, offset2 = offset
+                    else:
+                        offset1 = offset
+                        offset2 = 0
+                    if offset1 == "pubkey":
+                        real_offset = 107
+                    elif offset1 in o:
+                        real_offset = o[offset1]
+                    else:
+                        real_offset = offset1
+                    real_offset = int(real_offset) + offset2
+                    assert isinstance(real_offset, int), offset
+                    shares[shnum] = self.flip_bit(data, real_offset)
+        d.addCallback(_published)
+        if should_succeed:
+            d.addCallback(lambda res: r.retrieve())
+        else:
+            d.addCallback(lambda res:
+                          self.shouldFail(NotEnoughPeersError,
+                                          "_corrupt_all(offset=%s)" % (offset,),
+                                          substring,
+                                          r.retrieve))
+        return d
+
+    def test_corrupt_all_verbyte(self):
+        # when the version byte is not 0, we hit an assertion error in
+        # unpack_share().
+        return self._corrupt_all(0, "AssertionError")
+
+    def test_corrupt_all_seqnum(self):
+        # a corrupt sequence number will trigger a bad signature
+        return self._corrupt_all(1, "signature is invalid")
+
+    def test_corrupt_all_R(self):
+        # a corrupt root hash will trigger a bad signature
+        return self._corrupt_all(9, "signature is invalid")
+
+    def test_corrupt_all_IV(self):
+        # a corrupt salt/IV will trigger a bad signature
+        return self._corrupt_all(41, "signature is invalid")
+
+    def test_corrupt_all_k(self):
+        # a corrupt 'k' will trigger a bad signature
+        return self._corrupt_all(57, "signature is invalid")
+
+    def test_corrupt_all_N(self):
+        # a corrupt 'N' will trigger a bad signature
+        return self._corrupt_all(58, "signature is invalid")
+
+    def test_corrupt_all_segsize(self):
+        # a corrupt segsize will trigger a bad signature
+        return self._corrupt_all(59, "signature is invalid")
+
+    def test_corrupt_all_datalen(self):
+        # a corrupt data length will trigger a bad signature
+        return self._corrupt_all(67, "signature is invalid")
+
+    def test_corrupt_all_pubkey(self):
+        # a corrupt pubkey won't match the URI's fingerprint
+        return self._corrupt_all("pubkey", "pubkey doesn't match fingerprint",
+                                 refetch_pubkey=True)
+
+    def test_corrupt_all_sig(self):
+        # a corrupt signature is a bad one
+        # the signature runs from about [543:799], depending upon the length
+        # of the pubkey
+        return self._corrupt_all("signature", "signature is invalid",
+                                 refetch_pubkey=True)
+
+    def test_corrupt_all_share_hash_chain_number(self):
+        # a corrupt share hash chain entry will show up as a bad hash. If we
+        # mangle the first byte, that will look like a bad hash number,
+        # causing an IndexError
+        return self._corrupt_all("share_hash_chain", "corrupt hashes")
+
+    def test_corrupt_all_share_hash_chain_hash(self):
+        # a corrupt share hash chain entry will show up as a bad hash. If we
+        # mangle a few bytes in, that will look like a bad hash.
+        return self._corrupt_all(("share_hash_chain",4), "corrupt hashes")
+
+    def test_corrupt_all_block_hash_tree(self):
+        return self._corrupt_all("block_hash_tree", "block hash tree failure")
+
+    def test_corrupt_all_block(self):
+        return self._corrupt_all("share_data", "block hash tree failure")
+
+    def test_corrupt_all_encprivkey(self):
+        # a corrupted privkey won't even be noticed by the reader
+        return self._corrupt_all("enc_privkey", None, should_succeed=True)
+
-- 
2.45.2