From 893eea849ba7cf788295679875c319ff509aaba8 Mon Sep 17 00:00:00 2001
From: Brian Warner <warner@lothar.com>
Date: Sat, 7 Jan 2012 18:12:51 -0800
Subject: [PATCH] mutable/retrieve.py: clean up control flow to avoid dropping
 errors

* replace DeferredList with gatherResults, simplify result handling
* use BadShareError to signal recoverable problems in either fetch or
  validate, catch after _validate_block
* _validate_block is thus not responsible for noticing fetch problems
* rename _validation_or_decoding_failed() to _handle_bad_share()
* _get_needed_hashes() returns two Deferreds, instead of a hard-to-unpack
  DeferredList
---
 src/allmydata/mutable/retrieve.py | 73 ++++++++++++++-----------------
 1 file changed, 34 insertions(+), 39 deletions(-)

diff --git a/src/allmydata/mutable/retrieve.py b/src/allmydata/mutable/retrieve.py
index 055a2d39..3b556593 100644
--- a/src/allmydata/mutable/retrieve.py
+++ b/src/allmydata/mutable/retrieve.py
@@ -5,7 +5,8 @@ from zope.interface import implements
 from twisted.internet import defer
 from twisted.python import failure
 from twisted.internet.interfaces import IPushProducer, IConsumer
-from foolscap.api import eventually, fireEventually
+from foolscap.api import eventually, fireEventually, DeadReferenceError, \
+     RemoteException
 from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \
      DownloadStopped, MDMF_VERSION, SDMF_VERSION
 from allmydata.util import hashutil, log, mathutil, deferredutil
@@ -15,7 +16,8 @@ from allmydata.storage.server import si_b2a
 from pycryptopp.cipher.aes import AES
 from pycryptopp.publickey import rsa
 
-from allmydata.mutable.common import CorruptShareError, UncoordinatedWriteError
+from allmydata.mutable.common import CorruptShareError, BadShareError, \
+     UncoordinatedWriteError
 from allmydata.mutable.layout import MDMFSlotReadProxy
 
 class RetrieveStatus:
@@ -591,7 +593,7 @@ class Retrieve:
         self._bad_shares.add((server, shnum, f))
         self._status.add_problem(server, f)
         self._last_failure = f
-        if f.check(CorruptShareError):
+        if f.check(BadShareError):
             self.notify_server_corruption(server, shnum, str(f.value))
 
 
@@ -631,16 +633,15 @@ class Retrieve:
         ds = []
         for reader in self._active_readers:
             started = time.time()
-            d = reader.get_block_and_salt(segnum)
-            d2 = self._get_needed_hashes(reader, segnum)
-            dl = defer.DeferredList([d, d2], consumeErrors=True)
-            dl.addCallback(self._validate_block, segnum, reader, reader.server, started)
-            dl.addErrback(self._validation_or_decoding_failed, [reader])
-            ds.append(dl)
-        # _validation_or_decoding_failed is supposed to eat any recoverable
-        # errors (like corrupt shares), returning a None when that happens.
-        # If it raises an exception itself, or if it can't handle the error,
-        # the download should fail. So we can use gatherResults() here.
+            d1 = reader.get_block_and_salt(segnum)
+            d2,d3 = self._get_needed_hashes(reader, segnum)
+            d = deferredutil.gatherResults([d1,d2,d3])
+            d.addCallback(self._validate_block, segnum, reader, reader.server, started)
+            # _handle_bad_share takes care of recoverable errors (by dropping
+            # that share and returning None). Any other errors (i.e. code
+            # bugs) are passed through and cause the retrieve to fail.
+            d.addErrback(self._handle_bad_share, [reader])
+            ds.append(d)
         dl = deferredutil.gatherResults(ds)
         if self._verify:
             dl.addCallback(lambda ignored: "")
