download: fix stopProducing failure ('self._paused_at not defined'), add tests
authorBrian Warner <warner@allmydata.com>
Mon, 14 Jul 2008 22:25:21 +0000 (15:25 -0700)
committerBrian Warner <warner@allmydata.com>
Mon, 14 Jul 2008 22:25:21 +0000 (15:25 -0700)
src/allmydata/download.py
src/allmydata/interfaces.py
src/allmydata/test/test_encode.py

index 572878ae5a4773d5ffa1a3ae7cfc3a76beb9b38a..a33016f68d2cc7804d692da1abc7f77bcb8ed727 100644 (file)
@@ -496,6 +496,8 @@ class FileDownloader:
 
     def resumeProducing(self):
         if self._paused:
+            paused_for = time.time() - self._paused_at
+            self._results.timings['paused'] += paused_for
             p = self._paused
             self._paused = None
             eventually(p.callback, None)
@@ -505,8 +507,7 @@ class FileDownloader:
     def stopProducing(self):
         self.log("Download.stopProducing")
         self._stopped = True
-        paused_for = time.time() - self._paused_at
-        self._results.timings['paused'] += paused_for
+        self.resumeProducing()
         if self._status:
             self._status.set_stopped(True)
             self._status.set_active(False)
index ff6b6b5b5a654ca0303d809ab4478ff64775431a..1fbf95b141c786bbac92f8d2a8c2c8e9de1c3808 100644 (file)
@@ -1175,6 +1175,9 @@ class IDecoder(Interface):
         """
 
 class IDownloadTarget(Interface):
+    # Note that if the IDownloadTarget is also an IConsumable, the downloader
+    # will register itself as a producer. This allows the target to invoke
+    # downloader.pauseProducing, resumeProducing, and stopProducing.
     def open(size):
         """Called before any calls to write() or close(). If an error
         occurs before any data is available, fail() may be called without
index 6a0711e3fee84d74eaf4371bec4359533d4b93e1..6ba0e6dbf35edc9e16f3c68c591585f0aa178eec 100644 (file)
@@ -1,11 +1,12 @@
 
 from zope.interface import implements
 from twisted.trial import unittest
-from twisted.internet import defer
+from twisted.internet import defer, reactor
+from twisted.internet.interfaces import IConsumer
 from twisted.python.failure import Failure
 from foolscap import eventual
 from allmydata import encode, upload, download, hashtree, uri
-from allmydata.util import hashutil
+from allmydata.util import hashutil, testutil
 from allmydata.util.assertutil import _assert
 from allmydata.interfaces import IStorageBucketWriter, IStorageBucketReader
 
@@ -274,13 +275,47 @@ class Encode(unittest.TestCase):
         # 5 segments: 25, 25, 25, 25, 1
         return self.do_encode(25, 101, 100, 5, 15, 8)
 
-class Roundtrip(unittest.TestCase):
+class PausingTarget(download.Data):
+    implements(IConsumer)
+    def __init__(self):
+        download.Data.__init__(self)
+        self.size = 0
+        self.writes = 0
+    def write(self, data):
+        self.size += len(data)
+        self.writes += 1
+        if self.writes <= 2:
+            # we happen to use 4 segments, and want to avoid pausing on the
+            # last one (since then the _unpause timer will still be running)
+            self.producer.pauseProducing()
+            reactor.callLater(0.1, self._unpause)
+        return download.Data.write(self, data)
+    def _unpause(self):
+        self.producer.resumeProducing()
+    def registerProducer(self, producer, streaming):
+        self.producer = producer
+    def unregisterProducer(self):
+        self.producer = None
+
+class PausingAndStoppingTarget(PausingTarget):
+    def write(self, data):
+        self.producer.pauseProducing()
+        reactor.callLater(0.5, self._stop)
+    def _stop(self):
+        self.producer.stopProducing()
+
+class StoppingTarget(PausingTarget):
+    def write(self, data):
+        self.producer.stopProducing()
+
+class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin):
     def send_and_recover(self, k_and_happy_and_n=(25,75,100),
                          AVAILABLE_SHARES=None,
                          datalen=76,
                          max_segment_size=25,
                          bucket_modes={},
                          recover_mode="recover",
+                         target=None,
                          ):
         if AVAILABLE_SHARES is None:
             AVAILABLE_SHARES = k_and_happy_and_n[2]
@@ -288,7 +323,8 @@ class Roundtrip(unittest.TestCase):
         d = self.send(k_and_happy_and_n, AVAILABLE_SHARES,
                       max_segment_size, bucket_modes, data)
         # that fires with (uri_extension_hash, e, shareholders)
-        d.addCallback(self.recover, AVAILABLE_SHARES, recover_mode)
+        d.addCallback(self.recover, AVAILABLE_SHARES, recover_mode,
+                      target=target)
         # that fires with newdata
         def _downloaded((newdata, fd)):
             self.failUnless(newdata == data)
@@ -332,7 +368,7 @@ class Roundtrip(unittest.TestCase):
         return d
 
     def recover(self, (res, key, shareholders), AVAILABLE_SHARES,
-                recover_mode):
+                recover_mode, target=None):
         (uri_extension_hash, required_shares, num_shares, file_size) = res
 
         if "corrupt_key" in recover_mode:
@@ -350,7 +386,8 @@ class Roundtrip(unittest.TestCase):
         URI = u.to_string()
 
         client = FakeClient()
-        target = download.Data()
+        if not target:
+            target = download.Data()
         fd = download.FileDownloader(client, URI, target)
 
         # we manually cycle the FileDownloader through a number of steps that
@@ -436,6 +473,29 @@ class Roundtrip(unittest.TestCase):
     def test_101(self):
         return self.send_and_recover(datalen=101)
 
+    def test_pause(self):
+        # use a DownloadTarget that does pauseProducing/resumeProducing a few
+        # times, then finishes
+        t = PausingTarget()
+        d = self.send_and_recover(target=t)
+        return d
+
+    def test_pause_then_stop(self):
+        # use a DownloadTarget that pauses, then stops.
+        t = PausingAndStoppingTarget()
+        d = self.shouldFail(download.DownloadStopped, "test_pause_then_stop",
+                            "our Consumer called stopProducing()",
+                            self.send_and_recover, target=t)
+        return d
+
+    def test_stop(self):
+        # use a DownloadTarget that does an immediate stop (ticket #473)
+        t = StoppingTarget()
+        d = self.shouldFail(download.DownloadStopped, "test_stop",
+                            "our Consumer called stopProducing()",
+                            self.send_and_recover, target=t)
+        return d
+
     # the following tests all use 4-out-of-10 encoding
 
     def test_bad_blocks(self):