From 30133a7cdf00075c676918c71f7b42ca6470303a Mon Sep 17 00:00:00 2001
From: Brian Warner <warner@allmydata.com>
Date: Thu, 12 Apr 2007 19:41:48 -0700
Subject: [PATCH] hash trees: further cleanup, to make sure we're validating
 the right thing hashtree.py: improve the methods available for finding out
 which hash nodes  are needed. Change set_hashes() to require that every hash
 provided can  be validated up to the root. download.py: validate from the top
 down, including the URI-derived roothash  in the share hash tree, and
 stashing the thus-validated share hash for use  in the block hash tree.

---
 src/allmydata/download.py           |  63 ++++++++------
 src/allmydata/encode.py             |   8 +-
 src/allmydata/hashtree.py           | 130 +++++++++++++++++++---------
 src/allmydata/test/test_hashtree.py | 104 ++++++++++++++--------
 4 files changed, 197 insertions(+), 108 deletions(-)

diff --git a/src/allmydata/download.py b/src/allmydata/download.py
index a1b250c1..959d3fa7 100644
--- a/src/allmydata/download.py
+++ b/src/allmydata/download.py
@@ -48,48 +48,56 @@ class Output:
         return self.downloadable.finish()
 
 class ValidatedBucket:
-    def __init__(self, sharenum, bucket, share_hash_tree, num_blocks):
+    def __init__(self, sharenum, bucket,
+                 share_hash_tree, roothash,
+                 num_blocks):
         self.sharenum = sharenum
         self.bucket = bucket
+        self._share_hash = None # None means not validated yet
         self.share_hash_tree = share_hash_tree
+        self._roothash = roothash
         self.block_hash_tree = hashtree.IncompleteHashTree(num_blocks)
 
     def get_block(self, blocknum):
-        d1 = self.bucket.callRemote('get_block', blocknum)
-        # we might also need to grab some elements of our block hash tree, to
+        # the first time we use this bucket, we need to fetch enough elements
+        # of the share hash tree to validate it from our share hash up to the
+        # hashroot.
+        if not self._share_hash:
+            d1 = self.bucket.callRemote('get_share_hashes')
+        else:
+            d1 = defer.succeed(None)
+
+        # we might need to grab some elements of our block hash tree, to
         # validate the requested block up to the share hash
-        if self.block_hash_tree.needed_hashes(leaves=[blocknum]):
+        needed = self.block_hash_tree.needed_hashes(blocknum)
+        if needed:
+            # TODO: get fewer hashes, callRemote('get_block_hashes', needed)
             d2 = self.bucket.callRemote('get_block_hashes')
         else:
-            d2 = defer.succeed(None)
-        # we might need to grab some elements of the share hash tree to
-        # validate from our share hash up to the hashroot
-        if self.share_hash_tree.needed_hashes(leaves=[self.sharenum]):
-            d3 = self.bucket.callRemote('get_share_hashes')
-            need_to_validate_sharehash = True
-        else:
-            d3 = defer.succeed(None)
-            need_to_validate_sharehash = False
+            d2 = defer.succeed([])
+
+        d3 = self.bucket.callRemote('get_block', blocknum)
+
         d = defer.gatherResults([d1, d2, d3])
-        d.addCallback(self._got_data, blocknum, need_to_validate_sharehash)
+        d.addCallback(self._got_data, blocknum)
         return d
 
-    def _got_data(self, res, blocknum, need_to_validate_sharehash):
-        blockdata, blockhashes, sharehashes = res
+    def _got_data(self, res, blocknum):
+        sharehashes, blockhashes, blockdata = res
+
+        if not self._share_hash:
+            sh = dict(sharehashes)
+            sh[0] = self._roothash # always use our own root, from the URI
+            if self.share_hash_tree.get_leaf_index(self.sharenum) not in sh:
+                raise hashutil.NotEnoughHashesError
+            self.share_hash_tree.set_hashes(sh)
+            self._share_hash = self.share_hash_tree.get_leaf(self.sharenum)
+
         blockhash = hashutil.tagged_hash("encoded subshare", blockdata)
         # we always validate the blockhash
