]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blobdiff - src/allmydata/mutable/retrieve.py
Remove some bare asserts in retrieve.py (there are still quite a few left). refs...
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / mutable / retrieve.py
index 029818dcd39172b3862a78f9bd63ee58643923bc..6c2c5c9bf8c48d81c267d1a5ca79aa2b70ce1011 100644 (file)
@@ -5,17 +5,21 @@ from zope.interface import implements
 from twisted.internet import defer
 from twisted.python import failure
 from twisted.internet.interfaces import IPushProducer, IConsumer
-from foolscap.api import eventually, fireEventually
+from foolscap.api import eventually, fireEventually, DeadReferenceError, \
+     RemoteException
+
 from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \
      DownloadStopped, MDMF_VERSION, SDMF_VERSION
-from allmydata.util import hashutil, log, mathutil
+from allmydata.util.assertutil import _assert
+from allmydata.util import hashutil, log, mathutil, deferredutil
 from allmydata.util.dictutil import DictOfSets
 from allmydata import hashtree, codec
 from allmydata.storage.server import si_b2a
 from pycryptopp.cipher.aes import AES
 from pycryptopp.publickey import rsa
 
-from allmydata.mutable.common import CorruptShareError, UncoordinatedWriteError
+from allmydata.mutable.common import CorruptShareError, BadShareError, \
+     UncoordinatedWriteError
 from allmydata.mutable.layout import MDMFSlotReadProxy
 
 class RetrieveStatus:
@@ -60,10 +64,9 @@ class RetrieveStatus:
         return self._problems
 
     def add_fetch_timing(self, server, elapsed):
-        serverid = server.get_serverid()
-        if serverid not in self.timings["fetch_per_server"]:
-            self.timings["fetch_per_server"][serverid] = []
-        self.timings["fetch_per_server"][serverid].append(elapsed)
+        if server not in self.timings["fetch_per_server"]:
+            self.timings["fetch_per_server"][server] = []
+        self.timings["fetch_per_server"][server].append(elapsed)
     def accumulate_decode_time(self, elapsed):
         self.timings["decode"] += elapsed
     def accumulate_decrypt_time(self, elapsed):
@@ -129,7 +132,7 @@ class Retrieve:
 
         # verify means that we are using the downloader logic to verify all
         # of our shares. This tells the downloader a few things.
-        # 
+        #
         # 1. We need to download all of the shares.
         # 2. We don't need to decode or decrypt the shares, since our
         #    caller doesn't care about the plaintext, only the
@@ -240,8 +243,8 @@ class Retrieve:
         self._done_deferred = defer.Deferred()
         self._offset = offset
         self._read_length = size
-        self._setup_download()
         self._setup_encoding_parameters()
+        self._setup_download()
         self.log("starting download")
         self._started_fetching = time.time()
         # The download process beyond this is a state machine.
@@ -285,24 +288,28 @@ class Retrieve:
         self.remaining_sharemap = DictOfSets()
         for (shnum, server, timestamp) in shares:
             self.remaining_sharemap.add(shnum, server)
-            # If the servermap update fetched anything, it fetched at least 1
-            # KiB, so we ask for that much.
-            # TODO: Change the cache methods to allow us to fetch all of the
-            # data that they have, then change this method to do that.
-            any_cache = self._node._read_from_cache(self.verinfo, shnum,
-                                                    0, 1000)
-            reader = MDMFSlotReadProxy(server.get_rref(),
-                                       self._storage_index,
-                                       shnum,
-                                       any_cache)
+            # Reuse the SlotReader from the servermap.
+            key = (self.verinfo, server.get_serverid(),
+                   self._storage_index, shnum)
+            if key in self.servermap.proxies:
+                reader = self.servermap.proxies[key]
+            else:
+                reader = MDMFSlotReadProxy(server.get_rref(),
+                                           self._storage_index, shnum, None)
             reader.server = server
             self.readers[shnum] = reader
-        assert len(self.remaining_sharemap) >= k
+
+        if len(self.remaining_sharemap) < k:
+            self._raise_notenoughshareserror()
 
         self.shares = {} # maps shnum to validated blocks
         self._active_readers = [] # list of active readers for this dl.
         self._block_hash_trees = {} # shnum => hashtree
 
+        for i in xrange(self._total_shares):
+            # So we don't have to do this later.
+            self._block_hash_trees[i] = hashtree.IncompleteHashTree(self._num_segments)
+
         # We need one share hash tree for the entire file; its leaves
         # are the roots of the block hash trees for the shares that
         # comprise it, and its root is in the verinfo.
