]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/commitdiff
mutable/retrieve.py: rewrite partial-read handling 177/head
authorBrian Warner <warner@lothar.com>
Tue, 28 Jul 2015 17:11:47 +0000 (10:11 -0700)
committerBrian Warner <warner@lothar.com>
Tue, 28 Jul 2015 17:13:32 +0000 (10:13 -0700)
This should tolerate offset/size combinations that read the last byte of
the file, something which was broken before. It quits early in the case
of zero-byte reads, to simplify the resulting "which segments do I need"
logic. Probably addresses ticket:2459.

src/allmydata/mutable/retrieve.py

index 6c2c5c9bf8c48d81c267d1a5ca79aa2b70ce1011..10a9ed310e7219afe40e1d2e55459c201cb1d299 100644 (file)
@@ -10,7 +10,7 @@ from foolscap.api import eventually, fireEventually, DeadReferenceError, \
 
 from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \
      DownloadStopped, MDMF_VERSION, SDMF_VERSION
 
 from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \
      DownloadStopped, MDMF_VERSION, SDMF_VERSION
-from allmydata.util.assertutil import _assert
+from allmydata.util.assertutil import _assert, precondition
 from allmydata.util import hashutil, log, mathutil, deferredutil
 from allmydata.util.dictutil import DictOfSets
 from allmydata import hashtree, codec
 from allmydata.util import hashutil, log, mathutil, deferredutil
 from allmydata.util.dictutil import DictOfSets
 from allmydata import hashtree, codec
@@ -117,6 +117,10 @@ class Retrieve:
         self.servermap = servermap
         assert self._node.get_pubkey()
         self.verinfo = verinfo
         self.servermap = servermap
         assert self._node.get_pubkey()
         self.verinfo = verinfo
+        # TODO: make it possible to use self.verinfo.datalength instead
+        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
+         offsets_tuple) = self.verinfo
+        self._data_length = datalength
         # during repair, we may be called upon to grab the private key, since
         # it wasn't picked up during a verify=False checker run, and we'll
         # need it for repair to generate a new version.
         # during repair, we may be called upon to grab the private key, since
         # it wasn't picked up during a verify=False checker run, and we'll
         # need it for repair to generate a new version.
@@ -147,8 +151,6 @@ class Retrieve:
         self._status.set_helper(False)
         self._status.set_progress(0.0)
         self._status.set_active(True)
         self._status.set_helper(False)
         self._status.set_progress(0.0)
         self._status.set_active(True)
-        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
-         offsets_tuple) = self.verinfo
         self._status.set_size(datalength)
         self._status.set_encoding(k, N)
         self.readers = {}
         self._status.set_size(datalength)
         self._status.set_encoding(k, N)
         self.readers = {}
@@ -232,21 +234,37 @@ class Retrieve:
 
 
     def download(self, consumer=None, offset=0, size=None):
 
 
     def download(self, consumer=None, offset=0, size=None):
-        assert IConsumer.providedBy(consumer) or self._verify
-
+        precondition(self._verify or IConsumer.providedBy(consumer))
+        if size is None:
+            size = self._data_length - offset
+        if self._verify:
+            _assert(size == self._data_length, (size, self._data_length))
+        self.log("starting download")
+        self._done_deferred = defer.Deferred()
         if consumer:
             self._consumer = consumer
         if consumer:
             self._consumer = consumer
-            # we provide IPushProducer, so streaming=True, per
-            # IConsumer.
+            # we provide IPushProducer, so streaming=True, per IConsumer.
             self._consumer.registerProducer(self, streaming=True)
             self._consumer.registerProducer(self, streaming=True)
