git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/commitdiff
Merge branch 'sdmf-partial-2'
author Brian Warner <warner@lothar.com>
Tue, 28 Jul 2015 18:02:17 +0000 (11:02 -0700)
committer Brian Warner <warner@lothar.com>
Tue, 28 Jul 2015 18:02:17 +0000 (11:02 -0700)
This cleans up the mutable-retrieve code a bit, and should fix some
corner cases where an offset/size combination that reads the last byte
of the file (but not the first) could cause an assert to fire, making
the download hang. Should address ticket:2459 and ticket:2462.
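
In the patched retrieve code, a requested byte range (offset, size) is mapped
onto a range of segments, and the first and last segments are trimmed after
decryption; the tail is now trimmed before the head so the single-segment
(start == last) case works. A minimal standalone sketch of that arithmetic,
with illustrative names and equal-sized segments assumed (the real code also
handles a shorter tail segment):

    def segment_range(offset, size, segment_size, data_length):
        # which segments cover the bytes [offset, offset+size) ?
        assert 0 <= offset < data_length and size > 0
        assert offset + size <= data_length
        start = offset // segment_size              # first segment needed
        end = (offset + size - 1) // segment_size   # last segment needed (inclusive)
        return start, end

    def trim_segment(segment, index, start, end, offset, size, segment_size):
        # trim the tail first, then the head, mirroring the patched order
        if index == end:
            wanted = (offset + size) % segment_size
            if wanted != 0:
                segment = segment[:wanted]
        if index == start:
            segment = segment[offset % segment_size:]
        return segment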

src/allmydata/mutable/retrieve.py
src/allmydata/test/test_mutable.py

index 2be92163ec822df2c972bee9005a5ee264aa6a1f..10a9ed310e7219afe40e1d2e55459c201cb1d299 100644 (file)
@@ -7,8 +7,10 @@ from twisted.python import failure
 from twisted.internet.interfaces import IPushProducer, IConsumer
 from foolscap.api import eventually, fireEventually, DeadReferenceError, \
      RemoteException
+
 from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \
      DownloadStopped, MDMF_VERSION, SDMF_VERSION
+from allmydata.util.assertutil import _assert, precondition
 from allmydata.util import hashutil, log, mathutil, deferredutil
 from allmydata.util.dictutil import DictOfSets
 from allmydata import hashtree, codec
@@ -115,6 +117,10 @@ class Retrieve:
         self.servermap = servermap
         assert self._node.get_pubkey()
         self.verinfo = verinfo
+        # TODO: make it possible to use self.verinfo.datalength instead
+        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
+         offsets_tuple) = self.verinfo
+        self._data_length = datalength
         # during repair, we may be called upon to grab the private key, since
         # it wasn't picked up during a verify=False checker run, and we'll
         # need it for repair to generate a new version.
@@ -145,8 +151,6 @@ class Retrieve:
         self._status.set_helper(False)
         self._status.set_progress(0.0)
         self._status.set_active(True)
-        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
-         offsets_tuple) = self.verinfo
         self._status.set_size(datalength)
         self._status.set_encoding(k, N)
         self.readers = {}
@@ -230,21 +234,37 @@ class Retrieve:
 
 
     def download(self, consumer=None, offset=0, size=None):
-        assert IConsumer.providedBy(consumer) or self._verify
-
+        precondition(self._verify or IConsumer.providedBy(consumer))
+        if size is None:
+            size = self._data_length - offset
+        if self._verify:
+            _assert(size == self._data_length, (size, self._data_length))
+        self.log("starting download")
+        self._done_deferred = defer.Deferred()
         if consumer:
             self._consumer = consumer
-            # we provide IPushProducer, so streaming=True, per
-            # IConsumer.
+            # we provide IPushProducer, so streaming=True, per IConsumer.
             self._consumer.registerProducer(self, streaming=True)
+        self._started = time.time()
+        self._started_fetching = time.time()
+        if size == 0:
+            # short-circuit the rest of the process
+            self._done()
+        else:
+            self._start_download(consumer, offset, size)
+        return self._done_deferred
+
+    def _start_download(self, consumer, offset, size):
+        precondition((0 <= offset < self._data_length)
+                     and (size > 0)
+                     and (offset+size <= self._data_length),
+                     (offset, size, self._data_length))
 
-        self._done_deferred = defer.Deferred()
         self._offset = offset
         self._read_length = size
         self._setup_encoding_parameters()
         self._setup_download()
-        self.log("starting download")
-        self._started_fetching = time.time()
+
         # The download process beyond this is a state machine.
         # _add_active_servers will select the servers that we want to use
         # for the download, and then attempt to start downloading. After
@@ -254,7 +274,6 @@ class Retrieve:
         # will errback.  Otherwise, it will eventually callback with the
         # contents of the mutable file.
         self.loop()
-        return self._done_deferred
 
     def loop(self):
         d = fireEventually(None) # avoid #237 recursion limit problem
@@ -265,7 +284,6 @@ class Retrieve:
         d.addErrback(self._error)
 
     def _setup_download(self):
-        self._started = time.time()
         self._status.set_status("Retrieving Shares")
 
         # how many shares do we need?
@@ -325,6 +343,8 @@ class Retrieve:
         """
         # We don't need the block hash trees in this case.
         self._block_hash_trees = None
+        self._offset = 0
+        self._read_length = self._data_length
         self._setup_encoding_parameters()
 
         # _decode_blocks() expects the output of a gatherResults that
@@ -352,7 +372,7 @@ class Retrieve:
         self._required_shares = k
         self._total_shares = n
         self._segment_size = segsize
-        self._data_length = datalength
+        #self._data_length = datalength # set during __init__()
 
         if not IV:
             self._version = MDMF_VERSION
@@ -408,34 +428,29 @@ class Retrieve:
             # offset we were given.
             start = self._offset // self._segment_size
 
-            assert start < self._num_segments
+            _assert(start <= self._num_segments,
+                    start=start, num_segments=self._num_segments,
+                    offset=self._offset, segment_size=self._segment_size)
             self._start_segment = start
             self.log("got start segment: %d" % self._start_segment)
         else:
             self._start_segment = 0
 
-
-        # If self._read_length is None, then we want to read the whole
-        # file. Otherwise, we want to read only part of the file, and
-        # need to figure out where to stop reading.
-        if self._read_length is not None:
-            # our end segment is the last segment containing part of the
-            # segment that we were asked to read.
-            self.log("got read length %d" % self._read_length)
-            if self._read_length != 0:
-                end_data = self._offset + self._read_length
-
-                # We don't actually need to read the byte at end_data,
-                # but the one before it.
-                end = (end_data - 1) // self._segment_size
-
-                assert end < self._num_segments
-                self._last_segment = end
-            else:
-                self._last_segment = self._start_segment
-            self.log("got end segment: %d" % self._last_segment)
-        else:
-            self._last_segment = self._num_segments - 1
+        # We might want to read only part of the file, and need to figure out
+        # where to stop reading. Our end segment is the last segment
+        # containing part of the segment that we were asked to read.
+        _assert(self._read_length > 0, self._read_length)
+        end_data = self._offset + self._read_length
+
+        # We don't actually need to read the byte at end_data, but the one
+        # before it.
+        end = (end_data - 1) // self._segment_size
+        _assert(0 <= end < self._num_segments,
+                end=end, num_segments=self._num_segments,
+                end_data=end_data, offset=self._offset,
+                read_length=self._read_length, segment_size=self._segment_size)
+        self._last_segment = end
+        self.log("got end segment: %d" % self._last_segment)
 
         self._current_segment = self._start_segment
 
@@ -568,6 +583,7 @@ class Retrieve:
         I download, validate, decode, decrypt, and assemble the segment
         that this Retrieve is currently responsible for downloading.
         """
+
         if self._current_segment > self._last_segment:
             # No more segments to download, we're done.
             self.log("got plaintext, done")
@@ -654,33 +670,27 @@ class Retrieve:
         target that is handling the file download.
         """
         self.log("got plaintext for segment %d" % self._current_segment)
+
+        if self._read_length == 0:
+            self.log("on first+last segment, size=0, using 0 bytes")
+            segment = b""
+
+        if self._current_segment == self._last_segment:
+            # trim off the tail
+            wanted = (self._offset + self._read_length) % self._segment_size
+            if wanted != 0:
+                self.log("on the last segment: using first %d bytes" % wanted)
+                segment = segment[:wanted]
+            else:
+                self.log("on the last segment: using all %d bytes" %
+                         len(segment))
+
         if self._current_segment == self._start_segment:
-            # We're on the first segment. It's possible that we want
-            # only some part of the end of this segment, and that we
-            # just downloaded the whole thing to get that part. If so,
-            # we need to account for that and give the reader just the
-            # data that they want.
-            n = self._offset % self._segment_size
-            self.log("stripping %d bytes off of the first segment" % n)
-            self.log("original segment length: %d" % len(segment))
-            segment = segment[n:]
-            self.log("new segment length: %d" % len(segment))
-
-        if self._current_segment == self._last_segment and self._read_length is not None:
-            # We're on the last segment. It's possible that we only want
-            # part of the beginning of this segment, and that we
-            # downloaded the whole thing anyway. Make sure to give the
-            # caller only the portion of the segment that they want to
-            # receive.
-            extra = self._read_length
-            if self._start_segment != self._last_segment:
-                extra -= self._segment_size - \
-                            (self._offset % self._segment_size)
-            extra %= self._segment_size
-            self.log("original segment length: %d" % len(segment))
-            segment = segment[:extra]
-            self.log("new segment length: %d" % len(segment))
-            self.log("only taking %d bytes of the last segment" % extra)
+            # Trim off the head, if offset != 0. This should also work if
+            # start==last, because we trim the tail first.
+            skip = self._offset % self._segment_size
+            self.log("on the first segment: skipping first %d bytes" % skip)
+            segment = segment[skip:]
 
         if not self._verify:
             self._consumer.write(segment)
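
The net effect of the retrieve.py changes above is that download() is now the
single place where defaults are resolved: size=None means "from offset to the
end of the file", a zero-byte request completes immediately without contacting
any servers, and _start_download() only ever sees a validated, non-empty range.
A hedged restatement of that contract as a standalone helper (illustrative
only, not the actual method):

    def resolve_read_range(offset, size, data_length):
        # size=None means "from offset to the end of the file"
        if size is None:
            size = data_length - offset
        if size == 0:
            return None   # nothing to fetch; the download short-circuits
        # anything else must be a sane, non-empty range within the file
        assert 0 <= offset < data_length
        assert offset + size <= data_length
        return (offset, size)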
index 890d294e84d314d1e822a9a48c72362a601ddf0e..c9711c6dc3ac5287a6406f746667ce0d5cc18bc5 100644 (file)
@@ -17,7 +17,7 @@ from allmydata.interfaces import IRepairResults, ICheckAndRepairResults, \
 from allmydata.monitor import Monitor
 from allmydata.test.common import ShouldFailMixin
 from allmydata.test.no_network import GridTestMixin
-from foolscap.api import eventually, fireEventually
+from foolscap.api import eventually, fireEventually, flushEventualQueue
 from foolscap.logging import log
 from allmydata.storage_client import StorageFarmBroker
 from allmydata.storage.common import storage_index_to_dir
@@ -916,11 +916,13 @@ class PublishMixin:
         d.addCallback(_created)
         return d
 
-    def publish_mdmf(self):
+    def publish_mdmf(self, data=None):
         # like publish_one, except that the result is guaranteed to be
         # an MDMF file.
         # self.CONTENTS should have more than one segment.
-        self.CONTENTS = "This is an MDMF file" * 100000
+        if data is None:
+            data = "This is an MDMF file" * 100000
+        self.CONTENTS = data
         self.uploadable = MutableData(self.CONTENTS)
         self._storage = FakeStorage()
         self._nodemaker = make_nodemaker(self._storage)
@@ -933,10 +935,12 @@ class PublishMixin:
         return d
 
 
-    def publish_sdmf(self):
+    def publish_sdmf(self, data=None):
         # like publish_one, except that the result is guaranteed to be
         # an SDMF file
-        self.CONTENTS = "This is an SDMF file" * 1000
+        if data is None:
+            data = "This is an SDMF file" * 1000
+        self.CONTENTS = data
         self.uploadable = MutableData(self.CONTENTS)
         self._storage = FakeStorage()
         self._nodemaker = make_nodemaker(self._storage)
@@ -948,20 +952,6 @@ class PublishMixin:
         d.addCallback(_created)
         return d
 
-    def publish_empty_sdmf(self):
-        self.CONTENTS = ""
-        self.uploadable = MutableData(self.CONTENTS)
-        self._storage = FakeStorage()
-        self._nodemaker = make_nodemaker(self._storage, keysize=None)
-        self._storage_broker = self._nodemaker.storage_broker
-        d = self._nodemaker.create_mutable_file(self.uploadable,
-                                                version=SDMF_VERSION)
-        def _created(node):
-            self._fn = node
-            self._fn2 = self._nodemaker.create_from_cap(node.get_uri())
-        d.addCallback(_created)
-        return d
-
 
     def publish_multiple(self, version=0):
         self.CONTENTS = ["Contents 0",
@@ -1903,6 +1893,19 @@ class Checker(unittest.TestCase, CheckerMixin, PublishMixin):
                       "test_verify_mdmf_bad_encprivkey_uncheckable")
         return d
 
+    def test_verify_sdmf_empty(self):
+        d = self.publish_sdmf("")
+        d.addCallback(lambda ignored: self._fn.check(Monitor(), verify=True))
+        d.addCallback(self.check_good, "test_verify_sdmf")
+        d.addCallback(flushEventualQueue)
+        return d
+
+    def test_verify_mdmf_empty(self):
+        d = self.publish_mdmf("")
+        d.addCallback(lambda ignored: self._fn.check(Monitor(), verify=True))
+        d.addCallback(self.check_good, "test_verify_mdmf")
+        d.addCallback(flushEventualQueue)
+        return d
 
 class Repair(unittest.TestCase, PublishMixin, ShouldFailMixin):
 
@@ -2155,7 +2158,7 @@ class Repair(unittest.TestCase, PublishMixin, ShouldFailMixin):
         # In the buggy version, the check that precedes the retrieve+publish
         # cycle uses MODE_READ, instead of MODE_REPAIR, and fails to get the
         # privkey that repair needs.
-        d = self.publish_empty_sdmf()
+        d = self.publish_sdmf("")
         def _delete_one_share(ign):
             shares = self._storage._peers
             for peerid in shares:
@@ -3068,11 +3071,13 @@ class Version(GridTestMixin, unittest.TestCase, testutil.ShouldFailMixin, \
         self.c = self.g.clients[0]
         self.nm = self.c.nodemaker
         self.data = "test data" * 100000 # about 900 KiB; MDMF
-        self.small_data = "test data" * 10 # about 90 B; SDMF
+        self.small_data = "test data" * 10 # 90 B; SDMF
 
 
-    def do_upload_mdmf(self):
-        d = self.nm.create_mutable_file(MutableData(self.data),
+    def do_upload_mdmf(self, data=None):
+        if data is None:
+            data = self.data
+        d = self.nm.create_mutable_file(MutableData(data),
                                         version=MDMF_VERSION)
         def _then(n):
             assert isinstance(n, MutableFileNode)
@@ -3082,8 +3087,10 @@ class Version(GridTestMixin, unittest.TestCase, testutil.ShouldFailMixin, \
         d.addCallback(_then)
         return d
 
-    def do_upload_sdmf(self):
-        d = self.nm.create_mutable_file(MutableData(self.small_data))
+    def do_upload_sdmf(self, data=None):
+        if data is None:
+            data = self.small_data
+        d = self.nm.create_mutable_file(MutableData(data))
         def _then(n):
             assert isinstance(n, MutableFileNode)
             assert n._protocol_version == SDMF_VERSION
@@ -3359,56 +3366,127 @@ class Version(GridTestMixin, unittest.TestCase, testutil.ShouldFailMixin, \
         return d
 
 
-    def test_partial_read(self):
-        d = self.do_upload_mdmf()
-        d.addCallback(lambda ign: self.mdmf_node.get_best_readable_version())
-        modes = [("start_on_segment_boundary",
-                  mathutil.next_multiple(128 * 1024, 3), 50),
-                 ("ending_one_byte_after_segment_boundary",
-                  mathutil.next_multiple(128 * 1024, 3)-50, 51),
-                 ("zero_length_at_start", 0, 0),
-                 ("zero_length_in_middle", 50, 0),
-                 ("zero_length_at_segment_boundary",
-                  mathutil.next_multiple(128 * 1024, 3), 0),
-                 ]
+    def _test_partial_read(self, node, expected, modes, step):
+        d = node.get_best_readable_version()
         for (name, offset, length) in modes:
-            d.addCallback(self._do_partial_read, name, offset, length)
-        # then read only a few bytes at a time, and see that the results are
-        # what we expect.
+            d.addCallback(self._do_partial_read, name, expected, offset, length)
+        # then read the whole thing, but only a few bytes at a time, and see
+        # that the results are what we expect.
         def _read_data(version):
             c = consumer.MemoryConsumer()
             d2 = defer.succeed(None)
-            for i in xrange(0, len(self.data), 10000):
-                d2.addCallback(lambda ignored, i=i: version.read(c, i, 10000))
+            for i in xrange(0, len(expected), step):
+                d2.addCallback(lambda ignored, i=i: version.read(c, i, step))
             d2.addCallback(lambda ignored:
-                self.failUnlessEqual(self.data, "".join(c.chunks)))
+                self.failUnlessEqual(expected, "".join(c.chunks)))
             return d2
         d.addCallback(_read_data)
         return d
-    def _do_partial_read(self, version, name, offset, length):
+
+    def _do_partial_read(self, version, name, expected, offset, length):
         c = consumer.MemoryConsumer()
         d = version.read(c, offset, length)
-        expected = self.data[offset:offset+length]
+        if length is None:
+            expected_range = expected[offset:]
+        else:
+            expected_range = expected[offset:offset+length]
         d.addCallback(lambda ignored: "".join(c.chunks))
         def _check(results):
-            if results != expected:
-                print
+            if results != expected_range:
+                print "read([%d]+%s) got %d bytes, not %d" % \
+                      (offset, length, len(results), len(expected_range))
                 print "got: %s ... %s" % (results[:20], results[-20:])
-                print "exp: %s ... %s" % (expected[:20], expected[-20:])
-                self.fail("results[%s] != expected" % name)
+                print "exp: %s ... %s" % (expected_range[:20], expected_range[-20:])
+                self.fail("results[%s] != expected_range" % name)
             return version # daisy-chained to next call
         d.addCallback(_check)
         return d
 
+    def test_partial_read_mdmf_0(self):
+        data = ""
+        d = self.do_upload_mdmf(data=data)
+        modes = [("all1",    0,0),
+                 ("all2",    0,None),
+                 ]
+        d.addCallback(self._test_partial_read, data, modes, 1)
+        return d
+
+    def test_partial_read_mdmf_large(self):
+        segment_boundary = mathutil.next_multiple(128 * 1024, 3)
+        modes = [("start_on_segment_boundary",              segment_boundary, 50),
+                 ("ending_one_byte_after_segment_boundary", segment_boundary-50, 51),
+                 ("zero_length_at_start",                   0, 0),
+                 ("zero_length_in_middle",                  50, 0),
+                 ("zero_length_at_segment_boundary",        segment_boundary, 0),
+                 ("complete_file1",                         0, len(self.data)),
+                 ("complete_file2",                         0, None),
+                 ]
+        d = self.do_upload_mdmf()
+        d.addCallback(self._test_partial_read, self.data, modes, 10000)
+        return d
+
+    def test_partial_read_sdmf_0(self):
+        data = ""
+        modes = [("all1",    0,0),
+                 ("all2",    0,None),
+                 ]
+        d = self.do_upload_sdmf(data=data)
+        d.addCallback(self._test_partial_read, data, modes, 1)
+        return d
+
+    def test_partial_read_sdmf_2(self):
+        data = "hi"
+        modes = [("one_byte",                  0, 1),
+                 ("last_byte",                 1, 1),
+                 ("last_byte2",                1, None),
+                 ("complete_file",             0, 2),
+                 ("complete_file2",            0, None),
+                 ]
+        d = self.do_upload_sdmf(data=data)
+        d.addCallback(self._test_partial_read, data, modes, 1)
+        return d
+
+    def test_partial_read_sdmf_90(self):
+        modes = [("start_at_middle",           50, 40),
+                 ("start_at_middle2",          50, None),
+                 ("zero_length_at_start",      0, 0),
+                 ("zero_length_in_middle",     50, 0),
+                 ("zero_length_at_end",        90, 0),
+                 ("complete_file1",            0, None),
+                 ("complete_file2",            0, 90),
+                 ]
+        d = self.do_upload_sdmf()
+        d.addCallback(self._test_partial_read, self.small_data, modes, 10)
+        return d
+
+    def test_partial_read_sdmf_100(self):
+        data = "test data "*10
+        modes = [("start_at_middle",           50, 50),
+                 ("start_at_middle2",          50, None),
+                 ("zero_length_at_start",      0, 0),
+                 ("zero_length_in_middle",     50, 0),
+                 ("complete_file1",            0, 100),
+                 ("complete_file2",            0, None),
+                 ]
+        d = self.do_upload_sdmf(data=data)
+        d.addCallback(self._test_partial_read, data, modes, 10)
+        return d
+
 
     def _test_read_and_download(self, node, expected):
         d = node.get_best_readable_version()
         def _read_data(version):
             c = consumer.MemoryConsumer()
+            c2 = consumer.MemoryConsumer()
             d2 = defer.succeed(None)
             d2.addCallback(lambda ignored: version.read(c))
             d2.addCallback(lambda ignored:
                 self.failUnlessEqual(expected, "".join(c.chunks)))
+
+            d2.addCallback(lambda ignored: version.read(c2, offset=0,
+                                                        size=len(expected)))
+            d2.addCallback(lambda ignored:
+                self.failUnlessEqual(expected, "".join(c2.chunks)))
             return d2
         d.addCallback(_read_data)
         d.addCallback(lambda ignored: node.download_best_version())
@@ -3441,7 +3519,7 @@ class Update(GridTestMixin, unittest.TestCase, testutil.ShouldFailMixin):
         self.c = self.g.clients[0]
         self.nm = self.c.nodemaker
         self.data = "testdata " * 100000 # about 900 KiB; MDMF
-        self.small_data = "test data" * 10 # about 90 B; SDMF
+        self.small_data = "test data" * 10 # 90 B; SDMF
 
 
     def do_upload_sdmf(self):
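
For callers, the practical upshot (as exercised by the new tests) is that
partial reads near the end of a mutable file now complete instead of tripping
an assert. A hedged usage sketch in the style of the test helpers above, using
MemoryConsumer and version.read() as in test_mutable.py; get_size() on the
readable version is assumed to be available:

    # assumes: from allmydata.util import consumer  (as in the tests above)
    def read_tail(node, nbytes=50):
        # read only the last nbytes of a mutable file's best version
        d = node.get_best_readable_version()
        def _read(version):
            c = consumer.MemoryConsumer()
            offset = version.get_size() - nbytes   # get_size() assumed available
            d2 = version.read(c, offset, nbytes)
            d2.addCallback(lambda ign: "".join(c.chunks))
            return d2
        d.addCallback(_read)
        return d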