download.py: use producer/consumer to reduce memory usage, closes #129.
authorBrian Warner <warner@lothar.com>
Wed, 19 Sep 2007 07:34:47 +0000 (00:34 -0700)
committerBrian Warner <warner@lothar.com>
Wed, 19 Sep 2007 07:34:47 +0000 (00:34 -0700)
If the DownloadTarget is also an IConsumer, give it control of the brakes
by offering ourselves to target.registerProducer(). When they tell us to
pause, set a flag, which is checked between segment downloads and decodes.
webish.py: make WebDownloadTarget an IConsumer and pass control along to
the http.Request, which already knows how to be an IConsumer.
This reduces the memory footprint of stalled HTTP GETs to a bare minimum,
and thus closes #129.

src/allmydata/download.py
src/allmydata/webish.py

index 12acd64f7cc7d851f8778808dc683feb3950c7d6..8e5cd05e39ff3d8a8872d223fcb1a8e2a3ecce5c 100644 (file)
@@ -3,7 +3,9 @@ import os, random
 from zope.interface import implements
 from twisted.python import log
 from twisted.internet import defer
+from twisted.internet.interfaces import IPushProducer, IConsumer
 from twisted.application import service
+from foolscap.eventual import eventually
 
 from allmydata.util import idlib, mathutil, hashutil
 from allmydata.util.assertutil import _assert
@@ -23,6 +25,9 @@ class BadPlaintextHashValue(Exception):
 class BadCrypttextHashValue(Exception):
     pass
 
+class DownloadStopped(Exception):
+    pass
+
 class Output:
     def __init__(self, downloadable, key, total_length):
         self.downloadable = downloadable
@@ -282,6 +287,7 @@ class SegmentDownloader:
         self.parent.bucket_failed(vbucket)
 
 class FileDownloader:
+    implements(IPushProducer)
     check_crypttext_hash = True
     check_plaintext_hash = True
 
@@ -295,7 +301,12 @@ class FileDownloader:
         self._size = u.size
         self._num_needed_shares = u.needed_shares
 
+        if IConsumer.providedBy(downloadable):
+            downloadable.registerProducer(self, True)
+        self._downloadable = downloadable
         self._output = Output(downloadable, u.key, self._size)
+        self._paused = False
+        self._stopped = False
 
         self.active_buckets = {} # k: shnum, v: bucket
         self._share_buckets = [] # list of (sharenum, bucket) tuples
@@ -311,8 +322,23 @@ class FileDownloader:
                                 "crypttext_hashtree": 0,
                                 }
 
+    def pauseProducing(self):
+        if self._paused:
+            return
+        self._paused = defer.Deferred()
+
+    def resumeProducing(self):
+        if self._paused:
+            p = self._paused
+            self._paused = None
+            eventually(p.callback, None)
+
+    def stopProducing(self):
+        log.msg("Download.stopProducing")
+        self._stopped = True
+
     def start(self):
-        log.msg("starting download [%s]" % idlib.b2a(self._storage_index))
+        log.msg("starting download [%s]" % idlib.b2a(self._storage_index)[:6])
 
         # first step: who should we download from?
         d = defer.maybeDeferred(self._get_all_shareholders)
@@ -324,6 +350,11 @@ class FileDownloader:
         d.addCallback(self._create_validated_buckets)
         # once we know that, we can download blocks from everybody
         d.addCallback(self._download_all_segments)
+        def _finished(res):
+            if IConsumer.providedBy(self._downloadable):
+                self._downloadable.unregisterProducer()
+            return res
+        d.addBoth(_finished)
         def _failed(why):
             self._output.fail(why)
             return why
@@ -541,20 +572,40 @@ class FileDownloader:
         d = defer.succeed(None)
         for segnum in range(self._total_segments-1):
             d.addCallback(self._download_segment, segnum)
+            # this pause, at the end of write, prevents pre-fetch from
+            # happening until the consumer is ready for more data.
+            d.addCallback(self._check_for_pause)
         d.addCallback(self._download_tail_segment, self._total_segments-1)
         return d
 
