From: Brian Warner Date: Thu, 12 Apr 2007 20:07:40 +0000 (-0700) Subject: verify hash chains on incoming blocks X-Git-Tag: tahoe_v0.1.0-0-UNSTABLE~111 X-Git-Url: https://git.rkrishnan.org/simplejson/components/%22news.html//%22?a=commitdiff_plain;h=8f58b30db9430cbe3208930b421eb3407270d907;p=tahoe-lafs%2Ftahoe-lafs.git verify hash chains on incoming blocks Implement enough of chunk.IncompleteHashTree to be usable. Rearrange download: all block/hash requests now go through a ValidatedBucket instance, which is responsible for retrieving and verifying hashes before providing validated data. Download was changed to use ValidatedBuckets everywhere instead of unwrapped RIBucketReader references. --- diff --git a/src/allmydata/chunk.py b/src/allmydata/chunk.py index a797c464..3e4e09ed 100644 --- a/src/allmydata/chunk.py +++ b/src/allmydata/chunk.py @@ -48,6 +48,7 @@ or implied. It probably won't make your computer catch on fire, or eat your children, but it might. Use at your own risk. """ +from allmydata.util import idlib from allmydata.util.hashutil import tagged_hash, tagged_pair_hash __version__ = '1.0.0-allmydata' @@ -134,6 +135,25 @@ class CompleteBinaryTreeMixin: here = self.parent(here) return needed + def depth_first(self, i=0): + yield i, 0 + try: + for child,childdepth in self.depth_first(self.lchild(i)): + yield child, childdepth+1 + except IndexError: + pass + try: + for child,childdepth in self.depth_first(self.rchild(i)): + yield child, childdepth+1 + except IndexError: + pass + + def dump(self): + lines = [] + for i,depth in self.depth_first(): + lines.append("%s%3d: %s" % (" "*depth, i, idlib.b2a_or_none(self[i]))) + return "\n".join(lines) + "\n" + def empty_leaf_hash(i): return tagged_hash('Merkle tree empty leaf', "%d" % i) def pair_hash(a, b): @@ -225,49 +245,136 @@ class IncompleteHashTree(CompleteBinaryTreeMixin, list): rows.reverse() self[:] = sum(rows, []) - def needed_hashes(self, leafnum): - hashnum = self.first_leaf_num + leafnum - maybe_needed = self.needed_for(hashnum) - maybe_needed += [0] # need the root too + def needed_hashes(self, hashes=[], leaves=[]): + hashnums = set(list(hashes)) + for leafnum in leaves: + hashnums.add(self.first_leaf_num + leafnum) + maybe_needed = set() + for hashnum in hashnums: + maybe_needed.update(self.needed_for(hashnum)) + maybe_needed.add(0) # need the root too return set([i for i in maybe_needed if self[i] is None]) - def set_hash(self, i, newhash): - # note that we don't attempt to validate these - self[i] = newhash - - def set_leaf(self, leafnum, leafhash): - hashnum = self.first_leaf_num + leafnum - needed = self.needed_hashes(leafnum) - if needed: - msg = "we need hashes " + ",".join([str(i) for i in needed]) - raise NotEnoughHashesError(msg) - assert self[0] is not None # we can't validate without a root - added = set() # we'll remove these if the check fails - self[hashnum] = leafhash - added.add(hashnum) - # now propagate hash checks upwards until we reach the root - here = hashnum - while here != 0: - us = [here, self.sibling(here)] - us.sort() - leftnum, rightnum = us - lefthash = self[leftnum] - righthash = self[rightnum] - parent = self.parent(here) - parenthash = self[parent] - - ourhash = pair_hash(lefthash, righthash) - if parenthash is not None: - if ourhash != parenthash: - for i in added: - self[i] = None - raise BadHashError("h([%d]+[%d]) != h[%d]" % (leftnum, rightnum, - parent)) - else: - self[parent] = ourhash - added.add(parent) - here = self.parent(here) + def set_hashes(self, hashes={}, leaves={}, must_validate=False): + """Add a bunch of hashes to the tree. + + I will validate these to the best of my ability. If I already have a copy + of any of the new hashes, the new values must equal the existing ones, or + I will raise BadHashError. If adding a hash allows me to compute a parent + hash, those parent hashes must match or I will raise BadHashError. If I + raise BadHashError, I will forget about all the hashes that you tried to + add, leaving my state exactly the same as before I was called. If I + return successfully, I will remember all those hashes. + + If every hash that was added was validated, I will return True. If some + could not be validated because I did not have enough parent hashes, I + will return False. As a result, if I am called with both a leaf hash and + the root hash was already set, I will return True if and only if the leaf + hash could be validated against the root. + + If must_validate is True, I will raise NotEnoughHashesError instead of + returning False. If I raise NotEnoughHashesError, I will forget about all + the hashes that you tried to add. TODO: really? + + 'leaves' is a dictionary uses 'leaf index' values, which range from 0 + (the left-most leaf) to num_leaves-1 (the right-most leaf), and form the + base of the tree. 'hashes' uses 'hash_index' values, which range from 0 + (the root of the tree) to 2*num_leaves-2 (the right-most leaf). leaf[i] + is the same as hash[num_leaves-1+i]. + + The best way to use me is to obtain the root hash from some 'good' + channel, then call set_hash(0, root). Then use the 'bad' channel to + obtain data block 0 and the corresponding hash chain (a dict with the + same hashes that needed_hashes(0) tells you, e.g. {0:h0, 2:h2, 4:h4, + 8:h8} when len(L)=8). Hash the data block to create leaf0. Then call:: + + good = iht.set_hashes(hashes=hashchain, leaves={0: leaf0}) + + If 'good' is True, the data block was valid. If 'good' is False, the + hashchain did not have the right blocks and we don't know whether the + data block was good or bad. If set_hashes() raises an exception, either + the data was corrupted or one of the received hashes was corrupted. + """ + + assert isinstance(hashes, dict) + assert isinstance(leaves, dict) + new_hashes = hashes.copy() + for leafnum,leafhash in leaves.iteritems(): + hashnum = self.first_leaf_num + leafnum + if hashnum in new_hashes: + assert new_hashes[hashnum] == leafhash + new_hashes[hashnum] = leafhash + + added = set() # we'll remove these if the check fails - return None + try: + # first we provisionally add all hashes to the tree, comparing any + # duplicates + for i in new_hashes: + if self[i]: + if self[i] != new_hashes[i]: + raise BadHashError("new hash does not match existing hash at [%d]" + % i) + else: + self[i] = new_hashes[i] + added.add(i) + + # then we start from the bottom and compute new parent hashes upwards, + # comparing any that already exist. When this phase ends, all nodes + # that have a sibling will also have a parent. + + hashes_to_check = list(new_hashes.keys()) + # leaf-most first means reverse sorted order + while hashes_to_check: + hashes_to_check.sort() + i = hashes_to_check.pop(-1) + if i == 0: + # The root has no sibling. How lonely. + continue + if self[self.sibling(i)] is None: + # without a sibling, we can't compute a parent + continue + parentnum = self.parent(i) + # make sure we know right from left + leftnum, rightnum = sorted([i, self.sibling(i)]) + new_parent_hash = pair_hash(self[leftnum], self[rightnum]) + if self[parentnum]: + if self[parentnum] != new_parent_hash: + raise BadHashError("h([%d]+[%d]) != h[%d]" % (leftnum, rightnum, + parentnum)) + else: + self[parentnum] = new_parent_hash + added.add(parentnum) + hashes_to_check.insert(0, parentnum) + + # then we walk downwards from the top (root), and anything that is + # reachable is validated. If any of the hashes that we've added are + # unreachable, then they are unvalidated. + + reachable = set() + if self[0]: + reachable.add(0) + # TODO: this could be done more efficiently, by starting from each + # element of new_hashes and walking upwards instead, remembering a set + # of validated nodes so that the searches for later new_hashes goes + # faster. This approach is O(n), whereas O(ln(n)) should be feasible. + for i in range(1, len(self)): + if self[i] and self.parent(i) in reachable: + reachable.add(i) + + # were we unable to validate any of the new hashes? + unvalidated = set(new_hashes.keys()) - reachable + if unvalidated: + if must_validate: + those = ",".join([str(i) for i in sorted(unvalidated)]) + raise NotEnoughHashesError("unable to validate hashes %s" % those) + + except (BadHashError, NotEnoughHashesError): + for i in added: + self[i] = None + raise + + # if there were hashes that could not be validated, we return False + return not unvalidated diff --git a/src/allmydata/download.py b/src/allmydata/download.py index 458e7f2a..3265184d 100644 --- a/src/allmydata/download.py +++ b/src/allmydata/download.py @@ -5,7 +5,7 @@ from twisted.python import log from twisted.internet import defer from twisted.application import service -from allmydata.util import idlib, mathutil +from allmydata.util import idlib, mathutil, hashutil from allmydata.util.assertutil import _assert from allmydata import codec, chunk from allmydata.Crypto.Cipher import AES @@ -47,15 +47,59 @@ class Output: def finish(self): return self.downloadable.finish() +class ValidatedBucket: + def __init__(self, sharenum, bucket, share_hash_tree, num_blocks): + self.sharenum = sharenum + self.bucket = bucket + self.share_hash_tree = share_hash_tree + self.block_hash_tree = chunk.IncompleteHashTree(num_blocks) + + def get_block(self, blocknum): + d1 = self.bucket.callRemote('get_block', blocknum) + # we might also need to grab some elements of our block hash tree, to + # validate the requested block up to the share hash + if self.block_hash_tree.needed_hashes(leaves=[blocknum]): + d2 = self.bucket.callRemote('get_block_hashes') + else: + d2 = defer.succeed(None) + # we might need to grab some elements of the share hash tree to + # validate from our share hash up to the hashroot + if self.share_hash_tree.needed_hashes(leaves=[self.sharenum]): + d3 = self.bucket.callRemote('get_share_hashes') + else: + d3 = defer.succeed(None) + d = defer.gatherResults([d1, d2, d3]) + d.addCallback(self._got_data, blocknum) + return d + + def _got_data(self, res, blocknum): + blockdata, blockhashes, sharehashes = res + blockhash = hashutil.tagged_hash("encoded subshare", blockdata) + if blockhashes: + bh = dict(enumerate(blockhashes)) + self.block_hash_tree.set_hashes(bh, {blocknum: blockhash}, + must_validate=True) + if sharehashes: + sh = dict(sharehashes) + sharehash = self.block_hash_tree[0] + self.share_hash_tree.set_hashes(sh, {self.sharenum: sharehash}, + must_validate=True) + # If we made it here, the block is good. If the hash trees didn't + # like what they saw, they would have raised a BadHashError, causing + # our caller to see a Failure and thus ignore this block (as well as + # dropping this bucket). + return blockdata + + class BlockDownloader: - def __init__(self, bucket, blocknum, parent): - self.bucket = bucket + def __init__(self, vbucket, blocknum, parent): + self.vbucket = vbucket self.blocknum = blocknum self.parent = parent def start(self, segnum): - d = self.bucket.callRemote('get_block', segnum) + d = self.vbucket.get_block(segnum) d.addCallbacks(self._hold_block, self._got_block_error) return d @@ -64,7 +108,7 @@ class BlockDownloader: def _got_block_error(self, f): log.msg("BlockDownloader[%d] got error: %s" % (self.blocknum, f)) - self.parent.bucket_failed(self.blocknum, self.bucket) + self.parent.bucket_failed(self.blocknum, self.vbucket) class SegmentDownloader: def __init__(self, parent, segmentnumber, needed_shares): @@ -94,86 +138,34 @@ class SegmentDownloader: return d def _try(self): - while len(self.parent.active_buckets) < self.needed_blocks: - # need some more - otherblocknums = list(set(self.parent._share_buckets.keys()) - set(self.parent.active_buckets.keys())) - if not otherblocknums: - raise NotEnoughPeersError - blocknum = random.choice(otherblocknums) - bucket = random.choice(list(self.parent._share_buckets[blocknum])) - self.parent.active_buckets[blocknum] = bucket - + # fill our set of active buckets, maybe raising NotEnoughPeersError + active_buckets = self.parent._activate_enough_buckets() # Now we have enough buckets, in self.parent.active_buckets. - # before we get any blocks of a given share, we need to be able to - # validate that block and that share. Check to see if we have enough - # hashes. If we don't, grab them before continuing. - d = self._grab_needed_hashes() - d.addCallback(self._download_some_blocks) - return d - - def _grab_needed_hashes(self): - # each bucket is holding the hashes necessary to validate their - # share. So it suffices to ask everybody for all the hashes they know - # about. Eventually we'll have all that we need, so we can stop - # asking. - - # for each active share, see what hashes we need - ht = self.parent.get_share_hashtree() - needed_hashes = set() - for shnum in self.parent.active_buckets: - needed_hashes.update(ht.needed_hashes(shnum)) - if not needed_hashes: - return defer.succeed(None) - - # for now, just ask everybody for everything - # TODO: send fewer queries - dl = [] - for shnum, bucket in self.parent.active_buckets.iteritems(): - d = bucket.callRemote("get_share_hashes") - d.addCallback(self._got_share_hashes, shnum, bucket) - dl.append(d) - d.addCallback(self._validate_root) - return defer.DeferredList(dl) - - def _got_share_hashes(self, share_hashes, shnum, bucket): - ht = self.parent.get_share_hashtree() - for hashnum, sharehash in share_hashes: - # TODO: we're accumulating these hashes blindly, since we only - # validate the leaves. This makes it possible for someone to - # frame another server by giving us bad internal hashes. We pass - # 'shnum' and 'bucket' in so that if we detected problems with - # intermediate nodes, we could associate the error with the - # bucket and stop using them. - ht.set_hash(hashnum, sharehash) - - def _validate_root(self, res): - # TODO: I dunno, check that the hash tree looks good so far and that - # it adds up to the root. The idea is to reject any bad buckets - # early. - pass - - def _download_some_blocks(self, res): # in test cases, bd.start might mutate active_buckets right away, so # we need to put off calling start() until we've iterated all the way - # through it + # through it. downloaders = [] - for blocknum, bucket in self.parent.active_buckets.iteritems(): - bd = BlockDownloader(bucket, blocknum, self) + for blocknum, vbucket in active_buckets.iteritems(): + bd = BlockDownloader(vbucket, blocknum, self) downloaders.append(bd) l = [bd.start(self.segmentnumber) for bd in downloaders] - return defer.DeferredList(l) + return defer.DeferredList(l, fireOnOneErrback=True) def hold_block(self, blocknum, data): self.blocks[blocknum] = data - def bucket_failed(self, shnum, bucket): + def bucket_failed(self, shnum, vbucket): del self.parent.active_buckets[shnum] s = self.parent._share_buckets[shnum] - s.remove(bucket) + # s is a set of ValidatedBucket instances + s.remove(vbucket) + # ... which might now be empty if not s: + # there are no more buckets which can provide this share, so + # remove the key. This may prompt us to use a different share. del self.parent._share_buckets[shnum] - + class FileDownloader: debug = False @@ -201,30 +193,21 @@ class FileDownloader: key = "\x00" * 16 self._output = Output(downloadable, key) - # future: - # each time we start using a new shnum, we must acquire a share hash - # from one of the buckets that provides that shnum, then validate it - # against the rest of the share hash tree that they provide. Then, - # each time we get a block in that share, we must validate the block - # against the rest of the subshare hash tree that that bucket will - # provide. - self._share_hashtree = chunk.IncompleteHashTree(total_shares) - #self._block_hashtrees = {} # k: shnum, v: hashtree + self._share_hashtree.set_hashes({0: roothash}) - def get_share_hashtree(self): - return self._share_hashtree + self.active_buckets = {} # k: shnum, v: bucket + self._share_buckets = {} # k: shnum, v: set of buckets def start(self): log.msg("starting download [%s]" % (idlib.b2a(self._verifierid),)) if self.debug: print "starting download" - # first step: who should we download from? - self.active_buckets = {} # k: shnum, v: bucket - self._share_buckets = {} # k: shnum, v: set of buckets + # first step: who should we download from? d = defer.maybeDeferred(self._get_all_shareholders) d.addCallback(self._got_all_shareholders) + # once we know that, we can download blocks from them d.addCallback(self._download_all_segments) d.addCallback(self._done) return d @@ -243,19 +226,55 @@ class FileDownloader: def _got_response(self, buckets, connection): _assert(isinstance(buckets, dict), buckets) # soon foolscap will check this for us with its DictOf schema constraint for sharenum, bucket in buckets.iteritems(): - self._share_buckets.setdefault(sharenum, set()).add(bucket) - + self.add_share_bucket(sharenum, bucket) + + def add_share_bucket(self, sharenum, bucket): + vbucket = ValidatedBucket(sharenum, bucket, + self._share_hashtree, + self._total_segments) + self._share_buckets.setdefault(sharenum, set()).add(vbucket) + def _got_error(self, f): self._client.log("Somebody failed. -- %s" % (f,)) def _got_all_shareholders(self, res): if len(self._share_buckets) < self._num_needed_shares: raise NotEnoughPeersError + for s in self._share_buckets.values(): + for vb in s: + assert isinstance(vb, ValidatedBucket), \ + "vb is %s but should be a ValidatedBucket" % (vb,) + + + def _activate_enough_buckets(self): + """either return a mapping from shnum to a ValidatedBucket that can + provide data for that share, or raise NotEnoughPeersError""" + + while len(self.active_buckets) < self._num_needed_shares: + # need some more + handled_shnums = set(self.active_buckets.keys()) + available_shnums = set(self._share_buckets.keys()) + potential_shnums = list(available_shnums - handled_shnums) + if not potential_shnums: + raise NotEnoughPeersError + # choose a random share + shnum = random.choice(potential_shnums) + # and a random bucket that will provide it + validated_bucket = random.choice(list(self._share_buckets[shnum])) + self.active_buckets[shnum] = validated_bucket + return self.active_buckets - self.active_buckets = {} - self._output.open() def _download_all_segments(self, res): + # the promise: upon entry to this function, self._share_buckets + # contains enough buckets to complete the download, and some extra + # ones to tolerate some buckets dropping out or having errors. + # self._share_buckets is a dictionary that maps from shnum to a set + # of ValidatedBuckets, which themselves are wrappers around + # RIBucketReader references. + self.active_buckets = {} # k: shnum, v: ValidatedBucket instance + self._output.open() + d = defer.succeed(None) for segnum in range(self._total_segments-1): d.addCallback(self._download_segment, segnum) diff --git a/src/allmydata/test/test_encode.py b/src/allmydata/test/test_encode.py index 6372af89..a58351c7 100644 --- a/src/allmydata/test/test_encode.py +++ b/src/allmydata/test/test_encode.py @@ -156,9 +156,9 @@ class Roundtrip(unittest.TestCase): client = None target = download.Data() fd = download.FileDownloader(client, URI, target) - fd._share_buckets = {} for shnum in range(NUM_SHARES): - fd._share_buckets[shnum] = set([all_shareholders[shnum]]) + bucket = all_shareholders[shnum] + fd.add_share_bucket(shnum, bucket) fd._got_all_shareholders(None) d2 = fd._download_all_segments(None) d2.addCallback(fd._done) diff --git a/src/allmydata/test/test_hashtree.py b/src/allmydata/test/test_hashtree.py index 5cac2cc1..1a51c6c6 100644 --- a/src/allmydata/test/test_hashtree.py +++ b/src/allmydata/test/test_hashtree.py @@ -21,7 +21,24 @@ class Complete(unittest.TestCase): root = ht[0] self.failUnlessEqual(len(root), 32) + def testDump(self): + ht = make_tree(6) + expected = [(0,0), + (1,1), (3,2), (7,3), (8,3), (4,2), (9,3), (10,3), + (2,1), (5,2), (11,3), (12,3), (6,2), (13,3), (14,3), + ] + self.failUnlessEqual(list(ht.depth_first()), expected) + d = "\n" + ht.dump() + #print d + self.failUnless("\n 0:" in d) + self.failUnless("\n 1:" in d) + self.failUnless("\n 3:" in d) + self.failUnless("\n 7:" in d) + self.failUnless("\n 8:" in d) + self.failUnless("\n 4:" in d) + class Incomplete(unittest.TestCase): + def testCheck(self): # first create a complete hash tree ht = make_tree(6) @@ -30,32 +47,35 @@ class Incomplete(unittest.TestCase): # suppose we wanted to validate leaf[0] # leaf[0] is the same as node[7] - self.failUnlessEqual(iht.needed_hashes(0), set([8, 4, 2, 0])) - self.failUnlessEqual(iht.needed_hashes(1), set([7, 4, 2, 0])) - iht.set_hash(0, ht[0]) - self.failUnlessEqual(iht.needed_hashes(0), set([8, 4, 2])) - self.failUnlessEqual(iht.needed_hashes(1), set([7, 4, 2])) - iht.set_hash(5, ht[5]) - self.failUnlessEqual(iht.needed_hashes(0), set([8, 4, 2])) - self.failUnlessEqual(iht.needed_hashes(1), set([7, 4, 2])) - + self.failUnlessEqual(iht.needed_hashes(leaves=[0]), set([8, 4, 2, 0])) + self.failUnlessEqual(iht.needed_hashes(leaves=[1]), set([7, 4, 2, 0])) + iht.set_hashes({0: ht[0]}) # set the root + self.failUnlessEqual(iht.needed_hashes(leaves=[0]), set([8, 4, 2])) + self.failUnlessEqual(iht.needed_hashes(leaves=[1]), set([7, 4, 2])) + iht.set_hashes({5: ht[5]}) + self.failUnlessEqual(iht.needed_hashes(leaves=[0]), set([8, 4, 2])) + self.failUnlessEqual(iht.needed_hashes(leaves=[1]), set([7, 4, 2])) + + current_hashes = list(iht) try: # this should fail because there aren't enough hashes known - iht.set_leaf(0, tagged_hash("tag", "0")) + iht.set_hashes(leaves={0: tagged_hash("tag", "0")}, + must_validate=True) except chunk.NotEnoughHashesError: pass else: self.fail("didn't catch not enough hashes") + # and the set of hashes stored in the tree should still be the same + self.failUnlessEqual(list(iht), current_hashes) + # provide the missing hashes - iht.set_hash(2, ht[2]) - iht.set_hash(4, ht[4]) - iht.set_hash(8, ht[8]) - self.failUnlessEqual(iht.needed_hashes(0), set([])) + iht.set_hashes({2: ht[2], 4: ht[4], 8: ht[8]}) + self.failUnlessEqual(iht.needed_hashes(leaves=[0]), set()) try: # this should fail because the hash is just plain wrong - iht.set_leaf(0, tagged_hash("bad tag", "0")) + iht.set_hashes(leaves={0: tagged_hash("bad tag", "0")}) except chunk.BadHashError: pass else: @@ -63,34 +83,33 @@ class Incomplete(unittest.TestCase): try: # this should succeed - iht.set_leaf(0, tagged_hash("tag", "0")) + iht.set_hashes(leaves={0: tagged_hash("tag", "0")}) except chunk.BadHashError, e: self.fail("bad hash: %s" % e) try: # this should succeed too - iht.set_leaf(1, tagged_hash("tag", "1")) + iht.set_hashes(leaves={1: tagged_hash("tag", "1")}) except chunk.BadHashError: self.fail("bad hash") # giving it a bad internal hash should also cause problems - iht.set_hash(2, tagged_hash("bad tag", "x")) + iht.set_hashes({13: tagged_hash("bad tag", "x")}) try: - iht.set_leaf(0, tagged_hash("tag", "0")) + iht.set_hashes({14: tagged_hash("tag", "14")}) except chunk.BadHashError: pass else: self.fail("didn't catch bad hash") # undo our damage - iht.set_hash(2, ht[2]) + iht[13] = None - self.failUnlessEqual(iht.needed_hashes(4), set([12, 6])) + self.failUnlessEqual(iht.needed_hashes(leaves=[4]), set([12, 6])) - iht.set_hash(6, ht[6]) - iht.set_hash(12, ht[12]) + iht.set_hashes({6: ht[6], 12: ht[12]}) try: # this should succeed - iht.set_leaf(4, tagged_hash("tag", "4")) + iht.set_hashes(leaves={4: tagged_hash("tag", "4")}) except chunk.BadHashError, e: self.fail("bad hash: %s" % e)