Simplify immutable download API: use just filenode.read(consumer, offset, size)
author Brian Warner <warner@lothar.com>
Tue, 1 Dec 2009 22:44:35 +0000 (17:44 -0500)
committer Brian Warner <warner@lothar.com>
Tue, 1 Dec 2009 22:53:30 +0000 (17:53 -0500)
* remove Downloader.download_to_data/download_to_filename/download_to_filehandle
* remove download.Data/FileName/FileHandle targets
* remove filenode.download/download_to_data/download_to_filename methods
* leave Downloader.download (the whole Downloader will go away eventually)
* add util.consumer.MemoryConsumer/download_to_data, for convenience
  (this is mostly used by unit tests, but it gets used by enough non-test
   code to warrant putting it in allmydata.util)
* update tests
* removes about 180 lines of code. Yay negative code days!

Overall plan is to rewrite immutable/download.py and leave filenode.read() as
the sole read-side API.
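
For orientation, the read-side API that survives this change is used roughly
like this (a sketch only, not part of the patch; filenode and
expected_contents are placeholders for an immutable filenode and its known
plaintext):

    from allmydata.util.consumer import MemoryConsumer, download_to_data

    # whole-file download into memory, via the new convenience wrapper
    def _check(data):
        assert data == expected_contents
    d = download_to_data(filenode)
    d.addCallback(_check)

    # partial read: five bytes starting at offset 1
    d2 = download_to_data(filenode, 1, 5)

    # the primitive underneath: read() delivers bytes to any IConsumer and
    # fires with that consumer once the requested range is complete
    mc = MemoryConsumer()
    d3 = filenode.read(mc, offset=0, size=None)
    d3.addCallback(lambda mc: "".join(mc.chunks))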

18 files changed:
src/allmydata/control.py
src/allmydata/dirnode.py
src/allmydata/frontends/ftpd.py
src/allmydata/frontends/sftpd.py
src/allmydata/immutable/download.py
src/allmydata/immutable/filenode.py
src/allmydata/interfaces.py
src/allmydata/test/common.py
src/allmydata/test/test_download.py
src/allmydata/test/test_encode.py
src/allmydata/test/test_filenode.py
src/allmydata/test/test_immutable.py
src/allmydata/test/test_mutable.py
src/allmydata/test/test_no_network.py
src/allmydata/test/test_repairer.py
src/allmydata/test/test_system.py
src/allmydata/test/test_web.py
src/allmydata/util/consumer.py [new file with mode: 0755]

diff --git a/src/allmydata/control.py b/src/allmydata/control.py
index 09d10d617df5fad4e563acfb99b1d6c8a793380c..9fe8e8e7248f032aa52551662c21fb77ba30fbaf 100644 (file)
@@ -3,10 +3,11 @@ import os, time
 from zope.interface import implements
 from twisted.application import service
 from twisted.internet import defer
+from twisted.internet.interfaces import IConsumer
 from foolscap.api import Referenceable
 from allmydata.interfaces import RIControlClient
 from allmydata.util import fileutil, mathutil
-from allmydata.immutable import upload, download
+from allmydata.immutable import upload
 from twisted.python import log
 
 def get_memory_usage():
@@ -35,6 +36,22 @@ def log_memory_usage(where=""):
                                               stats["VmPeak"],
                                               where))
 
+class FileWritingConsumer:
+    implements(IConsumer)
+    def __init__(self, filename):
+        self.done = False
+        self.f = open(filename, "wb")
+    def registerProducer(self, p, streaming):
+        if streaming:
+            p.resumeProducing()
+        else:
+            while not self.done:
+                p.resumeProducing()
+    def write(self, data):
+        self.f.write(data)
+    def unregisterProducer(self):
+        self.done = True
+        self.f.close()
 
 class ControlServer(Referenceable, service.Service):
     implements(RIControlClient)
@@ -51,7 +68,8 @@ class ControlServer(Referenceable, service.Service):
 
     def remote_download_from_uri_to_file(self, uri, filename):
         filenode = self.parent.create_node_from_uri(uri)
-        d = filenode.download_to_filename(filename)
+        c = FileWritingConsumer(filename)
+        d = filenode.read(c)
         d.addCallback(lambda res: filename)
         return d
 
@@ -181,7 +199,7 @@ class SpeedTest:
             if i >= self.count:
                 return
             n = self.parent.create_node_from_uri(self.uris[i])
-            d1 = n.download(download.FileHandle(Discard()))
+            d1 = n.read(DiscardingConsumer())
             d1.addCallback(_download_one_file, i+1)
             return d1
         d.addCallback(_download_one_file, 0)
@@ -197,10 +215,17 @@ class SpeedTest:
             os.unlink(fn)
         return res
 
-class Discard:
+class DiscardingConsumer:
+    implements(IConsumer)
+    def __init__(self):
+        self.done = False
+    def registerProducer(self, p, streaming):
+        if streaming:
+            p.resumeProducing()
+        else:
+            while not self.done:
+                p.resumeProducing()
     def write(self, data):
         pass
-    # download_to_filehandle explicitly does not close the filehandle it was
-    # given: that is reserved for the provider of the filehandle. Therefore
-    # the lack of a close() method on this otherwise filehandle-like object
-    # is a part of the test.
+    def unregisterProducer(self):
+        self.done = True
diff --git a/src/allmydata/dirnode.py b/src/allmydata/dirnode.py
index 96c31ffbdf41b9997ebef1a6cbbdcf35695207f2..ae8516f6e8359f4e80378cc967bb7f3528a8d69f 100644 (file)
@@ -18,6 +18,7 @@ from allmydata.monitor import Monitor
 from allmydata.util import hashutil, mathutil, base32, log
 from allmydata.util.assertutil import precondition
 from allmydata.util.netstring import netstring, split_netstring
+from allmydata.util.consumer import download_to_data
 from allmydata.uri import LiteralFileURI, from_string, wrap_dirnode_cap
 from pycryptopp.cipher.aes import AES
 from allmydata.util.dictutil import AuxValueDict
@@ -217,7 +218,7 @@ class DirectoryNode:
             # use the IMutableFileNode API.
             d = self._node.download_best_version()
         else:
-            d = self._node.download_to_data()
+            d = download_to_data(self._node)
         d.addCallback(self._unpack_contents)
         return d
 
diff --git a/src/allmydata/frontends/ftpd.py b/src/allmydata/frontends/ftpd.py
index 4d4256df06c57dd04617489c6af0831feca79b24..bba29160b7be8f792bfc23c2b955e00a2dfaf5e5 100644 (file)
@@ -9,7 +9,6 @@ from twisted.protocols import ftp
 
 from allmydata.interfaces import IDirectoryNode, ExistingChildError, \
      NoSuchChildError
-from allmydata.immutable.download import ConsumerAdapter
 from allmydata.immutable.upload import FileHandle
 
 class ReadFile:
@@ -17,8 +16,7 @@ class ReadFile:
     def __init__(self, node):
         self.node = node
     def send(self, consumer):
-        ad = ConsumerAdapter(consumer)
-        d = self.node.download(ad)
+        d = self.node.read(consumer)
         return d # when consumed
 
 class FileWriter:
diff --git a/src/allmydata/frontends/sftpd.py b/src/allmydata/frontends/sftpd.py
index 5a713e1c34d06fd0f40be091a3ec611574b65886..4a866eec07a162aba9f6a7452f4b6be9866ae691 100644 (file)
@@ -4,7 +4,6 @@ from zope.interface import implements
 from twisted.python import components
 from twisted.application import service, strports
 from twisted.internet import defer
-from twisted.internet.interfaces import IConsumer
 from twisted.conch.ssh import factory, keys, session
 from twisted.conch.interfaces import ISFTPServer, ISFTPFile, IConchUser
 from twisted.conch.avatar import ConchUser
@@ -15,28 +14,7 @@ from twisted.cred import portal
 from allmydata.interfaces import IDirectoryNode, ExistingChildError, \
      NoSuchChildError
 from allmydata.immutable.upload import FileHandle
-
-class MemoryConsumer:
-    implements(IConsumer)
-    def __init__(self):
-        self.chunks = []
-        self.done = False
-    def registerProducer(self, p, streaming):
-        if streaming:
-            # call resumeProducing once to start things off
-            p.resumeProducing()
-        else:
-            while not self.done:
-                p.resumeProducing()
-    def write(self, data):
-        self.chunks.append(data)
-    def unregisterProducer(self):
-        self.done = True
-
-def download_to_data(n, offset=0, size=None):
-    d = n.read(MemoryConsumer(), offset, size)
-    d.addCallback(lambda mc: "".join(mc.chunks))
-    return d
+from allmydata.util.consumer import download_to_data
 
 class ReadFile:
     implements(ISFTPFile)
diff --git a/src/allmydata/immutable/download.py b/src/allmydata/immutable/download.py
index f6356caf374b9792eeec260bb5e86e1d37441b41..261b65d666326ef3ec46098324ce1008e9caf584 100644 (file)
@@ -1,4 +1,4 @@
-import os, random, weakref, itertools, time
+import random, weakref, itertools, time
 from zope.interface import implements
 from twisted.internet import defer
 from twisted.internet.interfaces import IPushProducer, IConsumer
@@ -1196,93 +1196,6 @@ class CiphertextDownloader(log.PrefixingLogMixin):
         return self._status
 
 
-class FileName:
-    implements(IDownloadTarget)
-    def __init__(self, filename):
-        self._filename = filename
-        self.f = None
-    def open(self, size):
-        self.f = open(self._filename, "wb")
-        return self.f
-    def write(self, data):
-        self.f.write(data)
-    def close(self):
-        if self.f:
-            self.f.close()
-    def fail(self, why):
-        if self.f:
-            self.f.close()
-            os.unlink(self._filename)
-    def register_canceller(self, cb):
-        pass # we won't use it
-    def finish(self):
-        pass
-    # The following methods are just because the target might be a
-    # repairer.DownUpConnector, and just because the current CHKUpload object
-    # expects to find the storage index and encoding parameters in its
-    # Uploadable.
-    def set_storageindex(self, storageindex):
-        pass
-    def set_encodingparams(self, encodingparams):
-        pass
-
-class Data:
-    implements(IDownloadTarget)
-    def __init__(self):
-        self._data = []
-    def open(self, size):
-        pass
-    def write(self, data):
-        self._data.append(data)
-    def close(self):
-        self.data = "".join(self._data)
-        del self._data
-    def fail(self, why):
-        del self._data
-    def register_canceller(self, cb):
-        pass # we won't use it
-    def finish(self):
-        return self.data
-    # The following methods are just because the target might be a
-    # repairer.DownUpConnector, and just because the current CHKUpload object
-    # expects to find the storage index and encoding parameters in its
-    # Uploadable.
-    def set_storageindex(self, storageindex):
-        pass
-    def set_encodingparams(self, encodingparams):
-        pass
-
-class FileHandle:
-    """Use me to download data to a pre-defined filehandle-like object. I
-    will use the target's write() method. I will *not* close the filehandle:
-    I leave that up to the originator of the filehandle. The download process
-    will return the filehandle when it completes.
-    """
-    implements(IDownloadTarget)
-    def __init__(self, filehandle):
-        self._filehandle = filehandle
-    def open(self, size):
-        pass
-    def write(self, data):
-        self._filehandle.write(data)
-    def close(self):
-        # the originator of the filehandle reserves the right to close it
-        pass
-    def fail(self, why):
-        pass
-    def register_canceller(self, cb):
-        pass
-    def finish(self):
-        return self._filehandle
-    # The following methods are just because the target might be a
-    # repairer.DownUpConnector, and just because the current CHKUpload object
-    # expects to find the storage index and encoding parameters in its
-    # Uploadable.
-    def set_storageindex(self, storageindex):
-        pass
-    def set_encodingparams(self, encodingparams):
-        pass
-
 class ConsumerAdapter:
     implements(IDownloadTarget, IConsumer)
     def __init__(self, consumer):
@@ -1351,11 +1264,3 @@ class Downloader:
             history.add_download(dl.get_download_status())
         d = dl.start()
         return d
-
-    # utility functions
-    def download_to_data(self, uri, _log_msg_id=None, history=None):
-        return self.download(uri, Data(), _log_msg_id=_log_msg_id, history=history)
-    def download_to_filename(self, uri, filename, _log_msg_id=None):
-        return self.download(uri, FileName(filename), _log_msg_id=_log_msg_id)
-    def download_to_filehandle(self, uri, filehandle, _log_msg_id=None):
-        return self.download(uri, FileHandle(filehandle), _log_msg_id=_log_msg_id)
diff --git a/src/allmydata/immutable/filenode.py b/src/allmydata/immutable/filenode.py
index c57709b3431aa860ffb52c9902ae15670445dda2..196d4e0335ae170d861ac18e093689c309ecc11e 100644 (file)
@@ -2,7 +2,7 @@ import copy, os.path, stat
 from cStringIO import StringIO
 from zope.interface import implements
 from twisted.internet import defer
-from twisted.internet.interfaces import IPushProducer, IConsumer
+from twisted.internet.interfaces import IPushProducer
 from twisted.protocols import basic
 from foolscap.api import eventually
 from allmydata.interfaces import IImmutableFileNode, ICheckable, \
@@ -284,6 +284,8 @@ class ImmutableFileNode(_ImmutableFileNodeBase, log.PrefixingLogMixin):
         return v.start()
 
     def read(self, consumer, offset=0, size=None):
+        self.log("read", offset=offset, size=size,
+                 umid="UPP8FA", level=log.OPERATIONAL)
         if size is None:
             size = self.get_size() - offset
         size = min(size, self.get_size() - offset)
@@ -291,24 +293,16 @@ class ImmutableFileNode(_ImmutableFileNodeBase, log.PrefixingLogMixin):
         if offset == 0 and size == self.get_size():
             # don't use the cache, just do a normal streaming download
             self.log("doing normal full download", umid="VRSBwg", level=log.OPERATIONAL)
-            return self.download(download.ConsumerAdapter(consumer))
+            target = download.ConsumerAdapter(consumer)
+            return self._downloader.download(self.get_cap(), target,
+                                             self._parentmsgid,
+                                             history=self._history)
 
         d = self.download_cache.when_range_available(offset, size)
         d.addCallback(lambda res:
                       self.download_cache.read(consumer, offset, size))
         return d
 
-    def download(self, target):
-        return self._downloader.download(self.get_cap(), target,
-                                         self._parentmsgid,
-                                         history=self._history)
-
-    def download_to_data(self):
-        return self._downloader.download_to_data(self.get_cap(),
-                                                 history=self._history)
-    def download_to_filename(self, filename):
-        return self._downloader.download_to_filename(self.get_cap(), filename)
-
 class LiteralProducer:
     implements(IPushProducer)
     def resumeProducing(self):
@@ -367,19 +361,3 @@ class LiteralFileNode(_ImmutableFileNodeBase):
         d = basic.FileSender().beginFileTransfer(StringIO(data), consumer)
         d.addCallback(lambda lastSent: consumer)
         return d
-
-    def download(self, target):
-        # note that this does not update the stats_provider
-        data = self.u.data
-        if IConsumer.providedBy(target):
-            target.registerProducer(LiteralProducer(), True)
-        target.open(len(data))
-        target.write(data)
-        if IConsumer.providedBy(target):
-            target.unregisterProducer()
-        target.close()
-        return defer.maybeDeferred(target.finish)
-
-    def download_to_data(self):
-        data = self.u.data
-        return defer.succeed(data)
diff --git a/src/allmydata/interfaces.py b/src/allmydata/interfaces.py
index c6897ef5e687224fafdc6c5d4f025968c735b098..babacd61b1a1ae7ef9dc5bcdb0e34e8c99ab914f 100644 (file)
@@ -575,13 +575,6 @@ class IFileNode(IFilesystemNode):
     container, like IDirectoryNode."""
 
 class IImmutableFileNode(IFileNode):
-    def download(target):
-        """Download the file's contents to a given IDownloadTarget"""
-
-    def download_to_data():
-        """Download the file's contents. Return a Deferred that fires
-        with those contents."""
-
     def read(consumer, offset=0, size=None):
         """Download a portion (possibly all) of the file's contents, making
         them available to the given IConsumer. Return a Deferred that fires
@@ -613,25 +606,8 @@ class IImmutableFileNode(IFileNode):
         p.stopProducing(), which will result in an exception being delivered
         via deferred.errback().
 
-        A simple download-to-memory consumer example would look like this::
-
-         class MemoryConsumer:
-           implements(IConsumer)
-           def __init__(self):
-             self.chunks = []
-             self.done = False
-           def registerProducer(self, p, streaming):
-             assert streaming == False
-             while not self.done:
-               p.resumeProducing()
-           def write(self, data):
-             self.chunks.append(data)
-           def unregisterProducer(self):
-             self.done = True
-         d = filenode.read(MemoryConsumer())
-         d.addCallback(lambda mc: "".join(mc.chunks))
-         return d
-
+        See src/allmydata/util/consumer.py for an example of a simple
+        download-to-memory consumer.
         """
 
 class IMutableFileNode(IFileNode):
diff --git a/src/allmydata/test/common.py b/src/allmydata/test/common.py
index 98bb6aed60685973d9d60f53a77ac360320a2300..7113ccebf576d4246a289e7663abfd3de82d6e25 100644 (file)
@@ -1,7 +1,7 @@
 import os, random, struct
 from zope.interface import implements
 from twisted.internet import defer
-from twisted.internet.interfaces import IConsumer
+from twisted.internet.interfaces import IPullProducer
 from twisted.python import failure
 from twisted.application import service
 from twisted.web.error import Error as WebError
@@ -18,6 +18,7 @@ from allmydata.storage.server import storage_index_to_dir
 from allmydata.storage.mutable import MutableShareFile
 from allmydata.util import hashutil, log, fileutil, pollmixin
 from allmydata.util.assertutil import precondition
+from allmydata.util.consumer import download_to_data
 from allmydata.stats import StatsGathererService
 from allmydata.key_generator import KeyGeneratorService
 import common_util as testutil
@@ -31,6 +32,11 @@ def flush_but_dont_ignore(res):
     d.addCallback(_done)
     return d
 
+class DummyProducer:
+    implements(IPullProducer)
+    def resumeProducing(self):
+        pass
+
 class FakeCHKFileNode:
     """I provide IImmutableFileNode, but all of my data is stored in a
     class-level dictionary."""
@@ -98,40 +104,36 @@ class FakeCHKFileNode:
     def is_readonly(self):
         return True
 
-    def download(self, target):
-        if self.my_uri.to_string() not in self.all_contents:
-            f = failure.Failure(NotEnoughSharesError(None, 0, 3))
-            target.fail(f)
-            return defer.fail(f)
-        data = self.all_contents[self.my_uri.to_string()]
-        target.open(len(data))
-        target.write(data)
-        target.close()
-        return defer.maybeDeferred(target.finish)
-    def download_to_data(self):
-        if self.my_uri.to_string() not in self.all_contents:
-            return defer.fail(NotEnoughSharesError(None, 0, 3))
-        data = self.all_contents[self.my_uri.to_string()]
-        return defer.succeed(data)
     def get_size(self):
         try:
             data = self.all_contents[self.my_uri.to_string()]
         except KeyError, le:
             raise NotEnoughSharesError(le, 0, 3)
         return len(data)
+
     def read(self, consumer, offset=0, size=None):
-        d = self.download_to_data()
-        def _got(data):
-            start = offset
-            if size is not None:
-                end = offset + size
-            else:
-                end = len(data)
-            consumer.write(data[start:end])
-            return consumer
-        d.addCallback(_got)
+        # we don't bother to call registerProducer/unregisterProducer,
+        # because it's a hassle to write a dummy Producer that does the right
+        # thing (we have to make sure that DummyProducer.resumeProducing
+        # writes the data into the consumer immediately, otherwise it will
+        # loop forever).
+
+        d = defer.succeed(None)
+        d.addCallback(self._read, consumer, offset, size)
         return d
 
+    def _read(self, ignored, consumer, offset, size):
+        if self.my_uri.to_string() not in self.all_contents:
+            raise NotEnoughSharesError(None, 0, 3)
+        data = self.all_contents[self.my_uri.to_string()]
+        start = offset
+        if size is not None:
+            end = offset + size
+        else:
+            end = len(data)
+        consumer.write(data[start:end])
+        return consumer
+
 def make_chk_file_cap(size):
     return uri.CHKFileURI(key=os.urandom(16),
                           uri_extension_hash=os.urandom(32),
@@ -927,6 +929,7 @@ class ShareManglingMixin(SystemTestMixin):
             d2 = cl0.upload(immutable.upload.Data(TEST_DATA, convergence=""))
             def _after_upload(u):
                 filecap = u.uri
+                self.n = self.clients[1].create_node_from_uri(filecap)
                 self.uri = uri.CHKFileURI.init_from_string(filecap)
                 return cl0.create_node_from_uri(filecap)
             d2.addCallback(_after_upload)
@@ -1045,8 +1048,7 @@ class ShareManglingMixin(SystemTestMixin):
         return sum_of_write_counts
 
     def _download_and_check_plaintext(self, unused=None):
-        d = self.clients[1].downloader.download_to_data(self.uri)
-
+        d = download_to_data(self.n)
         def _after_download(result):
             self.failUnlessEqual(result, TEST_DATA)
         d.addCallback(_after_download)
@@ -1140,28 +1142,6 @@ class ErrorMixin(WebErrorMixin):
             print "First Error:", f.value.subFailure
         return f
 
-class MemoryConsumer:
-    implements(IConsumer)
-    def __init__(self):
-        self.chunks = []
-        self.done = False
-    def registerProducer(self, p, streaming):
-        if streaming:
-            # call resumeProducing once to start things off
-            p.resumeProducing()
-        else:
-            while not self.done:
-                p.resumeProducing()
-    def write(self, data):
-        self.chunks.append(data)
-    def unregisterProducer(self):
-        self.done = True
-
-def download_to_data(n, offset=0, size=None):
-    d = n.read(MemoryConsumer(), offset, size)
-    d.addCallback(lambda mc: "".join(mc.chunks))
-    return d
-
 def corrupt_field(data, offset, size, debug=False):
     if random.random() < 0.5:
         newdata = testutil.flip_one_bit(data, offset, size)
diff --git a/src/allmydata/test/test_download.py b/src/allmydata/test/test_download.py
index a4743d2eb6c62a8e3038d850342007d68da80c91..b54bf017437b4326caae99e7ea11ed280ebb3eaa 100644 (file)
@@ -8,6 +8,7 @@ from twisted.trial import unittest
 from allmydata import uri
 from allmydata.storage.server import storage_index_to_dir
 from allmydata.util import base32, fileutil
+from allmydata.util.consumer import download_to_data
 from allmydata.immutable import upload
 from allmydata.test.no_network import GridTestMixin
 
@@ -173,7 +174,7 @@ class DownloadTest(GridTestMixin, unittest.TestCase):
 
     def download_immutable(self, ignored=None):
         n = self.c0.create_node_from_uri(immutable_uri)
-        d = n.download_to_data()
+        d = download_to_data(n)
         def _got_data(data):
             self.failUnlessEqual(data, plaintext)
         d.addCallback(_got_data)
diff --git a/src/allmydata/test/test_encode.py b/src/allmydata/test/test_encode.py
index c6a7bda1592e9f06ca20c95bc013c8bd439f3c3c..2d91dbc581b9aace89a0e00b4b6f8a78b072a37e 100644 (file)
@@ -1,13 +1,13 @@
 from zope.interface import implements
 from twisted.trial import unittest
 from twisted.internet import defer, reactor
-from twisted.internet.interfaces import IConsumer
 from twisted.python.failure import Failure
 from foolscap.api import fireEventually
 from allmydata import hashtree, uri
 from allmydata.immutable import encode, upload, download
 from allmydata.util import hashutil
 from allmydata.util.assertutil import _assert
+from allmydata.util.consumer import MemoryConsumer
 from allmydata.interfaces import IStorageBucketWriter, IStorageBucketReader, \
      NotEnoughSharesError, IStorageBroker
 from allmydata.monitor import Monitor
@@ -384,10 +384,9 @@ class Encode(unittest.TestCase):
         # 5 segments: 25, 25, 25, 25, 1
         return self.do_encode(25, 101, 100, 5, 15, 8)
 
-class PausingTarget(download.Data):
-    implements(IConsumer)
+class PausingConsumer(MemoryConsumer):
     def __init__(self):
-        download.Data.__init__(self)
+        MemoryConsumer.__init__(self)
         self.size = 0
         self.writes = 0
     def write(self, data):
@@ -398,22 +397,18 @@ class PausingTarget(download.Data):
             # last one (since then the _unpause timer will still be running)
             self.producer.pauseProducing()
             reactor.callLater(0.1, self._unpause)
-        return download.Data.write(self, data)
+        return MemoryConsumer.write(self, data)
     def _unpause(self):
         self.producer.resumeProducing()
-    def registerProducer(self, producer, streaming):
-        self.producer = producer
-    def unregisterProducer(self):
-        self.producer = None
 
-class PausingAndStoppingTarget(PausingTarget):
+class PausingAndStoppingConsumer(PausingConsumer):
     def write(self, data):
         self.producer.pauseProducing()
         reactor.callLater(0.5, self._stop)
     def _stop(self):
         self.producer.stopProducing()
 
-class StoppingTarget(PausingTarget):
+class StoppingConsumer(PausingConsumer):
     def write(self, data):
         self.producer.stopProducing()
 
@@ -425,7 +420,7 @@ class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin):
                          max_segment_size=25,
                          bucket_modes={},
                          recover_mode="recover",
-                         target=None,
+                         consumer=None,
                          ):
         if AVAILABLE_SHARES is None:
             AVAILABLE_SHARES = k_and_happy_and_n[2]
@@ -434,7 +429,7 @@ class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin):
                       max_segment_size, bucket_modes, data)
         # that fires with (uri_extension_hash, e, shareholders)
         d.addCallback(self.recover, AVAILABLE_SHARES, recover_mode,
-                      target=target)
+                      consumer=consumer)
         # that fires with newdata
         def _downloaded((newdata, fd)):
             self.failUnless(newdata == data, str((len(newdata), len(data))))
@@ -478,7 +473,7 @@ class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin):
         return d
 
     def recover(self, (res, key, shareholders), AVAILABLE_SHARES,
-                recover_mode, target=None):
+                recover_mode, consumer=None):
         verifycap = res
 
         if "corrupt_key" in recover_mode:
@@ -495,9 +490,10 @@ class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin):
                            size=verifycap.size)
 
         sb = FakeStorageBroker()
-        if not target:
-            target = download.Data()
-        target = download.DecryptingTarget(target, u.key)
+        if not consumer:
+            consumer = MemoryConsumer()
+        innertarget = download.ConsumerAdapter(consumer)
+        target = download.DecryptingTarget(innertarget, u.key)
         fd = download.CiphertextDownloader(sb, u.get_verify_cap(), target, monitor=Monitor())
 
         # we manually cycle the CiphertextDownloader through a number of steps that
@@ -544,7 +540,8 @@ class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin):
 
         d.addCallback(fd._download_all_segments)
         d.addCallback(fd._done)
-        def _done(newdata):
+        def _done(t):
+            newdata = "".join(consumer.chunks)
             return (newdata, fd)
         d.addCallback(_done)
         return d
@@ -582,26 +579,26 @@ class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin):
         return self.send_and_recover(datalen=101)
 
     def test_pause(self):
-        # use a DownloadTarget that does pauseProducing/resumeProducing a few
-        # times, then finishes
-        t = PausingTarget()
-        d = self.send_and_recover(target=t)
+        # use a download target that does pauseProducing/resumeProducing a
+        # few times, then finishes
+        c = PausingConsumer()
+        d = self.send_and_recover(consumer=c)
         return d
 
     def test_pause_then_stop(self):
-        # use a DownloadTarget that pauses, then stops.
-        t = PausingAndStoppingTarget()
+        # use a download target that pauses, then stops.
+        c = PausingAndStoppingConsumer()
         d = self.shouldFail(download.DownloadStopped, "test_pause_then_stop",
                             "our Consumer called stopProducing()",
-                            self.send_and_recover, target=t)
+                            self.send_and_recover, consumer=c)
         return d
 
     def test_stop(self):
-        # use a DownloadTarget that does an immediate stop (ticket #473)
-        t = StoppingTarget()
+        # use a download target that does an immediate stop (ticket #473)
+        c = StoppingConsumer()
         d = self.shouldFail(download.DownloadStopped, "test_stop",
                             "our Consumer called stopProducing()",
-                            self.send_and_recover, target=t)
+                            self.send_and_recover, consumer=c)
         return d
 
     # the following tests all use 4-out-of-10 encoding
diff --git a/src/allmydata/test/test_filenode.py b/src/allmydata/test/test_filenode.py
index feeda226d39a497c9502871942165640972524f7..a8de4207bf90c27e63515b4f7456bb133f2541c8 100644 (file)
@@ -2,11 +2,10 @@
 from twisted.trial import unittest
 from allmydata import uri, client
 from allmydata.monitor import Monitor
-from allmydata.immutable import download
 from allmydata.immutable.filenode import ImmutableFileNode, LiteralFileNode
 from allmydata.mutable.filenode import MutableFileNode
 from allmydata.util import hashutil, cachedir
-from allmydata.test.common import download_to_data
+from allmydata.util.consumer import download_to_data
 
 class NotANode:
     pass
@@ -77,17 +76,11 @@ class Node(unittest.TestCase):
         self.failUnlessEqual(v, None)
         self.failUnlessEqual(fn1.get_repair_cap(), None)
 
-        d = fn1.download(download.Data())
+        d = download_to_data(fn1)
         def _check(res):
             self.failUnlessEqual(res, DATA)
         d.addCallback(_check)
 
-        d.addCallback(lambda res: fn1.download_to_data())
-        d.addCallback(_check)
-
-        d.addCallback(lambda res: download_to_data(fn1))
-        d.addCallback(_check)
-
         d.addCallback(lambda res: download_to_data(fn1, 1, 5))
         def _check_segment(res):
             self.failUnlessEqual(res, DATA[1:1+5])
diff --git a/src/allmydata/test/test_immutable.py b/src/allmydata/test/test_immutable.py
index ad1843c04227254268fa6323a4a74e1cc3478448..c9dfc064f97fdd27230f88ac0d913594a648ecd0 100644 (file)
@@ -1,5 +1,6 @@
 from allmydata.test import common
 from allmydata.interfaces import NotEnoughSharesError
+from allmydata.util.consumer import download_to_data
 from twisted.internet import defer
 from twisted.trial import unittest
 import random
@@ -26,7 +27,7 @@ class Test(common.ShareManglingMixin, unittest.TestCase):
         d.addCallback(_then_delete_8)
 
         def _then_download(unused=None):
-            d2 = self.clients[1].downloader.download_to_data(self.uri)
+            d2 = download_to_data(self.n)
 
             def _after_download_callb(result):
                 self.fail() # should have gotten an errback instead
@@ -93,7 +94,7 @@ class Test(common.ShareManglingMixin, unittest.TestCase):
 
         before_download_reads = self._count_reads()
         def _attempt_to_download(unused=None):
-            d2 = self.clients[1].downloader.download_to_data(self.uri)
+            d2 = download_to_data(self.n)
 
             def _callb(res):
                 self.fail("Should have gotten an error from attempt to download, not %r" % (res,))
@@ -126,7 +127,7 @@ class Test(common.ShareManglingMixin, unittest.TestCase):
 
         before_download_reads = self._count_reads()
         def _attempt_to_download(unused=None):
-            d2 = self.clients[1].downloader.download_to_data(self.uri)
+            d2 = download_to_data(self.n)
 
             def _callb(res):
                 self.fail("Should have gotten an error from attempt to download, not %r" % (res,))
diff --git a/src/allmydata/test/test_mutable.py b/src/allmydata/test/test_mutable.py
index 7b7c0990f66aa0603dc5a16e5abccfef4c9c5d84..fc6dc37f00810c6b15b1b9fcd8a2f625839ba877 100644 (file)
@@ -5,7 +5,6 @@ from twisted.trial import unittest
 from twisted.internet import defer, reactor
 from allmydata import uri, client
 from allmydata.nodemaker import NodeMaker
-from allmydata.immutable import download
 from allmydata.util import base32
 from allmydata.util.idlib import shortnodeid_b2a
 from allmydata.util.hashutil import tagged_hash, ssk_writekey_hash, \
@@ -264,8 +263,6 @@ class Filenode(unittest.TestCase, testutil.ShouldFailMixin):
             d.addCallback(lambda res: n.overwrite("contents 2"))
             d.addCallback(lambda res: n.download_best_version())
             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
-            d.addCallback(lambda res: n.download(download.Data()))
-            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
             d.addCallback(lambda smap: n.upload("contents 3", smap))
             d.addCallback(lambda res: n.download_best_version())
@@ -495,8 +492,6 @@ class Filenode(unittest.TestCase, testutil.ShouldFailMixin):
             d.addCallback(lambda res: n.overwrite("contents 2"))
             d.addCallback(lambda res: n.download_best_version())
             d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
-            d.addCallback(lambda res: n.download(download.Data()))
-            d.addCallback(lambda res: self.failUnlessEqual(res, "contents 2"))
             d.addCallback(lambda res: n.get_servermap(MODE_WRITE))
             d.addCallback(lambda smap: n.upload("contents 3", smap))
             d.addCallback(lambda res: n.download_best_version())
diff --git a/src/allmydata/test/test_no_network.py b/src/allmydata/test/test_no_network.py
index 911dc86f33f468a321a101df46c870a173a8c672..ef7eed5cafb2cba80418328f3fdb7bb9d01f4e15 100644 (file)
@@ -5,7 +5,7 @@ from twisted.trial import unittest
 from twisted.application import service
 from allmydata.test.no_network import NoNetworkGrid
 from allmydata.immutable.upload import Data
-
+from allmydata.util.consumer import download_to_data
 
 class Harness(unittest.TestCase):
     def setUp(self):
@@ -32,7 +32,7 @@ class Harness(unittest.TestCase):
         d = c0.upload(data)
         def _uploaded(res):
             n = c0.create_node_from_uri(res.uri)
-            return n.download_to_data()
+            return download_to_data(n)
         d.addCallback(_uploaded)
         def _check(res):
             self.failUnlessEqual(res, DATA)
diff --git a/src/allmydata/test/test_repairer.py b/src/allmydata/test/test_repairer.py
index 337f89aa1189d0c9148833003fae9ef7dfdd9cd8..8c14f74eca8fa96a8e73da6732ab2b287b4e846b 100644 (file)
@@ -4,6 +4,7 @@ from allmydata.monitor import Monitor
 from allmydata import check_results
 from allmydata.interfaces import NotEnoughSharesError
 from allmydata.immutable import repairer, upload
+from allmydata.util.consumer import download_to_data
 from twisted.internet import defer
 from twisted.trial import unittest
 import random
@@ -428,7 +429,7 @@ class Repairer(GridTestMixin, unittest.TestCase, RepairTestMixin,
         d.addCallback(lambda ignored:
                       self.shouldFail(NotEnoughSharesError, "then_download",
                                       None,
-                                      self.c1_filenode.download_to_data))
+                                      download_to_data, self.c1_filenode))
 
         d.addCallback(lambda ignored:
                       self.shouldFail(NotEnoughSharesError, "then_repair",
@@ -499,7 +500,7 @@ class Repairer(GridTestMixin, unittest.TestCase, RepairTestMixin,
 
         d.addCallback(lambda ignored:
                       self.delete_shares_numbered(self.uri, range(3, 10+1)))
-        d.addCallback(lambda ignored: self.c1_filenode.download_to_data())
+        d.addCallback(lambda ignored: download_to_data(self.c1_filenode))
         d.addCallback(lambda newdata:
                       self.failUnlessEqual(newdata, common.TEST_DATA))
         return d
@@ -544,7 +545,7 @@ class Repairer(GridTestMixin, unittest.TestCase, RepairTestMixin,
 
         d.addCallback(lambda ignored:
                       self.delete_shares_numbered(self.uri, range(3, 10+1)))
-        d.addCallback(lambda ignored: self.c1_filenode.download_to_data())
+        d.addCallback(lambda ignored: download_to_data(self.c1_filenode))
         d.addCallback(lambda newdata:
                       self.failUnlessEqual(newdata, common.TEST_DATA))
         return d
diff --git a/src/allmydata/test/test_system.py b/src/allmydata/test/test_system.py
index 423657db7cf2d26b0f8035cb057b18e8b5c617b0..f5c5c91e9b60b93bc0dc83bbc786208de76e0e9c 100644 (file)
@@ -1,20 +1,19 @@
 from base64 import b32encode
 import os, sys, time, simplejson
 from cStringIO import StringIO
-from zope.interface import implements
 from twisted.trial import unittest
 from twisted.internet import defer
 from twisted.internet import threads # CLI tests use deferToThread
 from twisted.internet.error import ConnectionDone, ConnectionLost
-from twisted.internet.interfaces import IConsumer, IPushProducer
 import allmydata
 from allmydata import uri
 from allmydata.storage.mutable import MutableShareFile
 from allmydata.storage.server import si_a2b
-from allmydata.immutable import download, offloaded, upload
+from allmydata.immutable import offloaded, upload
 from allmydata.immutable.filenode import ImmutableFileNode, LiteralFileNode
 from allmydata.util import idlib, mathutil
 from allmydata.util import log, base32
+from allmydata.util.consumer import MemoryConsumer, download_to_data
 from allmydata.scripts import runner
 from allmydata.interfaces import IDirectoryNode, IFileNode, \
      NoSuchChildError, NoSharesError
@@ -26,8 +25,7 @@ from twisted.python.failure import Failure
 from twisted.web.client import getPage
 from twisted.web.error import Error
 
-from allmydata.test.common import SystemTestMixin, MemoryConsumer, \
-     download_to_data
+from allmydata.test.common import SystemTestMixin
 
 LARGE_DATA = """
 This is some data to publish to the virtual drive, which needs to be large
@@ -47,22 +45,6 @@ class CountingDataUploadable(upload.Data):
                 self.interrupt_after_d.callback(self)
         return upload.Data.read(self, length)
 
-class GrabEverythingConsumer:
-    implements(IConsumer)
-
-    def __init__(self):
-        self.contents = ""
-
-    def registerProducer(self, producer, streaming):
-        assert streaming
-        assert IPushProducer.providedBy(producer)
-
-    def write(self, data):
-        self.contents += data
-
-    def unregisterProducer(self):
-        pass
-
 class SystemTest(SystemTestMixin, unittest.TestCase):
     timeout = 3600 # It takes longer than 960 seconds on Zandr's ARM box.
 
@@ -137,7 +119,7 @@ class SystemTest(SystemTestMixin, unittest.TestCase):
             self.uri = theuri
             assert isinstance(self.uri, str), self.uri
             self.cap = uri.from_string(self.uri)
-            self.downloader = self.clients[1].downloader
+            self.n = self.clients[1].create_node_from_uri(self.uri)
         d.addCallback(_upload_done)
 
         def _upload_again(res):
@@ -153,42 +135,13 @@ class SystemTest(SystemTestMixin, unittest.TestCase):
 
         def _download_to_data(res):
             log.msg("DOWNLOADING")
-            return self.downloader.download_to_data(self.cap)
+            return download_to_data(self.n)
         d.addCallback(_download_to_data)
         def _download_to_data_done(data):
             log.msg("download finished")
             self.failUnlessEqual(data, DATA)
         d.addCallback(_download_to_data_done)
 
-        target_filename = os.path.join(self.basedir, "download.target")
-        def _download_to_filename(res):
-            return self.downloader.download_to_filename(self.cap,
-                                                        target_filename)
-        d.addCallback(_download_to_filename)
-        def _download_to_filename_done(res):
-            newdata = open(target_filename, "rb").read()
-            self.failUnlessEqual(newdata, DATA)
-        d.addCallback(_download_to_filename_done)
-
-        target_filename2 = os.path.join(self.basedir, "download.target2")
-        def _download_to_filehandle(res):
-            fh = open(target_filename2, "wb")
-            return self.downloader.download_to_filehandle(self.cap, fh)
-        d.addCallback(_download_to_filehandle)
-        def _download_to_filehandle_done(fh):
-            fh.close()
-            newdata = open(target_filename2, "rb").read()
-            self.failUnlessEqual(newdata, DATA)
-        d.addCallback(_download_to_filehandle_done)
-
-        consumer = GrabEverythingConsumer()
-        ct = download.ConsumerAdapter(consumer)
-        d.addCallback(lambda res:
-                      self.downloader.download(self.cap, ct))
-        def _download_to_consumer_done(ign):
-            self.failUnlessEqual(consumer.contents, DATA)
-        d.addCallback(_download_to_consumer_done)
-
         def _test_read(res):
             n = self.clients[1].create_node_from_uri(self.uri)
             d = download_to_data(n)
@@ -228,9 +181,10 @@ class SystemTest(SystemTestMixin, unittest.TestCase):
 
         def _download_nonexistent_uri(res):
             baduri = self.mangle_uri(self.uri)
+            badnode = self.clients[1].create_node_from_uri(baduri)
             log.msg("about to download non-existent URI", level=log.UNUSUAL,
                     facility="tahoe.tests")
-            d1 = self.downloader.download_to_data(uri.from_string(baduri))
+            d1 = download_to_data(badnode)
             def _baduri_should_fail(res):
                 log.msg("finished downloading non-existend URI",
                         level=log.UNUSUAL, facility="tahoe.tests")
@@ -255,8 +209,8 @@ class SystemTest(SystemTestMixin, unittest.TestCase):
             u = upload.Data(HELPER_DATA, convergence=convergence)
             d = self.extra_node.upload(u)
             def _uploaded(results):
-                cap = uri.from_string(results.uri)
-                return self.downloader.download_to_data(cap)
+                n = self.clients[1].create_node_from_uri(results.uri)
+                return download_to_data(n)
             d.addCallback(_uploaded)
             def _check(newdata):
                 self.failUnlessEqual(newdata, HELPER_DATA)
@@ -269,8 +223,8 @@ class SystemTest(SystemTestMixin, unittest.TestCase):
             u.debug_stash_RemoteEncryptedUploadable = True
             d = self.extra_node.upload(u)
             def _uploaded(results):
-                cap = uri.from_string(results.uri)
-                return self.downloader.download_to_data(cap)
+                n = self.clients[1].create_node_from_uri(results.uri)
+                return download_to_data(n)
             d.addCallback(_uploaded)
             def _check(newdata):
                 self.failUnlessEqual(newdata, HELPER_DATA)
@@ -357,7 +311,7 @@ class SystemTest(SystemTestMixin, unittest.TestCase):
             d.addCallback(lambda res: self.extra_node.upload(u2))
 
             def _uploaded(results):
-                cap = uri.from_string(results.uri)
+                cap = results.uri
                 log.msg("Second upload complete", level=log.NOISY,
                         facility="tahoe.test.test_system")
 
@@ -385,7 +339,8 @@ class SystemTest(SystemTestMixin, unittest.TestCase):
                                 "resumption saved us some work even though we were using random keys:"
                                 " read %d bytes out of %d total" %
                                 (bytes_sent, len(DATA)))
-                return self.downloader.download_to_data(cap)
+                n = self.clients[1].create_node_from_uri(cap)
+                return download_to_data(n)
             d.addCallback(_uploaded)
 
             def _check(newdata):
@@ -885,7 +840,7 @@ class SystemTest(SystemTestMixin, unittest.TestCase):
         d.addCallback(self.log, "check_publish1 got /")
         d.addCallback(lambda root: root.get(u"subdir1"))
         d.addCallback(lambda subdir1: subdir1.get(u"mydata567"))
-        d.addCallback(lambda filenode: filenode.download_to_data())
+        d.addCallback(lambda filenode: download_to_data(filenode))
         d.addCallback(self.log, "get finished")
         def _get_done(data):
             self.failUnlessEqual(data, self.data)
@@ -899,7 +854,7 @@ class SystemTest(SystemTestMixin, unittest.TestCase):
         d.addCallback(lambda dirnode:
                       self.failUnless(IDirectoryNode.providedBy(dirnode)))
         d.addCallback(lambda res: rootnode.get_child_at_path(u"subdir1/mydata567"))
-        d.addCallback(lambda filenode: filenode.download_to_data())
+        d.addCallback(lambda filenode: download_to_data(filenode))
         d.addCallback(lambda data: self.failUnlessEqual(data, self.data))
 
         d.addCallback(lambda res: rootnode.get_child_at_path(u"subdir1/mydata567"))
@@ -925,7 +880,7 @@ class SystemTest(SystemTestMixin, unittest.TestCase):
             return self._private_node.get_child_at_path(path)
 
         d.addCallback(lambda res: get_path(u"personal/sekrit data"))
-        d.addCallback(lambda filenode: filenode.download_to_data())
+        d.addCallback(lambda filenode: download_to_data(filenode))
         d.addCallback(lambda data: self.failUnlessEqual(data, self.smalldata))
         d.addCallback(lambda res: get_path(u"s2-rw"))
         d.addCallback(lambda dirnode: self.failUnless(dirnode.is_mutable()))
diff --git a/src/allmydata/test/test_web.py b/src/allmydata/test/test_web.py
index 9d0aad0e4b9522a2c65f3993260c3cc8e7c1b80b..1801a8bdee3d1714ceff39d34463d5e2734f525f 100644 (file)
@@ -17,6 +17,7 @@ from allmydata.unknown import UnknownNode
 from allmydata.web import status, common
 from allmydata.scripts.debug import CorruptShareOptions, corrupt_share
 from allmydata.util import fileutil, base32
+from allmydata.util.consumer import download_to_data
 from allmydata.test.common import FakeCHKFileNode, FakeMutableFileNode, \
      create_chk_filenode, WebErrorMixin, ShouldFailMixin, make_mutable_file_uri
 from allmydata.interfaces import IMutableFileNode
@@ -1292,7 +1293,7 @@ class Web(WebMixin, WebErrorMixin, testutil.StallMixin, unittest.TestCase):
     def failUnlessChildContentsAre(self, node, name, expected_contents):
         assert isinstance(name, unicode)
         d = node.get_child_at_path(name)
-        d.addCallback(lambda node: node.download_to_data())
+        d.addCallback(lambda node: download_to_data(node))
         def _check(contents):
             self.failUnlessEqual(contents, expected_contents)
         d.addCallback(_check)
diff --git a/src/allmydata/util/consumer.py b/src/allmydata/util/consumer.py
new file mode 100755 (executable)
index 0000000..4128c20
--- /dev/null
@@ -0,0 +1,30 @@
+
+"""This file defines a basic download-to-memory consumer, suitable for use in
+a filenode's read() method. See download_to_data() for an example of its use.
+"""
+
+from zope.interface import implements
+from twisted.internet.interfaces import IConsumer
+
+class MemoryConsumer:
+    implements(IConsumer)
+    def __init__(self):
+        self.chunks = []
+        self.done = False
+    def registerProducer(self, p, streaming):
+        self.producer = p
+        if streaming:
+            # call resumeProducing once to start things off
+            p.resumeProducing()
+        else:
+            while not self.done:
+                p.resumeProducing()
+    def write(self, data):
+        self.chunks.append(data)
+    def unregisterProducer(self):
+        self.done = True
+
+def download_to_data(n, offset=0, size=None):
+    d = n.read(MemoryConsumer(), offset, size)
+    d.addCallback(lambda mc: "".join(mc.chunks))
+    return d
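
The same pattern generalizes to consumers other than MemoryConsumer; a
consumer that writes straight to disk, modeled on the FileWritingConsumer
added to control.py above, would look roughly like this (a sketch only; the
filenode and the target filename are placeholders):

    from zope.interface import implements
    from twisted.internet.interfaces import IConsumer

    class FileConsumer:
        implements(IConsumer)
        def __init__(self, filename):
            self.f = open(filename, "wb")
            self.done = False
        def registerProducer(self, producer, streaming):
            if streaming:
                # push producer: a single resumeProducing() starts the flow
                producer.resumeProducing()
            else:
                # pull producer: keep asking until unregisterProducer() fires
                while not self.done:
                    producer.resumeProducing()
        def write(self, data):
            self.f.write(data)
        def unregisterProducer(self):
            self.done = True
            self.f.close()

    d = filenode.read(FileConsumer("downloaded.data"))
    # d fires (with the consumer) once the whole file has been written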