From a15ce96846c301a417116ef66f4831767ce93ffc Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 5 Sep 2011 12:02:42 -0700 Subject: [PATCH] Retrieve: implement/test stopProducing --- src/allmydata/mutable/retrieve.py | 9 ++- src/allmydata/test/test_mutable.py | 98 +++++++++++++++++------------- 2 files changed, 63 insertions(+), 44 deletions(-) diff --git a/src/allmydata/mutable/retrieve.py b/src/allmydata/mutable/retrieve.py index 100350a0..9c09abfe 100644 --- a/src/allmydata/mutable/retrieve.py +++ b/src/allmydata/mutable/retrieve.py @@ -7,7 +7,7 @@ from twisted.python import failure from twisted.internet.interfaces import IPushProducer, IConsumer from foolscap.api import eventually, fireEventually from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \ - MDMF_VERSION, SDMF_VERSION + DownloadStopped, MDMF_VERSION, SDMF_VERSION from allmydata.util import hashutil, log, mathutil from allmydata.util.dictutil import DictOfSets from allmydata import hashtree, codec @@ -143,6 +143,7 @@ class Retrieve: self._status.set_size(datalength) self._status.set_encoding(k, N) self.readers = {} + self._stopped = False self._pause_deferred = None self._offset = None self._read_length = None @@ -196,6 +197,10 @@ class Retrieve: eventually(p.callback, None) + def stopProducing(self): + self._stopped = True + self.resumeProducing() + def _check_for_paused(self, res): """ @@ -205,6 +210,8 @@ class Retrieve: the Deferred fires immediately. Otherwise, the Deferred fires when the downloader is unpaused. """ + if self._stopped: + raise DownloadStopped("our Consumer called stopProducing()") if self._pause_deferred is not None: d = defer.Deferred() self._pause_deferred.addCallback(lambda ignored: d.callback(res)) diff --git a/src/allmydata/test/test_mutable.py b/src/allmydata/test/test_mutable.py index 69ce131b..2b119ee1 100644 --- a/src/allmydata/test/test_mutable.py +++ b/src/allmydata/test/test_mutable.py @@ -3,16 +3,15 @@ import os, re, base64 from cStringIO import StringIO from twisted.trial import unittest from twisted.internet import defer, reactor -from twisted.internet.interfaces import IConsumer -from zope.interface import implements from allmydata import uri, client from allmydata.nodemaker import NodeMaker from allmydata.util import base32, consumer, fileutil, mathutil from allmydata.util.hashutil import tagged_hash, ssk_writekey_hash, \ ssk_pubkey_fingerprint_hash +from allmydata.util.consumer import MemoryConsumer from allmydata.util.deferredutil import gatherResults from allmydata.interfaces import IRepairResults, ICheckAndRepairResults, \ - NotEnoughSharesError, SDMF_VERSION, MDMF_VERSION + NotEnoughSharesError, SDMF_VERSION, MDMF_VERSION, DownloadStopped from allmydata.monitor import Monitor from allmydata.test.common import ShouldFailMixin from allmydata.test.no_network import GridTestMixin @@ -37,6 +36,9 @@ from allmydata.mutable.repairer import MustForceRepairError import allmydata.test.common_util as testutil from allmydata.test.common import TEST_RSA_KEY_SIZE +from allmydata.test.test_download import PausingConsumer, \ + PausingAndStoppingConsumer, StoppingConsumer, \ + ImmediatelyStoppingConsumer # this "FakeStorage" exists to put the share data in RAM and avoid using real @@ -544,26 +546,60 @@ class Filenode(unittest.TestCase, testutil.ShouldFailMixin): return d - def test_retrieve_pause(self): - # We should make sure that the retriever is able to pause + def test_retrieve_producer_mdmf(self): + # We should make sure that the retriever is able to pause and stop # correctly. - d = self.nodemaker.create_mutable_file(version=MDMF_VERSION) - def _created(node): - self.node = node + data = "contents1" * 100000 + d = self.nodemaker.create_mutable_file(MutableData(data), + version=MDMF_VERSION) + d.addCallback(lambda node: node.get_best_mutable_version()) + d.addCallback(self._test_retrieve_producer, "MDMF", data) + return d - return node.overwrite(MutableData("contents1" * 100000)) - d.addCallback(_created) - # Now we'll retrieve it into a pausing consumer. - d.addCallback(lambda ignored: - self.node.get_best_mutable_version()) - def _got_version(version): - self.c = PausingConsumer() - return version.read(self.c) - d.addCallback(_got_version) - d.addCallback(lambda ignored: - self.failUnlessEqual(self.c.data, "contents1" * 100000)) + # note: SDMF has only one big segment, so we can't use the usual + # after-the-first-write() trick to pause or stop the download. + # Disabled until we find a better approach. + def OFF_test_retrieve_producer_sdmf(self): + data = "contents1" * 100000 + d = self.nodemaker.create_mutable_file(MutableData(data), + version=SDMF_VERSION) + d.addCallback(lambda node: node.get_best_mutable_version()) + d.addCallback(self._test_retrieve_producer, "SDMF", data) return d + def _test_retrieve_producer(self, version, kind, data): + # Now we'll retrieve it into a pausing consumer. + c = PausingConsumer() + d = version.read(c) + d.addCallback(lambda ign: self.failUnlessEqual(c.size, len(data))) + + c2 = PausingAndStoppingConsumer() + d.addCallback(lambda ign: + self.shouldFail(DownloadStopped, kind+"_pause_stop", + "our Consumer called stopProducing()", + version.read, c2)) + + c3 = StoppingConsumer() + d.addCallback(lambda ign: + self.shouldFail(DownloadStopped, kind+"_stop", + "our Consumer called stopProducing()", + version.read, c3)) + + c4 = ImmediatelyStoppingConsumer() + d.addCallback(lambda ign: + self.shouldFail(DownloadStopped, kind+"_stop_imm", + "our Consumer called stopProducing()", + version.read, c4)) + + def _then(ign): + c5 = MemoryConsumer() + d1 = version.read(c5) + c5.producer.stopProducing() + return self.shouldFail(DownloadStopped, kind+"_stop_imm2", + "our Consumer called stopProducing()", + lambda: d1) + d.addCallback(_then) + return d def test_download_from_mdmf_cap(self): # We should be able to download an MDMF file given its cap @@ -1048,30 +1084,6 @@ class PublishMixin: index = versionmap[shnum] shares[peerid][shnum] = oldshares[index][peerid][shnum] -class PausingConsumer: - implements(IConsumer) - def __init__(self): - self.data = "" - self.already_paused = False - - def registerProducer(self, producer, streaming): - self.producer = producer - self.producer.resumeProducing() - - def unregisterProducer(self): - self.producer = None - - def _unpause(self, ignored): - self.producer.resumeProducing() - - def write(self, data): - self.data += data - if not self.already_paused: - self.producer.pauseProducing() - self.already_paused = True - reactor.callLater(15, self._unpause, None) - - class Servermap(unittest.TestCase, PublishMixin): def setUp(self): return self.publish_one() -- 2.45.2