From a15ce96846c301a417116ef66f4831767ce93ffc Mon Sep 17 00:00:00 2001
From: Brian Warner <warner@lothar.com>
Date: Mon, 5 Sep 2011 12:02:42 -0700
Subject: [PATCH] Retrieve: implement/test stopProducing

---
 src/allmydata/mutable/retrieve.py  |  9 ++-
 src/allmydata/test/test_mutable.py | 98 +++++++++++++++++-------------
 2 files changed, 63 insertions(+), 44 deletions(-)

diff --git a/src/allmydata/mutable/retrieve.py b/src/allmydata/mutable/retrieve.py
index 100350a0..9c09abfe 100644
--- a/src/allmydata/mutable/retrieve.py
+++ b/src/allmydata/mutable/retrieve.py
@@ -7,7 +7,7 @@ from twisted.python import failure
 from twisted.internet.interfaces import IPushProducer, IConsumer
 from foolscap.api import eventually, fireEventually
 from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \
-                                 MDMF_VERSION, SDMF_VERSION
+     DownloadStopped, MDMF_VERSION, SDMF_VERSION
 from allmydata.util import hashutil, log, mathutil
 from allmydata.util.dictutil import DictOfSets
 from allmydata import hashtree, codec
@@ -143,6 +143,7 @@ class Retrieve:
         self._status.set_size(datalength)
         self._status.set_encoding(k, N)
         self.readers = {}
+        self._stopped = False
         self._pause_deferred = None
         self._offset = None
         self._read_length = None
@@ -196,6 +197,10 @@ class Retrieve:
 
         eventually(p.callback, None)
 
+    def stopProducing(self):
+        self._stopped = True
+        self.resumeProducing()
+
 
     def _check_for_paused(self, res):
         """
@@ -205,6 +210,8 @@ class Retrieve:
         the Deferred fires immediately. Otherwise, the Deferred fires
         when the downloader is unpaused.
         """
+        if self._stopped:
+            raise DownloadStopped("our Consumer called stopProducing()")
         if self._pause_deferred is not None:
             d = defer.Deferred()
             self._pause_deferred.addCallback(lambda ignored: d.callback(res))
diff --git a/src/allmydata/test/test_mutable.py b/src/allmydata/test/test_mutable.py
index 69ce131b..2b119ee1 100644
--- a/src/allmydata/test/test_mutable.py
+++ b/src/allmydata/test/test_mutable.py
@@ -3,16 +3,15 @@ import os, re, base64
 from cStringIO import StringIO
 from twisted.trial import unittest
 from twisted.internet import defer, reactor
-from twisted.internet.interfaces import IConsumer
-from zope.interface import implements
 from allmydata import uri, client
 from allmydata.nodemaker import NodeMaker
 from allmydata.util import base32, consumer, fileutil, mathutil
 from allmydata.util.hashutil import tagged_hash, ssk_writekey_hash, \
      ssk_pubkey_fingerprint_hash
+from allmydata.util.consumer import MemoryConsumer
 from allmydata.util.deferredutil import gatherResults
 from allmydata.interfaces import IRepairResults, ICheckAndRepairResults, \
-     NotEnoughSharesError, SDMF_VERSION, MDMF_VERSION
+     NotEnoughSharesError, SDMF_VERSION, MDMF_VERSION, DownloadStopped
 from allmydata.monitor import Monitor
 from allmydata.test.common import ShouldFailMixin
 from allmydata.test.no_network import GridTestMixin
@@ -37,6 +36,9 @@ from allmydata.mutable.repairer import MustForceRepairError
 
 import allmydata.test.common_util as testutil
 from allmydata.test.common import TEST_RSA_KEY_SIZE
+from allmydata.test.test_download import PausingConsumer, \
+     PausingAndStoppingConsumer, StoppingConsumer, \
+     ImmediatelyStoppingConsumer
 
 
 # this "FakeStorage" exists to put the share data in RAM and avoid using real
@@ -544,26 +546,60 @@ class Filenode(unittest.TestCase, testutil.ShouldFailMixin):
         return d
 
 
