From a7aa6f868675a1c7f58fb152410c9609c33361ee Mon Sep 17 00:00:00 2001
From: Brian Warner <warner@lothar.com>
Date: Mon, 7 Jul 2008 17:36:00 -0700
Subject: [PATCH] implement a mutable checker+verifier. No repair yet. Part of
 #205.

---
 src/allmydata/mutable/checker.py   | 148 +++++++++++++++++++++++++++++
 src/allmydata/mutable/node.py      |   7 +-
 src/allmydata/test/test_mutable.py | 137 +++++++++++++++++++++++++-
 3 files changed, 285 insertions(+), 7 deletions(-)
 create mode 100644 src/allmydata/mutable/checker.py

diff --git a/src/allmydata/mutable/checker.py b/src/allmydata/mutable/checker.py
new file mode 100644
index 00000000..22d1db9b
--- /dev/null
+++ b/src/allmydata/mutable/checker.py
@@ -0,0 +1,148 @@
+
+import struct
+from twisted.internet import defer
+from twisted.python import failure
+from allmydata import hashtree
+from allmydata.util import hashutil
+
+from common import MODE_CHECK, CorruptShareError
+from servermap import ServerMap, ServermapUpdater
+from layout import unpack_share, SIGNED_PREFIX
+
+class MutableChecker:
+
+    def __init__(self, node):
+        self._node = node
+        self.healthy = True
+        self.problems = []
+        self._storage_index = self._node.get_storage_index()
+
+    def check(self, verify=False, repair=False):
+        servermap = ServerMap()
+        self.do_verify = verify
+        self.do_repair = repair
+        u = ServermapUpdater(self._node, servermap, MODE_CHECK)
+        d = u.update()
+        d.addCallback(self._got_mapupdate_results)
+        if verify:
+            d.addCallback(self._verify_all_shares)
+        d.addCallback(self._maybe_do_repair)
+        d.addCallback(self._return_results)
+        return d
+
+    def _got_mapupdate_results(self, servermap):
+        # the file is healthy if there is exactly one recoverable version, it
+        # has at least N distinct shares, and there are no unrecoverable
+        # versions: all existing shares will be for the same version.
+        self.best_version = None
+        if servermap.unrecoverable_versions():
+            self.healthy = False
+        num_recoverable = len(servermap.recoverable_versions())
+        if num_recoverable == 0:
+            self.healthy = False
+        else:
+            if num_recoverable > 1:
+                self.healthy = False
+            self.best_version = servermap.best_recoverable_version()
+            available_shares = servermap.shares_available()
+            (num_distinct_shares, k, N) = available_shares[self.best_version]
+            if num_distinct_shares < N:
+                self.healthy = False
+
+        return servermap
+
+    def _verify_all_shares(self, servermap):
+        # read every byte of each share
+        if not self.best_version:
+            return
+        versionmap = servermap.make_versionmap()
+        shares = versionmap[self.best_version]
+        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
+         offsets_tuple) = self.best_version
+        offsets = dict(offsets_tuple)
+        readv = [ (0, offsets["EOF"]) ]
+        dl = []
+        for (shnum, peerid, timestamp) in shares:
+            ss = servermap.connections[peerid]
+            d = self._do_read(ss, peerid, self._storage_index, [shnum], readv)
+            d.addCallback(self._got_answer, peerid)
+            dl.append(d)
+        return defer.DeferredList(dl, fireOnOneErrback=True)
+
+    def _do_read(self, ss, peerid, storage_index, shnums, readv):
+        # isolate the callRemote to a separate method, so tests can subclass
+        # Publish and override it
+        d = ss.callRemote("slot_readv", storage_index, shnums, readv)
+        return d
+
+    def _got_answer(self, datavs, peerid):
+        for shnum,datav in datavs.items():
+            data = datav[0]
+            try:
+                self._got_results_one_share(shnum, peerid, data)
+            except CorruptShareError:
+                f = failure.Failure()
+                self.add_problem(shnum, peerid, f)
+
+    def check_prefix(self, peerid, shnum, data):
+        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
+         offsets_tuple) = self.best_version
+        got_prefix = data[:struct.calcsize(SIGNED_PREFIX)]
+        if got_prefix != prefix:
+            raise CorruptShareError(peerid, shnum,
+                                    "prefix mismatch: share changed while we were reading it")
+
+    def _got_results_one_share(self, shnum, peerid, data):
+        self.check_prefix(peerid, shnum, data)
+
+        # the [seqnum:signature] pieces are validated by _compare_prefix,
+        # which checks their signature against the pubkey known to be
+        # associated with this file.
+
+        (seqnum, root_hash, IV, k, N, segsize, datalen, pubkey, signature,
+         share_hash_chain, block_hash_tree, share_data,
+         enc_privkey) = unpack_share(data)
+
+        # validate [share_hash_chain,block_hash_tree,share_data]
+
+        leaves = [hashutil.block_hash(share_data)]
+        t = hashtree.HashTree(leaves)
+        if list(t) != block_hash_tree:
+            raise CorruptShareError(peerid, shnum, "block hash tree failure")
+        share_hash_leaf = t[0]
+        t2 = hashtree.IncompleteHashTree(N)
+        # root_hash was checked by the signature
+        t2.set_hashes({0: root_hash})
+        try:
+            t2.set_hashes(hashes=share_hash_chain,
+                          leaves={shnum: share_hash_leaf})
+        except (hashtree.BadHashError, hashtree.NotEnoughHashesError,
+                IndexError), e:
+            msg = "corrupt hashes: %s" % (e,)
+            raise CorruptShareError(peerid, shnum, msg)
+
+        # validate enc_privkey: only possible if we have a write-cap
+        if not self._node.is_readonly():
+            alleged_privkey_s = self._node._decrypt_privkey(enc_privkey)
+            alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s)
+            if alleged_writekey != self._node.get_writekey():
+                raise CorruptShareError(peerid, shnum, "invalid privkey")
+
+    def _maybe_do_repair(self, res):
+        if self.healthy:
+            return
+        if not self.do_repair:
+            return
+        pass
+
+    def _return_results(self, res):
+        r = {}
+        r['healthy'] = self.healthy
+        r['problems'] = self.problems
+        return r
+
+
+    def add_problem(self, shnum, peerid, what):
+        self.healthy = False
+        self.problems.append( (peerid, self._storage_index, shnum, what) )
+
diff --git a/src/allmydata/mutable/node.py b/src/allmydata/mutable/node.py
index b54bddb4..d7b7a6d0 100644
--- a/src/allmydata/mutable/node.py
+++ b/src/allmydata/mutable/node.py
@@ -19,6 +19,7 @@ from common import MODE_READ, MODE_WRITE, UnrecoverableFileError, \
      ResponseCache, UncoordinatedWriteError
 from servermap import ServerMap, ServermapUpdater
 from retrieve import Retrieve
