mutable/retrieve.py: rewrite partial-read handling
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / mutable / retrieve.py
index 3b55659349e52ebc0495c9d3de6beea0c539bbca..10a9ed310e7219afe40e1d2e55459c201cb1d299 100644
@@ -7,8 +7,10 @@ from twisted.python import failure
 from twisted.internet.interfaces import IPushProducer, IConsumer
 from foolscap.api import eventually, fireEventually, DeadReferenceError, \
      RemoteException
+
 from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \
      DownloadStopped, MDMF_VERSION, SDMF_VERSION
+from allmydata.util.assertutil import _assert, precondition
 from allmydata.util import hashutil, log, mathutil, deferredutil
 from allmydata.util.dictutil import DictOfSets
 from allmydata import hashtree, codec
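
The assertutil helpers imported above behave like assert statements that always run and that fold their extra arguments into the AssertionError message. A minimal sketch of the two calling conventions this patch uses (semantics as recalled from allmydata.util.assertutil, so treat the message formatting as an assumption):

    from allmydata.util.assertutil import _assert, precondition

    def check_read(offset, size, data_length):
        # positional extras are repr()ed into the failure message
        precondition(0 <= offset < data_length, (offset, size, data_length))
        # keyword extras are rendered as name=value pairs
        _assert(offset + size <= data_length,
                offset=offset, size=size, data_length=data_length)
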
@@ -62,10 +64,9 @@ class RetrieveStatus:
         return self._problems
 
     def add_fetch_timing(self, server, elapsed):
-        serverid = server.get_serverid()
-        if serverid not in self.timings["fetch_per_server"]:
-            self.timings["fetch_per_server"][serverid] = []
-        self.timings["fetch_per_server"][serverid].append(elapsed)
+        if server not in self.timings["fetch_per_server"]:
+            self.timings["fetch_per_server"][server] = []
+        self.timings["fetch_per_server"][server].append(elapsed)
     def accumulate_decode_time(self, elapsed):
         self.timings["decode"] += elapsed
     def accumulate_decrypt_time(self, elapsed):
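
The rewritten add_fetch_timing keys the timings dict by the server object itself instead of its serverid. Its three-line membership test is the usual accumulate-into-a-dict-of-lists idiom; an equivalent standalone form using dict.setdefault (the names here are illustrative, not from the patch):

    timings = {}

    def add_fetch_timing(timings, server, elapsed):
        # setdefault returns the existing list for this server, or
        # installs and returns a fresh one, so append always has a target
        timings.setdefault(server, []).append(elapsed)

    add_fetch_timing(timings, "server-a", 0.31)
    add_fetch_timing(timings, "server-a", 0.27)
    # timings == {"server-a": [0.31, 0.27]}
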
@@ -116,6 +117,10 @@ class Retrieve:
         self.servermap = servermap
         assert self._node.get_pubkey()
         self.verinfo = verinfo
+        # TODO: make it possible to use self.verinfo.datalength instead
+        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
+         offsets_tuple) = self.verinfo
+        self._data_length = datalength
         # during repair, we may be called upon to grab the private key, since
         # it wasn't picked up during a verify=False checker run, and we'll
         # need it for repair to generate a new version.
@@ -131,7 +136,7 @@ class Retrieve:
 
         # verify means that we are using the downloader logic to verify all
         # of our shares. This tells the downloader a few things.
-        # 
+        #
         # 1. We need to download all of the shares.
         # 2. We don't need to decode or decrypt the shares, since our
         #    caller doesn't care about the plaintext, only the
@@ -146,8 +151,6 @@ class Retrieve:
         self._status.set_helper(False)
         self._status.set_progress(0.0)
         self._status.set_active(True)
-        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
-         offsets_tuple) = self.verinfo
         self._status.set_size(datalength)
         self._status.set_encoding(k, N)
         self.readers = {}
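
verinfo is a plain 9-tuple throughout this module, and the unpacking hoisted into __init__ fixes its layout. For reference, a sketch of the fields; the per-field meanings are inferred from how this module uses them (for example, an empty IV is what _setup_encoding_parameters later uses to recognize an MDMF share):

    (seqnum,        # sequence number of this version of the mutable file
     root_hash,     # root of the share hash tree
     IV,            # SDMF encryption IV; empty for MDMF
     segsize,       # segment size in bytes
     datalength,    # total plaintext length in bytes
     k,             # shares required to decode
     N,             # total shares
     prefix,        # signed version-info prefix
     offsets_tuple) = self.verinfo
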
@@ -231,21 +234,37 @@ class Retrieve:
 
 
     def download(self, consumer=None, offset=0, size=None):
-        assert IConsumer.providedBy(consumer) or self._verify
-
+        precondition(self._verify or IConsumer.providedBy(consumer))
+        if size is None:
+            size = self._data_length - offset
+        if self._verify:
+            _assert(size == self._data_length, (size, self._data_length))
+        self.log("starting download")
+        self._done_deferred = defer.Deferred()
         if consumer:
             self._consumer = consumer
-            # we provide IPushProducer, so streaming=True, per
-            # IConsumer.
+            # we provide IPushProducer, so streaming=True, per IConsumer.
             self._consumer.registerProducer(self, streaming=True)
+        self._started = time.time()
+        self._started_fetching = time.time()
+        if size == 0:
+            # short-circuit the rest of the process
+            self._done()
+        else:
+            self._start_download(consumer, offset, size)
+        return self._done_deferred
+
+    def _start_download(self, consumer, offset, size):
+        precondition((0 <= offset < self._data_length)
+                     and (size > 0)
+                     and (offset+size <= self._data_length),
+                     (offset, size, self._data_length))
 
-        self._done_deferred = defer.Deferred()
         self._offset = offset
         self._read_length = size
-        self._setup_download()
         self._setup_encoding_parameters()
-        self.log("starting download")
-        self._started_fetching = time.time()
+        self._setup_download()
+
         # The download process beyond this is a state machine.
         # _add_active_servers will select the servers that we want to use
         # for the download, and then attempt to start downloading. After
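
The reworked download() now normalizes size up front (None means "read to the end of the file"), short-circuits zero-length reads, and hands the real work to _start_download. A hypothetical caller with a throwaway consumer; the _Accumulator class and the retrieve variable are illustrative, not part of the patch:

    class _Accumulator(object):
        # minimal stand-in for an IConsumer: it just collects bytes
        def __init__(self):
            self.chunks = []
        def registerProducer(self, producer, streaming):
            self.producer = producer  # streaming=True, per IPushProducer
        def unregisterProducer(self):
            self.producer = None
        def write(self, data):
            self.chunks.append(data)

    consumer = _Accumulator()
    # read 1000 bytes starting at offset 5000; size=None would mean
    # "from offset to the end of the file"
    d = retrieve.download(consumer, offset=5000, size=1000)
    d.addCallback(lambda res: "".join(consumer.chunks))
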
@@ -255,7 +274,6 @@ class Retrieve:
         # will errback.  Otherwise, it will eventually callback with the
         # contents of the mutable file.
         self.loop()
-        return self._done_deferred
 
     def loop(self):
         d = fireEventually(None) # avoid #237 recursion limit problem
@@ -266,7 +284,6 @@ class Retrieve:
         d.addErrback(self._error)
 
     def _setup_download(self):
-        self._started = time.time()
         self._status.set_status("Retrieving Shares")
 
         # how many shares do we need?
@@ -287,24 +304,28 @@ class Retrieve:
         self.remaining_sharemap = DictOfSets()
         for (shnum, server, timestamp) in shares:
             self.remaining_sharemap.add(shnum, server)
-            # If the servermap update fetched anything, it fetched at least 1
-            # KiB, so we ask for that much.
-            # TODO: Change the cache methods to allow us to fetch all of the
-            # data that they have, then change this method to do that.
-            any_cache = self._node._read_from_cache(self.verinfo, shnum,
-                                                    0, 1000)
-            reader = MDMFSlotReadProxy(server.get_rref(),
-                                       self._storage_index,
-                                       shnum,
-                                       any_cache)
+            # Reuse the SlotReader from the servermap.
+            key = (self.verinfo, server.get_serverid(),
+                   self._storage_index, shnum)
+            if key in self.servermap.proxies:
+                reader = self.servermap.proxies[key]
+            else:
+                reader = MDMFSlotReadProxy(server.get_rref(),
+                                           self._storage_index, shnum, None)
             reader.server = server
             self.readers[shnum] = reader
-        assert len(self.remaining_sharemap) >= k
+
+        if len(self.remaining_sharemap) < k:
+            self._raise_notenoughshareserror()
 
         self.shares = {} # maps shnum to validated blocks
         self._active_readers = [] # list of active readers for this dl.
         self._block_hash_trees = {} # shnum => hashtree
 
+        for i in xrange(self._total_shares):
+            # So we don't have to do this later.
+            self._block_hash_trees[i] = hashtree.IncompleteHashTree(self._num_segments)
+
         # We need one share hash tree for the entire file; its leaves
         # are the roots of the block hash trees for the shares that
         # comprise it, and its root is in the verinfo.
@@ -320,9 +341,10 @@ class Retrieve:
         segment with. I return the plaintext associated with that
         segment.
         """
-        # shnum => block hash tree. Unused, but setup_encoding_parameters will
-        # want to set this.
+        # We don't need the block hash trees in this case.
         self._block_hash_trees = None
+        self._offset = 0
+        self._read_length = self._data_length
         self._setup_encoding_parameters()
 
         # _decode_blocks() expects the output of a gatherResults that
@@ -350,7 +372,7 @@ class Retrieve:
         self._required_shares = k
         self._total_shares = n
         self._segment_size = segsize
-        self._data_length = datalength
+        # self._data_length = datalength # set during __init__()
 
         if not IV:
             self._version = MDMF_VERSION
@@ -387,15 +409,10 @@ class Retrieve:
                  (k, n, self._num_segments, self._segment_size,
                   self._tail_segment_size))
 
-        if self._block_hash_trees is not None:
-            for i in xrange(self._total_shares):
-                # So we don't have to do this later.
-                self._block_hash_trees[i] = hashtree.IncompleteHashTree(self._num_segments)
-
         # Our last task is to tell the downloader where to start and
         # where to stop. We use three parameters for that:
         #   - self._start_segment: the segment that we need to start
-        #     downloading from. 
+        #     downloading from.
         #   - self._current_segment: the next segment that we need to
         #     download.
         #   - self._last_segment: The last segment that we were asked to
@@ -408,37 +425,32 @@ class Retrieve:
         if self._offset:
             self.log("got offset: %d" % self._offset)
             # our start segment is the first segment containing the
-            # offset we were given. 
+            # offset we were given.
             start = self._offset // self._segment_size
 
-            assert start < self._num_segments
+            _assert(start <= self._num_segments,
+                    start=start, num_segments=self._num_segments,
+                    offset=self._offset, segment_size=self._segment_size)
             self._start_segment = start
             self.log("got start segment: %d" % self._start_segment)
         else:
             self._start_segment = 0
 
-
-        # If self._read_length is None, then we want to read the whole
-        # file. Otherwise, we want to read only part of the file, and
-        # need to figure out where to stop reading.
-        if self._read_length is not None:
-            # our end segment is the last segment containing part of the
-            # segment that we were asked to read.
-            self.log("got read length %d" % self._read_length)
-            if self._read_length != 0:
-                end_data = self._offset + self._read_length
-
-                # We don't actually need to read the byte at end_data,
-                # but the one before it.
-                end = (end_data - 1) // self._segment_size
-
-                assert end < self._num_segments
-                self._last_segment = end
-            else:
-                self._last_segment = self._start_segment
-            self.log("got end segment: %d" % self._last_segment)
-        else:
-            self._last_segment = self._num_segments - 1
+        # We might want to read only part of the file, and need to figure out
+        # where to stop reading. Our end segment is the last segment
+        # containing part of the data that we were asked to read.
+        _assert(self._read_length > 0, self._read_length)
+        end_data = self._offset + self._read_length
+
+        # We don't actually need to read the byte at end_data, but the one
+        # before it.
+        end = (end_data - 1) // self._segment_size
+        _assert(0 <= end < self._num_segments,
+                end=end, num_segments=self._num_segments,
+                end_data=end_data, offset=self._offset,
+                read_length=self._read_length, segment_size=self._segment_size)
+        self._last_segment = end
+        self.log("got end segment: %d" % self._last_segment)
 
         self._current_segment = self._start_segment
 
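
The bookkeeping above reduces to two floor divisions: the first segment touched by the read, and the segment holding byte end_data - 1. A standalone check of the boundary cases, with an assumed segment size of 1000 bytes:

    def segment_range(offset, size, segment_size):
        start = offset // segment_size             # first segment touched
        end = (offset + size - 1) // segment_size  # segment of the last byte
        return (start, end)

    assert segment_range(0, 1000, 1000) == (0, 0)  # exactly one segment
    assert segment_range(999, 2, 1000) == (0, 1)   # straddles a boundary
    assert segment_range(1000, 1, 1000) == (1, 1)  # first byte of segment 1
    assert segment_range(2500, 600, 1000) == (2, 3)
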
@@ -505,7 +517,6 @@ class Retrieve:
                 d.addCallback(self._try_to_validate_privkey, reader, reader.server)
                 # XXX: don't just drop the Deferred. We need error-reporting
                 # but not flow-control here.
-        assert len(self._active_readers) >= self._required_shares
 
     def _try_to_validate_prefix(self, prefix, reader):
         """
@@ -531,40 +542,6 @@ class Retrieve:
                                           "indicate an uncoordinated write")
         # Otherwise, we're okay -- no issues.
 
-
-    def _remove_reader(self, reader):
-        """
-        At various points, we will wish to remove a server from
-        consideration and/or use. These include, but are not necessarily
-        limited to:
-
-            - A connection error.
-            - A mismatched prefix (that is, a prefix that does not match
-              our conception of the version information string).
-            - A failing block hash, salt hash, or share hash, which can
-              indicate disk failure/bit flips, or network trouble.
-
-        This method will do that. I will make sure that the
-        (shnum,reader) combination represented by my reader argument is
-        not used for anything else during this download. I will not
-        advise the reader of any corruption, something that my callers
-        may wish to do on their own.
-        """
-        # TODO: When you're done writing this, see if this is ever
-        # actually used for something that _mark_bad_share isn't. I have
-        # a feeling that they will be used for very similar things, and
-        # that having them both here is just going to be an epic amount
-        # of code duplication.
-        #
-        # (well, okay, not epic, but meaningful)
-        self.log("removing reader %s" % reader)
-        # Remove the reader from _active_readers
-        self._active_readers.remove(reader)
-        # TODO: self.readers.remove(reader)?
-        for shnum in list(self.remaining_sharemap.keys()):
-            self.remaining_sharemap.discard(shnum, reader.server)
-
-
     def _mark_bad_share(self, server, shnum, reader, f):
         """
         I mark the given (server, shnum) as a bad share, which means that it
@@ -589,24 +566,31 @@ class Retrieve:
                  (shnum, server.get_name()))
         prefix = self.verinfo[-2]
         self.servermap.mark_bad_share(server, shnum, prefix)
-        self._remove_reader(reader)
         self._bad_shares.add((server, shnum, f))
         self._status.add_problem(server, f)
         self._last_failure = f
+
+        # Remove the reader from _active_readers
+        self._active_readers.remove(reader)
+        for shnum in list(self.remaining_sharemap.keys()):
+            self.remaining_sharemap.discard(shnum, reader.server)
+
         if f.check(BadShareError):
             self.notify_server_corruption(server, shnum, str(f.value))
 
-
     def _download_current_segment(self):
         """
         I download, validate, decode, decrypt, and assemble the segment
         that this Retrieve is currently responsible for downloading.
         """
-        assert len(self._active_readers) >= self._required_shares
+
         if self._current_segment > self._last_segment:
             # No more segments to download, we're done.
             self.log("got plaintext, done")
             return self._done()
+        elif self._verify and len(self._active_readers) == 0:
+            self.log("no more good shares, no need to keep verifying")
+            return self._done()
         self.log("on segment %d of %d" %
                  (self._current_segment + 1, self._num_segments))
         d = self._process_segment(self._current_segment)
@@ -625,7 +609,6 @@ class Retrieve:
 
         # TODO: The old code uses a marker. Should this code do that
         # too? What did the Marker do?
-        assert len(self._active_readers) >= self._required_shares
 
         # We need to ask each of our active readers for its block and
         # salt. We will then validate those. If validation is
@@ -687,33 +670,27 @@ class Retrieve:
         target that is handling the file download.
         """
         self.log("got plaintext for segment %d" % self._current_segment)
+
+        if self._read_length == 0:
+            self.log("on first+last segment, size=0, using 0 bytes")
+            segment = b""
+
+        if self._current_segment == self._last_segment:
+            # trim off the tail
+            wanted = (self._offset + self._read_length) % self._segment_size
+            if wanted != 0:
+                self.log("on the last segment: using first %d bytes" % wanted)
+                segment = segment[:wanted]
+            else:
+                self.log("on the last segment: using all %d bytes" %
+                         len(segment))
+
         if self._current_segment == self._start_segment:
-            # We're on the first segment. It's possible that we want
-            # only some part of the end of this segment, and that we
-            # just downloaded the whole thing to get that part. If so,
-            # we need to account for that and give the reader just the
-            # data that they want.
-            n = self._offset % self._segment_size
-            self.log("stripping %d bytes off of the first segment" % n)
-            self.log("original segment length: %d" % len(segment))
-            segment = segment[n:]
-            self.log("new segment length: %d" % len(segment))
-
-        if self._current_segment == self._last_segment and self._read_length is not None:
-            # We're on the last segment. It's possible that we only want
-            # part of the beginning of this segment, and that we
-            # downloaded the whole thing anyway. Make sure to give the
-            # caller only the portion of the segment that they want to
-            # receive.
-            extra = self._read_length
-            if self._start_segment != self._last_segment:
-                extra -= self._segment_size - \
-                            (self._offset % self._segment_size)
-            extra %= self._segment_size
-            self.log("original segment length: %d" % len(segment))
-            segment = segment[:extra]
-            self.log("new segment length: %d" % len(segment))
-            self.log("only taking %d bytes of the last segment" % extra)
+            # Trim off the head, if offset != 0. This should also work if
+            # start==last, because we trim the tail first.
+            skip = self._offset % self._segment_size
+            self.log("on the first segment: skipping first %d bytes" % skip)
+            segment = segment[skip:]
 
         if not self._verify:
             self._consumer.write(segment)
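
The order of the two trims matters: cutting the tail first means the head trim still sees offsets relative to the original segment, which is why a read confined to a single segment (start == last) falls out correctly. The same two slices in isolation, under assumed values:

    def trim(segment, offset, read_length, segment_size,
             is_first, is_last):
        if is_last:
            wanted = (offset + read_length) % segment_size
            if wanted != 0:
                segment = segment[:wanted]  # drop bytes past the read
        if is_first:
            skip = offset % segment_size
            segment = segment[skip:]        # drop bytes before the read
        return segment

    # one 1000-byte segment, reading bytes 200..499 (offset=200, size=300)
    out = trim("x" * 1000, 200, 300, 1000, is_first=True, is_last=True)
    assert len(out) == 300  # tail cut at byte 500, then 200 skipped
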
@@ -767,6 +744,7 @@ class Retrieve:
 
         block_and_salt, blockhashes, sharehashes = results
         block, salt = block_and_salt
+        assert type(block) is str, (block, salt)
 
         blockhashes = dict(enumerate(blockhashes))
         self.log("the reader gave me the following blockhashes: %s" % \
@@ -800,24 +778,21 @@ class Retrieve:
 
         # Reaching this point means that we know that this segment
         # is correct. Now we need to check to see whether the share
-        # hash chain is also correct. 
+        # hash chain is also correct.
         # SDMF wrote share hash chains that didn't contain the
         # leaves, which would be produced from the block hash tree.
         # So we need to validate the block hash tree first. If
         # successful, then bht[0] will contain the root for the
         # shnum, which will be a leaf in the share hash tree, which
         # will allow us to validate the rest of the tree.
-        if self.share_hash_tree.needed_hashes(reader.shnum,
-                                              include_leaf=True) or \
-                                              self._verify:
-            try:
-                self.share_hash_tree.set_hashes(hashes=sharehashes,
-                                            leaves={reader.shnum: bht[0]})
-            except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
-                    IndexError), e:
-                raise CorruptShareError(server,
-                                        reader.shnum,
-                                        "corrupt hashes: %s" % e)
+        try:
+            self.share_hash_tree.set_hashes(hashes=sharehashes,
+                                        leaves={reader.shnum: bht[0]})
+        except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
+                IndexError), e:
+            raise CorruptShareError(server,
+                                    reader.shnum,
+                                    "corrupt hashes: %s" % e)
 
         self.log('share %d is valid for segment %d' % (reader.shnum,
                                                        segnum))
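
The contract this hunk now relies on unconditionally: an IncompleteHashTree seeded with a trusted root either accepts a supplied hash chain via set_hashes() or raises BadHashError / NotEnoughHashesError. A sketch of that round trip; the API usage is from memory of allmydata.hashtree, so treat the details as assumptions:

    from allmydata import hashtree
    from allmydata.util import hashutil

    # build a complete 4-leaf tree, then prove leaf 2 to a verifier
    # that initially knows only the root
    leaves = [hashutil.tagged_hash("demo", "leaf%d" % i) for i in range(4)]
    full = hashtree.HashTree(leaves)
    partial = hashtree.IncompleteHashTree(4)
    partial.set_hashes({0: full[0]})  # node 0 is the root
    needed = partial.needed_hashes(2, include_leaf=False)
    proof = dict([(i, full[i]) for i in needed])
    # raises BadHashError/NotEnoughHashesError if the chain is corrupt
    partial.set_hashes(hashes=proof, leaves={2: leaves[2]})
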
@@ -842,12 +817,13 @@ class Retrieve:
         #needed.discard(0)
         self.log("getting blockhashes for segment %d, share %d: %s" % \
                  (segnum, reader.shnum, str(needed)))
-        d1 = reader.get_blockhashes(needed, force_remote=True)
+        # TODO: is force_remote necessary here?
+        d1 = reader.get_blockhashes(needed, force_remote=False)
         if self.share_hash_tree.needed_hashes(reader.shnum):
             need = self.share_hash_tree.needed_hashes(reader.shnum)
             self.log("also need sharehashes for share %d: %s" % (reader.shnum,
                                                                  str(need)))
-            d2 = reader.get_sharehashes(need, force_remote=True)
+            d2 = reader.get_sharehashes(need, force_remote=False)
         else:
             d2 = defer.succeed({}) # the logic in the next method
                                    # expects a dict
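
The defer.succeed({}) arm keeps both paths feeding the same downstream logic: the next callback always receives a dict of share hashes, fetched or not. A minimal illustration of that shape (fetch_hashes and its values are made up for the example):

    from twisted.internet import defer

    def fetch_hashes(need_sharehashes):
        d1 = defer.succeed({3: "blockhash"})      # stands in for get_blockhashes
        if need_sharehashes:
            d2 = defer.succeed({1: "sharehash"})  # stands in for get_sharehashes
        else:
            d2 = defer.succeed({})                # same type either way
        return defer.gatherResults([d1, d2])

    d = fetch_hashes(False)
    # d fires with [{3: "blockhash"}, {}] -- a uniform shape either way
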
@@ -996,23 +972,25 @@ class Retrieve:
 
     def _raise_notenoughshareserror(self):
         """
-        I am called by _activate_enough_servers when there are not enough
-        active servers left to complete the download. After making some
-        useful logging statements, I throw an exception to that effect
-        to the caller of this Retrieve object through
+        I am called when there are not enough active servers left to complete
+        the download. After making some useful logging statements, I throw an
+        exception to that effect to the caller of this Retrieve object through
         self._done_deferred.
         """
 
         format = ("ran out of servers: "
-                  "have %(have)d of %(total)d segments "
-                  "found %(bad)d bad shares "
+                  "have %(have)d of %(total)d segments; "
+                  "found %(bad)d bad shares; "
+                  "have %(remaining)d remaining shares of the right version; "
                   "encoding %(k)d-of-%(n)d")
         args = {"have": self._current_segment,
                 "total": self._num_segments,
                 "need": self._last_segment,
                 "k": self._required_shares,
                 "n": self._total_shares,
-                "bad": len(self._bad_shares)}
+                "bad": len(self._bad_shares),
+                "remaining": len(self.remaining_sharemap),
+               }
         raise NotEnoughSharesError("%s, last failure: %s" %
                                    (format % args, str(self._last_failure)))
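
Because the message is built with %-formatting against a mapping, extra keys in args (such as "need", which the format string never references) are silently ignored. A toy rendering with assumed counts:

    format = ("ran out of servers: "
              "have %(have)d of %(total)d segments; "
              "found %(bad)d bad shares; "
              "have %(remaining)d remaining shares of the right version; "
              "encoding %(k)d-of-%(n)d")
    args = {"have": 2, "total": 5, "need": 4,
            "bad": 1, "remaining": 2, "k": 3, "n": 10}
    message = format % args
    # "ran out of servers: have 2 of 5 segments; found 1 bad shares;
    #  have 2 remaining shares of the right version; encoding 3-of-10"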