@@ -318,15 +325,14 @@ class Retrieve:
         segment with. I return the plaintext associated with that
         segment.
         """
-        # shnum => block hash tree. Unused, but setup_encoding_parameters will
-        # want to set this.
+        # We don't need the block hash trees in this case.
         self._block_hash_trees = None
         self._setup_encoding_parameters()
 
-        # _decode_blocks() expects the output of a DeferredList that contains
-        # the outputs of _validate_block() (each of which is a dict mapping
-        # shnum to (block,salt) bytestrings).
-        d = self._decode_blocks([(True, blocks_and_salts)], segnum)
+        # _decode_blocks() expects the output of a gatherResults that
+        # contains the outputs of _validate_block() (each of which is a dict
+        # mapping shnum to (block,salt) bytestrings).
+        d = self._decode_blocks([blocks_and_salts], segnum)
         d.addCallback(self._decrypt_segment)
         return d
 
@@ -385,15 +391,10 @@ class Retrieve:
                  (k, n, self._num_segments, self._segment_size,
                   self._tail_segment_size))
 
-        if self._block_hash_trees is not None:
-            for i in xrange(self._total_shares):
-                # So we don't have to do this later.
-                self._block_hash_trees[i] = hashtree.IncompleteHashTree(self._num_segments)
-
         # Our last task is to tell the downloader where to start and
         # where to stop. We use three parameters for that:
         #   - self._start_segment: the segment that we need to start
-        #     downloading from. 
+        #     downloading from.
         #   - self._current_segment: the next segment that we need to
         #     download.
         #   - self._last_segment: The last segment that we were asked to
@@ -406,10 +407,12 @@ class Retrieve:
         if self._offset:
             self.log("got offset: %d" % self._offset)
             # our start segment is the first segment containing the
-            # offset we were given. 
+            # offset we were given.
             start = self._offset // self._segment_size
 
-            assert start < self._num_segments
+            _assert(start < self._num_segments,
+                    start=start, num_segments=self._num_segments,
+                    offset=self._offset, segment_size=self._segment_size)
             self._start_segment = start
             self.log("got start segment: %d" % self._start_segment)
         else:
@@ -430,7 +433,10 @@ class Retrieve:
                 # but the one before it.
                 end = (end_data - 1) // self._segment_size
 
-                assert end < self._num_segments
+                _assert(end < self._num_segments,
+                        end=end, num_segments=self._num_segments,
+                        end_data=end_data, offset=self._offset, read_length=self._read_length,
+                        segment_size=self._segment_size)
                 self._last_segment = end
             else:
                 self._last_segment = self._start_segment
@@ -503,7 +509,6 @@ class Retrieve:
                 d.addCallback(self._try_to_validate_privkey, reader, reader.server)
                 # XXX: don't just drop the Deferred. We need error-reporting
                 # but not flow-control here.
-        assert len(self._active_readers) >= self._required_shares
 
     def _try_to_validate_prefix(self, prefix, reader):
         """
@@ -529,40 +534,6 @@ class Retrieve:
                                           "indicate an uncoordinated write")
         # Otherwise, we're okay -- no issues.
 
