From 1340c484c6c60c524096e98ee78243c8c00a12b7 Mon Sep 17 00:00:00 2001
From: Brian Warner <warner@lothar.com>
Date: Wed, 19 Sep 2007 00:34:47 -0700
Subject: [PATCH] download.py: use producer/consumer to reduce memory usage
 (closes #129)

If the DownloadTarget is also an IConsumer, give it control of the brakes
by offering ourselves to target.registerProducer(). When it tells us to
pause, set a flag that is checked between segment downloads and decodes.

webish.py: make WebDownloadTarget an IConsumer and pass control along to
the http.Request, which already knows how to be an IConsumer. This reduces
the memory footprint of stalled HTTP GETs to a bare minimum.

---
 src/allmydata/download.py | 65 ++++++++++++++++++++++++++++++++++++---
 src/allmydata/webish.py   |  9 +++++-
 2 files changed, 68 insertions(+), 6 deletions(-)
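
A minimal sketch of the pause-flag-as-a-Deferred pattern that FileDownloader
uses below; ToyProducer and its names are illustrative, not code from this
patch (only the Twisted IPushProducer method names are real):

    from twisted.internet import defer

    class ToyProducer:
        # IPushProducer-style object: the consumer we registered with calls
        # pauseProducing/resumeProducing/stopProducing on us.
        def __init__(self):
            self._paused = None      # None, or a Deferred that fires on resume
            self._stopped = False

        def pauseProducing(self):
            if self._paused is None:
                self._paused = defer.Deferred()

        def resumeProducing(self):
            if self._paused is not None:
                p, self._paused = self._paused, None
                # the real patch fires this via foolscap's eventually() to
                # avoid re-entrancy; calling it directly keeps the sketch short
                p.callback(None)

        def stopProducing(self):
            self._stopped = True

        def _check_for_pause(self, res):
            # chain the next pipeline step behind the resume Deferred, so
            # nothing more is fetched or decoded while the consumer is full
            if self._paused is not None:
                d = defer.Deferred()
                self._paused.addCallback(lambda ignored: d.callback(res))
                return d
            if self._stopped:
                raise RuntimeError("consumer called stopProducing()")
            return res

FileDownloader inserts _check_for_pause between fetching a segment, decoding
it, and writing it out, which is what keeps memory near one or two segments.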

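And the consumer side of the contract: WebDownloadTarget itself only forwards
registerProducer()/unregisterProducer() to the http.Request, which already
buffers and throttles, but any IConsumer with a buffer can work the brakes
the same way. ToyConsumer and HIGH_WATER are assumed names for illustration,
not part of this patch:

    HIGH_WATER = 1024*1024   # assumed threshold: pause after ~1 MiB buffered

    class ToyConsumer:
        # provides the IConsumer methods the downloader expects
        def __init__(self):
            self._producer = None
            self._buffered = 0

        def registerProducer(self, producer, streaming):
            # FileDownloader registers itself with streaming=True (push mode)
            self._producer = producer

        def unregisterProducer(self):
            self._producer = None

        def write(self, data):
            self._buffered += len(data)
            if self._producer and self._buffered >= HIGH_WATER:
                # ask the downloader to stop fetching/decoding more segments
                self._producer.pauseProducing()

        def flushed(self):
            # hypothetical hook: call this once the buffered data has drained
            self._buffered = 0
            if self._producer:
                self._producer.resumeProducing()
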
diff --git a/src/allmydata/download.py b/src/allmydata/download.py
index 12acd64f..8e5cd05e 100644
--- a/src/allmydata/download.py
+++ b/src/allmydata/download.py
@@ -3,7 +3,9 @@ import os, random
 from zope.interface import implements
 from twisted.python import log
 from twisted.internet import defer
+from twisted.internet.interfaces import IPushProducer, IConsumer
 from twisted.application import service
+from foolscap.eventual import eventually
 
 from allmydata.util import idlib, mathutil, hashutil
 from allmydata.util.assertutil import _assert
@@ -23,6 +25,9 @@ class BadPlaintextHashValue(Exception):
 class BadCrypttextHashValue(Exception):
     pass
 
+class DownloadStopped(Exception):
+    pass
+
 class Output:
     def __init__(self, downloadable, key, total_length):
         self.downloadable = downloadable
@@ -282,6 +287,7 @@ class SegmentDownloader:
         self.parent.bucket_failed(vbucket)
 
 class FileDownloader:
+    implements(IPushProducer)
     check_crypttext_hash = True
     check_plaintext_hash = True
 
@@ -295,7 +301,12 @@ class FileDownloader:
         self._size = u.size
         self._num_needed_shares = u.needed_shares
 
+        if IConsumer.providedBy(downloadable):
+            downloadable.registerProducer(self, True)
+        self._downloadable = downloadable
         self._output = Output(downloadable, u.key, self._size)
+        self._paused = False
+        self._stopped = False
 
         self.active_buckets = {} # k: shnum, v: bucket
         self._share_buckets = [] # list of (sharenum, bucket) tuples
@@ -311,8 +322,23 @@ class FileDownloader:
                                 "crypttext_hashtree": 0,
                                 }
 
+    def pauseProducing(self):
+        if self._paused:
+            return
+        self._paused = defer.Deferred()
+
+    def resumeProducing(self):
+        if self._paused:
+            p = self._paused
+            self._paused = None
+            eventually(p.callback, None)
+
+    def stopProducing(self):
+        log.msg("Download.stopProducing")
+        self._stopped = True
+
     def start(self):
-        log.msg("starting download [%s]" % idlib.b2a(self._storage_index))
+        log.msg("starting download [%s]" % idlib.b2a(self._storage_index)[:6])
 
         # first step: who should we download from?
         d = defer.maybeDeferred(self._get_all_shareholders)
@@ -324,6 +350,11 @@ class FileDownloader:
         d.addCallback(self._create_validated_buckets)
         # once we know that, we can download blocks from everybody
         d.addCallback(self._download_all_segments)
+        def _finished(res):
+            if IConsumer.providedBy(self._downloadable):
+                self._downloadable.unregisterProducer()
+            return res
+        d.addBoth(_finished)
         def _failed(why):
             self._output.fail(why)
             return why
@@ -541,20 +572,40 @@ class FileDownloader:
         d = defer.succeed(None)
         for segnum in range(self._total_segments-1):
             d.addCallback(self._download_segment, segnum)
+            # this pause, at the end of write, prevents pre-fetch from
+            # happening until the consumer is ready for more data.
+            d.addCallback(self._check_for_pause)
         d.addCallback(self._download_tail_segment, self._total_segments-1)
         return d
 
+    def _check_for_pause(self, res):
+        if self._paused:
+            d = defer.Deferred()
+            self._paused.addCallback(lambda ignored: d.callback(res))
+            return d
+        if self._stopped:
+            raise DownloadStopped("our Consumer called stopProducing()")
+        return res
+
     def _download_segment(self, res, segnum):
+        log.msg("downloading seg#%d of %d (%d%%)"
+                % (segnum, self._total_segments,
+                   100.0 * segnum / self._total_segments))
         # memory footprint: when the SegmentDownloader finishes pulling down
         # all shares, we have 1*segment_size of usage.
         segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares)
         d = segmentdler.start()
+        # pause before using more memory
+        d.addCallback(self._check_for_pause)
         # while the codec does its job, we hit 2*segment_size
         d.addCallback(lambda (shares, shareids):
                       self._codec.decode(shares, shareids))
         # once the codec is done, we drop back to 1*segment_size, because
         # 'shares' goes out of scope. The memory usage is all in the
         # plaintext now, spread out into a bunch of tiny buffers.
+
+        # pause/check-for-stop just before writing, to honor stopProducing
+        d.addCallback(self._check_for_pause)
         def _done(buffers):
             # we start by joining all these buffers together into a single
             # string. This makes Output.write easier, since it wants to hash
@@ -571,10 +622,17 @@ class FileDownloader:
         return d
 
     def _download_tail_segment(self, res, segnum):
+        log.msg("downloading seg#%d of %d (%d%%)"
+                % (segnum, self._total_segments,
+                   100.0 * segnum / self._total_segments))
         segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares)
         d = segmentdler.start()
+        # pause before using more memory
+        d.addCallback(self._check_for_pause)
         d.addCallback(lambda (shares, shareids):
                       self._tail_codec.decode(shares, shareids))
+        # pause/check-for-stop just before writing, to honor stopProducing
+        d.addCallback(self._check_for_pause)
         def _done(buffers):
             # trim off any padding added by the upload side
             segment = "".join(buffers)
@@ -589,11 +647,8 @@ class FileDownloader:
         return d
 
     def _done(self, res):
+        log.msg("download done [%s]" % idlib.b2a(self._storage_index)[:6])
         self._output.close()
-        log.msg("computed CRYPTTEXT_HASH: %s" %
-                idlib.b2a(self._output.crypttext_hash))
-        log.msg("computed PLAINTEXT_HASH: %s" %
-                idlib.b2a(self._output.plaintext_hash))
         if self.check_crypttext_hash:
             _assert(self._crypttext_hash == self._output.crypttext_hash,
                     "bad crypttext_hash: computed=%s, expected=%s" %
diff --git a/src/allmydata/webish.py b/src/allmydata/webish.py
index a1e99338..3cc6c83c 100644
--- a/src/allmydata/webish.py
+++ b/src/allmydata/webish.py
@@ -5,6 +5,7 @@ from twisted.application import service, strports, internet
 from twisted.web import static, resource, server, html, http
 from twisted.python import util, log
 from twisted.internet import defer
+from twisted.internet.interfaces import IConsumer
 from nevow import inevow, rend, loaders, appserver, url, tags as T
 from nevow.static import File as nevow_File # TODO: merge with static.File?
 from allmydata.util import fileutil
@@ -271,12 +272,18 @@ class Directory(rend.Page):
             return ""
 
 class WebDownloadTarget:
-    implements(IDownloadTarget)
+    implements(IDownloadTarget, IConsumer)
     def __init__(self, req, content_type, content_encoding):
         self._req = req
         self._content_type = content_type
         self._content_encoding = content_encoding
         self._opened = False
+        self._producer = None
+
+    def registerProducer(self, producer, streaming):
+        self._req.registerProducer(producer, streaming)
+    def unregisterProducer(self):
+        self._req.unregisterProducer()
 
     def open(self, size):
         self._opened = True
-- 
2.45.2