-        if blockhashes is None:
-            blockhashes = []
         bh = dict(enumerate(blockhashes))
-        self.block_hash_tree.set_hashes(bh, {blocknum: blockhash},
-                                        must_validate=True)
-        if need_to_validate_sharehash:
-            # we only need to validate the sharehash once, the first time we
-            # fetch a block
-            sh = dict(sharehashes)
-            sharehash = self.block_hash_tree[0]
-            self.share_hash_tree.set_hashes(sh, {self.sharenum: sharehash},
-                                            must_validate=True)
+        bh[0] = self._share_hash # replace blockhash root with validated value
+        self.block_hash_tree.set_hashes(bh, {blocknum: blockhash})
         # If we made it here, the block is good. If the hash trees didn't
         # like what they saw, they would have raised a BadHashError, causing
         # our caller to see a Failure and thus ignore this block (as well as
@@ -237,6 +245,7 @@ class FileDownloader:
     def add_share_bucket(self, sharenum, bucket):
         vbucket = ValidatedBucket(sharenum, bucket,
                                   self._share_hashtree,
+                                  self._roothash,
                                   self._total_segments)
         self._share_buckets.setdefault(sharenum, set()).add(vbucket)
 
diff --git a/src/allmydata/encode.py b/src/allmydata/encode.py
index 743e63b8..f22e72cc 100644
--- a/src/allmydata/encode.py
+++ b/src/allmydata/encode.py
@@ -252,6 +252,10 @@ class Encoder(object):
         return sh.callRemote("put_block_hashes", all_hashes)
 
     def send_all_share_hash_trees(self):
+        # each bucket gets a set of share hash tree nodes that are needed to
+        # validate their share. This includes the share hash itself, but does
+        # not include the top-level hash root (which is stored securely in
+        # the URI instead).
         log.msg("%s sending all share hash trees" % self)
         dl = []
         for h in self.share_root_hashes:
@@ -264,9 +268,7 @@ class Encoder(object):
         for i in range(self.num_shares):
             # the HashTree is given a list of leaves: 0,1,2,3..n .
             # These become nodes A+0,A+1,A+2.. of the tree, where A=n-1
-            tree_width = roundup_pow2(self.num_shares)
-            base_index = i + tree_width - 1
-            needed_hash_indices = t.needed_for(base_index)
+            needed_hash_indices = t.needed_hashes(i, include_leaf=True)
             hashes = [(hi, t[hi]) for hi in needed_hash_indices]
             dl.append(self.send_one_share_hash_tree(i, hashes))
         return defer.DeferredList(dl)
diff --git a/src/allmydata/hashtree.py b/src/allmydata/hashtree.py
index 900c5032..63816f5e 100644
--- a/src/allmydata/hashtree.py
+++ b/src/allmydata/hashtree.py
@@ -156,6 +156,12 @@ class CompleteBinaryTreeMixin:
                                         idlib.b2a_or_none(self[i])))
         return "\n".join(lines) + "\n"
 
+    def get_leaf_index(self, leafnum):
+        return self.first_leaf_num + leafnum
+
+    def get_leaf(self, leafnum):
+        return self[self.first_leaf_num + leafnum]
+
 def empty_leaf_hash(i):
     return tagged_hash('Merkle tree empty leaf', "%d" % i)
 def pair_hash(a, b):
@@ -193,6 +199,7 @@ class HashTree(CompleteBinaryTreeMixin, list):
         # Augment the list.
         start = len(L)
         end   = roundup_pow2(len(L))
+        self.first_leaf_num = end - 1
         L     = L + [None] * (end - start)
         for i in range(start, end):
             L[i] = empty_leaf_hash(i)
