From be5a6147b41d62043abe35ea69d6fce1779f72e8 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 10 Mar 2008 17:46:52 -0700 Subject: [PATCH] test_mutable: test all hash-failure cases except a corrupted encrypted private key --- src/allmydata/mutable.py | 36 ++++--- src/allmydata/test/test_mutable.py | 162 +++++++++++++++++++++++++++-- 2 files changed, 180 insertions(+), 18 deletions(-) diff --git a/src/allmydata/mutable.py b/src/allmydata/mutable.py index 5eb1b178..188075fc 100644 --- a/src/allmydata/mutable.py +++ b/src/allmydata/mutable.py @@ -49,11 +49,8 @@ SIGNED_PREFIX = ">BQ32s16s BBQQ" # this is covered by the signature HEADER = ">BQ32s16s BBQQ LLLLQQ" # includes offsets HEADER_LENGTH = struct.calcsize(HEADER) -def unpack_prefix_and_signature(data): - assert len(data) >= HEADER_LENGTH +def unpack_header(data): o = {} - prefix = data[:struct.calcsize(SIGNED_PREFIX)] - (version, seqnum, root_hash, @@ -65,6 +62,18 @@ def unpack_prefix_and_signature(data): o['share_data'], o['enc_privkey'], o['EOF']) = struct.unpack(HEADER, data[:HEADER_LENGTH]) + return (version, seqnum, root_hash, IV, k, N, segsize, datalen, o) + +def unpack_prefix_and_signature(data): + assert len(data) >= HEADER_LENGTH + prefix = data[:struct.calcsize(SIGNED_PREFIX)] + + (version, + seqnum, + root_hash, + IV, + k, N, segsize, datalen, + o) = unpack_header(data) assert version == 0 if len(data) < o['share_hash_chain']: @@ -535,7 +544,7 @@ class Retrieve: self._pubkey = self._deserialize_pubkey(pubkey_s) self._node._populate_pubkey(self._pubkey) - verinfo = (seqnum, root_hash, IV, segsize, datalength) + verinfo = (seqnum, root_hash, IV, segsize, datalength) #, k, N) self._status.sharemap[peerid].add(verinfo) if verinfo not in self._valid_versions: @@ -694,12 +703,12 @@ class Retrieve: # arbitrary, really I want this to be something like # k - max(known_version_sharecounts) + some extra break - new_search_distance = max(max(peer_indicies), - self._status.get_search_distance()) - self._status.set_search_distance(new_search_distance) if new_query_peers: self.log("sending %d new queries (read %d bytes)" % (len(new_query_peers), self._read_size), level=log.UNUSUAL) + new_search_distance = max(max(peer_indicies), + self._status.get_search_distance()) + self._status.set_search_distance(new_search_distance) for (peerid, ss) in new_query_peers: self._do_query(ss, peerid, self._storage_index, self._read_size) # we'll retrigger when those queries come back @@ -802,7 +811,8 @@ class Retrieve: try: t2.set_hashes(hashes=share_hash_chain, leaves={shnum: share_hash_leaf}) - except (hashtree.BadHashError, hashtree.NotEnoughHashesError), e: + except (hashtree.BadHashError, hashtree.NotEnoughHashesError, + IndexError), e: msg = "corrupt hashes: %s" % (e,) raise CorruptShareError(peerid, shnum, msg) self.log(" data valid! len=%d" % len(share_data)) @@ -864,16 +874,18 @@ class Retrieve: self._node._populate_root_hash(root_hash) return plaintext - def _done(self, contents): + def _done(self, res): + # res is either the new contents, or a Failure self.log("DONE") self._running = False self._status.set_active(False) self._status.set_status("Done") self._status.set_progress(1.0) - self._status.set_size(len(contents)) + if isinstance(res, str): + self._status.set_size(len(res)) elapsed = time.time() - self._started self._status.timings["total"] = elapsed - eventually(self._done_deferred.callback, contents) + eventually(self._done_deferred.callback, res) def get_status(self): return self._status diff --git a/src/allmydata/test/test_mutable.py b/src/allmydata/test/test_mutable.py index e1c2cd10..83a78e2e 100644 --- a/src/allmydata/test/test_mutable.py +++ b/src/allmydata/test/test_mutable.py @@ -1,5 +1,5 @@ -import itertools, struct +import itertools, struct, re from cStringIO import StringIO from twisted.trial import unittest from twisted.internet import defer @@ -52,6 +52,10 @@ class FakeStorage: # tests to examine and manipulate the published shares. It also lets us # control the order in which read queries are answered, to exercise more # of the error-handling code in mutable.Retrieve . + # + # Note that we ignore the storage index: this FakeStorage instance can + # only be used for a single storage index. + def __init__(self): self._peers = {} @@ -177,6 +181,12 @@ class FakePubKey: def serialize(self): return "PUBKEY-%d" % self.count def verify(self, msg, signature): + if signature[:5] != "SIGN(": + return False + if signature[5:-1] != msg: + return False + if signature[-1] != ")": + return False return True class FakePrivKey: @@ -433,6 +443,14 @@ class FakeRetrieve(mutable.Retrieve): vector.append(shares[shnum][offset:offset+length]) return defer.succeed(response) + def _deserialize_pubkey(self, pubkey_s): + mo = re.search(r"^PUBKEY-(\d+)$", pubkey_s) + if not mo: + raise RuntimeError("mangled pubkey") + count = mo.group(1) + return FakePubKey(int(count)) + + class Roundtrip(unittest.TestCase): def setup_for_publish(self, num_peers): @@ -444,16 +462,15 @@ class Roundtrip(unittest.TestCase): fn.create("") p = FakePublish(fn) p._storage = s - return c, fn, p + r = FakeRetrieve(fn) + r._storage = s + return c, s, fn, p, r def test_basic(self): - c, fn, p = self.setup_for_publish(20) + c, s, fn, p, r = self.setup_for_publish(20) contents = "New contents go here" d = p.publish(contents) def _published(res): - # TODO: examine peers and check on their shares - r = FakeRetrieve(fn) - r._storage = p._storage return r.retrieve() d.addCallback(_published) def _retrieved(new_contents): @@ -461,3 +478,136 @@ class Roundtrip(unittest.TestCase): d.addCallback(_retrieved) return d + def flip_bit(self, original, byte_offset): + return (original[:byte_offset] + + chr(ord(original[byte_offset]) ^ 0x01) + + original[byte_offset+1:]) + + + def shouldFail(self, expected_failure, which, substring, + callable, *args, **kwargs): + assert substring is None or isinstance(substring, str) + d = defer.maybeDeferred(callable, *args, **kwargs) + def done(res): + if isinstance(res, failure.Failure): + res.trap(expected_failure) + if substring: + self.failUnless(substring in str(res), + "substring '%s' not in '%s'" + % (substring, str(res))) + else: + self.fail("%s was supposed to raise %s, not get '%s'" % + (which, expected_failure, res)) + d.addBoth(done) + return d + + def _corrupt_all(self, offset, substring, refetch_pubkey=False, + should_succeed=False): + c, s, fn, p, r = self.setup_for_publish(20) + contents = "New contents go here" + d = p.publish(contents) + def _published(res): + if refetch_pubkey: + # clear the pubkey, to force a fetch + r._pubkey = None + for peerid in s._peers: + shares = s._peers[peerid] + for shnum in shares: + data = shares[shnum] + (version, + seqnum, + root_hash, + IV, + k, N, segsize, datalen, + o) = mutable.unpack_header(data) + if isinstance(offset, tuple): + offset1, offset2 = offset + else: + offset1 = offset + offset2 = 0 + if offset1 == "pubkey": + real_offset = 107 + elif offset1 in o: + real_offset = o[offset1] + else: + real_offset = offset1 + real_offset = int(real_offset) + offset2 + assert isinstance(real_offset, int), offset + shares[shnum] = self.flip_bit(data, real_offset) + d.addCallback(_published) + if should_succeed: + d.addCallback(lambda res: r.retrieve()) + else: + d.addCallback(lambda res: + self.shouldFail(NotEnoughPeersError, + "_corrupt_all(offset=%s)" % (offset,), + substring, + r.retrieve)) + return d + + def test_corrupt_all_verbyte(self): + # when the version byte is not 0, we hit an assertion error in + # unpack_share(). + return self._corrupt_all(0, "AssertionError") + + def test_corrupt_all_seqnum(self): + # a corrupt sequence number will trigger a bad signature + return self._corrupt_all(1, "signature is invalid") + + def test_corrupt_all_R(self): + # a corrupt root hash will trigger a bad signature + return self._corrupt_all(9, "signature is invalid") + + def test_corrupt_all_IV(self): + # a corrupt salt/IV will trigger a bad signature + return self._corrupt_all(41, "signature is invalid") + + def test_corrupt_all_k(self): + # a corrupt 'k' will trigger a bad signature + return self._corrupt_all(57, "signature is invalid") + + def test_corrupt_all_N(self): + # a corrupt 'N' will trigger a bad signature + return self._corrupt_all(58, "signature is invalid") + + def test_corrupt_all_segsize(self): + # a corrupt segsize will trigger a bad signature + return self._corrupt_all(59, "signature is invalid") + + def test_corrupt_all_datalen(self): + # a corrupt data length will trigger a bad signature + return self._corrupt_all(67, "signature is invalid") + + def test_corrupt_all_pubkey(self): + # a corrupt pubkey won't match the URI's fingerprint + return self._corrupt_all("pubkey", "pubkey doesn't match fingerprint", + refetch_pubkey=True) + + def test_corrupt_all_sig(self): + # a corrupt signature is a bad one + # the signature runs from about [543:799], depending upon the length + # of the pubkey + return self._corrupt_all("signature", "signature is invalid", + refetch_pubkey=True) + + def test_corrupt_all_share_hash_chain_number(self): + # a corrupt share hash chain entry will show up as a bad hash. If we + # mangle the first byte, that will look like a bad hash number, + # causing an IndexError + return self._corrupt_all("share_hash_chain", "corrupt hashes") + + def test_corrupt_all_share_hash_chain_hash(self): + # a corrupt share hash chain entry will show up as a bad hash. If we + # mangle a few bytes in, that will look like a bad hash. + return self._corrupt_all(("share_hash_chain",4), "corrupt hashes") + + def test_corrupt_all_block_hash_tree(self): + return self._corrupt_all("block_hash_tree", "block hash tree failure") + + def test_corrupt_all_block(self): + return self._corrupt_all("share_data", "block hash tree failure") + + def test_corrupt_all_encprivkey(self): + # a corrupted privkey won't even be noticed by the reader + return self._corrupt_all("enc_privkey", None, should_succeed=True) + -- 2.45.2