+        self._started = time.time()
+        self._started_fetching = time.time()
+        if size == 0:
+            # short-circuit the rest of the process
+            self._done()
+        else:
+            self._start_download(consumer, offset, size)
+        return self._done_deferred
+
+    def _start_download(self, consumer, offset, size):
+        precondition((0 <= offset < self._data_length)
+                     and (size > 0)
+                     and (offset+size <= self._data_length),
+                     (offset, size, self._data_length))
 
 
-        self._done_deferred = defer.Deferred()
         self._offset = offset
         self._read_length = size
         self._setup_encoding_parameters()
         self._setup_download()
         self._offset = offset
         self._read_length = size
         self._setup_encoding_parameters()
         self._setup_download()
-        self.log("starting download")
-        self._started_fetching = time.time()
+
         # The download process beyond this is a state machine.
         # _add_active_servers will select the servers that we want to use
         # for the download, and then attempt to start downloading. After
         # The download process beyond this is a state machine.
         # _add_active_servers will select the servers that we want to use
         # for the download, and then attempt to start downloading. After
@@ -256,7 +274,6 @@ class Retrieve:
         # will errback.  Otherwise, it will eventually callback with the
         # contents of the mutable file.
         self.loop()
         # will errback.  Otherwise, it will eventually callback with the
         # contents of the mutable file.
         self.loop()
-        return self._done_deferred
 
     def loop(self):
         d = fireEventually(None) # avoid #237 recursion limit problem
 
     def loop(self):
         d = fireEventually(None) # avoid #237 recursion limit problem
@@ -267,7 +284,6 @@ class Retrieve:
         d.addErrback(self._error)
 
     def _setup_download(self):
         d.addErrback(self._error)
 
     def _setup_download(self):
-        self._started = time.time()
         self._status.set_status("Retrieving Shares")
 
         # how many shares do we need?
         self._status.set_status("Retrieving Shares")
 
         # how many shares do we need?
@@ -327,6 +343,8 @@ class Retrieve:
         """
         # We don't need the block hash trees in this case.
         self._block_hash_trees = None
         """
         # We don't need the block hash trees in this case.
         self._block_hash_trees = None
+        self._offset = 0
+        self._read_length = self._data_length
         self._setup_encoding_parameters()
 
         # _decode_blocks() expects the output of a gatherResults that
         self._setup_encoding_parameters()
 
         # _decode_blocks() expects the output of a gatherResults that
@@ -354,7 +372,7 @@ class Retrieve:
         self._required_shares = k
         self._total_shares = n
         self._segment_size = segsize
         self._required_shares = k
         self._total_shares = n
         self._segment_size = segsize
-        self._data_length = datalength
+        #self._data_length = datalength # set during __init__()
 
         if not IV:
             self._version = MDMF_VERSION
 
         if not IV:
             self._version = MDMF_VERSION
@@ -410,7 +428,7 @@ class Retrieve:
             # offset we were given.
             start = self._offset // self._segment_size
 
             # offset we were given.
             start = self._offset // self._segment_size
 
-            _assert(start < self._num_segments,
+            _assert(start <= self._num_segments,
                     start=start, num_segments=self._num_segments,
                     offset=self._offset, segment_size=self._segment_size)
             self._start_segment = start
                     start=start, num_segments=self._num_segments,
                     offset=self._offset, segment_size=self._segment_size)
             self._start_segment = start
@@ -418,31 +436,21 @@ class Retrieve:
         else:
             self._start_segment = 0
 
         else:
             self._start_segment = 0
 