@@ -206,6 +213,36 @@ class HashTree(CompleteBinaryTreeMixin, list):
         rows.reverse()
         self[:] = sum(rows, [])
 
+    def needed_hashes(self, leafnum, include_leaf=False):
+        """Which hashes will someone need to validate a given data block?
+
+        I am used to answer a question: supposing you have the data block
+        that is used to form leaf hash N, and you want to validate that it,
+        which hashes would you need?
+
+        I accept a leaf number and return a set of 'hash index' values, which
+        are integers from 0 to len(self). In the 'hash index' number space,
+        hash[0] is the root hash, while hash[len(self)-1] is the last leaf
+        hash.
+
+        This method can be used to find out which hashes you should request
+        from some untrusted source (usually the same source that provides the
+        data block), so you can minimize storage or transmission overhead. It
+        can also be used to determine which hashes you should send to a
+        remote data store so that it will be able to provide validatable data
+        in the future.
+
+        I will not include '0' (the root hash) in the result, since the root
+        is generally stored somewhere that is more trusted than the source of
+        the remaining hashes. I will include the leaf hash itself only if you
+        ask me to, by passing include_leaf=True.
+        """
+
+        needed = set(self.needed_for(self.first_leaf_num + leafnum))
+        if include_leaf:
+            needed.add(self.first_leaf_num + leafnum)
+        return needed
+
 
 class NotEnoughHashesError(Exception):
     pass
@@ -250,18 +287,24 @@ class IncompleteHashTree(CompleteBinaryTreeMixin, list):
         rows.reverse()
         self[:] = sum(rows, [])
 
-    def needed_hashes(self, hashes=[], leaves=[]):
-        hashnums = set(list(hashes))
-        for leafnum in leaves:
-            hashnums.add(self.first_leaf_num + leafnum)
-        maybe_needed = set()
-        for hashnum in hashnums:
-            maybe_needed.update(self.needed_for(hashnum))
-        maybe_needed.add(0) # need the root too
-        return set([i for i in maybe_needed if self[i] is None])
 
