From: Brian Warner <warner@allmydata.com>
Date: Thu, 12 Apr 2007 20:07:40 +0000 (-0700)
Subject: verify hash chains on incoming blocks
X-Git-Tag: tahoe_v0.1.0-0-UNSTABLE~111
X-Git-Url: https://git.rkrishnan.org/%5B/frontends/%22news.html/%22doc.html/cyclelanguage?a=commitdiff_plain;h=8f58b30db9430cbe3208930b421eb3407270d907;p=tahoe-lafs%2Ftahoe-lafs.git

verify hash chains on incoming blocks
Implement enough of chunk.IncompleteHashTree to be usable.
Rearrange download: all block/hash requests now go through
a ValidatedBucket instance, which is responsible for retrieving
and verifying hashes before providing validated data. Download
was changed to use ValidatedBuckets everywhere instead of
unwrapped RIBucketReader references.
---

diff --git a/src/allmydata/chunk.py b/src/allmydata/chunk.py
index a797c464..3e4e09ed 100644
--- a/src/allmydata/chunk.py
+++ b/src/allmydata/chunk.py
@@ -48,6 +48,7 @@ or implied.  It probably won't make your computer catch on fire,
 or eat  your children, but it might.  Use at your own risk.
 """
 
+from allmydata.util import idlib
 from allmydata.util.hashutil import tagged_hash, tagged_pair_hash
 
 __version__ = '1.0.0-allmydata'
@@ -134,6 +135,25 @@ class CompleteBinaryTreeMixin:
       here = self.parent(here)
     return needed
 
+  def depth_first(self, i=0):
+    yield i, 0
+    try:
+      for child,childdepth in self.depth_first(self.lchild(i)):
+        yield child, childdepth+1
+    except IndexError:
+      pass
+    try:
+      for child,childdepth in self.depth_first(self.rchild(i)):
+        yield child, childdepth+1
+    except IndexError:
+      pass
+
+  def dump(self):
+    lines = []
+    for i,depth in self.depth_first():
+      lines.append("%s%3d: %s" % ("  "*depth, i, idlib.b2a_or_none(self[i])))
+    return "\n".join(lines) + "\n"
+
 def empty_leaf_hash(i):
   return tagged_hash('Merkle tree empty leaf', "%d" % i)
 def pair_hash(a, b):
@@ -225,49 +245,136 @@ class IncompleteHashTree(CompleteBinaryTreeMixin, list):
     rows.reverse()
     self[:] = sum(rows, [])
 
-  def needed_hashes(self, leafnum):
-    hashnum = self.first_leaf_num + leafnum
-    maybe_needed = self.needed_for(hashnum)
-    maybe_needed += [0] # need the root too
+  def needed_hashes(self, hashes=[], leaves=[]):
+    hashnums = set(list(hashes))
+    for leafnum in leaves:
+      hashnums.add(self.first_leaf_num + leafnum)
+    maybe_needed = set()
+    for hashnum in hashnums:
+      maybe_needed.update(self.needed_for(hashnum))
+    maybe_needed.add(0) # need the root too
     return set([i for i in maybe_needed if self[i] is None])
 
-  def set_hash(self, i, newhash):
-    # note that we don't attempt to validate these
-    self[i] = newhash
-
-  def set_leaf(self, leafnum, leafhash):
-    hashnum = self.first_leaf_num + leafnum
-    needed = self.needed_hashes(leafnum)
-    if needed:
-      msg = "we need hashes " + ",".join([str(i) for i in needed])
-      raise NotEnoughHashesError(msg)
-    assert self[0] is not None # we can't validate without a root
-    added = set() # we'll remove these if the check fails
-    self[hashnum] = leafhash
-    added.add(hashnum)
 
-    # now propagate hash checks upwards until we reach the root
-    here = hashnum
-    while here != 0:
-      us = [here, self.sibling(here)]
-      us.sort()
-      leftnum, rightnum = us
-      lefthash = self[leftnum]
-      righthash = self[rightnum]
-      parent = self.parent(here)
-      parenthash = self[parent]
-
-      ourhash = pair_hash(lefthash, righthash)
-      if parenthash is not None:
-        if ourhash != parenthash:
-          for i in added:
-            self[i] = None
-          raise BadHashError("h([%d]+[%d]) != h[%d]" % (leftnum, rightnum,
-                                                        parent))
-      else:
-        self[parent] = ourhash
-        added.add(parent)
-      here = self.parent(here)
+  def set_hashes(self, hashes={}, leaves={}, must_validate=False):
+    """Add a bunch of hashes to the tree.
+
+    I will validate these to the best of my ability. If I already have a copy
+    of any of the new hashes, the new values must equal the existing ones, or
+    I will raise BadHashError. If adding a hash allows me to compute a parent
+    hash, those parent hashes must match or I will raise BadHashError. If I
+    raise BadHashError, I will forget about all the hashes that you tried to
+    add, leaving my state exactly the same as before I was called. If I
+    return successfully, I will remember all those hashes.
+
+    If every hash that was added was validated, I will return True. If some
+    could not be validated because I did not have enough parent hashes, I
+    will return False. As a result, if I am called with both a leaf hash and
+    the root hash was already set, I will return True if and only if the leaf
+    hash could be validated against the root.
+
+    If must_validate is True, I will raise NotEnoughHashesError instead of
+    returning False. If I raise NotEnoughHashesError, I will forget about all
+    the hashes that you tried to add. TODO: really?
+
+    'leaves' is a dictionary uses 'leaf index' values, which range from 0
+    (the left-most leaf) to num_leaves-1 (the right-most leaf), and form the
+    base of the tree. 'hashes' uses 'hash_index' values, which range from 0
+    (the root of the tree) to 2*num_leaves-2 (the right-most leaf). leaf[i]
+    is the same as hash[num_leaves-1+i].
+
+    The best way to use me is to obtain the root hash from some 'good'
+    channel, then call set_hash(0, root). Then use the 'bad' channel to
+    obtain data block 0 and the corresponding hash chain (a dict with the
+    same hashes that needed_hashes(0) tells you, e.g. {0:h0, 2:h2, 4:h4,
+    8:h8} when len(L)=8). Hash the data block to create leaf0. Then call::
+
+     good = iht.set_hashes(hashes=hashchain, leaves={0: leaf0})
+
+    If 'good' is True, the data block was valid. If 'good' is False, the
+    hashchain did not have the right blocks and we don't know whether the
+    data block was good or bad. If set_hashes() raises an exception, either
+    the data was corrupted or one of the received hashes was corrupted.
+    """
+
+    assert isinstance(hashes, dict)
+    assert isinstance(leaves, dict)
+    new_hashes = hashes.copy()
+    for leafnum,leafhash in leaves.iteritems():
+      hashnum = self.first_leaf_num + leafnum
+      if hashnum in new_hashes:
+        assert new_hashes[hashnum] == leafhash
+      new_hashes[hashnum] = leafhash
+
+    added = set() # we'll remove these if the check fails
 
-    return None
+    try:
+      # first we provisionally add all hashes to the tree, comparing any
+      # duplicates
+      for i in new_hashes:
+        if self[i]:
+          if self[i] != new_hashes[i]:
+            raise BadHashError("new hash does not match existing hash at [%d]"
+                               % i)
+        else:
+          self[i] = new_hashes[i]
+          added.add(i)
+
+      # then we start from the bottom and compute new parent hashes upwards,
+      # comparing any that already exist. When this phase ends, all nodes
+      # that have a sibling will also have a parent.
+
+      hashes_to_check = list(new_hashes.keys())
+      # leaf-most first means reverse sorted order
+      while hashes_to_check:
+        hashes_to_check.sort()
+        i = hashes_to_check.pop(-1)
+        if i == 0:
+          # The root has no sibling. How lonely.
+          continue
+        if self[self.sibling(i)] is None:
+          # without a sibling, we can't compute a parent
+          continue
+        parentnum = self.parent(i)
+        # make sure we know right from left
+        leftnum, rightnum = sorted([i, self.sibling(i)])
+        new_parent_hash = pair_hash(self[leftnum], self[rightnum])
+        if self[parentnum]:
+          if self[parentnum] != new_parent_hash:
+            raise BadHashError("h([%d]+[%d]) != h[%d]" % (leftnum, rightnum,
+                                                          parentnum))
+        else:
+          self[parentnum] = new_parent_hash
+          added.add(parentnum)
+          hashes_to_check.insert(0, parentnum)
+
+      # then we walk downwards from the top (root), and anything that is
+      # reachable is validated. If any of the hashes that we've added are
+      # unreachable, then they are unvalidated.
+
+      reachable = set()
+      if self[0]:
+        reachable.add(0)
+      # TODO: this could be done more efficiently, by starting from each
+      # element of new_hashes and walking upwards instead, remembering a set
+      # of validated nodes so that the searches for later new_hashes goes
+      # faster. This approach is O(n), whereas O(ln(n)) should be feasible.
+      for i in range(1, len(self)):
+        if self[i] and self.parent(i) in reachable:
+          reachable.add(i)
+
+      # were we unable to validate any of the new hashes?
+      unvalidated = set(new_hashes.keys()) - reachable
+      if unvalidated:
+        if must_validate:
+          those = ",".join([str(i) for i in sorted(unvalidated)])
+          raise NotEnoughHashesError("unable to validate hashes %s" % those)
+
+    except (BadHashError, NotEnoughHashesError):
+      for i in added:
+        self[i] = None
+      raise
+
+    # if there were hashes that could not be validated, we return False
+    return not unvalidated
 
diff --git a/src/allmydata/download.py b/src/allmydata/download.py
index 458e7f2a..3265184d 100644
--- a/src/allmydata/download.py
+++ b/src/allmydata/download.py
@@ -5,7 +5,7 @@ from twisted.python import log
 from twisted.internet import defer
 from twisted.application import service
 
-from allmydata.util import idlib, mathutil
+from allmydata.util import idlib, mathutil, hashutil
 from allmydata.util.assertutil import _assert
 from allmydata import codec, chunk
 from allmydata.Crypto.Cipher import AES
@@ -47,15 +47,59 @@ class Output:
     def finish(self):
         return self.downloadable.finish()
 
+class ValidatedBucket:
+    def __init__(self, sharenum, bucket, share_hash_tree, num_blocks):
+        self.sharenum = sharenum
+        self.bucket = bucket
+        self.share_hash_tree = share_hash_tree
+        self.block_hash_tree = chunk.IncompleteHashTree(num_blocks)
+
+    def get_block(self, blocknum):
+        d1 = self.bucket.callRemote('get_block', blocknum)
+        # we might also need to grab some elements of our block hash tree, to
+        # validate the requested block up to the share hash
+        if self.block_hash_tree.needed_hashes(leaves=[blocknum]):
+            d2 = self.bucket.callRemote('get_block_hashes')
+        else:
+            d2 = defer.succeed(None)
+        # we might need to grab some elements of the share hash tree to
+        # validate from our share hash up to the hashroot
+        if self.share_hash_tree.needed_hashes(leaves=[self.sharenum]):
+            d3 = self.bucket.callRemote('get_share_hashes')
+        else:
+            d3 = defer.succeed(None)
+        d = defer.gatherResults([d1, d2, d3])
+        d.addCallback(self._got_data, blocknum)
+        return d
+
+    def _got_data(self, res, blocknum):
+        blockdata, blockhashes, sharehashes = res
+        blockhash = hashutil.tagged_hash("encoded subshare", blockdata)
+        if blockhashes:
+            bh = dict(enumerate(blockhashes))
+            self.block_hash_tree.set_hashes(bh, {blocknum: blockhash},
+                                            must_validate=True)
+        if sharehashes:
+            sh = dict(sharehashes)
+            sharehash = self.block_hash_tree[0]
+            self.share_hash_tree.set_hashes(sh, {self.sharenum: sharehash},
+                                            must_validate=True)
+        # If we made it here, the block is good. If the hash trees didn't
+        # like what they saw, they would have raised a BadHashError, causing
+        # our caller to see a Failure and thus ignore this block (as well as
+        # dropping this bucket).
+        return blockdata
+
+
 
 class BlockDownloader:
-    def __init__(self, bucket, blocknum, parent):
-        self.bucket = bucket
+    def __init__(self, vbucket, blocknum, parent):
+        self.vbucket = vbucket
         self.blocknum = blocknum
         self.parent = parent
         
     def start(self, segnum):
-        d = self.bucket.callRemote('get_block', segnum)
+        d = self.vbucket.get_block(segnum)
         d.addCallbacks(self._hold_block, self._got_block_error)
         return d
 
@@ -64,7 +108,7 @@ class BlockDownloader:
 
     def _got_block_error(self, f):
         log.msg("BlockDownloader[%d] got error: %s" % (self.blocknum, f))
-        self.parent.bucket_failed(self.blocknum, self.bucket)
+        self.parent.bucket_failed(self.blocknum, self.vbucket)
 
 class SegmentDownloader:
     def __init__(self, parent, segmentnumber, needed_shares):
@@ -94,86 +138,34 @@ class SegmentDownloader:
         return d
 
     def _try(self):
-        while len(self.parent.active_buckets) < self.needed_blocks:
-            # need some more
-            otherblocknums = list(set(self.parent._share_buckets.keys()) - set(self.parent.active_buckets.keys()))
-            if not otherblocknums:
-                raise NotEnoughPeersError
-            blocknum = random.choice(otherblocknums)
-            bucket = random.choice(list(self.parent._share_buckets[blocknum]))
-            self.parent.active_buckets[blocknum] = bucket
-
+        # fill our set of active buckets, maybe raising NotEnoughPeersError
+        active_buckets = self.parent._activate_enough_buckets()
         # Now we have enough buckets, in self.parent.active_buckets.
 
-        # before we get any blocks of a given share, we need to be able to
-        # validate that block and that share. Check to see if we have enough
-        # hashes. If we don't, grab them before continuing.
-        d = self._grab_needed_hashes()
-        d.addCallback(self._download_some_blocks)
-        return d
-
-    def _grab_needed_hashes(self):
-        # each bucket is holding the hashes necessary to validate their
-        # share. So it suffices to ask everybody for all the hashes they know
-        # about. Eventually we'll have all that we need, so we can stop
-        # asking.
-
-        # for each active share, see what hashes we need
-        ht = self.parent.get_share_hashtree()
-        needed_hashes = set()
-        for shnum in self.parent.active_buckets:
-            needed_hashes.update(ht.needed_hashes(shnum))
-        if not needed_hashes:
-            return defer.succeed(None)
-
-        # for now, just ask everybody for everything
-        # TODO: send fewer queries
-        dl = []
-        for shnum, bucket in self.parent.active_buckets.iteritems():
-            d = bucket.callRemote("get_share_hashes")
-            d.addCallback(self._got_share_hashes, shnum, bucket)
-            dl.append(d)
-        d.addCallback(self._validate_root)
-        return defer.DeferredList(dl)
-
-    def _got_share_hashes(self, share_hashes, shnum, bucket):
-        ht = self.parent.get_share_hashtree()
-        for hashnum, sharehash in share_hashes:
-            # TODO: we're accumulating these hashes blindly, since we only
-            # validate the leaves. This makes it possible for someone to
-            # frame another server by giving us bad internal hashes. We pass
-            # 'shnum' and 'bucket' in so that if we detected problems with
-            # intermediate nodes, we could associate the error with the
-            # bucket and stop using them.
-            ht.set_hash(hashnum, sharehash)
-
-    def _validate_root(self, res):
-        # TODO: I dunno, check that the hash tree looks good so far and that
-        # it adds up to the root. The idea is to reject any bad buckets
-        # early.
-        pass
-
-    def _download_some_blocks(self, res):
         # in test cases, bd.start might mutate active_buckets right away, so
         # we need to put off calling start() until we've iterated all the way
-        # through it
+        # through it.
         downloaders = []
-        for blocknum, bucket in self.parent.active_buckets.iteritems():
-            bd = BlockDownloader(bucket, blocknum, self)
+        for blocknum, vbucket in active_buckets.iteritems():
+            bd = BlockDownloader(vbucket, blocknum, self)
             downloaders.append(bd)
         l = [bd.start(self.segmentnumber) for bd in downloaders]
-        return defer.DeferredList(l)
+        return defer.DeferredList(l, fireOnOneErrback=True)
 
     def hold_block(self, blocknum, data):
         self.blocks[blocknum] = data
 
-    def bucket_failed(self, shnum, bucket):
+    def bucket_failed(self, shnum, vbucket):
         del self.parent.active_buckets[shnum]
         s = self.parent._share_buckets[shnum]
-        s.remove(bucket)
+        # s is a set of ValidatedBucket instances
+        s.remove(vbucket)
+        # ... which might now be empty
         if not s:
+            # there are no more buckets which can provide this share, so
+            # remove the key. This may prompt us to use a different share.
             del self.parent._share_buckets[shnum]
-        
+
 class FileDownloader:
     debug = False
 
@@ -201,30 +193,21 @@ class FileDownloader:
         key = "\x00" * 16
         self._output = Output(downloadable, key)
 
-        # future:
-        # each time we start using a new shnum, we must acquire a share hash
-        # from one of the buckets that provides that shnum, then validate it
-        # against the rest of the share hash tree that they provide. Then,
-        # each time we get a block in that share, we must validate the block
-        # against the rest of the subshare hash tree that that bucket will
-        # provide.
-
         self._share_hashtree = chunk.IncompleteHashTree(total_shares)
-        #self._block_hashtrees = {} # k: shnum, v: hashtree
+        self._share_hashtree.set_hashes({0: roothash})
 
-    def get_share_hashtree(self):
-        return self._share_hashtree
+        self.active_buckets = {} # k: shnum, v: bucket
+        self._share_buckets = {} # k: shnum, v: set of buckets
 
     def start(self):
         log.msg("starting download [%s]" % (idlib.b2a(self._verifierid),))
         if self.debug:
             print "starting download"
-        # first step: who should we download from?
-        self.active_buckets = {} # k: shnum, v: bucket
-        self._share_buckets = {} # k: shnum, v: set of buckets
 
+        # first step: who should we download from?
         d = defer.maybeDeferred(self._get_all_shareholders)
         d.addCallback(self._got_all_shareholders)
+        # once we know that, we can download blocks from them
         d.addCallback(self._download_all_segments)
         d.addCallback(self._done)
         return d
@@ -243,19 +226,55 @@ class FileDownloader:
     def _got_response(self, buckets, connection):
         _assert(isinstance(buckets, dict), buckets) # soon foolscap will check this for us with its DictOf schema constraint
         for sharenum, bucket in buckets.iteritems():
-            self._share_buckets.setdefault(sharenum, set()).add(bucket)
-        
+            self.add_share_bucket(sharenum, bucket)
+
+    def add_share_bucket(self, sharenum, bucket):
+        vbucket = ValidatedBucket(sharenum, bucket,
+                                  self._share_hashtree,
+                                  self._total_segments)
+        self._share_buckets.setdefault(sharenum, set()).add(vbucket)
+
     def _got_error(self, f):
         self._client.log("Somebody failed. -- %s" % (f,))
 
     def _got_all_shareholders(self, res):
         if len(self._share_buckets) < self._num_needed_shares:
             raise NotEnoughPeersError
+        for s in self._share_buckets.values():
+            for vb in s:
+                assert isinstance(vb, ValidatedBucket), \
+                       "vb is %s but should be a ValidatedBucket" % (vb,)
+
+
+    def _activate_enough_buckets(self):
+        """either return a mapping from shnum to a ValidatedBucket that can
+        provide data for that share, or raise NotEnoughPeersError"""
+
+        while len(self.active_buckets) < self._num_needed_shares:
+            # need some more
+            handled_shnums = set(self.active_buckets.keys())
+            available_shnums = set(self._share_buckets.keys())
+            potential_shnums = list(available_shnums - handled_shnums)
+            if not potential_shnums:
+                raise NotEnoughPeersError
+            # choose a random share
+            shnum = random.choice(potential_shnums)
+            # and a random bucket that will provide it
+            validated_bucket = random.choice(list(self._share_buckets[shnum]))
+            self.active_buckets[shnum] = validated_bucket
+        return self.active_buckets
 
-        self.active_buckets = {}
-        self._output.open()
 
     def _download_all_segments(self, res):
+        # the promise: upon entry to this function, self._share_buckets
+        # contains enough buckets to complete the download, and some extra
+        # ones to tolerate some buckets dropping out or having errors.
+        # self._share_buckets is a dictionary that maps from shnum to a set
+        # of ValidatedBuckets, which themselves are wrappers around
+        # RIBucketReader references.
+        self.active_buckets = {} # k: shnum, v: ValidatedBucket instance
+        self._output.open()
+
         d = defer.succeed(None)
         for segnum in range(self._total_segments-1):
             d.addCallback(self._download_segment, segnum)
diff --git a/src/allmydata/test/test_encode.py b/src/allmydata/test/test_encode.py
index 6372af89..a58351c7 100644
--- a/src/allmydata/test/test_encode.py
+++ b/src/allmydata/test/test_encode.py
@@ -156,9 +156,9 @@ class Roundtrip(unittest.TestCase):
             client = None
             target = download.Data()
             fd = download.FileDownloader(client, URI, target)
-            fd._share_buckets = {}
             for shnum in range(NUM_SHARES):
-                fd._share_buckets[shnum] = set([all_shareholders[shnum]])
+                bucket = all_shareholders[shnum]
+                fd.add_share_bucket(shnum, bucket)
             fd._got_all_shareholders(None)
             d2 = fd._download_all_segments(None)
             d2.addCallback(fd._done)
diff --git a/src/allmydata/test/test_hashtree.py b/src/allmydata/test/test_hashtree.py
index 5cac2cc1..1a51c6c6 100644
--- a/src/allmydata/test/test_hashtree.py
+++ b/src/allmydata/test/test_hashtree.py
@@ -21,7 +21,24 @@ class Complete(unittest.TestCase):
         root = ht[0]
         self.failUnlessEqual(len(root), 32)
 
+    def testDump(self):
+        ht = make_tree(6)
+        expected = [(0,0),
+                    (1,1), (3,2), (7,3), (8,3), (4,2), (9,3), (10,3),
+                    (2,1), (5,2), (11,3), (12,3), (6,2), (13,3), (14,3),
+                    ]
+        self.failUnlessEqual(list(ht.depth_first()), expected)
+        d = "\n" + ht.dump()
+        #print d
+        self.failUnless("\n  0:" in d)
+        self.failUnless("\n    1:" in d)
+        self.failUnless("\n      3:" in d)
+        self.failUnless("\n        7:" in d)
+        self.failUnless("\n        8:" in d)
+        self.failUnless("\n      4:" in d)
+
 class Incomplete(unittest.TestCase):
+
     def testCheck(self):
         # first create a complete hash tree
         ht = make_tree(6)
@@ -30,32 +47,35 @@ class Incomplete(unittest.TestCase):
 
         # suppose we wanted to validate leaf[0]
         #  leaf[0] is the same as node[7]
-        self.failUnlessEqual(iht.needed_hashes(0), set([8, 4, 2, 0]))
-        self.failUnlessEqual(iht.needed_hashes(1), set([7, 4, 2, 0]))
-        iht.set_hash(0, ht[0])
-        self.failUnlessEqual(iht.needed_hashes(0), set([8, 4, 2]))
-        self.failUnlessEqual(iht.needed_hashes(1), set([7, 4, 2]))
-        iht.set_hash(5, ht[5])
-        self.failUnlessEqual(iht.needed_hashes(0), set([8, 4, 2]))
-        self.failUnlessEqual(iht.needed_hashes(1), set([7, 4, 2]))
-
+        self.failUnlessEqual(iht.needed_hashes(leaves=[0]), set([8, 4, 2, 0]))
+        self.failUnlessEqual(iht.needed_hashes(leaves=[1]), set([7, 4, 2, 0]))
+        iht.set_hashes({0: ht[0]}) # set the root
+        self.failUnlessEqual(iht.needed_hashes(leaves=[0]), set([8, 4, 2]))
+        self.failUnlessEqual(iht.needed_hashes(leaves=[1]), set([7, 4, 2]))
+        iht.set_hashes({5: ht[5]})
+        self.failUnlessEqual(iht.needed_hashes(leaves=[0]), set([8, 4, 2]))
+        self.failUnlessEqual(iht.needed_hashes(leaves=[1]), set([7, 4, 2]))
+
+        current_hashes = list(iht)
         try:
             # this should fail because there aren't enough hashes known
-            iht.set_leaf(0, tagged_hash("tag", "0"))
+            iht.set_hashes(leaves={0: tagged_hash("tag", "0")},
+                           must_validate=True)
         except chunk.NotEnoughHashesError:
             pass
         else:
             self.fail("didn't catch not enough hashes")
 
+        # and the set of hashes stored in the tree should still be the same
+        self.failUnlessEqual(list(iht), current_hashes)
+
         # provide the missing hashes
-        iht.set_hash(2, ht[2])
-        iht.set_hash(4, ht[4])
-        iht.set_hash(8, ht[8])
-        self.failUnlessEqual(iht.needed_hashes(0), set([]))
+        iht.set_hashes({2: ht[2], 4: ht[4], 8: ht[8]})
+        self.failUnlessEqual(iht.needed_hashes(leaves=[0]), set())
 
         try:
             # this should fail because the hash is just plain wrong
-            iht.set_leaf(0, tagged_hash("bad tag", "0"))
+            iht.set_hashes(leaves={0: tagged_hash("bad tag", "0")})
         except chunk.BadHashError:
             pass
         else:
@@ -63,34 +83,33 @@ class Incomplete(unittest.TestCase):
 
         try:
             # this should succeed
-            iht.set_leaf(0, tagged_hash("tag", "0"))
+            iht.set_hashes(leaves={0: tagged_hash("tag", "0")})
         except chunk.BadHashError, e:
             self.fail("bad hash: %s" % e)
 
         try:
             # this should succeed too
-            iht.set_leaf(1, tagged_hash("tag", "1"))
+            iht.set_hashes(leaves={1: tagged_hash("tag", "1")})
         except chunk.BadHashError:
             self.fail("bad hash")
 
         # giving it a bad internal hash should also cause problems
-        iht.set_hash(2, tagged_hash("bad tag", "x"))
+        iht.set_hashes({13: tagged_hash("bad tag", "x")})
         try:
-            iht.set_leaf(0, tagged_hash("tag", "0"))
+            iht.set_hashes({14: tagged_hash("tag", "14")})
         except chunk.BadHashError:
             pass
         else:
             self.fail("didn't catch bad hash")
         # undo our damage
-        iht.set_hash(2, ht[2])
+        iht[13] = None
 
-        self.failUnlessEqual(iht.needed_hashes(4), set([12, 6]))
+        self.failUnlessEqual(iht.needed_hashes(leaves=[4]), set([12, 6]))
 
-        iht.set_hash(6, ht[6])
-        iht.set_hash(12, ht[12])
+        iht.set_hashes({6: ht[6], 12: ht[12]})
         try:
             # this should succeed
-            iht.set_leaf(4, tagged_hash("tag", "4"))
+            iht.set_hashes(leaves={4: tagged_hash("tag", "4")})
         except chunk.BadHashError, e:
             self.fail("bad hash: %s" % e)