]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blobdiff - src/allmydata/mutable/retrieve.py
Remove some bare asserts in retrieve.py (there are still quite a few left). refs...
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / mutable / retrieve.py
index 291fb120c9ab1f86c6f5547c81679980b3fd0a20..6c2c5c9bf8c48d81c267d1a5ca79aa2b70ce1011 100644 (file)
@@ -7,8 +7,10 @@ from twisted.python import failure
 from twisted.internet.interfaces import IPushProducer, IConsumer
 from foolscap.api import eventually, fireEventually, DeadReferenceError, \
      RemoteException
+
 from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \
      DownloadStopped, MDMF_VERSION, SDMF_VERSION
+from allmydata.util.assertutil import _assert
 from allmydata.util import hashutil, log, mathutil, deferredutil
 from allmydata.util.dictutil import DictOfSets
 from allmydata import hashtree, codec
@@ -130,7 +132,7 @@ class Retrieve:
 
         # verify means that we are using the downloader logic to verify all
         # of our shares. This tells the downloader a few things.
-        # 
+        #
         # 1. We need to download all of the shares.
         # 2. We don't need to decode or decrypt the shares, since our
         #    caller doesn't care about the plaintext, only the
@@ -241,8 +243,8 @@ class Retrieve:
         self._done_deferred = defer.Deferred()
         self._offset = offset
         self._read_length = size
-        self._setup_download()
         self._setup_encoding_parameters()
+        self._setup_download()
         self.log("starting download")
         self._started_fetching = time.time()
         # The download process beyond this is a state machine.
@@ -286,24 +288,28 @@ class Retrieve:
         self.remaining_sharemap = DictOfSets()
         for (shnum, server, timestamp) in shares:
             self.remaining_sharemap.add(shnum, server)
-            # If the servermap update fetched anything, it fetched at least 1
-            # KiB, so we ask for that much.
-            # TODO: Change the cache methods to allow us to fetch all of the
-            # data that they have, then change this method to do that.
-            any_cache = self._node._read_from_cache(self.verinfo, shnum,
-                                                    0, 1000)
-            reader = MDMFSlotReadProxy(server.get_rref(),
-                                       self._storage_index,
-                                       shnum,
-                                       any_cache)
+            # Reuse the SlotReader from the servermap.
+            key = (self.verinfo, server.get_serverid(),
+                   self._storage_index, shnum)
+            if key in self.servermap.proxies:
+                reader = self.servermap.proxies[key]
+            else:
+                reader = MDMFSlotReadProxy(server.get_rref(),
+                                           self._storage_index, shnum, None)
             reader.server = server
             self.readers[shnum] = reader
-        assert len(self.remaining_sharemap) >= k
+
+        if len(self.remaining_sharemap) < k:
+            self._raise_notenoughshareserror()
 
         self.shares = {} # maps shnum to validated blocks
         self._active_readers = [] # list of active readers for this dl.
         self._block_hash_trees = {} # shnum => hashtree
 
+        for i in xrange(self._total_shares):
+            # So we don't have to do this later.
+            self._block_hash_trees[i] = hashtree.IncompleteHashTree(self._num_segments)
+
         # We need one share hash tree for the entire file; its leaves
         # are the roots of the block hash trees for the shares that
         # comprise it, and its root is in the verinfo.
@@ -319,8 +325,7 @@ class Retrieve:
         segment with. I return the plaintext associated with that
         segment.
         """
-        # shnum => block hash tree. Unused, but setup_encoding_parameters will
-        # want to set this.
+        # We don't need the block hash trees in this case.
         self._block_hash_trees = None
         self._setup_encoding_parameters()
 
@@ -386,15 +391,10 @@ class Retrieve:
                  (k, n, self._num_segments, self._segment_size,
                   self._tail_segment_size))
 
-        if self._block_hash_trees is not None:
-            for i in xrange(self._total_shares):
-                # So we don't have to do this later.
-                self._block_hash_trees[i] = hashtree.IncompleteHashTree(self._num_segments)
-
         # Our last task is to tell the downloader where to start and
         # where to stop. We use three parameters for that:
         #   - self._start_segment: the segment that we need to start
-        #     downloading from. 
+        #     downloading from.
         #   - self._current_segment: the next segment that we need to
         #     download.
         #   - self._last_segment: The last segment that we were asked to
@@ -407,10 +407,12 @@ class Retrieve:
         if self._offset:
             self.log("got offset: %d" % self._offset)
             # our start segment is the first segment containing the
-            # offset we were given. 
+            # offset we were given.
             start = self._offset // self._segment_size
 
-            assert start < self._num_segments
+            _assert(start < self._num_segments,
+                    start=start, num_segments=self._num_segments,
+                    offset=self._offset, segment_size=self._segment_size)
             self._start_segment = start
             self.log("got start segment: %d" % self._start_segment)
         else:
@@ -431,7 +433,10 @@ class Retrieve:
                 # but the one before it.
                 end = (end_data - 1) // self._segment_size
 
-                assert end < self._num_segments
+                _assert(end < self._num_segments,
+                        end=end, num_segments=self._num_segments,
+                        end_data=end_data, offset=self._offset, read_length=self._read_length,
+                        segment_size=self._segment_size)
                 self._last_segment = end
             else:
                 self._last_segment = self._start_segment
@@ -529,40 +534,6 @@ class Retrieve:
                                           "indicate an uncoordinated write")
         # Otherwise, we're okay -- no issues.
 
-
-    def _remove_reader(self, reader):
-        """
-        At various points, we will wish to remove a server from
-        consideration and/or use. These include, but are not necessarily
-        limited to:
-
-            - A connection error.
-            - A mismatched prefix (that is, a prefix that does not match
-              our conception of the version information string).
-            - A failing block hash, salt hash, or share hash, which can
-              indicate disk failure/bit flips, or network trouble.
-
-        This method will do that. I will make sure that the
-        (shnum,reader) combination represented by my reader argument is
-        not used for anything else during this download. I will not
-        advise the reader of any corruption, something that my callers
-        may wish to do on their own.
-        """
-        # TODO: When you're done writing this, see if this is ever
-        # actually used for something that _mark_bad_share isn't. I have
-        # a feeling that they will be used for very similar things, and
-        # that having them both here is just going to be an epic amount
-        # of code duplication.
-        #
-        # (well, okay, not epic, but meaningful)
-        self.log("removing reader %s" % reader)
-        # Remove the reader from _active_readers
-        self._active_readers.remove(reader)
-        # TODO: self.readers.remove(reader)?
-        for shnum in list(self.remaining_sharemap.keys()):
-            self.remaining_sharemap.discard(shnum, reader.server)
-
-
     def _mark_bad_share(self, server, shnum, reader, f):
         """
         I mark the given (server, shnum) as a bad share, which means that it