-    def test_retrieve_pause(self):
-        # We should make sure that the retriever is able to pause
+    def test_retrieve_producer_mdmf(self):
+        # We should make sure that the retriever is able to pause and stop
         # correctly.
-        d = self.nodemaker.create_mutable_file(version=MDMF_VERSION)
-        def _created(node):
-            self.node = node
+        data = "contents1" * 100000
+        d = self.nodemaker.create_mutable_file(MutableData(data),
+                                               version=MDMF_VERSION)
+        d.addCallback(lambda node: node.get_best_mutable_version())
+        d.addCallback(self._test_retrieve_producer, "MDMF", data)
+        return d
 
-            return node.overwrite(MutableData("contents1" * 100000))
-        d.addCallback(_created)
-        # Now we'll retrieve it into a pausing consumer.
-        d.addCallback(lambda ignored:
-            self.node.get_best_mutable_version())
-        def _got_version(version):
-            self.c = PausingConsumer()
-            return version.read(self.c)
-        d.addCallback(_got_version)
-        d.addCallback(lambda ignored:
-            self.failUnlessEqual(self.c.data, "contents1" * 100000))
+    # note: SDMF has only one big segment, so we can't use the usual
+    # after-the-first-write() trick to pause or stop the download.
+    # Disabled until we find a better approach.
+    def OFF_test_retrieve_producer_sdmf(self):
+        data = "contents1" * 100000
+        d = self.nodemaker.create_mutable_file(MutableData(data),
+                                               version=SDMF_VERSION)
+        d.addCallback(lambda node: node.get_best_mutable_version())
+        d.addCallback(self._test_retrieve_producer, "SDMF", data)
         return d
 
+    def _test_retrieve_producer(self, version, kind, data):
+        # Now we'll retrieve it into a pausing consumer.
+        c = PausingConsumer()
+        d = version.read(c)
+        d.addCallback(lambda ign: self.failUnlessEqual(c.size, len(data)))
+
+        c2 = PausingAndStoppingConsumer()
+        d.addCallback(lambda ign:
+                      self.shouldFail(DownloadStopped, kind+"_pause_stop",
+                                      "our Consumer called stopProducing()",
+                                      version.read, c2))
+
+        c3 = StoppingConsumer()
+        d.addCallback(lambda ign:
+                      self.shouldFail(DownloadStopped, kind+"_stop",
+                                      "our Consumer called stopProducing()",
+                                      version.read, c3))
+
+        c4 = ImmediatelyStoppingConsumer()
+        d.addCallback(lambda ign:
+                      self.shouldFail(DownloadStopped, kind+"_stop_imm",
+                                      "our Consumer called stopProducing()",
+                                      version.read, c4))
+
+        def _then(ign):
+            c5 = MemoryConsumer()
+            d1 = version.read(c5)
+            c5.producer.stopProducing()
+            return self.shouldFail(DownloadStopped, kind+"_stop_imm2",
+                                   "our Consumer called stopProducing()",
+                                   lambda: d1)
+        d.addCallback(_then)
+        return d
 
     def test_download_from_mdmf_cap(self):
         # We should be able to download an MDMF file given its cap
@@ -1048,30 +1084,6 @@ class PublishMixin:
                     index = versionmap[shnum]
                     shares[peerid][shnum] = oldshares[index][peerid][shnum]
 
-class PausingConsumer:
-    implements(IConsumer)
-    def __init__(self):
-        self.data = ""
-        self.already_paused = False
-
-    def registerProducer(self, producer, streaming):
-        self.producer = producer
-        self.producer.resumeProducing()
-
-    def unregisterProducer(self):
-        self.producer = None
-
-    def _unpause(self, ignored):
-        self.producer.resumeProducing()
-
-    def write(self, data):
-        self.data += data
-        if not self.already_paused:
-           self.producer.pauseProducing()
-           self.already_paused = True
-           reactor.callLater(15, self._unpause, None)
-
-
 class Servermap(unittest.TestCase, PublishMixin):
     def setUp(self):
         return self.publish_one()
-- 
2.45.2