+    def needed_hashes(self, leafnum, include_leaf=False):
+        """Which new hashes do I need to validate a given data block?
 
-    def set_hashes(self, hashes={}, leaves={}, must_validate=False):
+        I am much like HashTree.needed_hashes(), except that I don't include
+        hashes that I already know about. When needed_hashes() is called on
+        an empty IncompleteHashTree, it will return the same set as a
+        HashTree of the same size. But later, once hashes have been added
+        with set_hashes(), I will ask for fewer hashes, since some of the
+        necessary ones have already been set.
+        """
+
+        maybe_needed = set(self.needed_for(self.first_leaf_num + leafnum))
+        if include_leaf:
+            maybe_needed.add(self.first_leaf_num + leafnum)
+        return set([i for i in maybe_needed if self[i] is None])
+
+    def set_hashes(self, hashes={}, leaves={}):
         """Add a bunch of hashes to the tree.
 
         I will validate these to the best of my ability. If I already have a
@@ -273,15 +316,12 @@ class IncompleteHashTree(CompleteBinaryTreeMixin, list):
         before I was called. If I return successfully, I will remember all
         those hashes.
 
-        If every hash that was added was validated, I will return True. If
-        some could not be validated because I did not have enough parent
-        hashes, I will return False. As a result, if I am called with both a
-        leaf hash and the root hash was already set, I will return True if
-        and only if the leaf hash could be validated against the root.
-
-        If must_validate is True, I will raise NotEnoughHashesError instead
-        of returning False. If I raise NotEnoughHashesError, I will forget
-        about all the hashes that you tried to add. TODO: really?
+        I insist upon being able to validate all of the hashes that were
+        given to me. If I cannot do this because I'm missing some hashes, I
+        will raise NotEnoughHashesError (and forget about all the hashes that
+        you tried to add). Note that this means that the root hash must
+        either be included in 'hashes', or it must have been provided at some
+        point in the past.
 
         'leaves' is a dictionary uses 'leaf index' values, which range from 0
         (the left-most leaf) to num_leaves-1 (the right-most leaf), and form
@@ -290,28 +330,42 @@ class IncompleteHashTree(CompleteBinaryTreeMixin, list):
         leaf). leaf[i] is the same as hash[num_leaves-1+i].
 
         The best way to use me is to obtain the root hash from some 'good'
-        channel, then call set_hash(0, root). Then use the 'bad' channel to
-        obtain data block 0 and the corresponding hash chain (a dict with the
-        same hashes that needed_hashes(0) tells you, e.g. {0:h0, 2:h2, 4:h4,
-        8:h8} when len(L)=8). Hash the data block to create leaf0. Then
-        call::
-
-          good = iht.set_hashes(hashes=hashchain, leaves={0: leaf0})
-
-        If 'good' is True, the data block was valid. If 'good' is False, the
-        hashchain did not have the right blocks and we don't know whether the
-        data block was good or bad. If set_hashes() raises an exception,
-        either the data was corrupted or one of the received hashes was
-        corrupted.
+        channel, and use the 'bad' channel to obtain data block 0 and the
+        corresponding hash chain (a dict with the same hashes that
+        needed_hashes(0) tells you, e.g. {0:h0, 2:h2, 4:h4, 8:h8} when
+        len(L)=8). Hash the data block to create leaf0, then feed everything
+        into set_hashes() and see if it raises an exception or not::
+
+          iht = IncompleteHashTree(numleaves)
+          roothash = trusted_channel.get_roothash()
+          otherhashes = untrusted_channel.get_hashes()
+          # otherhashes.keys() should == iht.needed_hashes(leaves=[0])
+          datablock0 = untrusted_channel.get_data(0)
+          leaf0 = HASH(datablock0)
+          # HASH() is probably hashutil.tagged_hash(tag, datablock0)
+          hashes = otherhashes.copy()
+          hashes[0] = roothash # from 'good' channel
+          iht.set_hashes(hashes, leaves={0: leaf0})
+
+        If the set_hashes() call doesn't raise an exception, the data block
+        was valid. If it raises BadHashError, then either the data block was
+        corrupted or one of the received hashes was corrupted.
         """
 
         assert isinstance(hashes, dict)
+        for h in hashes.values():
+            assert isinstance(h, str)
         assert isinstance(leaves, dict)
+        for h in leaves.values():
+            assert isinstance(h, str)
         new_hashes = hashes.copy()
         for leafnum,leafhash in leaves.iteritems():
             hashnum = self.first_leaf_num + leafnum
             if hashnum in new_hashes:
-                assert new_hashes[hashnum] == leafhash
+                if new_hashes[hashnum] != leafhash:
+                    raise BadHashError("got conflicting hashes in my "
+                                       "arguments: leaves[%d] != hashes[%d]"
+                                       % (leafnum, hashnum))
             new_hashes[hashnum] = leafhash
 
         added = set() # we'll remove these if the check fails
@@ -374,15 +428,11 @@ class IncompleteHashTree(CompleteBinaryTreeMixin, list):
             # were we unable to validate any of the new hashes?
             unvalidated = set(new_hashes.keys()) - reachable
             if unvalidated:
-                if must_validate:
-                    those = ",".join([str(i) for i in sorted(unvalidated)])
-                    raise NotEnoughHashesError("unable to validate hashes %s" % those)
+                those = ",".join([str(i) for i in sorted(unvalidated)])
+                raise NotEnoughHashesError("unable to validate hashes %s"
+                                           % those)
 
         except (BadHashError, NotEnoughHashesError):
             for i in added:
                 self[i] = None
             raise