@@ -587,14 +558,18 @@ class Retrieve:
                  (shnum, server.get_name()))
         prefix = self.verinfo[-2]
         self.servermap.mark_bad_share(server, shnum, prefix)
-        self._remove_reader(reader)
         self._bad_shares.add((server, shnum, f))
         self._status.add_problem(server, f)
         self._last_failure = f
+
+        # Remove the reader from _active_readers
+        self._active_readers.remove(reader)
+        for shnum in list(self.remaining_sharemap.keys()):
+            self.remaining_sharemap.discard(shnum, reader.server)
+
         if f.check(BadShareError):
             self.notify_server_corruption(server, shnum, str(f.value))
 
-
     def _download_current_segment(self):
         """
         I download, validate, decode, decrypt, and assemble the segment
@@ -766,6 +741,7 @@ class Retrieve:
 
         block_and_salt, blockhashes, sharehashes = results
         block, salt = block_and_salt
+        assert type(block) is str, (block, salt)
 
         blockhashes = dict(enumerate(blockhashes))
         self.log("the reader gave me the following blockhashes: %s" % \
@@ -799,7 +775,7 @@ class Retrieve:
 
         # Reaching this point means that we know that this segment
         # is correct. Now we need to check to see whether the share
-        # hash chain is also correct. 
+        # hash chain is also correct.
         # SDMF wrote share hash chains that didn't contain the
         # leaves, which would be produced from the block hash tree.
         # So we need to validate the block hash tree first. If
@@ -838,12 +814,13 @@ class Retrieve:
         #needed.discard(0)
         self.log("getting blockhashes for segment %d, share %d: %s" % \
                  (segnum, reader.shnum, str(needed)))
-        d1 = reader.get_blockhashes(needed, force_remote=True)
+        # TODO is force_remote necessary here?
+        d1 = reader.get_blockhashes(needed, force_remote=False)
         if self.share_hash_tree.needed_hashes(reader.shnum):
             need = self.share_hash_tree.needed_hashes(reader.shnum)
             self.log("also need sharehashes for share %d: %s" % (reader.shnum,
                                                                  str(need)))
-            d2 = reader.get_sharehashes(need, force_remote=True)
+            d2 = reader.get_sharehashes(need, force_remote=False)
         else:
             d2 = defer.succeed({}) # the logic in the next method
                                    # expects a dict
@@ -992,23 +969,25 @@ class Retrieve:
 
     def _raise_notenoughshareserror(self):
         """
-        I am called by _activate_enough_servers when there are not enough
-        active servers left to complete the download. After making some
-        useful logging statements, I throw an exception to that effect
-        to the caller of this Retrieve object through
+        I am called when there are not enough active servers left to complete
+        the download. After making some useful logging statements, I throw an
+        exception to that effect to the caller of this Retrieve object through
         self._done_deferred.
         """
 
         format = ("ran out of servers: "
-                  "have %(have)d of %(total)d segments "
-                  "found %(bad)d bad shares "
+                  "have %(have)d of %(total)d segments; "
+                  "found %(bad)d bad shares; "
+                  "have %(remaining)d remaining shares of the right version; "
                   "encoding %(k)d-of-%(n)d")
         args = {"have": self._current_segment,
                 "total": self._num_segments,
                 "need": self._last_segment,
                 "k": self._required_shares,
                 "n": self._total_shares,
-                "bad": len(self._bad_shares)}
+                "bad": len(self._bad_shares),
+                "remaining": len(self.remaining_sharemap),
+               }
         raise NotEnoughSharesError("%s, last failure: %s" %
                                    (format % args, str(self._last_failure)))