+from checker import MutableChecker
 
 
 class BackoffAgent:
@@ -235,9 +236,9 @@ class MutableFileNode:
 
     #################################
 
-    def check(self):
-        verifier = self.get_verifier()
-        return self._client.getServiceNamed("checker").check(verifier)
+    def check(self, verify=False, repair=False):
+        checker = MutableChecker(self)
+        return checker.check(verify, repair)
 
     # allow the use of IDownloadTarget
     def download(self, target):
diff --git a/src/allmydata/test/test_mutable.py b/src/allmydata/test/test_mutable.py
index c431a8be..25215c96 100644
--- a/src/allmydata/test/test_mutable.py
+++ b/src/allmydata/test/test_mutable.py
@@ -20,7 +20,7 @@ from allmydata.mutable.node import MutableFileNode, BackoffAgent
 from allmydata.mutable.common import DictOfSets, ResponseCache, \
      MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_READ, \
      NeedMoreDataError, UnrecoverableFileError, UncoordinatedWriteError, \
-     NotEnoughServersError
+     NotEnoughServersError, CorruptShareError
 from allmydata.mutable.retrieve import Retrieve
 from allmydata.mutable.publish import Publish
 from allmydata.mutable.servermap import ServerMap, ServermapUpdater
@@ -223,7 +223,7 @@ def flip_bit(original, byte_offset):
             chr(ord(original[byte_offset]) ^ 0x01) +
             original[byte_offset+1:])
 
-def corrupt(res, s, offset, shnums_to_corrupt=None):
+def corrupt(res, s, offset, shnums_to_corrupt=None, offset_offset=0):
     # if shnums_to_corrupt is None, corrupt all shares. Otherwise it is a
     # list of shnums to corrupt.
     for peerid in s._peers:
@@ -250,7 +250,7 @@ def corrupt(res, s, offset, shnums_to_corrupt=None):
                 real_offset = o[offset1]
             else:
                 real_offset = offset1
-            real_offset = int(real_offset) + offset2
+            real_offset = int(real_offset) + offset2 + offset_offset
             assert isinstance(real_offset, int), offset
             shares[shnum] = flip_bit(data, real_offset)
     return res
@@ -1119,6 +1119,131 @@ class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin):
         return d
 
 
