From: Brian Warner Date: Tue, 1 Dec 2009 22:44:35 +0000 (-0500) Subject: Simplify immutable download API: use just filenode.read(consumer, offset, size) X-Git-Url: https://git.rkrishnan.org/simplejson/components/com_hotproperty/install.html?a=commitdiff_plain;h=96834da0a271b703ce8b448450a7f379b1751f27;p=tahoe-lafs%2Ftahoe-lafs.git Simplify immutable download API: use just filenode.read(consumer, offset, size) * remove Downloader.download_to_data/download_to_filename/download_to_filehandle * remove download.Data/FileName/FileHandle targets * remove filenode.download/download_to_data/download_to_filename methods * leave Downloader.download (the whole Downloader will go away eventually) * add util.consumer.MemoryConsumer/download_to_data, for convenience (this is mostly used by unit tests, but it gets used by enough non-test code to warrant putting it in allmydata.util) * update tests * removes about 180 lines of code. Yay negative code days! Overall plan is to rewrite immutable/download.py and leave filenode.read() as the sole read-side API. --- diff --git a/src/allmydata/control.py b/src/allmydata/control.py index 09d10d61..9fe8e8e7 100644 --- a/src/allmydata/control.py +++ b/src/allmydata/control.py @@ -3,10 +3,11 @@ import os, time from zope.interface import implements from twisted.application import service from twisted.internet import defer +from twisted.internet.interfaces import IConsumer from foolscap.api import Referenceable from allmydata.interfaces import RIControlClient from allmydata.util import fileutil, mathutil -from allmydata.immutable import upload, download +from allmydata.immutable import upload from twisted.python import log def get_memory_usage(): @@ -35,6 +36,22 @@ def log_memory_usage(where=""): stats["VmPeak"], where)) +class FileWritingConsumer: + implements(IConsumer) + def __init__(self, filename): + self.done = False + self.f = open(filename, "wb") + def registerProducer(self, p, streaming): + if streaming: + p.resumeProducing() + else: + while not self.done: + p.resumeProducing() + def write(self, data): + self.f.write(data) + def unregisterProducer(self): + self.done = True + self.f.close() class ControlServer(Referenceable, service.Service): implements(RIControlClient) @@ -51,7 +68,8 @@ class ControlServer(Referenceable, service.Service): def remote_download_from_uri_to_file(self, uri, filename): filenode = self.parent.create_node_from_uri(uri) - d = filenode.download_to_filename(filename) + c = FileWritingConsumer(filename) + d = filenode.read(c) d.addCallback(lambda res: filename) return d @@ -181,7 +199,7 @@ class SpeedTest: if i >= self.count: return n = self.parent.create_node_from_uri(self.uris[i]) - d1 = n.download(download.FileHandle(Discard())) + d1 = n.read(DiscardingConsumer()) d1.addCallback(_download_one_file, i+1) return d1 d.addCallback(_download_one_file, 0) @@ -197,10 +215,17 @@ class SpeedTest: os.unlink(fn) return res -class Discard: +class DiscardingConsumer: + implements(IConsumer) + def __init__(self): + self.done = False + def registerProducer(self, p, streaming): + if streaming: + p.resumeProducing() + else: + while not self.done: + p.resumeProducing() def write(self, data): pass - # download_to_filehandle explicitly does not close the filehandle it was - # given: that is reserved for the provider of the filehandle. Therefore - # the lack of a close() method on this otherwise filehandle-like object - # is a part of the test. + def unregisterProducer(self): + self.done = True diff --git a/src/allmydata/dirnode.py b/src/allmydata/dirnode.py index 96c31ffb..ae8516f6 100644 --- a/src/allmydata/dirnode.py +++ b/src/allmydata/dirnode.py @@ -18,6 +18,7 @@ from allmydata.monitor import Monitor from allmydata.util import hashutil, mathutil, base32, log from allmydata.util.assertutil import precondition from allmydata.util.netstring import netstring, split_netstring +from allmydata.util.consumer import download_to_data from allmydata.uri import LiteralFileURI, from_string, wrap_dirnode_cap from pycryptopp.cipher.aes import AES from allmydata.util.dictutil import AuxValueDict @@ -217,7 +218,7 @@ class DirectoryNode: # use the IMutableFileNode API. d = self._node.download_best_version() else: - d = self._node.download_to_data() + d = download_to_data(self._node) d.addCallback(self._unpack_contents) return d diff --git a/src/allmydata/frontends/ftpd.py b/src/allmydata/frontends/ftpd.py index 4d4256df..bba29160 100644 --- a/src/allmydata/frontends/ftpd.py +++ b/src/allmydata/frontends/ftpd.py @@ -9,7 +9,6 @@ from twisted.protocols import ftp from allmydata.interfaces import IDirectoryNode, ExistingChildError, \ NoSuchChildError -from allmydata.immutable.download import ConsumerAdapter from allmydata.immutable.upload import FileHandle class ReadFile: @@ -17,8 +16,7 @@ class ReadFile: def __init__(self, node): self.node = node def send(self, consumer): - ad = ConsumerAdapter(consumer) - d = self.node.download(ad) + d = self.node.read(consumer) return d # when consumed class FileWriter: diff --git a/src/allmydata/frontends/sftpd.py b/src/allmydata/frontends/sftpd.py index 5a713e1c..4a866eec 100644 --- a/src/allmydata/frontends/sftpd.py +++ b/src/allmydata/frontends/sftpd.py @@ -4,7 +4,6 @@ from zope.interface import implements from twisted.python import components from twisted.application import service, strports from twisted.internet import defer -from twisted.internet.interfaces import IConsumer from twisted.conch.ssh import factory, keys, session from twisted.conch.interfaces import ISFTPServer, ISFTPFile, IConchUser from twisted.conch.avatar import ConchUser @@ -15,28 +14,7 @@ from twisted.cred import portal from allmydata.interfaces import IDirectoryNode, ExistingChildError, \ NoSuchChildError from allmydata.immutable.upload import FileHandle - -class MemoryConsumer: - implements(IConsumer) - def __init__(self): - self.chunks = [] - self.done = False - def registerProducer(self, p, streaming): - if streaming: - # call resumeProducing once to start things off - p.resumeProducing() - else: - while not self.done: - p.resumeProducing() - def write(self, data): - self.chunks.append(data) - def unregisterProducer(self): - self.done = True - -def download_to_data(n, offset=0, size=None): - d = n.read(MemoryConsumer(), offset, size) - d.addCallback(lambda mc: "".join(mc.chunks)) - return d +from allmydata.util.consumer import download_to_data class ReadFile: implements(ISFTPFile) diff --git a/src/allmydata/immutable/download.py b/src/allmydata/immutable/download.py index f6356caf..261b65d6 100644 --- a/src/allmydata/immutable/download.py +++ b/src/allmydata/immutable/download.py @@ -1,4 +1,4 @@ -import os, random, weakref, itertools, time +import random, weakref, itertools, time from zope.interface import implements from twisted.internet import defer from twisted.internet.interfaces import IPushProducer, IConsumer @@ -1196,93 +1196,6 @@ class CiphertextDownloader(log.PrefixingLogMixin): return self._status -class FileName: - implements(IDownloadTarget) - def __init__(self, filename): - self._filename = filename - self.f = None - def open(self, size): - self.f = open(self._filename, "wb") - return self.f - def write(self, data): - self.f.write(data) - def close(self): - if self.f: - self.f.close() - def fail(self, why): - if self.f: - self.f.close() - os.unlink(self._filename) - def register_canceller(self, cb): - pass # we won't use it - def finish(self): - pass - # The following methods are just because the target might be a - # repairer.DownUpConnector, and just because the current CHKUpload object - # expects to find the storage index and encoding parameters in its - # Uploadable. - def set_storageindex(self, storageindex): - pass - def set_encodingparams(self, encodingparams): - pass - -class Data: - implements(IDownloadTarget) - def __init__(self): - self._data = [] - def open(self, size): - pass - def write(self, data): - self._data.append(data) - def close(self): - self.data = "".join(self._data) - del self._data - def fail(self, why): - del self._data - def register_canceller(self, cb): - pass # we won't use it - def finish(self): - return self.data - # The following methods are just because the target might be a - # repairer.DownUpConnector, and just because the current CHKUpload object - # expects to find the storage index and encoding parameters in its - # Uploadable. - def set_storageindex(self, storageindex): - pass - def set_encodingparams(self, encodingparams): - pass - -class FileHandle: - """Use me to download data to a pre-defined filehandle-like object. I - will use the target's write() method. I will *not* close the filehandle: - I leave that up to the originator of the filehandle. The download process - will return the filehandle when it completes. - """ - implements(IDownloadTarget) - def __init__(self, filehandle): - self._filehandle = filehandle - def open(self, size): - pass - def write(self, data): - self._filehandle.write(data) - def close(self): - # the originator of the filehandle reserves the right to close it - pass - def fail(self, why): - pass - def register_canceller(self, cb): - pass - def finish(self): - return self._filehandle - # The following methods are just because the target might be a - # repairer.DownUpConnector, and just because the current CHKUpload object - # expects to find the storage index and encoding parameters in its - # Uploadable. - def set_storageindex(self, storageindex): - pass - def set_encodingparams(self, encodingparams): - pass - class ConsumerAdapter: implements(IDownloadTarget, IConsumer) def __init__(self, consumer): @@ -1351,11 +1264,3 @@ class Downloader: history.add_download(dl.get_download_status()) d = dl.start() return d - - # utility functions - def download_to_data(self, uri, _log_msg_id=None, history=None): - return self.download(uri, Data(), _log_msg_id=_log_msg_id, history=history) - def download_to_filename(self, uri, filename, _log_msg_id=None): - return self.download(uri, FileName(filename), _log_msg_id=_log_msg_id) - def download_to_filehandle(self, uri, filehandle, _log_msg_id=None): - return self.download(uri, FileHandle(filehandle), _log_msg_id=_log_msg_id) diff --git a/src/allmydata/immutable/filenode.py b/src/allmydata/immutable/filenode.py index c57709b3..196d4e03 100644 --- a/src/allmydata/immutable/filenode.py +++ b/src/allmydata/immutable/filenode.py @@ -2,7 +2,7 @@ import copy, os.path, stat from cStringIO import StringIO from zope.interface import implements from twisted.internet import defer -from twisted.internet.interfaces import IPushProducer, IConsumer +from twisted.internet.interfaces import IPushProducer from twisted.protocols import basic from foolscap.api import eventually from allmydata.interfaces import IImmutableFileNode, ICheckable, \ @@ -284,6 +284,8 @@ class ImmutableFileNode(_ImmutableFileNodeBase, log.PrefixingLogMixin): return v.start() def read(self, consumer, offset=0, size=None): + self.log("read", offset=offset, size=size, + umid="UPP8FA", level=log.OPERATIONAL) if size is None: size = self.get_size() - offset size = min(size, self.get_size() - offset) @@ -291,24 +293,16 @@ class ImmutableFileNode(_ImmutableFileNodeBase, log.PrefixingLogMixin): if offset == 0 and size == self.get_size(): # don't use the cache, just do a normal streaming download self.log("doing normal full download", umid="VRSBwg", level=log.OPERATIONAL) - return self.download(download.ConsumerAdapter(consumer)) + target = download.ConsumerAdapter(consumer) + return self._downloader.download(self.get_cap(), target, + self._parentmsgid, + history=self._history) d = self.download_cache.when_range_available(offset, size) d.addCallback(lambda res: self.download_cache.read(consumer, offset, size)) return d - def download(self, target): - return self._downloader.download(self.get_cap(), target, - self._parentmsgid, - history=self._history) - - def download_to_data(self): - return self._downloader.download_to_data(self.get_cap(), - history=self._history) - def download_to_filename(self, filename): - return self._downloader.download_to_filename(self.get_cap(), filename) - class LiteralProducer: implements(IPushProducer) def resumeProducing(self): @@ -367,19 +361,3 @@ class LiteralFileNode(_ImmutableFileNodeBase): d = basic.FileSender().beginFileTransfer(StringIO(data), consumer) d.addCallback(lambda lastSent: consumer) return d - - def download(self, target): - # note that this does not update the stats_provider - data = self.u.data - if IConsumer.providedBy(target): - target.registerProducer(LiteralProducer(), True) - target.open(len(data)) - target.write(data) - if IConsumer.providedBy(target): - target.unregisterProducer() - target.close() - return defer.maybeDeferred(target.finish) - - def download_to_data(self): - data = self.u.data - return defer.succeed(data) diff --git a/src/allmydata/interfaces.py b/src/allmydata/interfaces.py index c6897ef5..babacd61 100644 --- a/src/allmydata/interfaces.py +++ b/src/allmydata/interfaces.py @@ -575,13 +575,6 @@ class IFileNode(IFilesystemNode): container, like IDirectoryNode.""" class IImmutableFileNode(IFileNode): - def download(target): - """Download the file's contents to a given IDownloadTarget""" - - def download_to_data(): - """Download the file's contents. Return a Deferred that fires - with those contents.""" - def read(consumer, offset=0, size=None): """Download a portion (possibly all) of the file's contents, making them available to the given IConsumer. Return a Deferred that fires @@ -613,25 +606,8 @@ class IImmutableFileNode(IFileNode): p.stopProducing(), which will result in an exception being delivered via deferred.errback(). - A simple download-to-memory consumer example would look like this:: - - class MemoryConsumer: - implements(IConsumer) - def __init__(self): - self.chunks = [] - self.done = False - def registerProducer(self, p, streaming): - assert streaming == False - while not self.done: - p.resumeProducing() - def write(self, data): - self.chunks.append(data) - def unregisterProducer(self): - self.done = True - d = filenode.read(MemoryConsumer()) - d.addCallback(lambda mc: "".join(mc.chunks)) - return d - + See src/allmydata/util/consumer.py for an example of a simple + download-to-memory consumer. """ class IMutableFileNode(IFileNode): diff --git a/src/allmydata/test/common.py b/src/allmydata/test/common.py index 98bb6aed..7113cceb 100644 --- a/src/allmydata/test/common.py +++ b/src/allmydata/test/common.py @@ -1,7 +1,7 @@ import os, random, struct from zope.interface import implements from twisted.internet import defer -from twisted.internet.interfaces import IConsumer +from twisted.internet.interfaces import IPullProducer from twisted.python import failure from twisted.application import service from twisted.web.error import Error as WebError @@ -18,6 +18,7 @@ from allmydata.storage.server import storage_index_to_dir from allmydata.storage.mutable import MutableShareFile from allmydata.util import hashutil, log, fileutil, pollmixin from allmydata.util.assertutil import precondition +from allmydata.util.consumer import download_to_data from allmydata.stats import StatsGathererService from allmydata.key_generator import KeyGeneratorService import common_util as testutil @@ -31,6 +32,11 @@ def flush_but_dont_ignore(res): d.addCallback(_done) return d +class DummyProducer: + implements(IPullProducer) + def resumeProducing(self): + pass + class FakeCHKFileNode: """I provide IImmutableFileNode, but all of my data is stored in a class-level dictionary.""" @@ -98,40 +104,36 @@ class FakeCHKFileNode: def is_readonly(self): return True - def download(self, target): - if self.my_uri.to_string() not in self.all_contents: - f = failure.Failure(NotEnoughSharesError(None, 0, 3)) - target.fail(f) - return defer.fail(f) - data = self.all_contents[self.my_uri.to_string()] - target.open(len(data)) - target.write(data) - target.close() - return defer.maybeDeferred(target.finish) - def download_to_data(self): - if self.my_uri.to_string() not in self.all_contents: - return defer.fail(NotEnoughSharesError(None, 0, 3)) - data = self.all_contents[self.my_uri.to_string()] - return defer.succeed(data) def get_size(self): try: data = self.all_contents[self.my_uri.to_string()] except KeyError, le: raise NotEnoughSharesError(le, 0, 3) return len(data) + def read(self, consumer, offset=0, size=None): - d = self.download_to_data() - def _got(data): - start = offset - if size is not None: - end = offset + size - else: - end = len(data) - consumer.write(data[start:end]) - return consumer - d.addCallback(_got) + # we don't bother to call registerProducer/unregisterProducer, + # because it's a hassle to write a dummy Producer that does the right + # thing (we have to make sure that DummyProducer.resumeProducing + # writes the data into the consumer immediately, otherwise it will + # loop forever). + + d = defer.succeed(None) + d.addCallback(self._read, consumer, offset, size) return d + def _read(self, ignored, consumer, offset, size): + if self.my_uri.to_string() not in self.all_contents: + raise NotEnoughSharesError(None, 0, 3) + data = self.all_contents[self.my_uri.to_string()] + start = offset + if size is not None: + end = offset + size + else: + end = len(data) + consumer.write(data[start:end]) + return consumer + def make_chk_file_cap(size): return uri.CHKFileURI(key=os.urandom(16), uri_extension_hash=os.urandom(32), @@ -927,6 +929,7 @@ class ShareManglingMixin(SystemTestMixin): d2 = cl0.upload(immutable.upload.Data(TEST_DATA, convergence="")) def _after_upload(u): filecap = u.uri + self.n = self.clients[1].create_node_from_uri(filecap) self.uri = uri.CHKFileURI.init_from_string(filecap) return cl0.create_node_from_uri(filecap) d2.addCallback(_after_upload) @@ -1045,8 +1048,7 @@ class ShareManglingMixin(SystemTestMixin): return sum_of_write_counts def _download_and_check_plaintext(self, unused=None): - d = self.clients[1].downloader.download_to_data(self.uri) - + d = download_to_data(self.n) def _after_download(result): self.failUnlessEqual(result, TEST_DATA) d.addCallback(_after_download) @@ -1140,28 +1142,6 @@ class ErrorMixin(WebErrorMixin): print "First Error:", f.value.subFailure return f -class MemoryConsumer: - implements(IConsumer) - def __init__(self): - self.chunks = [] - self.done = False - def registerProducer(self, p, streaming): - if streaming: - # call resumeProducing once to start things off - p.resumeProducing() - else: - while not self.done: - p.resumeProducing() - def write(self, data): - self.chunks.append(data) - def unregisterProducer(self): - self.done = True - -def download_to_data(n, offset=0, size=None): - d = n.read(MemoryConsumer(), offset, size) - d.addCallback(lambda mc: "".join(mc.chunks)) - return d - def corrupt_field(data, offset, size, debug=False): if random.random() < 0.5: newdata = testutil.flip_one_bit(data, offset, size) diff --git a/src/allmydata/test/test_download.py b/src/allmydata/test/test_download.py index a4743d2e..b54bf017 100644 --- a/src/allmydata/test/test_download.py +++ b/src/allmydata/test/test_download.py @@ -8,6 +8,7 @@ from twisted.trial import unittest from allmydata import uri from allmydata.storage.server import storage_index_to_dir from allmydata.util import base32, fileutil +from allmydata.util.consumer import download_to_data from allmydata.immutable import upload from allmydata.test.no_network import GridTestMixin @@ -173,7 +174,7 @@ class DownloadTest(GridTestMixin, unittest.TestCase): def download_immutable(self, ignored=None): n = self.c0.create_node_from_uri(immutable_uri) - d = n.download_to_data() + d = download_to_data(n) def _got_data(data): self.failUnlessEqual(data, plaintext) d.addCallback(_got_data) diff --git a/src/allmydata/test/test_encode.py b/src/allmydata/test/test_encode.py index c6a7bda1..2d91dbc5 100644 --- a/src/allmydata/test/test_encode.py +++ b/src/allmydata/test/test_encode.py @@ -1,13 +1,13 @@ from zope.interface import implements from twisted.trial import unittest from twisted.internet import defer, reactor -from twisted.internet.interfaces import IConsumer from twisted.python.failure import Failure from foolscap.api import fireEventually from allmydata import hashtree, uri from allmydata.immutable import encode, upload, download from allmydata.util import hashutil from allmydata.util.assertutil import _assert +from allmydata.util.consumer import MemoryConsumer from allmydata.interfaces import IStorageBucketWriter, IStorageBucketReader, \ NotEnoughSharesError, IStorageBroker from allmydata.monitor import Monitor @@ -384,10 +384,9 @@ class Encode(unittest.TestCase): # 5 segments: 25, 25, 25, 25, 1 return self.do_encode(25, 101, 100, 5, 15, 8) -class PausingTarget(download.Data): - implements(IConsumer) +class PausingConsumer(MemoryConsumer): def __init__(self): - download.Data.__init__(self) + MemoryConsumer.__init__(self) self.size = 0 self.writes = 0 def write(self, data): @@ -398,22 +397,18 @@ class PausingTarget(download.Data): # last one (since then the _unpause timer will still be running) self.producer.pauseProducing() reactor.callLater(0.1, self._unpause) - return download.Data.write(self, data) + return MemoryConsumer.write(self, data) def _unpause(self): self.producer.resumeProducing() - def registerProducer(self, producer, streaming): - self.producer = producer - def unregisterProducer(self): - self.producer = None -class PausingAndStoppingTarget(PausingTarget): +class PausingAndStoppingConsumer(PausingConsumer): def write(self, data): self.producer.pauseProducing() reactor.callLater(0.5, self._stop) def _stop(self): self.producer.stopProducing() -class StoppingTarget(PausingTarget): +class StoppingConsumer(PausingConsumer): def write(self, data): self.producer.stopProducing() @@ -425,7 +420,7 @@ class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin): max_segment_size=25, bucket_modes={}, recover_mode="recover", - target=None, + consumer=None, ): if AVAILABLE_SHARES is None: AVAILABLE_SHARES = k_and_happy_and_n[2] @@ -434,7 +429,7 @@ class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin): max_segment_size, bucket_modes, data) # that fires with (uri_extension_hash, e, shareholders) d.addCallback(self.recover, AVAILABLE_SHARES, recover_mode, - target=target) + consumer=consumer) # that fires with newdata def _downloaded((newdata, fd)): self.failUnless(newdata == data, str((len(newdata), len(data)))) @@ -478,7 +473,7 @@ class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin): return d def recover(self, (res, key, shareholders), AVAILABLE_SHARES, - recover_mode, target=None): + recover_mode, consumer=None): verifycap = res if "corrupt_key" in recover_mode: @@ -495,9 +490,10 @@ class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin): size=verifycap.size) sb = FakeStorageBroker() - if not target: - target = download.Data() - target = download.DecryptingTarget(target, u.key) + if not consumer: + consumer = MemoryConsumer() + innertarget = download.ConsumerAdapter(consumer) + target = download.DecryptingTarget(innertarget, u.key) fd = download.CiphertextDownloader(sb, u.get_verify_cap(), target, monitor=Monitor()) # we manually cycle the CiphertextDownloader through a number of steps that @@ -544,7 +540,8 @@ class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin): d.addCallback(fd._download_all_segments) d.addCallback(fd._done) - def _done(newdata): + def _done(t): + newdata = "".join(consumer.chunks) return (newdata, fd) d.addCallback(_done) return d @@ -582,26 +579,26 @@ class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin): return self.send_and_recover(datalen=101) def test_pause(self): - # use a DownloadTarget that does pauseProducing/resumeProducing a few - # times, then finishes - t = PausingTarget() - d = self.send_and_recover(target=t) + # use a download target that does pauseProducing/resumeProducing a + # few times, then finishes + c = PausingConsumer() + d = self.send_and_recover(consumer=c) return d def test_pause_then_stop(self): - # use a DownloadTarget that pauses, then stops. - t = PausingAndStoppingTarget() + # use a download target that pauses, then stops. + c = PausingAndStoppingConsumer() d = self.shouldFail(download.DownloadStopped, "test_pause_then_stop", "our Consumer called stopProducing()", - self.send_and_recover, target=t) + self.send_and_recover, consumer=c) return d def test_stop(self): - # use a DownloadTarget that does an immediate stop (ticket #473) - t = StoppingTarget() + # use a download targetthat does an immediate stop (ticket #473) + c = StoppingConsumer() d = self.shouldFail(download.DownloadStopped, "test_stop", "our Consumer called stopProducing()", - self.send_and_recover, target=t) + self.send_and_recover, consumer=c) return d # the following tests all use 4-out-of-10 encoding diff --git a/src/allmydata/test/test_filenode.py b/src/allmydata/test/test_filenode.py index feeda226..a8de4207 100644 --- a/src/allmydata/test/test_filenode.py +++ b/src/allmydata/test/test_filenode.py @@ -2,11 +2,10 @@ from twisted.trial import unittest from allmydata import uri, client from allmydata.monitor import Monitor -from allmydata.immutable import download from allmydata.immutable.filenode import ImmutableFileNode, LiteralFileNode from allmydata.mutable.filenode import MutableFileNode from allmydata.util import hashutil, cachedir -from allmydata.test.common import download_to_data +from allmydata.util.consumer import download_to_data class NotANode: pass @@ -77,17 +76,11 @@ class Node(unittest.TestCase): self.failUnlessEqual(v, None) self.failUnlessEqual(fn1.get_repair_cap(), None) - d = fn1.download(download.Data()) + d = download_to_data(fn1) def _check(res): self.failUnlessEqual(res, DATA) d.addCallback(_check) - d.addCallback(lambda res: fn1.download_to_data()) - d.addCallback(_check) - - d.addCallback(lambda res: download_to_data(fn1)) - d.addCallback(_check) - d.addCallback(lambda res: download_to_data(fn1, 1, 5)) def _check_segment(res): self.failUnlessEqual(res, DATA[1:1+5]) diff --git a/src/allmydata/test/test_immutable.py b/src/allmydata/test/test_immutable.py index ad1843c0..c9dfc064 100644 --- a/src/allmydata/test/test_immutable.py +++ b/src/allmydata/test/test_immutable.py @@ -1,5 +1,6 @@ from allmydata.test import common from allmydata.interfaces import NotEnoughSharesError +from allmydata.util.consumer import download_to_data from twisted.internet import defer from twisted.trial import unittest import random @@ -26,7 +27,7 @@ class Test(common.ShareManglingMixin, unittest.TestCase): d.addCallback(_then_delete_8) def _then_download(unused=None): - d2 = self.clients[1].downloader.download_to_data(self.uri) + d2 = download_to_data(self.n) def _after_download_callb(result): self.fail() # should have gotten an errback instead @@ -93,7 +94,7 @@ class Test(common.ShareManglingMixin, unittest.TestCase): before_download_reads = self._count_reads() def _attempt_to_download(unused=None): - d2 = self.clients[1].downloader.download_to_data(self.uri) + d2 = download_to_data(self.n) def _callb(res): self.fail("Should have gotten an error from attempt to download, not %r" % (res,)) @@ -126,7 +127,7 @@ class Test(common.ShareManglingMixin, unittest.TestCase): before_download_reads = self._count_reads() def _attempt_to_download(unused=None): - d2 = self.clients[1].downloader.download_to_data(self.uri) + d2 = download_to_data(self.n) def _callb(res): self.fail("Should have gotten an error from attempt to download, not %r" % (res,)) diff --git a/src/allmydata/test/test_mutable.py b/src/allmydata/test/test_mutable.py index 7b7c0990..fc6dc37f 100644 --- a/src/allmydata/test/test_mutable.py +++ b/src/allmydata/test/test_mutable.py @@ -5,7 +5,6 @@ from twisted.trial import unittest from twisted.internet import defer, reactor from allmydata import uri, client from allmydata.nodemaker import NodeMaker -from allmydata.immutable import download from allmydata.util import base32 from allmydata.util.idlib import shortnodeid_b2a from allmydata.util.hashutil import tagged_hash, ssk_writekey_hash, \ @@ -264,8 +263,6 @@ class Filenode(unittest.TestCase, testutil.ShouldFailMixin): d.addCallback(lambda res: n.overwrite("contents 2")) d.addCallback(lambda res: n.download_best_version()) d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2")) - d.addCallback(lambda res: n.download(download.Data())) - d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2")) d.addCallback(lambda res: n.get_servermap(MODE_WRITE)) d.addCallback(lambda smap: n.upload("contents 3", smap)) d.addCallback(lambda res: n.download_best_version()) @@ -495,8 +492,6 @@ class Filenode(unittest.TestCase, testutil.ShouldFailMixin): d.addCallback(lambda res: n.overwrite("contents 2")) d.addCallback(lambda res: n.download_best_version()) d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2")) - d.addCallback(lambda res: n.download(download.Data())) - d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2")) d.addCallback(lambda res: n.get_servermap(MODE_WRITE)) d.addCallback(lambda smap: n.upload("contents 3", smap)) d.addCallback(lambda res: n.download_best_version()) diff --git a/src/allmydata/test/test_no_network.py b/src/allmydata/test/test_no_network.py index 911dc86f..ef7eed5c 100644 --- a/src/allmydata/test/test_no_network.py +++ b/src/allmydata/test/test_no_network.py @@ -5,7 +5,7 @@ from twisted.trial import unittest from twisted.application import service from allmydata.test.no_network import NoNetworkGrid from allmydata.immutable.upload import Data - +from allmydata.util.consumer import download_to_data class Harness(unittest.TestCase): def setUp(self): @@ -32,7 +32,7 @@ class Harness(unittest.TestCase): d = c0.upload(data) def _uploaded(res): n = c0.create_node_from_uri(res.uri) - return n.download_to_data() + return download_to_data(n) d.addCallback(_uploaded) def _check(res): self.failUnlessEqual(res, DATA) diff --git a/src/allmydata/test/test_repairer.py b/src/allmydata/test/test_repairer.py index 337f89aa..8c14f74e 100644 --- a/src/allmydata/test/test_repairer.py +++ b/src/allmydata/test/test_repairer.py @@ -4,6 +4,7 @@ from allmydata.monitor import Monitor from allmydata import check_results from allmydata.interfaces import NotEnoughSharesError from allmydata.immutable import repairer, upload +from allmydata.util.consumer import download_to_data from twisted.internet import defer from twisted.trial import unittest import random @@ -428,7 +429,7 @@ class Repairer(GridTestMixin, unittest.TestCase, RepairTestMixin, d.addCallback(lambda ignored: self.shouldFail(NotEnoughSharesError, "then_download", None, - self.c1_filenode.download_to_data)) + download_to_data, self.c1_filenode)) d.addCallback(lambda ignored: self.shouldFail(NotEnoughSharesError, "then_repair", @@ -499,7 +500,7 @@ class Repairer(GridTestMixin, unittest.TestCase, RepairTestMixin, d.addCallback(lambda ignored: self.delete_shares_numbered(self.uri, range(3, 10+1))) - d.addCallback(lambda ignored: self.c1_filenode.download_to_data()) + d.addCallback(lambda ignored: download_to_data(self.c1_filenode)) d.addCallback(lambda newdata: self.failUnlessEqual(newdata, common.TEST_DATA)) return d @@ -544,7 +545,7 @@ class Repairer(GridTestMixin, unittest.TestCase, RepairTestMixin, d.addCallback(lambda ignored: self.delete_shares_numbered(self.uri, range(3, 10+1))) - d.addCallback(lambda ignored: self.c1_filenode.download_to_data()) + d.addCallback(lambda ignored: download_to_data(self.c1_filenode)) d.addCallback(lambda newdata: self.failUnlessEqual(newdata, common.TEST_DATA)) return d diff --git a/src/allmydata/test/test_system.py b/src/allmydata/test/test_system.py index 423657db..f5c5c91e 100644 --- a/src/allmydata/test/test_system.py +++ b/src/allmydata/test/test_system.py @@ -1,20 +1,19 @@ from base64 import b32encode import os, sys, time, simplejson from cStringIO import StringIO -from zope.interface import implements from twisted.trial import unittest from twisted.internet import defer from twisted.internet import threads # CLI tests use deferToThread from twisted.internet.error import ConnectionDone, ConnectionLost -from twisted.internet.interfaces import IConsumer, IPushProducer import allmydata from allmydata import uri from allmydata.storage.mutable import MutableShareFile from allmydata.storage.server import si_a2b -from allmydata.immutable import download, offloaded, upload +from allmydata.immutable import offloaded, upload from allmydata.immutable.filenode import ImmutableFileNode, LiteralFileNode from allmydata.util import idlib, mathutil from allmydata.util import log, base32 +from allmydata.util.consumer import MemoryConsumer, download_to_data from allmydata.scripts import runner from allmydata.interfaces import IDirectoryNode, IFileNode, \ NoSuchChildError, NoSharesError @@ -26,8 +25,7 @@ from twisted.python.failure import Failure from twisted.web.client import getPage from twisted.web.error import Error -from allmydata.test.common import SystemTestMixin, MemoryConsumer, \ - download_to_data +from allmydata.test.common import SystemTestMixin LARGE_DATA = """ This is some data to publish to the virtual drive, which needs to be large @@ -47,22 +45,6 @@ class CountingDataUploadable(upload.Data): self.interrupt_after_d.callback(self) return upload.Data.read(self, length) -class GrabEverythingConsumer: - implements(IConsumer) - - def __init__(self): - self.contents = "" - - def registerProducer(self, producer, streaming): - assert streaming - assert IPushProducer.providedBy(producer) - - def write(self, data): - self.contents += data - - def unregisterProducer(self): - pass - class SystemTest(SystemTestMixin, unittest.TestCase): timeout = 3600 # It takes longer than 960 seconds on Zandr's ARM box. @@ -137,7 +119,7 @@ class SystemTest(SystemTestMixin, unittest.TestCase): self.uri = theuri assert isinstance(self.uri, str), self.uri self.cap = uri.from_string(self.uri) - self.downloader = self.clients[1].downloader + self.n = self.clients[1].create_node_from_uri(self.uri) d.addCallback(_upload_done) def _upload_again(res): @@ -153,42 +135,13 @@ class SystemTest(SystemTestMixin, unittest.TestCase): def _download_to_data(res): log.msg("DOWNLOADING") - return self.downloader.download_to_data(self.cap) + return download_to_data(self.n) d.addCallback(_download_to_data) def _download_to_data_done(data): log.msg("download finished") self.failUnlessEqual(data, DATA) d.addCallback(_download_to_data_done) - target_filename = os.path.join(self.basedir, "download.target") - def _download_to_filename(res): - return self.downloader.download_to_filename(self.cap, - target_filename) - d.addCallback(_download_to_filename) - def _download_to_filename_done(res): - newdata = open(target_filename, "rb").read() - self.failUnlessEqual(newdata, DATA) - d.addCallback(_download_to_filename_done) - - target_filename2 = os.path.join(self.basedir, "download.target2") - def _download_to_filehandle(res): - fh = open(target_filename2, "wb") - return self.downloader.download_to_filehandle(self.cap, fh) - d.addCallback(_download_to_filehandle) - def _download_to_filehandle_done(fh): - fh.close() - newdata = open(target_filename2, "rb").read() - self.failUnlessEqual(newdata, DATA) - d.addCallback(_download_to_filehandle_done) - - consumer = GrabEverythingConsumer() - ct = download.ConsumerAdapter(consumer) - d.addCallback(lambda res: - self.downloader.download(self.cap, ct)) - def _download_to_consumer_done(ign): - self.failUnlessEqual(consumer.contents, DATA) - d.addCallback(_download_to_consumer_done) - def _test_read(res): n = self.clients[1].create_node_from_uri(self.uri) d = download_to_data(n) @@ -228,9 +181,10 @@ class SystemTest(SystemTestMixin, unittest.TestCase): def _download_nonexistent_uri(res): baduri = self.mangle_uri(self.uri) + badnode = self.clients[1].create_node_from_uri(baduri) log.msg("about to download non-existent URI", level=log.UNUSUAL, facility="tahoe.tests") - d1 = self.downloader.download_to_data(uri.from_string(baduri)) + d1 = download_to_data(badnode) def _baduri_should_fail(res): log.msg("finished downloading non-existend URI", level=log.UNUSUAL, facility="tahoe.tests") @@ -255,8 +209,8 @@ class SystemTest(SystemTestMixin, unittest.TestCase): u = upload.Data(HELPER_DATA, convergence=convergence) d = self.extra_node.upload(u) def _uploaded(results): - cap = uri.from_string(results.uri) - return self.downloader.download_to_data(cap) + n = self.clients[1].create_node_from_uri(results.uri) + return download_to_data(n) d.addCallback(_uploaded) def _check(newdata): self.failUnlessEqual(newdata, HELPER_DATA) @@ -269,8 +223,8 @@ class SystemTest(SystemTestMixin, unittest.TestCase): u.debug_stash_RemoteEncryptedUploadable = True d = self.extra_node.upload(u) def _uploaded(results): - cap = uri.from_string(results.uri) - return self.downloader.download_to_data(cap) + n = self.clients[1].create_node_from_uri(results.uri) + return download_to_data(n) d.addCallback(_uploaded) def _check(newdata): self.failUnlessEqual(newdata, HELPER_DATA) @@ -357,7 +311,7 @@ class SystemTest(SystemTestMixin, unittest.TestCase): d.addCallback(lambda res: self.extra_node.upload(u2)) def _uploaded(results): - cap = uri.from_string(results.uri) + cap = results.uri log.msg("Second upload complete", level=log.NOISY, facility="tahoe.test.test_system") @@ -385,7 +339,8 @@ class SystemTest(SystemTestMixin, unittest.TestCase): "resumption saved us some work even though we were using random keys:" " read %d bytes out of %d total" % (bytes_sent, len(DATA))) - return self.downloader.download_to_data(cap) + n = self.clients[1].create_node_from_uri(cap) + return download_to_data(n) d.addCallback(_uploaded) def _check(newdata): @@ -885,7 +840,7 @@ class SystemTest(SystemTestMixin, unittest.TestCase): d.addCallback(self.log, "check_publish1 got /") d.addCallback(lambda root: root.get(u"subdir1")) d.addCallback(lambda subdir1: subdir1.get(u"mydata567")) - d.addCallback(lambda filenode: filenode.download_to_data()) + d.addCallback(lambda filenode: download_to_data(filenode)) d.addCallback(self.log, "get finished") def _get_done(data): self.failUnlessEqual(data, self.data) @@ -899,7 +854,7 @@ class SystemTest(SystemTestMixin, unittest.TestCase): d.addCallback(lambda dirnode: self.failUnless(IDirectoryNode.providedBy(dirnode))) d.addCallback(lambda res: rootnode.get_child_at_path(u"subdir1/mydata567")) - d.addCallback(lambda filenode: filenode.download_to_data()) + d.addCallback(lambda filenode: download_to_data(filenode)) d.addCallback(lambda data: self.failUnlessEqual(data, self.data)) d.addCallback(lambda res: rootnode.get_child_at_path(u"subdir1/mydata567")) @@ -925,7 +880,7 @@ class SystemTest(SystemTestMixin, unittest.TestCase): return self._private_node.get_child_at_path(path) d.addCallback(lambda res: get_path(u"personal/sekrit data")) - d.addCallback(lambda filenode: filenode.download_to_data()) + d.addCallback(lambda filenode: download_to_data(filenode)) d.addCallback(lambda data: self.failUnlessEqual(data, self.smalldata)) d.addCallback(lambda res: get_path(u"s2-rw")) d.addCallback(lambda dirnode: self.failUnless(dirnode.is_mutable())) diff --git a/src/allmydata/test/test_web.py b/src/allmydata/test/test_web.py index 9d0aad0e..1801a8bd 100644 --- a/src/allmydata/test/test_web.py +++ b/src/allmydata/test/test_web.py @@ -17,6 +17,7 @@ from allmydata.unknown import UnknownNode from allmydata.web import status, common from allmydata.scripts.debug import CorruptShareOptions, corrupt_share from allmydata.util import fileutil, base32 +from allmydata.util.consumer import download_to_data from allmydata.test.common import FakeCHKFileNode, FakeMutableFileNode, \ create_chk_filenode, WebErrorMixin, ShouldFailMixin, make_mutable_file_uri from allmydata.interfaces import IMutableFileNode @@ -1292,7 +1293,7 @@ class Web(WebMixin, WebErrorMixin, testutil.StallMixin, unittest.TestCase): def failUnlessChildContentsAre(self, node, name, expected_contents): assert isinstance(name, unicode) d = node.get_child_at_path(name) - d.addCallback(lambda node: node.download_to_data()) + d.addCallback(lambda node: download_to_data(node)) def _check(contents): self.failUnlessEqual(contents, expected_contents) d.addCallback(_check) diff --git a/src/allmydata/util/consumer.py b/src/allmydata/util/consumer.py new file mode 100755 index 00000000..4128c200 --- /dev/null +++ b/src/allmydata/util/consumer.py @@ -0,0 +1,30 @@ + +"""This file defines a basic download-to-memory consumer, suitable for use in +a filenode's read() method. See download_to_data() for an example of its use. +""" + +from zope.interface import implements +from twisted.internet.interfaces import IConsumer + +class MemoryConsumer: + implements(IConsumer) + def __init__(self): + self.chunks = [] + self.done = False + def registerProducer(self, p, streaming): + self.producer = p + if streaming: + # call resumeProducing once to start things off + p.resumeProducing() + else: + while not self.done: + p.resumeProducing() + def write(self, data): + self.chunks.append(data) + def unregisterProducer(self): + self.done = True + +def download_to_data(n, offset=0, size=None): + d = n.read(MemoryConsumer(), offset, size) + d.addCallback(lambda mc: "".join(mc.chunks)) + return d