def get_leaf(self, leafnum):
return self[self.first_leaf_num + leafnum]
+ def depth_of(self, i):
+ """Return the depth or level of the given node. Level 0 contains node
+ Level 1 contains nodes 1 and 2. Level 2 contains nodes 3,4,5,6."""
+ depth = 0
+ while i != 0:
+ depth += 1
+ i = self.parent(i)
+ return depth
+
def empty_leaf_hash(i):
return tagged_hash('Merkle tree empty leaf', "%d" % i)
def pair_hash(a, b):
from 0 (the root of the tree) to 2*num_leaves-2 (the right-most
leaf). leaf[i] is the same as hash[num_leaves-1+i].
- The best way to use me is to obtain the root hash from some 'good'
- channel, and use the 'bad' channel to obtain data block 0 and the
+ The best way to use me is to start by obtaining the root hash from
+ some 'good' channel and populate me with it:
+
+ iht = IncompleteHashTree(numleaves)
+ roothash = trusted_channel.get_roothash()
+ iht.set_hashes(hashes={0: roothash})
+
+ Then use the 'bad' channel to obtain data block 0 and the
corresponding hash chain (a dict with the same hashes that
needed_hashes(0) tells you, e.g. {0:h0, 2:h2, 4:h4, 8:h8} when
len(L)=8). Hash the data block to create leaf0, then feed everything
into set_hashes() and see if it raises an exception or not::
- iht = IncompleteHashTree(numleaves)
- roothash = trusted_channel.get_roothash()
- otherhashes = untrusted_channel.get_hashes()
- # otherhashes.keys() should == iht.needed_hashes(leaves=[0])
- datablock0 = untrusted_channel.get_data(0)
- leaf0 = HASH(datablock0)
- # HASH() is probably hashutil.tagged_hash(tag, datablock0)
- hashes = otherhashes.copy()
- hashes[0] = roothash # from 'good' channel
- iht.set_hashes(hashes, leaves={0: leaf0})
+ otherhashes = untrusted_channel.get_hashes()
+ # otherhashes.keys() should == iht.needed_hashes(leaves=[0])
+ datablock0 = untrusted_channel.get_data(0)
+ leaf0 = HASH(datablock0)
+ # HASH() is probably hashutil.tagged_hash(tag, datablock0)
+ iht.set_hashes(otherhashes, leaves={0: leaf0})
If the set_hashes() call doesn't raise an exception, the data block
was valid. If it raises BadHashError, then either the data block was
- corrupted or one of the received hashes was corrupted.
+ corrupted or one of the received hashes was corrupted. If it raises
+ NotEnoughHashesError, then the otherhashes dictionary was incomplete.
"""
assert isinstance(hashes, dict)
% (leafnum, hashnum))
new_hashes[hashnum] = leafhash
- added = set() # we'll remove these if the check fails
+ remove_upon_failure = set() # we'll remove these if the check fails
+
+ # visualize this method in the following way:
+ # A: start with the empty or partially-populated tree as shown in
+ # the HashTree docstring
+ # B: add all of our input hashes to the tree, filling in some of the
+ # holes. Don't overwrite anything, but new values must equal the
+ # existing ones. Mark everything that was added with a red dot
+ # (meaning "not yet validated")
+ # C: start with the lowest/deepest level. Pick any red-dotted node,
+ # hash it with its sibling to compute the parent hash. Add the
+ # parent to the tree just like in step B (if the parent already
+ # exists, the values must be equal; if not, add our computed
+ # value with a red dot). If we have no sibling, throw
+ # NotEnoughHashesError, since we won't be able to validate this
+ # node. Remove the red dot. If there was a red dot on our
+ # sibling, remove it too.
+ # D: finish all red-dotted nodes in one level before moving up to
+ # the next.
+ # E: if we hit NotEnoughHashesError or BadHashError before getting
+ # to the root, discard every hash we've added.
try:
+ num_levels = self.depth_of(len(self)-1)
+ # hashes_to_check[level] is set(index). This holds the "red dots"
+ # described above
+ hashes_to_check = [set() for level in range(num_levels+1)]
+
# first we provisionally add all hashes to the tree, comparing
# any duplicates
- for i in new_hashes:
+ for i,h in new_hashes.iteritems():
+ level = self.depth_of(i)
+ hashes_to_check[level].add(i)
+
if self[i]:
- if self[i] != new_hashes[i]:
- msg = "new hash %s does not match existing hash %s at " % (base32.b2a(new_hashes[i]), base32.b2a(self[i]))
- msg += self._name_hash(i)
- raise BadHashError(msg)
- else:
- self[i] = new_hashes[i]
- added.add(i)
-
- # then we start from the bottom and compute new parent hashes
- # upwards, comparing any that already exist. When this phase
- # ends, all nodes that have a sibling will also have a parent.
-
- hashes_to_check = list(new_hashes.keys())
- # leaf-most first means reverse sorted order
- while hashes_to_check:
- hashes_to_check.sort()
- i = hashes_to_check.pop(-1)
- if i == 0:
- # The root has no sibling. How lonely.
- continue
- if self[self.sibling(i)] is None:
- # without a sibling, we can't compute a parent
- continue
- parentnum = self.parent(i)
- # make sure we know right from left
- leftnum, rightnum = sorted([i, self.sibling(i)])
- new_parent_hash = pair_hash(self[leftnum], self[rightnum])
- if self[parentnum]:
- if self[parentnum] != new_parent_hash:
- raise BadHashError("h([%d]+[%d]) != h[%d]" %
- (leftnum, rightnum, parentnum))
+ if self[i] != h:
+ raise BadHashError("new hash %s does not match "
+ "existing hash %s at %s"
+ % (base32.b2a(h),
+ base32.b2a(self[i]),
+ self._name_hash(i)))
else:
- self[parentnum] = new_parent_hash
- added.add(parentnum)
- hashes_to_check.insert(0, parentnum)
-
- # then we walk downwards from the top (root), and anything that
- # is reachable is validated. If any of the hashes that we've
- # added are unreachable, then they are unvalidated.
-
- reachable = set()
- if self[0]:
- reachable.add(0)
- # TODO: this could be done more efficiently, by starting from
- # each element of new_hashes and walking upwards instead,
- # remembering a set of validated nodes so that the searches for
- # later new_hashes goes faster. This approach is O(n), whereas
- # O(ln(n)) should be feasible.
- for i in range(1, len(self)):
- if self[i] and self.parent(i) in reachable:
- reachable.add(i)
-
- # were we unable to validate any of the new hashes?
- unvalidated = set(new_hashes.keys()) - reachable
- if unvalidated:
- those = ",".join([str(i) for i in sorted(unvalidated)])
- raise NotEnoughHashesError("unable to validate hashes %s"
- % those)
+ self[i] = h
+ remove_upon_failure.add(i)
+
+ for level in reversed(range(len(hashes_to_check))):
+ this_level = hashes_to_check[level]
+ while this_level:
+ i = this_level.pop()
+ if i == 0:
+ # The root has no sibling. How lonely. TODO: consider
+ # setting the root in our constructor, then throw
+ # NotEnoughHashesError here, because if we've
+ # generated the root from below, we don't have
+ # anything to validate it against.
+ continue
+ siblingnum = self.sibling(i)
+ if self[siblingnum] is None:
+ # without a sibling, we can't compute a parent, and
+ # we can't verify this node
+ raise NotEnoughHashesError("unable to validate [%d]"%i)
+ parentnum = self.parent(i)
+ # make sure we know right from left
+ leftnum, rightnum = sorted([i, siblingnum])
+ new_parent_hash = pair_hash(self[leftnum], self[rightnum])
+ if self[parentnum]:
+ if self[parentnum] != new_parent_hash:
+ raise BadHashError("h([%d]+[%d]) != h[%d]" %
+ (leftnum, rightnum, parentnum))
+ else:
+ self[parentnum] = new_parent_hash
+ remove_upon_failure.add(parentnum)
+ parent_level = self.depth_of(parentnum)
+ assert parent_level == level-1
+ hashes_to_check[parent_level].add(parentnum)
+
+ # our sibling is now as valid as this node
+ this_level.discard(siblingnum)
+ # we're done!
except (BadHashError, NotEnoughHashesError):
- for i in added:
+ for i in remove_upon_failure:
self[i] = None
raise
self.failUnlessEqual(ht.needed_hashes(5, False), set([11, 6, 1]))
self.failUnlessEqual(ht.needed_hashes(5, True), set([12, 11, 6, 1]))
+ def test_depth_of(self):
+ ht = hashtree.IncompleteHashTree(8)
+ self.failUnlessEqual(ht.depth_of(0), 0)
+ for i in [1,2]:
+ self.failUnlessEqual(ht.depth_of(i), 1, "i=%d"%i)
+ for i in [3,4,5,6]:
+ self.failUnlessEqual(ht.depth_of(i), 2, "i=%d"%i)
+ for i in [7,8,9,10,11,12,13,14]:
+ self.failUnlessEqual(ht.depth_of(i), 3, "i=%d"%i)
+ self.failUnlessRaises(IndexError, ht.depth_of, 15)
+
+ def test_large(self):
+ # IncompleteHashTree.set_hashes() used to take O(N**2). This test is
+ # meant to show that it now takes O(N) or maybe O(N*ln(N)). I wish
+ # there were a good way to assert this (like counting VM operations
+ # or something): the problem was inside list.sort(), so there's no
+ # good way to instrument set_hashes() to count what we care about. On
+ # my laptop, 10k leaves takes 1.1s in this fixed version, and 11.6s
+ # in the old broken version. An 80k-leaf test (corresponding to a
+ # 10GB file with a 128KiB segsize) 10s in the fixed version, and
+ # several hours in the broken version, but 10s on my laptop (plus the
+ # 20s of setup code) probably means 200s on our dapper buildslave,
+ # which is painfully long for a unit test.
+ self.do_test_speed(10000)
+
+ def do_test_speed(self, SIZE):
+ # on my laptop, SIZE=80k (corresponding to a 10GB file with a 128KiB
+ # segsize) takes:
+ # 7s to build the (complete) HashTree
+ # 13s to set up the dictionary
+ # 10s to run set_hashes()
+ ht = make_tree(SIZE)
+ iht = hashtree.IncompleteHashTree(SIZE)
+
+ needed = set()
+ for i in range(SIZE):
+ needed.update(ht.needed_hashes(i, True))
+ all = dict([ (i, ht[i]) for i in needed])
+ iht.set_hashes(hashes=all)
+
def test_check(self):
# first create a complete hash tree
ht = make_tree(6)