mutable/retrieve.py: remove all bare assert()s
diff --git a/src/allmydata/mutable/retrieve.py b/src/allmydata/mutable/retrieve.py
index 100350a0b62b00925bfae12dddd70b56649665e6..9d3038aedaee515435381bb81d9dc460e4cc9901 100644
--- a/src/allmydata/mutable/retrieve.py
+++ b/src/allmydata/mutable/retrieve.py
@@ -5,17 +5,21 @@ from zope.interface import implements
 from twisted.internet import defer
 from twisted.python import failure
 from twisted.internet.interfaces import IPushProducer, IConsumer
-from foolscap.api import eventually, fireEventually
+from foolscap.api import eventually, fireEventually, DeadReferenceError, \
+     RemoteException
+
 from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \
-                                 MDMF_VERSION, SDMF_VERSION
-from allmydata.util import hashutil, log, mathutil
+     DownloadStopped, MDMF_VERSION, SDMF_VERSION
+from allmydata.util.assertutil import _assert, precondition
+from allmydata.util import hashutil, log, mathutil, deferredutil
 from allmydata.util.dictutil import DictOfSets
 from allmydata import hashtree, codec
 from allmydata.storage.server import si_b2a
 from pycryptopp.cipher.aes import AES
 from pycryptopp.publickey import rsa
 
-from allmydata.mutable.common import CorruptShareError, UncoordinatedWriteError
+from allmydata.mutable.common import CorruptShareError, BadShareError, \
+     UncoordinatedWriteError
 from allmydata.mutable.layout import MDMFSlotReadProxy
 
 class RetrieveStatus:
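
The commit's central change is visible right here in the imports: bare assert statements are replaced with _assert() and precondition() from allmydata.util.assertutil, which keep firing under python -O and can carry keyword context into the AssertionError. A rough stand-alone sketch of that idea (not the real assertutil code) for readers unfamiliar with it:

    # Hedged sketch: approximates what assertutil-style helpers provide; the
    # real allmydata.util.assertutil is more elaborate.
    def _assert(predicate, *args, **kwargs):
        # Unlike a bare 'assert', an ordinary function call is not stripped
        # out when Python runs with optimization (-O) enabled.
        if not predicate:
            details = ", ".join([repr(a) for a in args] +
                                ["%s=%r" % (k, v)
                                 for (k, v) in sorted(kwargs.items())])
            raise AssertionError(details)
        return True

    def precondition(predicate, *args, **kwargs):
        # Same mechanism, named to document that the *caller* broke the API.
        return _assert(predicate, *args, **kwargs)

    # Usage, mirroring calls introduced later in this patch:
    #   _assert(start <= num_segments, start=start, num_segments=num_segments)
    #   precondition(isinstance(readers, list), readers)
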
@@ -27,7 +31,7 @@ class RetrieveStatus:
         self.timings["decode"] = 0.0
         self.timings["decrypt"] = 0.0
         self.timings["cumulative_verify"] = 0.0
-        self.problems = {}
+        self._problems = {}
         self.active = True
         self.storage_index = None
         self.helper = False
@@ -56,11 +60,13 @@ class RetrieveStatus:
         return self.active
     def get_counter(self):
         return self.counter
+    def get_problems(self):
+        return self._problems
 
-    def add_fetch_timing(self, peerid, elapsed):
-        if peerid not in self.timings["fetch_per_server"]:
-            self.timings["fetch_per_server"][peerid] = []
-        self.timings["fetch_per_server"][peerid].append(elapsed)
+    def add_fetch_timing(self, server, elapsed):
+        if server not in self.timings["fetch_per_server"]:
+            self.timings["fetch_per_server"][server] = []
+        self.timings["fetch_per_server"][server].append(elapsed)
     def accumulate_decode_time(self, elapsed):
         self.timings["decode"] += elapsed
     def accumulate_decrypt_time(self, elapsed):
@@ -79,6 +85,9 @@ class RetrieveStatus:
         self.progress = value
     def set_active(self, value):
         self.active = value
+    def add_problem(self, server, f):
+        serverid = server.get_serverid()
+        self._problems[serverid] = f
 
 class Marker:
     pass
@@ -91,23 +100,26 @@ class Retrieve:
     # will use a single ServerMap instance.
     implements(IPushProducer)
 
-    def __init__(self, filenode, servermap, verinfo, fetch_privkey=False,
-                 verify=False):
+    def __init__(self, filenode, storage_broker, servermap, verinfo,
+                 fetch_privkey=False, verify=False):
         self._node = filenode
-        assert self._node.get_pubkey()
+        _assert(self._node.get_pubkey())
+        self._storage_broker = storage_broker
         self._storage_index = filenode.get_storage_index()
-        assert self._node.get_readkey()
+        _assert(self._node.get_readkey())
         self._last_failure = None
         prefix = si_b2a(self._storage_index)[:5]
         self._log_number = log.msg("Retrieve(%s): starting" % prefix)
-        self._outstanding_queries = {} # maps (peerid,shnum) to start_time
         self._running = True
         self._decoding = False
         self._bad_shares = set()
 
         self.servermap = servermap
-        assert self._node.get_pubkey()
         self.verinfo = verinfo
+        # TODO: make it possible to use self.verinfo.datalength instead
+        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
+         offsets_tuple) = self.verinfo
+        self._data_length = datalength
         # during repair, we may be called upon to grab the private key, since
         # it wasn't picked up during a verify=False checker run, and we'll
         # need it for repair to generate a new version.
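
For context, verinfo is the 9-tuple describing one recoverable version of the mutable file; the constructor now unpacks it once so self._data_length exists before download() runs. A small illustrative helper (a sketch, not part of the patch) naming the fields:

    def describe_verinfo(verinfo):
        # Sketch only: names the fields of the 9-tuple unpacked in __init__
        # above; offsets_tuple records where each section lives in the share.
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = verinfo
        return {"seqnum": seqnum, "segsize": segsize,
                "datalength": datalength, "k": k, "N": N}
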
@@ -123,7 +135,7 @@ class Retrieve:
 
         # verify means that we are using the downloader logic to verify all
         # of our shares. This tells the downloader a few things.
-        # 
+        #
         # 1. We need to download all of the shares.
         # 2. We don't need to decode or decrypt the shares, since our
         #    caller doesn't care about the plaintext, only the
@@ -138,11 +150,10 @@ class Retrieve:
         self._status.set_helper(False)
         self._status.set_progress(0.0)
         self._status.set_active(True)
-        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
-         offsets_tuple) = self.verinfo
         self._status.set_size(datalength)
         self._status.set_encoding(k, N)
         self.readers = {}
+        self._stopped = False
         self._pause_deferred = None
         self._offset = None
         self._read_length = None
@@ -175,7 +186,7 @@ class Retrieve:
         if self._pause_deferred is not None:
             return
 
-        # fired when the download is unpaused. 
+        # fired when the download is unpaused.
         self._old_status = self._status.get_status()
         self._set_current_status("paused")
 
@@ -196,6 +207,10 @@ class Retrieve:
 
         eventually(p.callback, None)
 
+    def stopProducing(self):
+        self._stopped = True
+        self.resumeProducing()
+
 
     def _check_for_paused(self, res):
         """
@@ -209,46 +224,65 @@ class Retrieve:
             d = defer.Deferred()
             self._pause_deferred.addCallback(lambda ignored: d.callback(res))
             return d
-        return defer.succeed(res)
+        return res
 
+    def _check_for_stopped(self, res):
+        if self._stopped:
+            raise DownloadStopped("our Consumer called stopProducing()")
+        return res
 
-    def download(self, consumer=None, offset=0, size=None):
-        assert IConsumer.providedBy(consumer) or self._verify
 
+    def download(self, consumer=None, offset=0, size=None):
+        precondition(self._verify or IConsumer.providedBy(consumer))
+        if size is None:
+            size = self._data_length - offset
+        if self._verify:
+            _assert(size == self._data_length, (size, self._data_length))
+        self.log("starting download")
+        self._done_deferred = defer.Deferred()
         if consumer:
             self._consumer = consumer
-            # we provide IPushProducer, so streaming=True, per
-            # IConsumer.
+            # we provide IPushProducer, so streaming=True, per IConsumer.
             self._consumer.registerProducer(self, streaming=True)
+        self._started = time.time()
+        self._started_fetching = time.time()
+        if size == 0:
+            # short-circuit the rest of the process
+            self._done()
+        else:
+            self._start_download(consumer, offset, size)
+        return self._done_deferred
+
+    def _start_download(self, consumer, offset, size):
+        precondition((0 <= offset < self._data_length)
+                     and (size > 0)
+                     and (offset+size <= self._data_length),
+                     (offset, size, self._data_length))
 
-        self._done_deferred = defer.Deferred()
         self._offset = offset
         self._read_length = size
-        self._setup_download()
         self._setup_encoding_parameters()
-        self.log("starting download")
-        self._started_fetching = time.time()
+        self._setup_download()
+
         # The download process beyond this is a state machine.
-        # _add_active_peers will select the peers that we want to use
+        # _add_active_servers will select the servers that we want to use
         # for the download, and then attempt to start downloading. After
         # each segment, it will check for doneness, reacting to broken
-        # peers and corrupt shares as necessary. If it runs out of good
-        # peers before downloading all of the segments, _done_deferred
+        # servers and corrupt shares as necessary. If it runs out of good
+        # servers before downloading all of the segments, _done_deferred
         # will errback.  Otherwise, it will eventually callback with the
         # contents of the mutable file.
         self.loop()
-        return self._done_deferred
 
     def loop(self):
         d = fireEventually(None) # avoid #237 recursion limit problem
-        d.addCallback(lambda ign: self._activate_enough_peers())
+        d.addCallback(lambda ign: self._activate_enough_servers())
         d.addCallback(lambda ign: self._download_current_segment())
         # when we're done, _download_current_segment will call _done. If we
         # aren't, it will call loop() again.
         d.addErrback(self._error)
 
     def _setup_download(self):
-        self._started = time.time()
         self._status.set_status("Retrieving Shares")
 
         # how many shares do we need?
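
The pause/stop machinery above is a small Deferred gate: while paused, _check_for_paused parks each intermediate result behind self._pause_deferred; resumeProducing fires that deferred, and stopProducing resumes and then lets _check_for_stopped raise DownloadStopped. A stripped-down sketch of the same pattern, with invented names and a generic exception standing in for DownloadStopped:

    from twisted.internet import defer

    class PausableGate(object):
        # Hypothetical sketch of the pause/stop gate used by Retrieve above.
        def __init__(self):
            self._pause_deferred = None
            self._stopped = False

        def pause(self):                     # cf. pauseProducing()
            if self._pause_deferred is None:
                self._pause_deferred = defer.Deferred()

        def resume(self):                    # cf. resumeProducing()
            if self._pause_deferred is not None:
                p, self._pause_deferred = self._pause_deferred, None
                p.callback(None)

        def stop(self):                      # cf. stopProducing()
            self._stopped = True
            self.resume()

        def check_for_paused(self, res):
            # While paused, hold 'res' until resume() fires the gate.
            if self._pause_deferred is not None:
                d = defer.Deferred()
                self._pause_deferred.addCallback(lambda ign: d.callback(res))
                return d
            return res

        def check_for_stopped(self, res):
            if self._stopped:
                raise RuntimeError("consumer called stopProducing()")
            return res
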
@@ -267,27 +301,30 @@ class Retrieve:
         shares = versionmap[self.verinfo]
         # this sharemap is consumed as we decide to send requests
         self.remaining_sharemap = DictOfSets()
-        for (shnum, peerid, timestamp) in shares:
-            self.remaining_sharemap.add(shnum, peerid)
-            # If the servermap update fetched anything, it fetched at least 1
-            # KiB, so we ask for that much.
-            # TODO: Change the cache methods to allow us to fetch all of the
-            # data that they have, then change this method to do that.
-            any_cache = self._node._read_from_cache(self.verinfo, shnum,
-                                                    0, 1000)
-            ss = self.servermap.connections[peerid]
-            reader = MDMFSlotReadProxy(ss,
-                                       self._storage_index,
-                                       shnum,
-                                       any_cache)
-            reader.peerid = peerid
+        for (shnum, server, timestamp) in shares:
+            self.remaining_sharemap.add(shnum, server)
+            # Reuse the SlotReader from the servermap.
+            key = (self.verinfo, server.get_serverid(),
+                   self._storage_index, shnum)
+            if key in self.servermap.proxies:
+                reader = self.servermap.proxies[key]
+            else:
+                reader = MDMFSlotReadProxy(server.get_rref(),
+                                           self._storage_index, shnum, None)
+            reader.server = server
             self.readers[shnum] = reader
-        assert len(self.remaining_sharemap) >= k
+
+        if len(self.remaining_sharemap) < k:
+            self._raise_notenoughshareserror()
 
         self.shares = {} # maps shnum to validated blocks
         self._active_readers = [] # list of active readers for this dl.
         self._block_hash_trees = {} # shnum => hashtree
 
+        for i in xrange(self._total_shares):
+            # So we don't have to do this later.
+            self._block_hash_trees[i] = hashtree.IncompleteHashTree(self._num_segments)
+
         # We need one share hash tree for the entire file; its leaves
         # are the roots of the block hash trees for the shares that
         # comprise it, and its root is in the verinfo.
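
Two behavioural changes are folded into the reader setup above: slot readers already built by the servermap are reused (keyed by verinfo, serverid, storage index, and shnum), and having fewer than k distinct shares now raises NotEnoughSharesError immediately instead of tripping an assert. A hypothetical condensation of that setup, with make_proxy standing in for the MDMFSlotReadProxy constructor:

    def build_readers(shares, proxies, verinfo, storage_index, k, make_proxy):
        # Sketch: 'shares' is an iterable of (shnum, server) pairs, 'proxies'
        # is the servermap's cache of already-built slot readers, and
        # make_proxy(server, shnum) stands in for the MDMFSlotReadProxy
        # constructor used above.
        readers = {}
        for (shnum, server) in shares:
            key = (verinfo, server.get_serverid(), storage_index, shnum)
            if key in proxies:
                reader = proxies[key]
            else:
                reader = make_proxy(server, shnum)
            reader.server = server
            readers[shnum] = reader
        if len(readers) < k:
            raise ValueError("only %d distinct shares, need %d"
                             % (len(readers), k))
        return readers
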
@@ -303,16 +340,16 @@ class Retrieve:
         segment with. I return the plaintext associated with that
         segment.
         """
-        # shnum => block hash tree. Unused, but setup_encoding_parameters will
-        # want to set this.
+        # We don't need the block hash trees in this case.
         self._block_hash_trees = None
+        self._offset = 0
+        self._read_length = self._data_length
         self._setup_encoding_parameters()
 
-        # This is the form expected by decode.
-        blocks_and_salts = blocks_and_salts.items()
-        blocks_and_salts = [(True, [d]) for d in blocks_and_salts]
-
-        d = self._decode_blocks(blocks_and_salts, segnum)
+        # _decode_blocks() expects the output of a gatherResults that
+        # contains the outputs of _validate_block() (each of which is a dict
+        # mapping shnum to (block,salt) bytestrings).
+        d = self._decode_blocks([blocks_and_salts], segnum)
         d.addCallback(self._decrypt_segment)
         return d
 
@@ -334,7 +371,7 @@ class Retrieve:
         self._required_shares = k
         self._total_shares = n
         self._segment_size = segsize
-        self._data_length = datalength
+        #self._data_length = datalength # set during __init__()
 
         if not IV:
             self._version = MDMF_VERSION
@@ -371,15 +408,10 @@ class Retrieve:
                  (k, n, self._num_segments, self._segment_size,
                   self._tail_segment_size))
 
-        if self._block_hash_trees is not None:
-            for i in xrange(self._total_shares):
-                # So we don't have to do this later.
-                self._block_hash_trees[i] = hashtree.IncompleteHashTree(self._num_segments)
-
         # Our last task is to tell the downloader where to start and
         # where to stop. We use three parameters for that:
         #   - self._start_segment: the segment that we need to start
-        #     downloading from. 
+        #     downloading from.
         #   - self._current_segment: the next segment that we need to
         #     download.
         #   - self._last_segment: The last segment that we were asked to
@@ -392,41 +424,36 @@ class Retrieve:
         if self._offset:
             self.log("got offset: %d" % self._offset)
             # our start segment is the first segment containing the
-            # offset we were given. 
+            # offset we were given.
             start = self._offset // self._segment_size
 
-            assert start < self._num_segments
+            _assert(start <= self._num_segments,
+                    start=start, num_segments=self._num_segments,
+                    offset=self._offset, segment_size=self._segment_size)
             self._start_segment = start
             self.log("got start segment: %d" % self._start_segment)
         else:
             self._start_segment = 0
 
-
-        # If self._read_length is None, then we want to read the whole
-        # file. Otherwise, we want to read only part of the file, and
-        # need to figure out where to stop reading.
-        if self._read_length is not None:
-            # our end segment is the last segment containing part of the
-            # segment that we were asked to read.
-            self.log("got read length %d" % self._read_length)
-            if self._read_length != 0:
-                end_data = self._offset + self._read_length
-
-                # We don't actually need to read the byte at end_data,
-                # but the one before it.
-                end = (end_data - 1) // self._segment_size
-
-                assert end < self._num_segments
-                self._last_segment = end
-            else:
-                self._last_segment = self._start_segment
-            self.log("got end segment: %d" % self._last_segment)
-        else:
-            self._last_segment = self._num_segments - 1
+        # We might want to read only part of the file, and need to figure out
+        # where to stop reading. Our end segment is the last segment
+        # containing part of the segment that we were asked to read.
+        _assert(self._read_length > 0, self._read_length)
+        end_data = self._offset + self._read_length
+
+        # We don't actually need to read the byte at end_data, but the one
+        # before it.
+        end = (end_data - 1) // self._segment_size
+        _assert(0 <= end < self._num_segments,
+                end=end, num_segments=self._num_segments,
+                end_data=end_data, offset=self._offset,
+                read_length=self._read_length, segment_size=self._segment_size)
+        self._last_segment = end
+        self.log("got end segment: %d" % self._last_segment)
 
         self._current_segment = self._start_segment
 
-    def _activate_enough_peers(self):
+    def _activate_enough_servers(self):
         """
         I populate self._active_readers with enough active readers to
         retrieve the contents of this mutable file. I am called before
@@ -435,9 +462,9 @@ class Retrieve:
         """
         # TODO: It would be cool to investigate other heuristics for
         # reader selection. For instance, the cost (in time the user
-        # spends waiting for their file) of selecting a really slow peer
+        # spends waiting for their file) of selecting a really slow server
         # that happens to have a primary share is probably more than
-        # selecting a really fast peer that doesn't have a primary
+        # selecting a really fast server that doesn't have a primary
         # share. Maybe the servermap could be extended to provide this
         # information; it could keep track of latency information while
         # it gathers more important data, and then this routine could
@@ -453,32 +480,29 @@ class Retrieve:
         #  instead of just reasoning about what the effect might be. Out
         #  of scope for MDMF, though.)
 
-        # We need at least self._required_shares readers to download a
-        # segment. If we're verifying, we need all shares.
-        if self._verify:
-            needed = self._total_shares
-        else:
-            needed = self._required_shares
         # XXX: Why don't format= log messages work here?
-        self.log("adding %d peers to the active peers list" % needed)
-
-        if len(self._active_readers) >= needed:
-            # enough shares are active
-            return
 
-        more = needed - len(self._active_readers)
         known_shnums = set(self.remaining_sharemap.keys())
         used_shnums = set([r.shnum for r in self._active_readers])
         unused_shnums = known_shnums - used_shnums
-        # We favor lower numbered shares, since FEC is faster with
-        # primary shares than with other shares, and lower-numbered
-        # shares are more likely to be primary than higher numbered
-        # shares.
-        new_shnums = sorted(unused_shnums)[:more]
-        if len(new_shnums) < more and not self._verify:
-            # We don't have enough readers to retrieve the file; fail.
-            self._raise_notenoughshareserror()
 
+        if self._verify:
+            new_shnums = unused_shnums # use them all
+        elif len(self._active_readers) < self._required_shares:
+            # need more shares
+            more = self._required_shares - len(self._active_readers)
+            # We favor lower numbered shares, since FEC is faster with
+            # primary shares than with other shares, and lower-numbered
+            # shares are more likely to be primary than higher numbered
+            # shares.
+            new_shnums = sorted(unused_shnums)[:more]
+            if len(new_shnums) < more:
+                # We don't have enough readers to retrieve the file; fail.
+                self._raise_notenoughshareserror()
+        else:
+            new_shnums = []
+
+        self.log("adding %d new servers to the active list" % len(new_shnums))
         for shnum in new_shnums:
             reader = self.readers[shnum]
             self._active_readers.append(reader)
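
The selection logic above reduces to: when verifying, activate every unused share; otherwise top up to k readers, preferring low share numbers because those are more likely to be primary shares and zfec decodes fastest from primaries. A compact sketch of just that choice:

    def choose_new_shnums(unused_shnums, active_count, k, verifying):
        # Sketch of the branch above; raises if k readers cannot be reached.
        if verifying:
            return sorted(unused_shnums)        # use them all
        if active_count >= k:
            return []                           # already have enough
        more = k - active_count
        new = sorted(unused_shnums)[:more]      # favor primary (low) shnums
        if len(new) < more:
            raise ValueError("need %d more shares, only %d available"
                             % (more, len(new)))
        return new

    # choose_new_shnums(set([5, 0, 2, 7]), active_count=1, k=3,
    #                   verifying=False) == [0, 2]
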
@@ -489,10 +513,9 @@ class Retrieve:
             # segment decoding, then we'll take more drastic measures.
             if self._need_privkey and not self._node.is_readonly():
                 d = reader.get_encprivkey()
-                d.addCallback(self._try_to_validate_privkey, reader)
+                d.addCallback(self._try_to_validate_privkey, reader, reader.server)
                 # XXX: don't just drop the Deferred. We need error-reporting
                 # but not flow-control here.
-        assert len(self._active_readers) >= self._required_shares
 
     def _try_to_validate_prefix(self, prefix, reader):
         """
@@ -518,49 +541,15 @@ class Retrieve:
                                           "indicate an uncoordinated write")
         # Otherwise, we're okay -- no issues.
 
-
-    def _remove_reader(self, reader):
-        """
-        At various points, we will wish to remove a peer from
-        consideration and/or use. These include, but are not necessarily
-        limited to:
-
-            - A connection error.
-            - A mismatched prefix (that is, a prefix that does not match
-              our conception of the version information string).
-            - A failing block hash, salt hash, or share hash, which can
-              indicate disk failure/bit flips, or network trouble.
-
-        This method will do that. I will make sure that the
-        (shnum,reader) combination represented by my reader argument is
-        not used for anything else during this download. I will not
-        advise the reader of any corruption, something that my callers
-        may wish to do on their own.
-        """
-        # TODO: When you're done writing this, see if this is ever
-        # actually used for something that _mark_bad_share isn't. I have
-        # a feeling that they will be used for very similar things, and
-        # that having them both here is just going to be an epic amount
-        # of code duplication.
-        #
-        # (well, okay, not epic, but meaningful)
-        self.log("removing reader %s" % reader)
-        # Remove the reader from _active_readers
-        self._active_readers.remove(reader)
-        # TODO: self.readers.remove(reader)?
-        for shnum in list(self.remaining_sharemap.keys()):
-            self.remaining_sharemap.discard(shnum, reader.peerid)
-
-
-    def _mark_bad_share(self, reader, f):
+    def _mark_bad_share(self, server, shnum, reader, f):
         """
-        I mark the (peerid, shnum) encapsulated by my reader argument as
-        a bad share, which means that it will not be used anywhere else.
+        I mark the given (server, shnum) as a bad share, which means that it
+        will not be used anywhere else.
 
         There are several reasons to want to mark something as a bad
         share. These include:
 
-            - A connection error to the peer.
+            - A connection error to the server.
             - A mismatched prefix (that is, a prefix that does not match
               our local conception of the version information string).
             - A failing block hash, salt hash, share hash, or other
@@ -569,33 +558,38 @@ class Retrieve:
         This method will ensure that readers that we wish to mark bad
         (for these reasons or other reasons) are not used for the rest
         of the download. Additionally, it will attempt to tell the
-        remote peer (with no guarantee of success) that its share is
+        remote server (with no guarantee of success) that its share is
         corrupt.
         """
         self.log("marking share %d on server %s as bad" % \
-                 (reader.shnum, reader))
+                 (shnum, server.get_name()))
         prefix = self.verinfo[-2]
-        self.servermap.mark_bad_share(reader.peerid,
-                                      reader.shnum,
-                                      prefix)
-        self._remove_reader(reader)
-        self._bad_shares.add((reader.peerid, reader.shnum, f))
-        self._status.problems[reader.peerid] = f
+        self.servermap.mark_bad_share(server, shnum, prefix)
+        self._bad_shares.add((server, shnum, f))
+        self._status.add_problem(server, f)
         self._last_failure = f
-        self.notify_server_corruption(reader.peerid, reader.shnum,
-                                      str(f.value))
 
+        # Remove the reader from _active_readers
+        self._active_readers.remove(reader)
+        for shnum in list(self.remaining_sharemap.keys()):
+            self.remaining_sharemap.discard(shnum, reader.server)
+
+        if f.check(BadShareError):
+            self.notify_server_corruption(server, shnum, str(f.value))
 
     def _download_current_segment(self):
         """
         I download, validate, decode, decrypt, and assemble the segment
         that this Retrieve is currently responsible for downloading.
         """
-        assert len(self._active_readers) >= self._required_shares
+
         if self._current_segment > self._last_segment:
             # No more segments to download, we're done.
             self.log("got plaintext, done")
             return self._done()
+        elif self._verify and len(self._active_readers) == 0:
+            self.log("no more good shares, no need to keep verifying")
+            return self._done()
         self.log("on segment %d of %d" %
                  (self._current_segment + 1, self._num_segments))
         d = self._process_segment(self._current_segment)
@@ -614,7 +608,6 @@ class Retrieve:
 
         # TODO: The old code uses a marker. Should this code do that
         # too? What did the Marker do?
-        assert len(self._active_readers) >= self._required_shares
 
         # We need to ask each of our active readers for its block and
         # salt. We will then validate those. If validation is
@@ -622,13 +615,16 @@ class Retrieve:
         ds = []
         for reader in self._active_readers:
             started = time.time()
-            d = reader.get_block_and_salt(segnum)
-            d2 = self._get_needed_hashes(reader, segnum)
-            dl = defer.DeferredList([d, d2], consumeErrors=True)
-            dl.addCallback(self._validate_block, segnum, reader, started)
-            dl.addErrback(self._validation_or_decoding_failed, [reader])
-            ds.append(dl)
-        dl = defer.DeferredList(ds)
+            d1 = reader.get_block_and_salt(segnum)
+            d2,d3 = self._get_needed_hashes(reader, segnum)
+            d = deferredutil.gatherResults([d1,d2,d3])
+            d.addCallback(self._validate_block, segnum, reader, reader.server, started)
+            # _handle_bad_share takes care of recoverable errors (by dropping
+            # that share and returning None). Any other errors (i.e. code
+            # bugs) are passed through and cause the retrieve to fail.
+            d.addErrback(self._handle_bad_share, [reader])
+            ds.append(d)
+        dl = deferredutil.gatherResults(ds)
         if self._verify:
             dl.addCallback(lambda ignored: "")
             dl.addCallback(self._set_segment)
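
Per reader, the rewritten loop gathers three fetches (block+salt, block hashes, share hashes) into one Deferred, validates the result, and converts any tolerable failure into None so the segment-level gather still fires. The patch uses tahoe's deferredutil.gatherResults; the sketch below substitutes Twisted's defer.gatherResults, which is close enough for illustration:

    from twisted.internet import defer

    def fetch_segment(readers, fetch_parts, validate, handle_bad_share):
        # Sketch: fetch_parts(reader) returns the three Deferreds gathered
        # above (block+salt, blockhashes, sharehashes); validate() turns
        # their results into {shnum: (block, salt)}; handle_bad_share()
        # converts a tolerable failure into None so the segment-level
        # gather still succeeds.
        ds = []
        for reader in readers:
            d = defer.gatherResults(fetch_parts(reader), consumeErrors=True)
            d.addCallback(validate, reader)
            d.addErrback(handle_bad_share, [reader])
            ds.append(d)
        return defer.gatherResults(ds, consumeErrors=True)
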
@@ -637,34 +633,34 @@ class Retrieve:
         return dl
 
 
-    def _maybe_decode_and_decrypt_segment(self, blocks_and_salts, segnum):
+    def _maybe_decode_and_decrypt_segment(self, results, segnum):
         """
-        I take the results of fetching and validating the blocks from a
-        callback chain in another method. If the results are such that
-        they tell me that validation and fetching succeeded without
-        incident, I will proceed with decoding and decryption.
-        Otherwise, I will do nothing.
+        I take the results of fetching and validating the blocks from
+        _process_segment. If validation and fetching succeeded without
+        incident, I will proceed with decoding and decryption. Otherwise, I
+        will do nothing.
         """
         self.log("trying to decode and decrypt segment %d" % segnum)
-        failures = False
-        for block_and_salt in blocks_and_salts:
-            if not block_and_salt[0] or block_and_salt[1] == None:
-                self.log("some validation operations failed; not proceeding")
-                failures = True
-                break
-        if not failures:
-            self.log("everything looks ok, building segment %d" % segnum)
-            d = self._decode_blocks(blocks_and_salts, segnum)
-            d.addCallback(self._decrypt_segment)
-            d.addErrback(self._validation_or_decoding_failed,
-                         self._active_readers)
-            # check to see whether we've been paused before writing
-            # anything.
-            d.addCallback(self._check_for_paused)
-            d.addCallback(self._set_segment)
-            return d
-        else:
+
+        # 'results' is the output of a gatherResults set up in
+        # _process_segment(). Each component Deferred will either contain the
+        # non-Failure output of _validate_block() for a single block (i.e.
+        # {shnum:(block,salt)}), or None if _validate_block threw an
+        # exception and _handle_bad_share handled it (by dropping that
+        # server).
+
+        if None in results:
+            self.log("some validation operations failed; not proceeding")
             return defer.succeed(None)
+        self.log("everything looks ok, building segment %d" % segnum)
+        d = self._decode_blocks(results, segnum)
+        d.addCallback(self._decrypt_segment)
+        # check to see whether we've been paused before writing
+        # anything.
+        d.addCallback(self._check_for_paused)
+        d.addCallback(self._check_for_stopped)
+        d.addCallback(self._set_segment)
+        return d
 
 
     def _set_segment(self, segment):
@@ -673,33 +669,27 @@ class Retrieve:
         target that is handling the file download.
         """
         self.log("got plaintext for segment %d" % self._current_segment)
+
+        if self._read_length == 0:
+            self.log("on first+last segment, size=0, using 0 bytes")
+            segment = b""
+
+        if self._current_segment == self._last_segment:
+            # trim off the tail
+            wanted = (self._offset + self._read_length) % self._segment_size
+            if wanted != 0:
+                self.log("on the last segment: using first %d bytes" % wanted)
+                segment = segment[:wanted]
+            else:
+                self.log("on the last segment: using all %d bytes" %
+                         len(segment))
+
         if self._current_segment == self._start_segment:
-            # We're on the first segment. It's possible that we want
-            # only some part of the end of this segment, and that we
-            # just downloaded the whole thing to get that part. If so,
-            # we need to account for that and give the reader just the
-            # data that they want.
-            n = self._offset % self._segment_size
-            self.log("stripping %d bytes off of the first segment" % n)
-            self.log("original segment length: %d" % len(segment))
-            segment = segment[n:]
-            self.log("new segment length: %d" % len(segment))
-
-        if self._current_segment == self._last_segment and self._read_length is not None:
-            # We're on the last segment. It's possible that we only want
-            # part of the beginning of this segment, and that we
-            # downloaded the whole thing anyway. Make sure to give the
-            # caller only the portion of the segment that they want to
-            # receive.
-            extra = self._read_length
-            if self._start_segment != self._last_segment:
-                extra -= self._segment_size - \
-                            (self._offset % self._segment_size)
-            extra %= self._segment_size
-            self.log("original segment length: %d" % len(segment))
-            segment = segment[:extra]
-            self.log("new segment length: %d" % len(segment))
-            self.log("only taking %d bytes of the last segment" % extra)
+            # Trim off the head, if offset != 0. This should also work if
+            # start==last, because we trim the tail first.
+            skip = self._offset % self._segment_size
+            self.log("on the first segment: skipping first %d bytes" % skip)
+            segment = segment[skip:]
 
         if not self._verify:
             self._consumer.write(segment)
@@ -709,25 +699,37 @@ class Retrieve:
         self._current_segment += 1
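
The trimming above is the subtle part: the tail is cut first (on the last segment) using (offset + read_length) % segment_size, then the head (on the first segment) using offset % segment_size, which also works when the first and last segment coincide. A worked sketch with small numbers:

    def trim_segment(segment, segnum, first_seg, last_seg, offset,
                     read_length, segment_size):
        # Sketch of _set_segment's trimming, same order of operations.
        if read_length == 0:
            return ""
        if segnum == last_seg:
            wanted = (offset + read_length) % segment_size
            if wanted != 0:
                segment = segment[:wanted]      # drop bytes after the range
        if segnum == first_seg:
            skip = offset % segment_size
            segment = segment[skip:]            # drop bytes before the range
        return segment

    # Example: segment_size=10, offset=3, read_length=4, so segment 0 is
    # both first and last; the tail cut keeps bytes [0:7), the head cut
    # then drops [0:3), leaving exactly bytes 3..6 of the file:
    #   trim_segment("0123456789", 0, 0, 0, 3, 4, 10) == "3456"
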
 
 
-    def _validation_or_decoding_failed(self, f, readers):
+    def _handle_bad_share(self, f, readers):
         """
         I am called when a block or a salt fails to correctly validate, or when
         the decryption or decoding operation fails for some reason.  I react to
         this failure by notifying the remote server of corruption, and then
-        removing the remote peer from further activity.
+        removing the remote server from further activity.
         """
-        assert isinstance(readers, list)
+        # these are the errors we can tolerate: by giving up on this share
+        # and finding others to replace it. Any other errors (i.e. coding
+        # bugs) are re-raised, causing the download to fail.
+        f.trap(DeadReferenceError, RemoteException, BadShareError)
+
+        # DeadReferenceError happens when we try to fetch data from a server
+        # that has gone away. RemoteException happens if the server had an
+        # internal error. BadShareError encompasses: (UnknownVersionError,
+        # LayoutInvalid, struct.error) which happen when we get obviously
+        # wrong data, and CorruptShareError which happens later, when we
+        # perform integrity checks on the data.
+
+        precondition(isinstance(readers, list), readers)
         bad_shnums = [reader.shnum for reader in readers]
 
-        self.log("validation or decoding failed on share(s) %s, peer(s) %s "
+        self.log("validation or decoding failed on share(s) %s, server(s) %s "
                  ", segment %d: %s" % \
                  (bad_shnums, readers, self._current_segment, str(f)))
         for reader in readers:
-            self._mark_bad_share(reader, f)
-        return
+            self._mark_bad_share(reader.server, reader.shnum, reader, f)
+        return None
 
 
-    def _validate_block(self, results, segnum, reader, started):
+    def _validate_block(self, results, segnum, reader, server, started):
         """
         I validate a block from one share on a remote server.
         """
@@ -736,32 +738,18 @@ class Retrieve:
         self.log("validating share %d for segment %d" % (reader.shnum,
                                                              segnum))
         elapsed = time.time() - started
-        self._status.add_fetch_timing(reader.peerid, elapsed)
+        self._status.add_fetch_timing(server, elapsed)
         self._set_current_status("validating blocks")
-        # Did we fail to fetch either of the things that we were
-        # supposed to? Fail if so.
-        if not results[0][0] and results[1][0]:
-            # handled by the errback handler.
-
-            # These all get batched into one query, so the resulting
-            # failure should be the same for all of them, so we can just
-            # use the first one.
-            assert isinstance(results[0][1], failure.Failure)
-
-            f = results[0][1]
-            raise CorruptShareError(reader.peerid,
-                                    reader.shnum,
-                                    "Connection error: %s" % str(f))
 
-        block_and_salt, block_and_sharehashes = results
-        block, salt = block_and_salt[1]
-        blockhashes, sharehashes = block_and_sharehashes[1]
+        block_and_salt, blockhashes, sharehashes = results
+        block, salt = block_and_salt
+        _assert(type(block) is str, (block, salt))
 
-        blockhashes = dict(enumerate(blockhashes[1]))
+        blockhashes = dict(enumerate(blockhashes))
         self.log("the reader gave me the following blockhashes: %s" % \
                  blockhashes.keys())
         self.log("the reader gave me the following sharehashes: %s" % \
-                 sharehashes[1].keys())
+                 sharehashes.keys())
         bht = self._block_hash_trees[reader.shnum]
 
         if bht.needed_hashes(segnum, include_leaf=True):
@@ -769,7 +757,7 @@ class Retrieve:
                 bht.set_hashes(blockhashes)
             except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
                     IndexError), e:
-                raise CorruptShareError(reader.peerid,
+                raise CorruptShareError(server,
                                         reader.shnum,
                                         "block hash tree failure: %s" % e)
 
@@ -783,30 +771,27 @@ class Retrieve:
            bht.set_hashes(leaves={segnum: blockhash})
         except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
                 IndexError), e:
-            raise CorruptShareError(reader.peerid,
+            raise CorruptShareError(server,
                                     reader.shnum,
                                     "block hash tree failure: %s" % e)
 
         # Reaching this point means that we know that this segment
         # is correct. Now we need to check to see whether the share
-        # hash chain is also correct. 
+        # hash chain is also correct.
         # SDMF wrote share hash chains that didn't contain the
         # leaves, which would be produced from the block hash tree.
         # So we need to validate the block hash tree first. If
         # successful, then bht[0] will contain the root for the
         # shnum, which will be a leaf in the share hash tree, which
         # will allow us to validate the rest of the tree.
-        if self.share_hash_tree.needed_hashes(reader.shnum,
-                                              include_leaf=True) or \
-                                              self._verify:
-            try:
-                self.share_hash_tree.set_hashes(hashes=sharehashes[1],
-                                            leaves={reader.shnum: bht[0]})
-            except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
-                    IndexError), e:
-                raise CorruptShareError(reader.peerid,
-                                        reader.shnum,
-                                        "corrupt hashes: %s" % e)
+        try:
+            self.share_hash_tree.set_hashes(hashes=sharehashes,
+                                        leaves={reader.shnum: bht[0]})
+        except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
+                IndexError), e:
+            raise CorruptShareError(server,
+                                    reader.shnum,
+                                    "corrupt hashes: %s" % e)
 
         self.log('share %d is valid for segment %d' % (reader.shnum,
                                                        segnum))
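
The validation order above matters: the block hash tree for this share is checked first, so its root bht[0] is known, and that root is then offered as this shnum's leaf in the file-wide share hash tree. A hedged outline of that flow using the same IncompleteHashTree calls as the code above (error handling elided, blockhash assumed to be computed by the caller):

    def validate_share_block(bht, share_hash_tree, shnum, segnum, blockhash,
                             blockhashes, sharehashes):
        # Sketch of the checks above. 'bht' is this share's (incomplete)
        # block hash tree, 'share_hash_tree' covers the whole file, and
        # 'blockhash' is the hash of the block being validated, computed by
        # the caller. Each set_hashes() call raises on a mismatch.
        if bht.needed_hashes(segnum, include_leaf=True):
            bht.set_hashes(blockhashes)
        bht.set_hashes(leaves={segnum: blockhash})
        # bht[0] is now the validated root of this share's block hash tree;
        # it doubles as shnum's leaf in the file-wide share hash tree.
        share_hash_tree.set_hashes(hashes=sharehashes,
                                   leaves={shnum: bht[0]})
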
@@ -831,32 +816,29 @@ class Retrieve:
         #needed.discard(0)
         self.log("getting blockhashes for segment %d, share %d: %s" % \
                  (segnum, reader.shnum, str(needed)))
-        d1 = reader.get_blockhashes(needed, force_remote=True)
+        # TODO is force_remote necessary here?
+        d1 = reader.get_blockhashes(needed, force_remote=False)
         if self.share_hash_tree.needed_hashes(reader.shnum):
             need = self.share_hash_tree.needed_hashes(reader.shnum)
             self.log("also need sharehashes for share %d: %s" % (reader.shnum,
                                                                  str(need)))
-            d2 = reader.get_sharehashes(need, force_remote=True)
+            d2 = reader.get_sharehashes(need, force_remote=False)
         else:
             d2 = defer.succeed({}) # the logic in the next method
                                    # expects a dict
-        dl = defer.DeferredList([d1, d2], consumeErrors=True)
-        return dl
+        return d1,d2
 
 
-    def _decode_blocks(self, blocks_and_salts, segnum):
+    def _decode_blocks(self, results, segnum):
         """
         I take a list of k blocks and salts, and decode that into a
         single encrypted segment.
         """
-        d = {}
-        # We want to merge our dictionaries to the form 
-        # {shnum: blocks_and_salts}
-        #
-        # The dictionaries come from validate block that way, so we just
-        # need to merge them.
-        for block_and_salt in blocks_and_salts:
-            d.update(block_and_salt[1])
+        # 'results' is one or more dicts (each {shnum:(block,salt)}), and we
+        # want to merge them all
+        blocks_and_salts = {}
+        for d in results:
+            blocks_and_salts.update(d)
 
         # All of these blocks should have the same salt; in SDMF, it is
         # the file-wide IV, while in MDMF it is the per-segment salt. In
@@ -865,10 +847,10 @@ class Retrieve:
         # d.items()[0] is like (shnum, (block, salt))
         # d.items()[0][1] is like (block, salt)
         # d.items()[0][1][1] is the salt.
-        salt = d.items()[0][1][1]
+        salt = blocks_and_salts.items()[0][1][1]
         # Next, extract just the blocks from the dict. We'll use the
         # salt in the next step.
-        share_and_shareids = [(k, v[0]) for k, v in d.items()]
+        share_and_shareids = [(k, v[0]) for k, v in blocks_and_salts.items()]
         d2 = dict(share_and_shareids)
         shareids = []
         shares = []
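
The merge above flattens the per-reader dicts into one {shnum: (block, salt)} map, pulls out a single salt (every block of a segment carries the same one), and splits the rest into parallel shareid/block lists for zfec. A small sketch of that reshaping (Python 2 dict semantics, matching this codebase):

    def merge_blocks(results):
        # Sketch of the reshaping above. 'results' is a list of dicts, each
        # mapping shnum -> (block, salt), one per active reader.
        blocks_and_salts = {}
        for d in results:
            blocks_and_salts.update(d)
        # every block of one segment carries the same salt (file-wide IV for
        # SDMF, per-segment salt for MDMF), so any entry will do
        salt = blocks_and_salts.values()[0][1]
        shareids = sorted(blocks_and_salts.keys())
        shares = [blocks_and_salts[shnum][0] for shnum in shareids]
        return salt, shareids, shares

    # merge_blocks([{0: ("b0", "s")}, {3: ("b3", "s")}])
    #   == ("s", [0, 3], ["b0", "b3"])
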
@@ -878,7 +860,7 @@ class Retrieve:
 
         self._set_current_status("decoding")
         started = time.time()
-        assert len(shareids) >= self._required_shares, len(shareids)
+        _assert(len(shareids) >= self._required_shares, len(shareids))
         # zfec really doesn't want extra shares
         shareids = shareids[:self._required_shares]
         shares = shares[:self._required_shares]
@@ -923,13 +905,13 @@ class Retrieve:
         return plaintext
 
 
-    def notify_server_corruption(self, peerid, shnum, reason):
-        ss = self.servermap.connections[peerid]
-        ss.callRemoteOnly("advise_corrupt_share",
-                          "mutable", self._storage_index, shnum, reason)
+    def notify_server_corruption(self, server, shnum, reason):
+        rref = server.get_rref()
+        rref.callRemoteOnly("advise_corrupt_share",
+                            "mutable", self._storage_index, shnum, reason)
 
 
-    def _try_to_validate_privkey(self, enc_privkey, reader):
+    def _try_to_validate_privkey(self, enc_privkey, reader, server):
         alleged_privkey_s = self._node._decrypt_privkey(enc_privkey)
         alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s)
         if alleged_writekey != self._node.get_writekey():
@@ -937,13 +919,13 @@ class Retrieve:
                      (reader, reader.shnum),
                      level=log.WEIRD, umid="YIw4tA")
             if self._verify:
-                self.servermap.mark_bad_share(reader.peerid, reader.shnum,
+                self.servermap.mark_bad_share(server, reader.shnum,
                                               self.verinfo[-2])
-                e = CorruptShareError(reader.peerid,
+                e = CorruptShareError(server,
                                       reader.shnum,
                                       "invalid privkey")
                 f = failure.Failure(e)
-                self._bad_shares.add((reader.peerid, reader.shnum, f))
+                self._bad_shares.add((server, reader.shnum, f))
             return
 
         # it's good
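
Private-key validation above is a consistency check: decrypt the fetched encprivkey, hash it with ssk_writekey_hash, and compare against the writekey the node already holds. A hedged sketch, with decrypt_privkey standing in for the node's own decryption method:

    from allmydata.util import hashutil

    def privkey_looks_valid(node, enc_privkey, decrypt_privkey):
        # Sketch of the check above: a fetched encrypted private key is
        # accepted only if its hash matches the writekey the node already
        # holds. decrypt_privkey() stands in for the node's own method.
        alleged_privkey_s = decrypt_privkey(enc_privkey)
        alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s)
        return alleged_writekey == node.get_writekey()
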
@@ -978,7 +960,7 @@ class Retrieve:
         self._node._populate_total_shares(N)
 
         if self._verify:
-            ret = list(self._bad_shares)
+            ret = self._bad_shares
             self.log("done verifying, found %d bad shares" % len(ret))
         else:
             # TODO: upload status here?
@@ -989,23 +971,25 @@ class Retrieve:
 
     def _raise_notenoughshareserror(self):
         """
-        I am called by _activate_enough_peers when there are not enough
-        active peers left to complete the download. After making some
-        useful logging statements, I throw an exception to that effect
-        to the caller of this Retrieve object through
+        I am called when there are not enough active servers left to complete
+        the download. After making some useful logging statements, I throw an
+        exception to that effect to the caller of this Retrieve object through
         self._done_deferred.
         """
 
-        format = ("ran out of peers: "
-                  "have %(have)d of %(total)d segments "
-                  "found %(bad)d bad shares "
+        format = ("ran out of servers: "
+                  "have %(have)d of %(total)d segments; "
+                  "found %(bad)d bad shares; "
+                  "have %(remaining)d remaining shares of the right version; "
                   "encoding %(k)d-of-%(n)d")
         args = {"have": self._current_segment,
                 "total": self._num_segments,
                 "need": self._last_segment,
                 "k": self._required_shares,
                 "n": self._total_shares,
-                "bad": len(self._bad_shares)}
+                "bad": len(self._bad_shares),
+                "remaining": len(self.remaining_sharemap),
+               }
         raise NotEnoughSharesError("%s, last failure: %s" %
                                    (format % args, str(self._last_failure)))