mutable/servermap: improve test coverage
author     Brian Warner <warner@allmydata.com>
           Tue, 22 Apr 2008 23:47:52 +0000 (16:47 -0700)
committer  Brian Warner <warner@allmydata.com>
           Tue, 22 Apr 2008 23:47:52 +0000 (16:47 -0700)
src/allmydata/mutable/servermap.py
src/allmydata/test/test_mutable.py

diff --git a/src/allmydata/mutable/servermap.py b/src/allmydata/mutable/servermap.py
index ea9de5338d1e884142ee25ae3b328c556e9a9b4e..0bf8e9d74f97499191c2597bb10d32ead5b702f8 100644
--- a/src/allmydata/mutable/servermap.py
+++ b/src/allmydata/mutable/servermap.py
@@ -484,9 +484,9 @@ class ServermapUpdater:
 
     def _got_results(self, datavs, peerid, readsize, stuff, started):
         lp = self.log(format="got result from [%(peerid)s], %(numshares)d shares",
-                     peerid=idlib.shortnodeid_b2a(peerid),
-                     numshares=len(datavs),
-                     level=log.NOISY)
+                      peerid=idlib.shortnodeid_b2a(peerid),
+                      numshares=len(datavs),
+                      level=log.NOISY)
         now = time.time()
         elapsed = now - started
         self._queries_outstanding.discard(peerid)
@@ -508,7 +508,7 @@ class ServermapUpdater:
         for shnum,datav in datavs.items():
             data = datav[0]
             try:
-                verinfo = self._got_results_one_share(shnum, data, peerid)
+                verinfo = self._got_results_one_share(shnum, data, peerid, lp)
                 last_verinfo = verinfo
                 last_shnum = shnum
                 self._node._cache.add(verinfo, shnum, 0, data, now)
@@ -527,6 +527,8 @@ class ServermapUpdater:
         if self._need_privkey and last_verinfo:
             # send them a request for the privkey. We send one request per
             # server.
+            lp2 = self.log("sending privkey request",
+                           parent=lp, level=log.NOISY)
             (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
              offsets_tuple) = last_verinfo
             o = dict(offsets_tuple)
@@ -538,8 +540,8 @@ class ServermapUpdater:
             d = self._do_read(ss, peerid, self._storage_index,
                               [last_shnum], readv)
             d.addCallback(self._got_privkey_results, peerid, last_shnum,
-                          privkey_started)
-            d.addErrback(self._privkey_query_failed, peerid, last_shnum)
+                          privkey_started, lp2)
+            d.addErrback(self._privkey_query_failed, peerid, last_shnum, lp2)
             d.addErrback(log.err)
             d.addCallback(self._check_for_done)
             d.addErrback(self._fatal_error)
@@ -547,10 +549,11 @@ class ServermapUpdater:
         # all done!
         self.log("_got_results done", parent=lp)
 
-    def _got_results_one_share(self, shnum, data, peerid):
-        lp = self.log(format="_got_results: got shnum #%(shnum)d from peerid %(peerid)s",
-                      shnum=shnum,
-                      peerid=idlib.shortnodeid_b2a(peerid))
+    def _got_results_one_share(self, shnum, data, peerid, lp):
+        self.log(format="_got_results: got shnum #%(shnum)d from peerid %(peerid)s",
+                 shnum=shnum,
+                 peerid=idlib.shortnodeid_b2a(peerid),
+                 parent=lp)
 
         # this might raise NeedMoreDataError, if the pubkey and signature
         # live at some weird offset. That shouldn't happen, so I'm going to
@@ -567,7 +570,7 @@ class ServermapUpdater:
             self._node._populate_pubkey(self._deserialize_pubkey(pubkey_s))
 
         if self._need_privkey:
-            self._try_to_extract_privkey(data, peerid, shnum)
+            self._try_to_extract_privkey(data, peerid, shnum, lp)
 
         (ig_version, ig_seqnum, ig_root_hash, ig_IV, ig_k, ig_N,
          ig_segsize, ig_datalen, offsets) = unpack_header(data)
@@ -610,7 +613,7 @@ class ServermapUpdater:
         verifier = rsa.create_verifying_key_from_string(pubkey_s)
         return verifier
 
-    def _try_to_extract_privkey(self, data, peerid, shnum):
+    def _try_to_extract_privkey(self, data, peerid, shnum, lp):
         try:
             r = unpack_share(data)
         except NeedMoreDataError, e:
@@ -620,7 +623,8 @@ class ServermapUpdater:
             self.log("shnum %d on peerid %s: share was too short (%dB) "
                      "to get the encprivkey; [%d:%d] ought to hold it" %
                      (shnum, idlib.shortnodeid_b2a(peerid), len(data),
-                      offset, offset+length))
+                      offset, offset+length),
+                     parent=lp)
             # NOTE: if uncoordinated writes are taking place, someone might
             # change the share (and most probably move the encprivkey) before
             # we get a chance to do one of these reads and fetch it. This
@@ -636,20 +640,22 @@ class ServermapUpdater:
          pubkey, signature, share_hash_chain, block_hash_tree,
          share_data, enc_privkey) = r
 
-        return self._try_to_validate_privkey(enc_privkey, peerid, shnum)
+        return self._try_to_validate_privkey(enc_privkey, peerid, shnum, lp)
 
-    def _try_to_validate_privkey(self, enc_privkey, peerid, shnum):
+    def _try_to_validate_privkey(self, enc_privkey, peerid, shnum, lp):
 
         alleged_privkey_s = self._node._decrypt_privkey(enc_privkey)
         alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s)
         if alleged_writekey != self._node.get_writekey():
             self.log("invalid privkey from %s shnum %d" %
-                     (idlib.nodeid_b2a(peerid)[:8], shnum), level=log.WEIRD)
+                     (idlib.nodeid_b2a(peerid)[:8], shnum),
+                     parent=lp, level=log.WEIRD)
             return
 
         # it's good
         self.log("got valid privkey from shnum %d on peerid %s" %
-                 (shnum, idlib.shortnodeid_b2a(peerid)))
+                 (shnum, idlib.shortnodeid_b2a(peerid)),
+                 parent=lp)
         privkey = rsa.create_signing_key_from_string(alleged_privkey_s)
         self._node._populate_encprivkey(enc_privkey)
         self._node._populate_privkey(privkey)
@@ -669,7 +675,7 @@ class ServermapUpdater:
         self._queries_completed += 1
         self._last_failure = f
 
-    def _got_privkey_results(self, datavs, peerid, shnum, started):
+    def _got_privkey_results(self, datavs, peerid, shnum, started, lp):
         now = time.time()
         elapsed = now - started
         self._status.add_per_server_time(peerid, "privkey", started, elapsed)
@@ -681,12 +687,12 @@ class ServermapUpdater:
             return
         datav = datavs[shnum]
         enc_privkey = datav[0]
-        self._try_to_validate_privkey(enc_privkey, peerid, shnum)
+        self._try_to_validate_privkey(enc_privkey, peerid, shnum, lp)
 
-    def _privkey_query_failed(self, f, peerid, shnum):
+    def _privkey_query_failed(self, f, peerid, shnum, lp):
         self._queries_outstanding.discard(peerid)
         self.log("error during privkey query: %s %s" % (f, f.value),
-                 level=log.WEIRD)
+                 parent=lp, level=log.WEIRD)
         if not self._running:
             return
         self._queries_outstanding.discard(peerid)
@@ -702,12 +708,14 @@ class ServermapUpdater:
         lp = self.log(format=("_check_for_done, mode is '%(mode)s', "
                               "%(outstanding)d queries outstanding, "
                               "%(extra)d extra peers available, "
-                              "%(must)d 'must query' peers left"
+                              "%(must)d 'must query' peers left, "
+                              "need_privkey=%(need_privkey)s"
                               ),
                       mode=self.mode,
                       outstanding=len(self._queries_outstanding),
                       extra=len(self.extra_peers),
                       must=len(self._must_query),
+                      need_privkey=self._need_privkey,
                       level=log.NOISY,
                       )
 
diff --git a/src/allmydata/test/test_mutable.py b/src/allmydata/test/test_mutable.py
index 7832000b211e8d218e047922c5fa704168870887..dc3369b2f20a138a0664cab8aec4db31b26e394f 100644
--- a/src/allmydata/test/test_mutable.py
+++ b/src/allmydata/test/test_mutable.py
@@ -3,6 +3,7 @@ import os, struct
 from cStringIO import StringIO
 from twisted.trial import unittest
 from twisted.internet import defer, reactor
+from twisted.python import failure
 from allmydata import uri, download, storage
 from allmydata.util import base32, testutil, idlib
 from allmydata.util.idlib import shortnodeid_b2a
@@ -54,18 +55,29 @@ class FakeStorage:
         # order).
         self._sequence = None
         self._pending = {}
+        self._pending_timer = None
+        self._special_answers = {}
 
     def read(self, peerid, storage_index):
         shares = self._peers.get(peerid, {})
+        if self._special_answers.get(peerid, []):
+            mode = self._special_answers[peerid].pop(0)
+            if mode == "fail":
+                shares = failure.Failure(IntentionalError())
+            elif mode == "none":
+                shares = {}
+            elif mode == "normal":
+                pass
         if self._sequence is None:
             return defer.succeed(shares)
         d = defer.Deferred()
         if not self._pending:
-            reactor.callLater(1.0, self._fire_readers)
+            self._pending_timer = reactor.callLater(1.0, self._fire_readers)
         self._pending[peerid] = (d, shares)
         return d
 
     def _fire_readers(self):
+        self._pending_timer = None
         pending = self._pending
         self._pending = {}
         extra = []
@@ -654,7 +666,7 @@ class Servermap(unittest.TestCase):
         d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10))
 
         # create a new file, which is large enough to knock the privkey out
-        # of the early part of the fil
+        # of the early part of the file
         LARGE = "These are Larger contents" * 200 # about 5KB
         d.addCallback(lambda res: self._client.create_mutable_file(LARGE))
         def _created(large_fn):
@@ -1342,6 +1354,7 @@ class LocalWrapper:
     def __init__(self, original):
         self.original = original
         self.broken = False
+        self.post_call_notifier = None
     def callRemote(self, methname, *args, **kwargs):
         def _call():
             if self.broken:
@@ -1350,6 +1363,8 @@ class LocalWrapper:
             return meth(*args, **kwargs)
         d = fireEventually()
         d.addCallback(lambda res: _call())
+        if self.post_call_notifier:
+            d.addCallback(self.post_call_notifier, methname)
         return d
 
 class LessFakeClient(FakeClient):
@@ -1469,3 +1484,72 @@ class Problems(unittest.TestCase, testutil.ShouldFailMixin):
         d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
         return d
 
+    def test_privkey_query_error(self):
+        # when a servermap is updated with MODE_WRITE, it tries to get the
+        # privkey. Something might go wrong during this query attempt.
+        self.client = FakeClient(20)
+        # we need some contents that are large enough to push the privkey out
+        # of the early part of the file
+        LARGE = "These are Larger contents" * 200 # about 5KB
+        d = self.client.create_mutable_file(LARGE)
+        def _created(n):
+            self.uri = n.get_uri()
+            self.n2 = self.client.create_node_from_uri(self.uri)
+            # we start by doing a map update to figure out which is the first
+            # server.
+            return n.get_servermap(MODE_WRITE)
+        d.addCallback(_created)
+        d.addCallback(lambda res: fireEventually(res))
+        def _got_smap1(smap):
+            peer0 = list(smap.make_sharemap()[0])[0]
+            # we tell the server to respond to this peer first, so that it
+            # will be asked for the privkey first
+            self.client._storage._sequence = [peer0]
+            # now we make the peer fail their second query
+            self.client._storage._special_answers[peer0] = ["normal", "fail"]
+        d.addCallback(_got_smap1)
+        # now we update a servermap from a new node (which doesn't have the
+        # privkey yet, forcing it to use a separate privkey query). Each
+        # query response will trigger a privkey query, and since we're using
+        # _sequence to make the peer0 response come back first, we'll send it
+        # a privkey query first, and _sequence will again ensure that the
+        # peer0 query will also come back before the others, and then
+        # _special_answers will make sure that the query raises an exception.
+        # The whole point of these hijinks is to exercise the code in
+        # _privkey_query_failed. Note that the map-update will succeed, since
+        # we'll just get a copy from one of the other shares.
+        d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
+        # Using FakeStorage._sequence means there will be read requests still
+        # floating around.. wait for them to retire
+        def _cancel_timer(res):
+            if self.client._storage._pending_timer:
+                self.client._storage._pending_timer.cancel()
+            return res
+        d.addBoth(_cancel_timer)
+        return d
+
+    def test_privkey_query_missing(self):
+        # like test_privkey_query_error, but the shares are deleted by the
+        # second query, instead of raising an exception.
+        self.client = FakeClient(20)
+        LARGE = "These are Larger contents" * 200 # about 5KB
+        d = self.client.create_mutable_file(LARGE)
+        def _created(n):
+            self.uri = n.get_uri()
+            self.n2 = self.client.create_node_from_uri(self.uri)
+            return n.get_servermap(MODE_WRITE)
+        d.addCallback(_created)
+        d.addCallback(lambda res: fireEventually(res))
+        def _got_smap1(smap):
+            peer0 = list(smap.make_sharemap()[0])[0]
+            self.client._storage._sequence = [peer0]
+            self.client._storage._special_answers[peer0] = ["normal", "none"]
+        d.addCallback(_got_smap1)
+        d.addCallback(lambda res: self.n2.get_servermap(MODE_WRITE))
+        def _cancel_timer(res):
+            if self.client._storage._pending_timer:
+                self.client._storage._pending_timer.cancel()
+            return res
+        d.addBoth(_cancel_timer)
+        return d
+
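
Aside: the new tests depend on the FakeStorage._special_answers hook added above, a per-peer queue of answer modes that is consumed one entry per read. A minimal standalone sketch of the same idea, with hypothetical names and without the Twisted Deferred machinery:

class IntentionalError(Exception):
    pass

class MiniFakeStorage:
    def __init__(self):
        self._peers = {}            # peerid -> {shnum: data}
        self._special_answers = {}  # peerid -> list of answer modes

    def read(self, peerid):
        shares = self._peers.get(peerid, {})
        if self._special_answers.get(peerid, []):
            mode = self._special_answers[peerid].pop(0)
            if mode == "fail":
                raise IntentionalError("simulated server failure")
            elif mode == "none":
                shares = {}   # pretend the shares have vanished
            # "normal" falls through and returns the real shares
        return shares

s = MiniFakeStorage()
s._peers["peer0"] = {0: "share data"}
s._special_answers["peer0"] = ["normal", "fail"]
print(s.read("peer0"))     # first query answers normally: {0: 'share data'}
try:
    s.read("peer0")        # second query fails, like the privkey query in the test
except IntentionalError:
    print("second query failed, exercising the _privkey_query_failed path")

In the patched FakeStorage the "fail" mode hands back a twisted.python.failure.Failure through a Deferred rather than raising directly, but the queue-of-modes idea is the same.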