-
-    def _remove_reader(self, reader):
-        """
-        At various points, we will wish to remove a server from
-        consideration and/or use. These include, but are not necessarily
-        limited to:
-
-            - A connection error.
-            - A mismatched prefix (that is, a prefix that does not match
-              our conception of the version information string).
-            - A failing block hash, salt hash, or share hash, which can
-              indicate disk failure/bit flips, or network trouble.
-
-        This method will do that. I will make sure that the
-        (shnum,reader) combination represented by my reader argument is
-        not used for anything else during this download. I will not
-        advise the reader of any corruption, something that my callers
-        may wish to do on their own.
-        """
-        # TODO: When you're done writing this, see if this is ever
-        # actually used for something that _mark_bad_share isn't. I have
-        # a feeling that they will be used for very similar things, and
-        # that having them both here is just going to be an epic amount
-        # of code duplication.
-        #
-        # (well, okay, not epic, but meaningful)
-        self.log("removing reader %s" % reader)
-        # Remove the reader from _active_readers
-        self._active_readers.remove(reader)
-        # TODO: self.readers.remove(reader)?
-        for shnum in list(self.remaining_sharemap.keys()):
-            self.remaining_sharemap.discard(shnum, reader.server)
-
-
     def _mark_bad_share(self, server, shnum, reader, f):
         """
         I mark the given (server, shnum) as a bad share, which means that it
@@ -587,23 +558,30 @@ class Retrieve:
                  (shnum, server.get_name()))
         prefix = self.verinfo[-2]
         self.servermap.mark_bad_share(server, shnum, prefix)
-        self._remove_reader(reader)
         self._bad_shares.add((server, shnum, f))
         self._status.add_problem(server, f)
         self._last_failure = f
-        self.notify_server_corruption(server, shnum, str(f.value))
 
+        # Remove the reader from _active_readers
+        self._active_readers.remove(reader)
+        for shnum in list(self.remaining_sharemap.keys()):
+            self.remaining_sharemap.discard(shnum, reader.server)
+
+        if f.check(BadShareError):
+            self.notify_server_corruption(server, shnum, str(f.value))
 
     def _download_current_segment(self):
         """
         I download, validate, decode, decrypt, and assemble the segment
         that this Retrieve is currently responsible for downloading.
         """
-        assert len(self._active_readers) >= self._required_shares
         if self._current_segment > self._last_segment:
             # No more segments to download, we're done.
             self.log("got plaintext, done")
             return self._done()
+        elif self._verify and len(self._active_readers) == 0:
+            self.log("no more good shares, no need to keep verifying")
+            return self._done()
         self.log("on segment %d of %d" %
                  (self._current_segment + 1, self._num_segments))
         d = self._process_segment(self._current_segment)
@@ -622,7 +600,6 @@ class Retrieve:
 
         # TODO: The old code uses a marker. Should this code do that
         # too? What did the Marker do?
-        assert len(self._active_readers) >= self._required_shares
 
         # We need to ask each of our active readers for its block and
         # salt. We will then validate those. If validation is
@@ -630,13 +607,16 @@ class Retrieve:
         ds = []
         for reader in self._active_readers:
             started = time.time()
-            d = reader.get_block_and_salt(segnum)
-            d2 = self._get_needed_hashes(reader, segnum)
-            dl = defer.DeferredList([d, d2], consumeErrors=True)
-            dl.addCallback(self._validate_block, segnum, reader, reader.server, started)
-            dl.addErrback(self._validation_or_decoding_failed, [reader])
-            ds.append(dl)
-        dl = defer.DeferredList(ds)
+            d1 = reader.get_block_and_salt(segnum)
+            d2,d3 = self._get_needed_hashes(reader, segnum)
+            d = deferredutil.gatherResults([d1,d2,d3])
+            d.addCallback(self._validate_block, segnum, reader, reader.server, started)
+            # _handle_bad_share takes care of recoverable errors (by dropping
+            # that share and returning None). Any other errors (i.e. code
+            # bugs) are passed through and cause the retrieve to fail.
+            d.addErrback(self._handle_bad_share, [reader])
+            ds.append(d)
+        dl = deferredutil.gatherResults(ds)
         if self._verify:
             dl.addCallback(lambda ignored: "")
             dl.addCallback(self._set_segment)
@@ -645,35 +625,34 @@ class Retrieve:
         return dl
 
 
-    def _maybe_decode_and_decrypt_segment(self, blocks_and_salts, segnum):
+    def _maybe_decode_and_decrypt_segment(self, results, segnum):
         """