@@ -672,8 +673,6 @@ class Retrieve:
         self.log("everything looks ok, building segment %d" % segnum)
         d = self._decode_blocks(results, segnum)
         d.addCallback(self._decrypt_segment)
-        d.addErrback(self._validation_or_decoding_failed,
-                     self._active_readers)
         # check to see whether we've been paused before writing
         # anything.
         d.addCallback(self._check_for_paused)
@@ -724,13 +723,25 @@ class Retrieve:
         self._current_segment += 1
 
 
-    def _validation_or_decoding_failed(self, f, readers):
+    def _handle_bad_share(self, f, readers):
         """
         I am called when a block or a salt fails to correctly validate, or when
         the decryption or decoding operation fails for some reason.  I react to
         this failure by notifying the remote server of corruption, and then
         removing the remote server from further activity.
         """
+        # these are the errors we can tolerate: by giving up on this share
+        # and finding others to replace it. Any other errors (i.e. coding
+        # bugs) are re-raised, causing the download to fail.
+        f.trap(DeadReferenceError, RemoteException, BadShareError)
+
+        # DeadReferenceError happens when we try to fetch data from a server
+        # that has gone away. RemoteException happens if the server had an
+        # internal error. BadShareError encompasses: (UnknownVersionError,
+        # LayoutInvalid, struct.error) which happen when we get obviously
+        # wrong data, and CorruptShareError which happens later, when we
+        # perform integrity checks on the data.
+
         assert isinstance(readers, list)
         bad_shnums = [reader.shnum for reader in readers]
 
@@ -739,7 +750,7 @@ class Retrieve:
                  (bad_shnums, readers, self._current_segment, str(f)))
         for reader in readers:
             self._mark_bad_share(reader.server, reader.shnum, reader, f)
-        return
+        return None
 
 
     def _validate_block(self, results, segnum, reader, server, started):
@@ -753,30 +764,15 @@ class Retrieve:
         elapsed = time.time() - started
         self._status.add_fetch_timing(server, elapsed)
         self._set_current_status("validating blocks")
-        # Did we fail to fetch either of the things that we were
-        # supposed to? Fail if so.
-        if not results[0][0] and results[1][0]:
-            # handled by the errback handler.
-
-            # These all get batched into one query, so the resulting
-            # failure should be the same for all of them, so we can just
-            # use the first one.
-            assert isinstance(results[0][1], failure.Failure)
-
-            f = results[0][1]
-            raise CorruptShareError(server,
-                                    reader.shnum,
-                                    "Connection error: %s" % str(f))
 
-        block_and_salt, block_and_sharehashes = results
-        block, salt = block_and_salt[1]
-        blockhashes, sharehashes = block_and_sharehashes[1]
+        block_and_salt, blockhashes, sharehashes = results
+        block, salt = block_and_salt
 
-        blockhashes = dict(enumerate(blockhashes[1]))
+        blockhashes = dict(enumerate(blockhashes))
         self.log("the reader gave me the following blockhashes: %s" % \
                  blockhashes.keys())
         self.log("the reader gave me the following sharehashes: %s" % \
-                 sharehashes[1].keys())
+                 sharehashes.keys())
         bht = self._block_hash_trees[reader.shnum]
 
         if bht.needed_hashes(segnum, include_leaf=True):
@@ -815,7 +811,7 @@ class Retrieve:
                                               include_leaf=True) or \
                                               self._verify:
             try:
-                self.share_hash_tree.set_hashes(hashes=sharehashes[1],
+                self.share_hash_tree.set_hashes(hashes=sharehashes,
                                             leaves={reader.shnum: bht[0]})
             except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
                     IndexError), e:
@@ -855,8 +851,7 @@ class Retrieve:
         else:
             d2 = defer.succeed({}) # the logic in the next method
                                    # expects a dict
-        dl = defer.DeferredList([d1, d2], consumeErrors=True)
-        return dl
+        return d1,d2
 
 
     def _decode_blocks(self, results, segnum):
-- 
2.45.2