+class CheckerMixin:
+    def check_good(self, r, where):
+        self.failUnless(r['healthy'], where)
+        self.failIf(r['problems'], where)
+        return r
+
+    def check_bad(self, r, where):
+        self.failIf(r['healthy'], where)
+        return r
+
+    def check_expected_failure(self, r, expected_exception, substring, where):
+        for (peerid, storage_index, shnum, f) in r['problems']:
+            if f.check(expected_exception):
+                self.failUnless(substring in str(f),
+                                "%s: substring '%s' not in '%s'" %
+                                (where, substring, str(f)))
+                return
+        self.fail("%s: didn't see expected exception %s in problems %s" %
+                  (where, expected_exception, r['problems']))
+
+
+class Checker(unittest.TestCase, CheckerMixin):
+    def setUp(self):
+        # publish a file and create shares, which can then be manipulated
+        # later.
+        self.CONTENTS = "New contents go here" * 1000
+        num_peers = 20
+        self._client = FakeClient(num_peers)
+        self._storage = self._client._storage
+        d = self._client.create_mutable_file(self.CONTENTS)
+        def _created(node):
+            self._fn = node
+        d.addCallback(_created)
+        return d
+
+
+    def test_check_good(self):
+        d = self._fn.check()
+        d.addCallback(self.check_good, "test_check_good")
+        return d
+
+    def test_check_no_shares(self):
+        for shares in self._storage._peers.values():
+            shares.clear()
+        d = self._fn.check()
+        d.addCallback(self.check_bad, "test_check_no_shares")
+        return d
+
+    def test_check_not_enough_shares(self):
+        for shares in self._storage._peers.values():
+            for shnum in shares.keys():
+                if shnum > 0:
+                    del shares[shnum]
+        d = self._fn.check()
+        d.addCallback(self.check_bad, "test_check_not_enough_shares")
+        return d
+
+    def test_check_all_bad_sig(self):
+        corrupt(None, self._storage, 1) # bad sig
+        d = self._fn.check()
+        d.addCallback(self.check_bad, "test_check_all_bad_sig")
+        return d
+
+    def test_check_all_bad_blocks(self):
+        corrupt(None, self._storage, "share_data", [9]) # bad blocks
+        # the Checker won't notice this.. it doesn't look at actual data
+        d = self._fn.check()
+        d.addCallback(self.check_good, "test_check_all_bad_blocks")
+        return d
+
+    def test_verify_good(self):
+        d = self._fn.check(verify=True)
+        d.addCallback(self.check_good, "test_verify_good")
+        return d
+
+    def test_verify_all_bad_sig(self):
+        corrupt(None, self._storage, 1) # bad sig
+        d = self._fn.check(verify=True)
+        d.addCallback(self.check_bad, "test_verify_all_bad_sig")
+        return d
+
+    def test_verify_one_bad_sig(self):
+        corrupt(None, self._storage, 1, [9]) # bad sig
+        d = self._fn.check(verify=True)
+        d.addCallback(self.check_bad, "test_verify_one_bad_sig")
+        return d
+
+    def test_verify_one_bad_block(self):
+        corrupt(None, self._storage, "share_data", [9]) # bad blocks
+        # the Verifier *will* notice this, since it examines every byte
+        d = self._fn.check(verify=True)
+        d.addCallback(self.check_bad, "test_verify_one_bad_block")
+        d.addCallback(self.check_expected_failure,
+                      CorruptShareError, "block hash tree failure",
+                      "test_verify_one_bad_block")
+        return d
+
+    def test_verify_one_bad_sharehash(self):
+        corrupt(None, self._storage, "share_hash_chain", [9], 5)
+        d = self._fn.check(verify=True)
+        d.addCallback(self.check_bad, "test_verify_one_bad_sharehash")
+        d.addCallback(self.check_expected_failure,
+                      CorruptShareError, "corrupt hashes",
+                      "test_verify_one_bad_sharehash")
+        return d
+
+    def test_verify_one_bad_encprivkey(self):
+        corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
+        d = self._fn.check(verify=True)
+        d.addCallback(self.check_bad, "test_verify_one_bad_encprivkey")
+        d.addCallback(self.check_expected_failure,
+                      CorruptShareError, "invalid privkey",
+                      "test_verify_one_bad_encprivkey")
+        return d
+
+    def test_verify_one_bad_encprivkey_uncheckable(self):
+        corrupt(None, self._storage, "enc_privkey", [9]) # bad privkey
+        readonly_fn = self._fn.get_readonly()
+        # a read-only node has no way to validate the privkey
+        d = readonly_fn.check(verify=True)
+        d.addCallback(self.check_good,
+                      "test_verify_one_bad_encprivkey_uncheckable")
+        return d
+
+
 class MultipleEncodings(unittest.TestCase):
     def setUp(self):
         self.CONTENTS = "New contents go here"
@@ -1261,7 +1386,7 @@ class MultipleEncodings(unittest.TestCase):
         d.addCallback(_retrieved)
         return d
 
-class MultipleVersions(unittest.TestCase):
+class MultipleVersions(unittest.TestCase, CheckerMixin):
     def setUp(self):
         self.CONTENTS = ["Contents 0",
                          "Contents 1",
@@ -1324,6 +1449,10 @@ class MultipleVersions(unittest.TestCase):
         self._set_versions(dict([(i,2) for i in (0,2,4,6,8)]))
         d = self._fn.download_best_version()
         d.addCallback(lambda res: self.failUnlessEqual(res, self.CONTENTS[4]))
+        # and the checker should report problems
+        d.addCallback(lambda res: self._fn.check())
+        d.addCallback(self.check_bad, "test_multiple_versions")
+
         # but if everything is at version 2, that's what we should download
         d.addCallback(lambda res:
                       self._set_versions(dict([(i,2) for i in range(10)])))
-- 
2.45.2