-
-        # If self._read_length is None, then we want to read the whole
-        # file. Otherwise, we want to read only part of the file, and
-        # need to figure out where to stop reading.
-        if self._read_length is not None:
-            # our end segment is the last segment containing part of the
-            # segment that we were asked to read.
-            self.log("got read length %d" % self._read_length)
-            if self._read_length != 0:
-                end_data = self._offset + self._read_length
-
-                # We don't actually need to read the byte at end_data,
-                # but the one before it.
-                end = (end_data - 1) // self._segment_size
-
-                _assert(end < self._num_segments,
-                        end=end, num_segments=self._num_segments,
-                        end_data=end_data, offset=self._offset, read_length=self._read_length,
-                        segment_size=self._segment_size)
-                self._last_segment = end
-            else:
-                self._last_segment = self._start_segment
-            self.log("got end segment: %d" % self._last_segment)
-        else:
-            self._last_segment = self._num_segments - 1
+        # We might want to read only part of the file, and need to figure out
+        # where to stop reading. Our end segment is the last segment
+        # containing part of the segment that we were asked to read.
+        _assert(self._read_length > 0, self._read_length)
+        end_data = self._offset + self._read_length
+
+        # We don't actually need to read the byte at end_data, but the one
+        # before it.
+        end = (end_data - 1) // self._segment_size
+        _assert(0 <= end < self._num_segments,
+                end=end, num_segments=self._num_segments,
+                end_data=end_data, offset=self._offset,
+                read_length=self._read_length, segment_size=self._segment_size)
+        self._last_segment = end
+        self.log("got end segment: %d" % self._last_segment)
 
         self._current_segment = self._start_segment
 
 
         self._current_segment = self._start_segment
 
@@ -575,6 +583,7 @@ class Retrieve:
         I download, validate, decode, decrypt, and assemble the segment
         that this Retrieve is currently responsible for downloading.
         """
         I download, validate, decode, decrypt, and assemble the segment
         that this Retrieve is currently responsible for downloading.
         """
+
         if self._current_segment > self._last_segment:
             # No more segments to download, we're done.
             self.log("got plaintext, done")
         if self._current_segment > self._last_segment:
             # No more segments to download, we're done.
             self.log("got plaintext, done")
@@ -661,33 +670,27 @@ class Retrieve:
         target that is handling the file download.
         """
         self.log("got plaintext for segment %d" % self._current_segment)
         target that is handling the file download.
         """
         self.log("got plaintext for segment %d" % self._current_segment)
+
+        if self._read_length == 0:
+            self.log("on first+last segment, size=0, using 0 bytes")
+            segment = b""
+
+        if self._current_segment == self._last_segment:
+            # trim off the tail
+            wanted = (self._offset + self._read_length) % self._segment_size
+            if wanted != 0:
+                self.log("on the last segment: using first %d bytes" % wanted)
+                segment = segment[:wanted]
+            else:
+                self.log("on the last segment: using all %d bytes" %
+                         len(segment))
+
         if self._current_segment == self._start_segment:
         if self._current_segment == self._start_segment:
-            # We're on the first segment. It's possible that we want
-            # only some part of the end of this segment, and that we
-            # just downloaded the whole thing to get that part. If so,
-            # we need to account for that and give the reader just the
-            # data that they want.
-            n = self._offset % self._segment_size
-            self.log("stripping %d bytes off of the first segment" % n)
-            self.log("original segment length: %d" % len(segment))
-            segment = segment[n:]
-            self.log("new segment length: %d" % len(segment))
-
-        if self._current_segment == self._last_segment and self._read_length is not None:
-            # We're on the last segment. It's possible that we only want
-            # part of the beginning of this segment, and that we
-            # downloaded the whole thing anyway. Make sure to give the
-            # caller only the portion of the segment that they want to
-            # receive.
-            extra = self._read_length
-            if self._start_segment != self._last_segment:
-                extra -= self._segment_size - \
-                            (self._offset % self._segment_size)
-            extra %= self._segment_size
-            self.log("original segment length: %d" % len(segment))
-            segment = segment[:extra]
-            self.log("new segment length: %d" % len(segment))
-            self.log("only taking %d bytes of the last segment" % extra)
+            # Trim off the head, if offset != 0. This should also work if
+            # start==last, because we trim the tail first.
+            skip = self._offset % self._segment_size
+            self.log("on the first segment: skipping first %d bytes" % skip)
+            segment = segment[skip:]
 
         if not self._verify:
             self._consumer.write(segment)
 
         if not self._verify:
             self._consumer.write(segment)