+    def _check_for_pause(self, res):
+        if self._paused:
+            d = defer.Deferred()
+            self._paused.addCallback(lambda ignored: d.callback(res))
+            return d
+        if self._stopped:
+            raise DownloadStopped("our Consumer called stopProducing()")
+        return res
+
     def _download_segment(self, res, segnum):
+        log.msg("downloading seg#%d of %d (%d%%)"
+                % (segnum, self._total_segments,
+                   100.0 * segnum / self._total_segments))
         # memory footprint: when the SegmentDownloader finishes pulling down
         # all shares, we have 1*segment_size of usage.
         segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares)
         d = segmentdler.start()
+        # pause before using more memory
+        d.addCallback(self._check_for_pause)
         # while the codec does its job, we hit 2*segment_size
         d.addCallback(lambda (shares, shareids):
                       self._codec.decode(shares, shareids))
         # once the codec is done, we drop back to 1*segment_size, because
         # 'shares' goes out of scope. The memory usage is all in the
         # plaintext now, spread out into a bunch of tiny buffers.
+
+        # pause/check-for-stop just before writing, to honor stopProducing
+        d.addCallback(self._check_for_pause)
         def _done(buffers):
             # we start by joining all these buffers together into a single
             # string. This makes Output.write easier, since it wants to hash
@@ -571,10 +622,17 @@ class FileDownloader:
         return d
 
     def _download_tail_segment(self, res, segnum):
+        log.msg("downloading seg#%d of %d (%d%%)"
+                % (segnum, self._total_segments,
+                   100.0 * segnum / self._total_segments))
         segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares)
         d = segmentdler.start()
+        # pause before using more memory
+        d.addCallback(self._check_for_pause)
         d.addCallback(lambda (shares, shareids):
                       self._tail_codec.decode(shares, shareids))
+        # pause/check-for-stop just before writing, to honor stopProducing
+        d.addCallback(self._check_for_pause)
         def _done(buffers):
             # trim off any padding added by the upload side
             segment = "".join(buffers)
@@ -589,11 +647,8 @@ class FileDownloader:
         return d
 
     def _done(self, res):
+        log.msg("download done [%s]" % idlib.b2a(self._storage_index)[:6])
         self._output.close()
-        log.msg("computed CRYPTTEXT_HASH: %s" %
-                idlib.b2a(self._output.crypttext_hash))
-        log.msg("computed PLAINTEXT_HASH: %s" %
-                idlib.b2a(self._output.plaintext_hash))
         if self.check_crypttext_hash:
             _assert(self._crypttext_hash == self._output.crypttext_hash,
                     "bad crypttext_hash: computed=%s, expected=%s" %
index a1e9933830556d4cf8b1a7daec7a3a1f1dc5c4dd..3cc6c83c0fbb5220865afef23a991533c5a77415 100644 (file)
@@ -5,6 +5,7 @@ from twisted.application import service, strports, internet
 from twisted.web import static, resource, server, html, http
 from twisted.python import util, log
 from twisted.internet import defer
+from twisted.internet.interfaces import IConsumer
 from nevow import inevow, rend, loaders, appserver, url, tags as T
 from nevow.static import File as nevow_File # TODO: merge with static.File?
 from allmydata.util import fileutil
@@ -271,12 +272,18 @@ class Directory(rend.Page):
             return ""
 
 class WebDownloadTarget:
-    implements(IDownloadTarget)
+    implements(IDownloadTarget, IConsumer)
     def __init__(self, req, content_type, content_encoding):
         self._req = req
         self._content_type = content_type
         self._content_encoding = content_encoding
         self._opened = False
+        self._producer = None
+
+    def registerProducer(self, producer, streaming):
+        self._req.registerProducer(producer, streaming)
+    def unregisterProducer(self):
+        self._req.unregisterProducer()
 
     def open(self, size):
         self._opened = True