-        I take the results of fetching and validating the blocks from a
-        callback chain in another method. If the results are such that
-        they tell me that validation and fetching succeeded without
-        incident, I will proceed with decoding and decryption.
-        Otherwise, I will do nothing.
+        I take the results of fetching and validating the blocks from
+        _process_segment. If validation and fetching succeeded without
+        incident, I will proceed with decoding and decryption. Otherwise, I
+        will do nothing.
         """
         self.log("trying to decode and decrypt segment %d" % segnum)
-        failures = False
-        for block_and_salt in blocks_and_salts:
-            if not block_and_salt[0] or block_and_salt[1] == None:
-                self.log("some validation operations failed; not proceeding")
-                failures = True
-                break
-        if not failures:
-            self.log("everything looks ok, building segment %d" % segnum)
-            d = self._decode_blocks(blocks_and_salts, segnum)
-            d.addCallback(self._decrypt_segment)
-            d.addErrback(self._validation_or_decoding_failed,
-                         self._active_readers)
-            # check to see whether we've been paused before writing
-            # anything.
-            d.addCallback(self._check_for_paused)
-            d.addCallback(self._check_for_stopped)
-            d.addCallback(self._set_segment)
-            return d
-        else:
+
+        # 'results' is the output of a gatherResults set up in
+        # _process_segment(). Each component Deferred will either contain the
+        # non-Failure output of _validate_block() for a single block (i.e.
+        # {segnum:(block,salt)}), or None if _validate_block threw an
+        # exception and _validation_or_decoding_failed handled it (by
+        # dropping that server).
+
+        if None in results:
+            self.log("some validation operations failed; not proceeding")
             return defer.succeed(None)
+        self.log("everything looks ok, building segment %d" % segnum)
+        d = self._decode_blocks(results, segnum)
+        d.addCallback(self._decrypt_segment)
+        # check to see whether we've been paused before writing
+        # anything.
+        d.addCallback(self._check_for_paused)
+        d.addCallback(self._check_for_stopped)
+        d.addCallback(self._set_segment)
+        return d
 
 
     def _set_segment(self, segment):
@@ -718,13 +697,25 @@ class Retrieve:
         self._current_segment += 1
 
 
-    def _validation_or_decoding_failed(self, f, readers):
+    def _handle_bad_share(self, f, readers):
         """
         I am called when a block or a salt fails to correctly validate, or when
         the decryption or decoding operation fails for some reason.  I react to
         this failure by notifying the remote server of corruption, and then
         removing the remote server from further activity.
         """
+        # these are the errors we can tolerate: by giving up on this share
+        # and finding others to replace it. Any other errors (i.e. coding
+        # bugs) are re-raised, causing the download to fail.
+        f.trap(DeadReferenceError, RemoteException, BadShareError)
+
+        # DeadReferenceError happens when we try to fetch data from a server
+        # that has gone away. RemoteException happens if the server had an
+        # internal error. BadShareError encompasses: (UnknownVersionError,
+        # LayoutInvalid, struct.error) which happen when we get obviously
+        # wrong data, and CorruptShareError which happens later, when we
+        # perform integrity checks on the data.
+
         assert isinstance(readers, list)
         bad_shnums = [reader.shnum for reader in readers]
 
@@ -733,7 +724,7 @@ class Retrieve:
                  (bad_shnums, readers, self._current_segment, str(f)))
         for reader in readers:
             self._mark_bad_share(reader.server, reader.shnum, reader, f)
-        return
+        return None
 
 
     def _validate_block(self, results, segnum, reader, server, started):
@@ -747,30 +738,16 @@ class Retrieve:
         elapsed = time.time() - started
         self._status.add_fetch_timing(server, elapsed)
         self._set_current_status("validating blocks")
-        # Did we fail to fetch either of the things that we were
-        # supposed to? Fail if so.
-        if not results[0][0] and results[1][0]:
-            # handled by the errback handler.
-
-            # These all get batched into one query, so the resulting
-            # failure should be the same for all of them, so we can just
-            # use the first one.
-            assert isinstance(results[0][1], failure.Failure)
-
-            f = results[0][1]
-            raise CorruptShareError(server,
-                                    reader.shnum,
-                                    "Connection error: %s" % str(f))
 
-        block_and_salt, block_and_sharehashes = results
-        block, salt = block_and_salt[1]
-        blockhashes, sharehashes = block_and_sharehashes[1]
+        block_and_salt, blockhashes, sharehashes = results
+        block, salt = block_and_salt
+        assert type(block) is str, (block, salt)
 
-        blockhashes = dict(enumerate(blockhashes[1]))
+        blockhashes = dict(enumerate(blockhashes))
         self.log("the reader gave me the following blockhashes: %s" % \
                  blockhashes.keys())
         self.log("the reader gave me the following sharehashes: %s" % \
-                 sharehashes[1].keys())
+                 sharehashes.keys())
         bht = self._block_hash_trees[reader.shnum]
 
         if bht.needed_hashes(segnum, include_leaf=True):
@@ -798,24 +775,21 @@ class Retrieve:
 
         # Reaching this point means that we know that this segment
         # is correct. Now we need to check to see whether the share
-        # hash chain is also correct. 
+        # hash chain is also correct.
         # SDMF wrote share hash chains that didn't contain the
         # leaves, which would be produced from the block hash tree.
         # So we need to validate the block hash tree first. If
         # successful, then bht[0] will contain the root for the
         # shnum, which will be a leaf in the share hash tree, which
         # will allow us to validate the rest of the tree.
-        if self.share_hash_tree.needed_hashes(reader.shnum,
-                                              include_leaf=True) or \
-                                              self._verify:
-            try:
-                self.share_hash_tree.set_hashes(hashes=sharehashes[1],
-                                            leaves={reader.shnum: bht[0]})
-            except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
-                    IndexError), e:
-                raise CorruptShareError(server,
-                                        reader.shnum,
-                                        "corrupt hashes: %s" % e)
+        try:
+            self.share_hash_tree.set_hashes(hashes=sharehashes,
+                                        leaves={reader.shnum: bht[0]})
+        except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
+                IndexError), e:
+            raise CorruptShareError(server,
+                                    reader.shnum,
+                                    "corrupt hashes: %s" % e)
 
         self.log('share %d is valid for segment %d' % (reader.shnum,
                                                        segnum))
@@ -840,32 +814,29 @@ class Retrieve:
         #needed.discard(0)
         self.log("getting blockhashes for segment %d, share %d: %s" % \
                  (segnum, reader.shnum, str(needed)))
-        d1 = reader.get_blockhashes(needed, force_remote=True)
+        # TODO is force_remote necessary here?
+        d1 = reader.get_blockhashes(needed, force_remote=False)
         if self.share_hash_tree.needed_hashes(reader.shnum):
             need = self.share_hash_tree.needed_hashes(reader.shnum)
             self.log("also need sharehashes for share %d: %s" % (reader.shnum,
                                                                  str(need)))
-            d2 = reader.get_sharehashes(need, force_remote=True)
+            d2 = reader.get_sharehashes(need, force_remote=False)
         else:
             d2 = defer.succeed({}) # the logic in the next method
                                    # expects a dict
-        dl = defer.DeferredList([d1, d2], consumeErrors=True)
-        return dl
+        return d1,d2
 
 
-    def _decode_blocks(self, blocks_and_salts, segnum):
+    def _decode_blocks(self, results, segnum):
         """
         I take a list of k blocks and salts, and decode that into a
         single encrypted segment.
         """
-        d = {}
-        # We want to merge our dictionaries to the form 
-        # {shnum: blocks_and_salts}
-        #
-        # The dictionaries come from validate block that way, so we just
-        # need to merge them.
-        for block_and_salt in blocks_and_salts:
-            d.update(block_and_salt[1])
+        # 'results' is one or more dicts (each {shnum:(block,salt)}), and we
+        # want to merge them all
+        blocks_and_salts = {}
+        for d in results:
+            blocks_and_salts.update(d)
 
         # All of these blocks should have the same salt; in SDMF, it is
         # the file-wide IV, while in MDMF it is the per-segment salt. In
@@ -874,10 +845,10 @@ class Retrieve:
         # d.items()[0] is like (shnum, (block, salt))
         # d.items()[0][1] is like (block, salt)
         # d.items()[0][1][1] is the salt.
-        salt = d.items()[0][1][1]
+        salt = blocks_and_salts.items()[0][1][1]
         # Next, extract just the blocks from the dict. We'll use the
         # salt in the next step.
-        share_and_shareids = [(k, v[0]) for k, v in d.items()]
+        share_and_shareids = [(k, v[0]) for k, v in blocks_and_salts.items()]
         d2 = dict(share_and_shareids)
         shareids = []
         shares = []
@@ -998,23 +969,25 @@ class Retrieve:
 
     def _raise_notenoughshareserror(self):
         """
-        I am called by _activate_enough_servers when there are not enough
-        active servers left to complete the download. After making some
-        useful logging statements, I throw an exception to that effect
-        to the caller of this Retrieve object through
+        I am called when there are not enough active servers left to complete
+        the download. After making some useful logging statements, I throw an
+        exception to that effect to the caller of this Retrieve object through
         self._done_deferred.
         """
 
         format = ("ran out of servers: "
-                  "have %(have)d of %(total)d segments "
-                  "found %(bad)d bad shares "
+                  "have %(have)d of %(total)d segments; "
+                  "found %(bad)d bad shares; "
+                  "have %(remaining)d remaining shares of the right version; "
                   "encoding %(k)d-of-%(n)d")
         args = {"have": self._current_segment,
                 "total": self._num_segments,
                 "need": self._last_segment,
                 "k": self._required_shares,
                 "n": self._total_shares,
-                "bad": len(self._bad_shares)}
+                "bad": len(self._bad_shares),
+                "remaining": len(self.remaining_sharemap),
+               }
         raise NotEnoughSharesError("%s, last failure: %s" %
                                    (format % args, str(self._last_failure)))