-
-        # if there were hashes that could not be validated, we return False
-        return not unvalidated
-
diff --git a/src/allmydata/test/test_hashtree.py b/src/allmydata/test/test_hashtree.py
index 2349ca2c..b2907a61 100644
--- a/src/allmydata/test/test_hashtree.py
+++ b/src/allmydata/test/test_hashtree.py
@@ -13,15 +13,27 @@ def make_tree(numleaves):
     return ht
 
 class Complete(unittest.TestCase):
-    def testCreate(self):
-        # try out various sizes
+    def test_create(self):
+        # try out various sizes, since we pad to a power of two
         ht = make_tree(6)
-        ht = make_tree(8)
         ht = make_tree(9)
+        ht = make_tree(8)
         root = ht[0]
         self.failUnlessEqual(len(root), 32)
+        self.failUnlessEqual(ht.get_leaf(0), tagged_hash("tag", "0"))
+        self.failUnlessRaises(IndexError, ht.get_leaf, 8)
+        self.failUnlessEqual(ht.get_leaf_index(0), 7)
 
-    def testDump(self):
+    def test_needed_hashes(self):
+        ht = make_tree(8)
+        self.failUnlessEqual(ht.needed_hashes(0), set([8, 4, 2]))
+        self.failUnlessEqual(ht.needed_hashes(0, True), set([7, 8, 4, 2]))
+        self.failUnlessEqual(ht.needed_hashes(1), set([7, 4, 2]))
+        self.failUnlessEqual(ht.needed_hashes(7), set([13, 5, 1]))
+        self.failUnlessEqual(ht.needed_hashes(7, False), set([13, 5, 1]))
+        self.failUnlessEqual(ht.needed_hashes(7, True), set([14, 13, 5, 1]))
+
+    def test_dump(self):
         ht = make_tree(6)
         expected = [(0,0),
                     (1,1), (3,2), (7,3), (8,3), (4,2), (9,3), (10,3),
@@ -39,7 +51,16 @@ class Complete(unittest.TestCase):
 
 class Incomplete(unittest.TestCase):
 
-    def testCheck(self):
+    def test_create(self):
+        ht = hashtree.IncompleteHashTree(6)
+        ht = hashtree.IncompleteHashTree(9)
+        ht = hashtree.IncompleteHashTree(8)
+        self.failUnlessEqual(ht[0], None)
+        self.failUnlessEqual(ht.get_leaf(0), None)
+        self.failUnlessRaises(IndexError, ht.get_leaf, 8)
+        self.failUnlessEqual(ht.get_leaf_index(0), 7)
+
+    def test_check(self):
         # first create a complete hash tree
         ht = make_tree(6)
         # then create a corresponding incomplete tree
@@ -47,20 +68,23 @@ class Incomplete(unittest.TestCase):
 
         # suppose we wanted to validate leaf[0]
         #  leaf[0] is the same as node[7]
-        self.failUnlessEqual(iht.needed_hashes(leaves=[0]), set([8, 4, 2, 0]))
-        self.failUnlessEqual(iht.needed_hashes(leaves=[1]), set([7, 4, 2, 0]))
-        iht.set_hashes({0: ht[0]}) # set the root
-        self.failUnlessEqual(iht.needed_hashes(leaves=[0]), set([8, 4, 2]))
-        self.failUnlessEqual(iht.needed_hashes(leaves=[1]), set([7, 4, 2]))
-        iht.set_hashes({5: ht[5]})
-        self.failUnlessEqual(iht.needed_hashes(leaves=[0]), set([8, 4, 2]))
-        self.failUnlessEqual(iht.needed_hashes(leaves=[1]), set([7, 4, 2]))
+        self.failUnlessEqual(iht.needed_hashes(0), set([8, 4, 2]))
+        self.failUnlessEqual(iht.needed_hashes(0, True), set([7, 8, 4, 2]))
+        self.failUnlessEqual(iht.needed_hashes(1), set([7, 4, 2]))
+        iht[0] = ht[0] # set the root
+        self.failUnlessEqual(iht.needed_hashes(0), set([8, 4, 2]))
+        self.failUnlessEqual(iht.needed_hashes(1), set([7, 4, 2]))
+        iht[5] = ht[5]
+        self.failUnlessEqual(iht.needed_hashes(0), set([8, 4, 2]))
+        self.failUnlessEqual(iht.needed_hashes(1), set([7, 4, 2]))
+
+        # reset
+        iht = hashtree.IncompleteHashTree(6)
 
         current_hashes = list(iht)
         try:
             # this should fail because there aren't enough hashes known
-            iht.set_hashes(leaves={0: tagged_hash("tag", "0")},
-                           must_validate=True)
+            iht.set_hashes(leaves={0: tagged_hash("tag", "0")})
         except hashtree.NotEnoughHashesError:
             pass
         else:
@@ -68,48 +92,52 @@ class Incomplete(unittest.TestCase):
 
         # and the set of hashes stored in the tree should still be the same
         self.failUnlessEqual(list(iht), current_hashes)
+        # and we should still need the same
+        self.failUnlessEqual(iht.needed_hashes(0), set([8, 4, 2]))
+
+        chain = {0: ht[0], 2: ht[2], 4: ht[4], 8: ht[8]}
+        try:
+            # this should fail because the leaf hash is just plain wrong
+            iht.set_hashes(chain, leaves={0: tagged_hash("bad tag", "0")})
+        except hashtree.BadHashError:
+            pass
+        else:
+            self.fail("didn't catch bad hash")
 
-        # provide the missing hashes
-        iht.set_hashes({2: ht[2], 4: ht[4], 8: ht[8]})
-        self.failUnlessEqual(iht.needed_hashes(leaves=[0]), set())
+        bad_chain = chain.copy()
+        bad_chain[2] = ht[2] + "BOGUS"
 
+        # this should fail because the internal hash is wrong
         try:
-            # this should fail because the hash is just plain wrong
-            iht.set_hashes(leaves={0: tagged_hash("bad tag", "0")})
+            iht.set_hashes(bad_chain, leaves={0: tagged_hash("tag", "0")})
         except hashtree.BadHashError:
             pass
         else:
             self.fail("didn't catch bad hash")
 
+        # this should succeed
         try:
-            # this should succeed
-            iht.set_hashes(leaves={0: tagged_hash("tag", "0")})
+            iht.set_hashes(chain, leaves={0: tagged_hash("tag", "0")})
         except hashtree.BadHashError, e:
             self.fail("bad hash: %s" % e)
 
+        self.failUnlessEqual(ht.get_leaf(0), tagged_hash("tag", "0"))
+        self.failUnlessRaises(IndexError, ht.get_leaf, 8)
+
+        # this should succeed too
         try:
-            # this should succeed too
             iht.set_hashes(leaves={1: tagged_hash("tag", "1")})
         except hashtree.BadHashError:
             self.fail("bad hash")
 
-        # giving it a bad internal hash should also cause problems
-        iht.set_hashes({13: tagged_hash("bad tag", "x")})
-        try:
-            iht.set_hashes({14: tagged_hash("tag", "14")})
-        except hashtree.BadHashError:
-            pass
-        else:
-            self.fail("didn't catch bad hash")
-        # undo our damage
-        iht[13] = None
-
-        self.failUnlessEqual(iht.needed_hashes(leaves=[4]), set([12, 6]))
+        # now that leaves 0 and 1 are known, some of the internal nodes are
+        # known
+        self.failUnlessEqual(iht.needed_hashes(4), set([12, 6]))
+        chain = {6: ht[6], 12: ht[12]}
 
-        iht.set_hashes({6: ht[6], 12: ht[12]})
+        # this should succeed
         try:
-            # this should succeed
-            iht.set_hashes(leaves={4: tagged_hash("tag", "4")})
+            iht.set_hashes(chain, leaves={4: tagged_hash("tag", "4")})
         except hashtree.BadHashError, e:
             self.fail("bad hash: %s" % e)
 
-- 
2.45.2