from zope.interface import implements
from twisted.trial import unittest
-from twisted.internet import defer
+from twisted.internet import defer, reactor
+from twisted.internet.interfaces import IConsumer
from twisted.python.failure import Failure
from foolscap import eventual
from allmydata import encode, upload, download, hashtree, uri
-from allmydata.util import hashutil
+from allmydata.util import hashutil, testutil
from allmydata.util.assertutil import _assert
from allmydata.interfaces import IStorageBucketWriter, IStorageBucketReader
# 5 segments: 25, 25, 25, 25, 1
return self.do_encode(25, 101, 100, 5, 15, 8)
-class Roundtrip(unittest.TestCase):
+class PausingTarget(download.Data):
+ implements(IConsumer)
+ def __init__(self):
+ download.Data.__init__(self)
+ self.size = 0
+ self.writes = 0
+ def write(self, data):
+ self.size += len(data)
+ self.writes += 1
+ if self.writes <= 2:
+ # we happen to use 4 segments, and want to avoid pausing on the
+ # last one (since then the _unpause timer will still be running)
+ self.producer.pauseProducing()
+ reactor.callLater(0.1, self._unpause)
+ return download.Data.write(self, data)
+ def _unpause(self):
+ self.producer.resumeProducing()
+ def registerProducer(self, producer, streaming):
+ self.producer = producer
+ def unregisterProducer(self):
+ self.producer = None
+
+class PausingAndStoppingTarget(PausingTarget):
+ def write(self, data):
+ self.producer.pauseProducing()
+ reactor.callLater(0.5, self._stop)
+ def _stop(self):
+ self.producer.stopProducing()
+
+class StoppingTarget(PausingTarget):
+ def write(self, data):
+ self.producer.stopProducing()
+
+class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin):
def send_and_recover(self, k_and_happy_and_n=(25,75,100),
AVAILABLE_SHARES=None,
datalen=76,
max_segment_size=25,
bucket_modes={},
recover_mode="recover",
+ target=None,
):
if AVAILABLE_SHARES is None:
AVAILABLE_SHARES = k_and_happy_and_n[2]
d = self.send(k_and_happy_and_n, AVAILABLE_SHARES,
max_segment_size, bucket_modes, data)
# that fires with (uri_extension_hash, e, shareholders)
- d.addCallback(self.recover, AVAILABLE_SHARES, recover_mode)
+ d.addCallback(self.recover, AVAILABLE_SHARES, recover_mode,
+ target=target)
# that fires with newdata
def _downloaded((newdata, fd)):
self.failUnless(newdata == data)
return d
def recover(self, (res, key, shareholders), AVAILABLE_SHARES,
- recover_mode):
+ recover_mode, target=None):
(uri_extension_hash, required_shares, num_shares, file_size) = res
if "corrupt_key" in recover_mode:
URI = u.to_string()
client = FakeClient()
- target = download.Data()
+ if not target:
+ target = download.Data()
fd = download.FileDownloader(client, URI, target)
# we manually cycle the FileDownloader through a number of steps that
def test_101(self):
return self.send_and_recover(datalen=101)
+ def test_pause(self):
+ # use a DownloadTarget that does pauseProducing/resumeProducing a few
+ # times, then finishes
+ t = PausingTarget()
+ d = self.send_and_recover(target=t)
+ return d
+
+ def test_pause_then_stop(self):
+ # use a DownloadTarget that pauses, then stops.
+ t = PausingAndStoppingTarget()
+ d = self.shouldFail(download.DownloadStopped, "test_pause_then_stop",
+ "our Consumer called stopProducing()",
+ self.send_and_recover, target=t)
+ return d
+
+ def test_stop(self):
+ # use a DownloadTarget that does an immediate stop (ticket #473)
+ t = StoppingTarget()
+ d = self.shouldFail(download.DownloadStopped, "test_stop",
+ "our Consumer called stopProducing()",
+ self.send_and_recover, target=t)
+ return d
+
# the following tests all use 4-out-of-10 encoding
def test_bad_blocks(self):