From 466014f66fbccff7668a270eafa1d372c5e11457 Mon Sep 17 00:00:00 2001
From: Brian Warner <warner@lothar.com>
Date: Tue, 31 Mar 2009 13:21:27 -0700
Subject: [PATCH] hashtree: fix O(N**2) behavior, to improve fatal alacrity
 problems in a 10GB file (#670). Also improve docstring.

---
 src/allmydata/hashtree.py           | 172 ++++++++++++++++------------
 src/allmydata/test/test_hashtree.py |  40 +++++++
 2 files changed, 139 insertions(+), 73 deletions(-)

diff --git a/src/allmydata/hashtree.py b/src/allmydata/hashtree.py
index 20d2afc7..20b97d35 100644
--- a/src/allmydata/hashtree.py
+++ b/src/allmydata/hashtree.py
@@ -162,6 +162,15 @@ class CompleteBinaryTreeMixin:
     def get_leaf(self, leafnum):
         return self[self.first_leaf_num + leafnum]
 
+    def depth_of(self, i):
+        """Return the depth or level of the given node. Level 0 contains node
+        Level 1 contains nodes 1 and 2. Level 2 contains nodes 3,4,5,6."""
+        depth = 0
+        while i != 0:
+            depth += 1
+            i = self.parent(i)
+        return depth
+
 def empty_leaf_hash(i):
     return tagged_hash('Merkle tree empty leaf', "%d" % i)
 def pair_hash(a, b):
@@ -337,27 +346,30 @@ class IncompleteHashTree(CompleteBinaryTreeMixin, list):
         from 0 (the root of the tree) to 2*num_leaves-2 (the right-most
         leaf). leaf[i] is the same as hash[num_leaves-1+i].
 
-        The best way to use me is to obtain the root hash from some 'good'
-        channel, and use the 'bad' channel to obtain data block 0 and the
+        The best way to use me is to start by obtaining the root hash from
+        some 'good' channel and populate me with it:
+
+         iht = IncompleteHashTree(numleaves)
+         roothash = trusted_channel.get_roothash()
+         iht.set_hashes(hashes={0: roothash})
+
+        Then use the 'bad' channel to obtain data block 0 and the
         corresponding hash chain (a dict with the same hashes that
         needed_hashes(0) tells you, e.g. {0:h0, 2:h2, 4:h4, 8:h8} when
         len(L)=8). Hash the data block to create leaf0, then feed everything
         into set_hashes() and see if it raises an exception or not::
 
-          iht = IncompleteHashTree(numleaves)
-          roothash = trusted_channel.get_roothash()
-          otherhashes = untrusted_channel.get_hashes()
-          # otherhashes.keys() should == iht.needed_hashes(leaves=[0])
-          datablock0 = untrusted_channel.get_data(0)
-          leaf0 = HASH(datablock0)
-          # HASH() is probably hashutil.tagged_hash(tag, datablock0)
-          hashes = otherhashes.copy()
-          hashes[0] = roothash # from 'good' channel
-          iht.set_hashes(hashes, leaves={0: leaf0})
+         otherhashes = untrusted_channel.get_hashes()
+         # otherhashes.keys() should == iht.needed_hashes(leaves=[0])
+         datablock0 = untrusted_channel.get_data(0)
+         leaf0 = HASH(datablock0)
+         # HASH() is probably hashutil.tagged_hash(tag, datablock0)
+         iht.set_hashes(otherhashes, leaves={0: leaf0})
 
         If the set_hashes() call doesn't raise an exception, the data block
         was valid. If it raises BadHashError, then either the data block was
-        corrupted or one of the received hashes was corrupted.
+        corrupted or one of the received hashes was corrupted. If it raises
+        NotEnoughHashesError, then the otherhashes dictionary was incomplete.
         """
 
         assert isinstance(hashes, dict)
@@ -376,73 +388,87 @@ class IncompleteHashTree(CompleteBinaryTreeMixin, list):
                                        % (leafnum, hashnum))
             new_hashes[hashnum] = leafhash
 
-        added = set() # we'll remove these if the check fails
+        remove_upon_failure = set() # we'll remove these if the check fails
+
+        # visualize this method in the following way:
+        #  A: start with the empty or partially-populated tree as shown in
+        #     the HashTree docstring
+        #  B: add all of our input hashes to the tree, filling in some of the
+        #     holes. Don't overwrite anything, but new values must equal the
+        #     existing ones. Mark everything that was added with a red dot
+        #     (meaning "not yet validated")
+        #  C: start with the lowest/deepest level. Pick any red-dotted node,
+        #     hash it with its sibling to compute the parent hash. Add the
+        #     parent to the tree just like in step B (if the parent already
+        #     exists, the values must be equal; if not, add our computed
+        #     value with a red dot). If we have no sibling, throw
+        #     NotEnoughHashesError, since we won't be able to validate this
+        #     node. Remove the red dot. If there was a red dot on our
+        #     sibling, remove it too.
+        #  D: finish all red-dotted nodes in one level before moving up to
+        #     the next.
+        #  E: if we hit NotEnoughHashesError or BadHashError before getting
+        #     to the root, discard every hash we've added.
 
         try:
+            num_levels = self.depth_of(len(self)-1)
+            # hashes_to_check[level] is set(index). This holds the "red dots"
+            # described above
+            hashes_to_check = [set() for level in range(num_levels+1)]
+
             # first we provisionally add all hashes to the tree, comparing
             # any duplicates
-            for i in new_hashes:
+            for i,h in new_hashes.iteritems():
+                level = self.depth_of(i)
+                hashes_to_check[level].add(i)
+
                 if self[i]:
-                    if self[i] != new_hashes[i]:
-                        msg = "new hash %s does not match existing hash %s at " % (base32.b2a(new_hashes[i]), base32.b2a(self[i]))
-                        msg += self._name_hash(i)
-                        raise BadHashError(msg)
-                else:
-                    self[i] = new_hashes[i]
-                    added.add(i)
-
-            # then we start from the bottom and compute new parent hashes
-            # upwards, comparing any that already exist. When this phase
-            # ends, all nodes that have a sibling will also have a parent.
-
-            hashes_to_check = list(new_hashes.keys())
-            # leaf-most first means reverse sorted order
-            while hashes_to_check:
-                hashes_to_check.sort()
-                i = hashes_to_check.pop(-1)
-                if i == 0:
-                    # The root has no sibling. How lonely.
-                    continue
-                if self[self.sibling(i)] is None:
-                    # without a sibling, we can't compute a parent
-                    continue
-                parentnum = self.parent(i)
-                # make sure we know right from left
-                leftnum, rightnum = sorted([i, self.sibling(i)])
-                new_parent_hash = pair_hash(self[leftnum], self[rightnum])
-                if self[parentnum]:
-                    if self[parentnum] != new_parent_hash:
-                        raise BadHashError("h([%d]+[%d]) != h[%d]" %
-                                           (leftnum, rightnum, parentnum))
+                    if self[i] != h:
+                        raise BadHashError("new hash %s does not match "
+                                           "existing hash %s at %s"
+                                           % (base32.b2a(h),
+                                              base32.b2a(self[i]),
+                                              self._name_hash(i)))
                 else:
-                    self[parentnum] = new_parent_hash
-                    added.add(parentnum)
-                    hashes_to_check.insert(0, parentnum)
-
-            # then we walk downwards from the top (root), and anything that
-            # is reachable is validated. If any of the hashes that we've
-            # added are unreachable, then they are unvalidated.
-
-            reachable = set()
-            if self[0]:
-                reachable.add(0)
-            # TODO: this could be done more efficiently, by starting from
-            # each element of new_hashes and walking upwards instead,
-            # remembering a set of validated nodes so that the searches for
-            # later new_hashes goes faster. This approach is O(n), whereas
-            # O(ln(n)) should be feasible.
-            for i in range(1, len(self)):
-                if self[i] and self.parent(i) in reachable:
-                    reachable.add(i)
-
-            # were we unable to validate any of the new hashes?
-            unvalidated = set(new_hashes.keys()) - reachable
-            if unvalidated:
-                those = ",".join([str(i) for i in sorted(unvalidated)])
-                raise NotEnoughHashesError("unable to validate hashes %s"
-                                           % those)
+                    self[i] = h
+                    remove_upon_failure.add(i)
+
+            for level in reversed(range(len(hashes_to_check))):
+                this_level = hashes_to_check[level]
+                while this_level:
+                    i = this_level.pop()
+                    if i == 0:
+                        # The root has no sibling. How lonely. TODO: consider
+                        # setting the root in our constructor, then throw
+                        # NotEnoughHashesError here, because if we've
+                        # generated the root from below, we don't have
+                        # anything to validate it against.
+                        continue
+                    siblingnum = self.sibling(i)
+                    if self[siblingnum] is None:
+                        # without a sibling, we can't compute a parent, and
+                        # we can't verify this node
+                        raise NotEnoughHashesError("unable to validate [%d]"%i)
+                    parentnum = self.parent(i)
+                    # make sure we know right from left
+                    leftnum, rightnum = sorted([i, siblingnum])
+                    new_parent_hash = pair_hash(self[leftnum], self[rightnum])
+                    if self[parentnum]:
+                        if self[parentnum] != new_parent_hash:
+                            raise BadHashError("h([%d]+[%d]) != h[%d]" %
+                                               (leftnum, rightnum, parentnum))
+                    else:
+                        self[parentnum] = new_parent_hash
+                        remove_upon_failure.add(parentnum)
+                        parent_level = self.depth_of(parentnum)
+                        assert parent_level == level-1
+                        hashes_to_check[parent_level].add(parentnum)
+
+                    # our sibling is now as valid as this node
+                    this_level.discard(siblingnum)
+            # we're done!
 
         except (BadHashError, NotEnoughHashesError):
-            for i in added:
+            for i in remove_upon_failure:
                 self[i] = None
             raise
diff --git a/src/allmydata/test/test_hashtree.py b/src/allmydata/test/test_hashtree.py
index 5388c6e9..75b4c933 100644
--- a/src/allmydata/test/test_hashtree.py
+++ b/src/allmydata/test/test_hashtree.py
@@ -80,6 +80,46 @@ class Incomplete(unittest.TestCase):
         self.failUnlessEqual(ht.needed_hashes(5, False), set([11, 6, 1]))
         self.failUnlessEqual(ht.needed_hashes(5, True), set([12, 11, 6, 1]))
 
+    def test_depth_of(self):
+        ht = hashtree.IncompleteHashTree(8)
+        self.failUnlessEqual(ht.depth_of(0), 0)
+        for i in [1,2]:
+            self.failUnlessEqual(ht.depth_of(i), 1, "i=%d"%i)
+        for i in [3,4,5,6]:
+            self.failUnlessEqual(ht.depth_of(i), 2, "i=%d"%i)
+        for i in [7,8,9,10,11,12,13,14]:
+            self.failUnlessEqual(ht.depth_of(i), 3, "i=%d"%i)
+        self.failUnlessRaises(IndexError, ht.depth_of, 15)
+
+    def test_large(self):
+        # IncompleteHashTree.set_hashes() used to take O(N**2). This test is
+        # meant to show that it now takes O(N) or maybe O(N*ln(N)). I wish
+        # there were a good way to assert this (like counting VM operations
+        # or something): the problem was inside list.sort(), so there's no
+        # good way to instrument set_hashes() to count what we care about. On
+        # my laptop, 10k leaves takes 1.1s in this fixed version, and 11.6s
+        # in the old broken version. An 80k-leaf test (corresponding to a
+        # 10GB file with a 128KiB segsize) 10s in the fixed version, and
+        # several hours in the broken version, but 10s on my laptop (plus the
+        # 20s of setup code) probably means 200s on our dapper buildslave,
+        # which is painfully long for a unit test.
+        self.do_test_speed(10000)
+
+    def do_test_speed(self, SIZE):
+        # on my laptop, SIZE=80k (corresponding to a 10GB file with a 128KiB
+        # segsize) takes:
+        #  7s to build the (complete) HashTree
+        #  13s to set up the dictionary
+        #  10s to run set_hashes()
+        ht = make_tree(SIZE)
+        iht = hashtree.IncompleteHashTree(SIZE)
+
+        needed = set()
+        for i in range(SIZE):
+            needed.update(ht.needed_hashes(i, True))
+        all = dict([ (i, ht[i]) for i in needed])
+        iht.set_hashes(hashes=all)
+
     def test_check(self):
         # first create a complete hash tree
         ht = make_tree(6)
-- 
2.45.2