]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/commitdiff
immutable: refactor download to do only download-and-decode, not decryption
authorZooko O'Whielacronx <zooko@zooko.com>
Thu, 8 Jan 2009 18:53:49 +0000 (11:53 -0700)
committerZooko O'Whielacronx <zooko@zooko.com>
Thu, 8 Jan 2009 18:53:49 +0000 (11:53 -0700)
FileDownloader takes a verify cap and produces ciphertext, instead of taking a read cap and producing plaintext.
FileDownloader does all integrity checking including the mandatory ciphertext hash tree and the optional ciphertext flat hash, rather than expecting its target to do some of that checking.
Rename immutable.download.Output to immutable.download.DecryptingOutput. An instance of DecryptingOutput can be passed to FileDownloader to use as the latter's target.  Text pushed to the DecryptingOutput is decrypted and then pushed to *its* target.
DecryptingOutput satisfies the IConsumer interface, and if its target also satisfies IConsumer, then it forwards and pause/unpause signals to its producer (which is the FileDownloader).
This patch also changes some logging code to use the new logging mixin class.
Check integrity of a segment and decrypt the segment one block-sized buffer at a time instead of copying the buffers together into one segment-sized buffer (reduces peak memory usage, I think, and is probably a tad faster/less CPU, depending on your encoding parameters).
Refactor FileDownloader so that processing of segments and of tail-segment share as much code is possible.
FileDownloader and FileNode take caps as instances of URI (Python objects), not as strings.

src/allmydata/client.py
src/allmydata/immutable/download.py
src/allmydata/immutable/filenode.py
src/allmydata/test/test_encode.py
src/allmydata/test/test_filenode.py

index 3395f7f6d1f0c4f4e0a696f172203d78068ba835..d1465d6d5582789483e373d92bc4f7b22cc014a7 100644 (file)
@@ -361,7 +361,7 @@ class Client(node.Node, pollmixin.PollMixin):
                 else:
                     key = base32.b2a(u.storage_index)
                     cachefile = self.download_cache.get_file(key)
-                    node = FileNode(u.to_string(), self, cachefile) # CHK
+                    node = FileNode(u, self, cachefile) # CHK
             else:
                 assert IMutableFileURI.providedBy(u), u
                 node = MutableFileNode(self).init_from_uri(u)
index 2ba3492dcbdf3cbdc522d92d4fa21a2d53f7d2a0..f4921f53f00493b70ac8c35de22f7f54830b7cfe 100644 (file)
@@ -42,75 +42,26 @@ class DownloadResults:
         self.timings = {}
         self.file_size = None
 
-class Output:
-    def __init__(self, downloadable, key, total_length, log_parent,
-                 download_status):
+class DecryptingTarget(log.PrefixingLogMixin):
+    implements(IDownloadTarget, IConsumer)
+    def __init__(self, downloadable, key, _log_msg_id=None):
+        precondition(IDownloadTarget.providedBy(downloadable), downloadable)
         self.downloadable = downloadable
         self._decryptor = AES(key)
-        self._crypttext_hasher = hashutil.crypttext_hasher()
-        self.length = 0
-        self.total_length = total_length
-        self._segment_number = 0
-        self._crypttext_hash_tree = None
-        self._opened = False
-        self._log_parent = log_parent
-        self._status = download_status
-        self._status.set_progress(0.0)
-
-    def log(self, *args, **kwargs):
-        if "parent" not in kwargs:
-            kwargs["parent"] = self._log_parent
-        if "facility" not in kwargs:
-            kwargs["facility"] = "download.output"
-        return log.msg(*args, **kwargs)
-
-    def got_crypttext_hash_tree(self, crypttext_hash_tree):
-        self._crypttext_hash_tree = crypttext_hash_tree
-
-    def write_segment(self, crypttext):
-        self.length += len(crypttext)
-        self._status.set_progress( float(self.length) / self.total_length )
-
-        # memory footprint: 'crypttext' is the only segment_size usage
-        # outstanding. While we decrypt it into 'plaintext', we hit
-        # 2*segment_size.
-        self._crypttext_hasher.update(crypttext)
-        if self._crypttext_hash_tree:
-            ch = hashutil.crypttext_segment_hasher()
-            ch.update(crypttext)
-            crypttext_leaves = {self._segment_number: ch.digest()}
-            self.log(format="crypttext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s",
-                     bytes=len(crypttext),
-                     segnum=self._segment_number, hash=base32.b2a(ch.digest()),
-                     level=log.NOISY)
-            self._crypttext_hash_tree.set_hashes(leaves=crypttext_leaves)
-
-        plaintext = self._decryptor.process(crypttext)
-        del crypttext
-
-        # now we're back down to 1*segment_size.
-        self._segment_number += 1
-        # We're still at 1*segment_size. The Downloadable is responsible for
-        # any memory usage beyond this.
-        if not self._opened:
-            self._opened = True
-            self.downloadable.open(self.total_length)
+        prefix = str(downloadable)
+        log.PrefixingLogMixin.__init__(self, "allmydata.immutable.download", _log_msg_id, prefix=prefix)
+    def registerProducer(self, producer, streaming):
+        if IConsumer.providedBy(self.downloadable):
+            self.downloadable.registerProducer(producer, streaming)
+    def unregisterProducer(self):
+        if IConsumer.providedBy(self.downloadable):
+            self.downloadable.unregisterProducer()
+    def write(self, ciphertext):
+        plaintext = self._decryptor.process(ciphertext)
         self.downloadable.write(plaintext)
-
-    def fail(self, why):
-        # this is really unusual, and deserves maximum forensics
-        if why.check(DownloadStopped):
-            # except DownloadStopped just means the consumer aborted the
-            # download, not so scary
-            self.log("download stopped", level=log.UNUSUAL)
-        else:
-            self.log("download failed!", failure=why,
-                     level=log.SCARY, umid="lp1vaQ")
-        self.downloadable.fail(why)
-
+    def open(self, size):
+        self.downloadable.open(size)
     def close(self):
-        self.crypttext_hash = self._crypttext_hasher.digest()
-        self.log("download finished, closing IDownloadable", level=log.NOISY)
         self.downloadable.close()
     def finish(self):
         return self.downloadable.finish()
@@ -653,11 +604,14 @@ class DownloadStatus:
         self.results = value
 
 class FileDownloader(log.PrefixingLogMixin):
+    """ I download shares, check their integrity, then decode them, check the integrity of the
+    resulting ciphertext, then and write it to my target. """
     implements(IPushProducer)
     _status = None
 
     def __init__(self, client, u, downloadable):
-        precondition(isinstance(u, uri.CHKFileURI), u)
+        precondition(IVerifierURI.providedBy(u), u)
+        precondition(IDownloadTarget.providedBy(downloadable), downloadable)
 
         prefix=base32.b2a_l(u.get_storage_index()[:8], 60)
         log.PrefixingLogMixin.__init__(self, facility="tahoe.immutable.download", prefix=prefix)
@@ -691,8 +645,7 @@ class FileDownloader(log.PrefixingLogMixin):
         if IConsumer.providedBy(downloadable):
             downloadable.registerProducer(self, True)
         self._downloadable = downloadable
-        self._output = Output(downloadable, u.key, self._uri.size, self._parentmsgid,
-                              self._status)
+        self._opened = False
 
         self.active_buckets = {} # k: shnum, v: bucket
         self._share_buckets = [] # list of (sharenum, bucket) tuples
@@ -700,8 +653,15 @@ class FileDownloader(log.PrefixingLogMixin):
 
         self._fetch_failures = {"uri_extension": 0, "crypttext_hash_tree": 0, }
 
-        self._share_hash_tree = None
-        self._crypttext_hash_tree = None
+        self._ciphertext_hasher = hashutil.crypttext_hasher()
+
+        self._bytes_done = 0
+        self._status.set_progress(float(self._bytes_done)/self._uri.size)
+
+        # _got_uri_extension() will create the following:
+        # self._crypttext_hash_tree
+        # self._share_hash_tree
+        # self._current_segnum = 0
 
     def pauseProducing(self):
         if self._paused:
@@ -730,7 +690,6 @@ class FileDownloader(log.PrefixingLogMixin):
             self._status.set_active(False)
 
     def start(self):
-        assert isinstance(self._uri, uri.CHKFileURI), (self._uri, type(self._uri))
         self.log("starting download")
 
         # first step: who should we download from?
@@ -754,7 +713,12 @@ class FileDownloader(log.PrefixingLogMixin):
             if self._status:
                 self._status.set_status("Failed")
                 self._status.set_active(False)
-            self._output.fail(why)
+            if why.check(DownloadStopped):
+                # DownloadStopped just means the consumer aborted the download; not so scary.
+                self.log("download stopped", level=log.UNUSUAL)
+            else:
+                # This is really unusual, and deserves maximum forensics.
+                self.log("download failed!", failure=why, level=log.SCARY, umid="lp1vaQ")
             return why
         d.addErrback(_failed)
         d.addCallback(self._done)
@@ -818,7 +782,6 @@ class FileDownloader(log.PrefixingLogMixin):
             del self._share_vbuckets[shnum]
 
     def _got_all_shareholders(self, res):
-        assert isinstance(self._uri, uri.CHKFileURI), (self._uri, type(self._uri))
         if self._results:
             now = time.time()
             self._results.timings["peer_selection"] = now - self._started
@@ -832,7 +795,6 @@ class FileDownloader(log.PrefixingLogMixin):
         #               "vb is %s but should be a ValidatedReadBucketProxy" % (vb,)
 
     def _obtain_uri_extension(self, ignored):
-        assert isinstance(self._uri, uri.CHKFileURI), self._uri
         # all shareholders are supposed to have a copy of uri_extension, and
         # all are supposed to be identical. We compute the hash of the data
         # that comes back, and compare it against the version in our URI. If
@@ -844,7 +806,7 @@ class FileDownloader(log.PrefixingLogMixin):
 
         vups = []
         for sharenum, bucket in self._share_buckets:
-            vups.append(ValidatedExtendedURIProxy(bucket, self._uri.get_verify_cap(), self._fetch_failures))
+            vups.append(ValidatedExtendedURIProxy(bucket, self._uri, self._fetch_failures))
         vto = ValidatedThingObtainer(vups, debugname="vups", log_id=self._parentmsgid)
         d = vto.start()
 
@@ -886,7 +848,6 @@ class FileDownloader(log.PrefixingLogMixin):
         def _got_crypttext_hash_tree(res):
             # Good -- the self._crypttext_hash_tree that we passed to vchtp is now populated
             # with hashes.
-            self._output.got_crypttext_hash_tree(self._crypttext_hash_tree)
             if self._results:
                 elapsed = time.time() - _get_crypttext_hash_tree_started
                 self._results.timings["hashtrees"] = elapsed
@@ -896,7 +857,6 @@ class FileDownloader(log.PrefixingLogMixin):
     def _activate_enough_buckets(self):
         """either return a mapping from shnum to a ValidatedReadBucketProxy that can
         provide data for that share, or raise NotEnoughSharesError"""
-        assert isinstance(self._uri, uri.CHKFileURI), self._uri
 
         while len(self.active_buckets) < self._uri.needed_shares:
             # need some more
@@ -934,12 +894,11 @@ class FileDownloader(log.PrefixingLogMixin):
         self._started_fetching = time.time()
 
         d = defer.succeed(None)
-        for segnum in range(self._vup.num_segments-1):
+        for segnum in range(self._vup.num_segments):
             d.addCallback(self._download_segment, segnum)
             # this pause, at the end of write, prevents pre-fetch from
             # happening until the consumer is ready for more data.
             d.addCallback(self._check_for_pause)
-        d.addCallback(self._download_tail_segment, self._vup.num_segments-1)
         return d
 
     def _check_for_pause(self, res):
@@ -952,7 +911,6 @@ class FileDownloader(log.PrefixingLogMixin):
         return res
 
     def _download_segment(self, res, segnum):
-        assert isinstance(self._uri, uri.CHKFileURI), self._uri
         if self._status:
             self._status.set_status("Downloading segment %d of %d" %
                                     (segnum+1, self._vup.num_segments))
@@ -979,8 +937,11 @@ class FileDownloader(log.PrefixingLogMixin):
             return res
         if self._results:
             d.addCallback(_started_decode)
-        d.addCallback(lambda (shares, shareids):
-                      self._codec.decode(shares, shareids))
+        if segnum + 1 == self._vup.num_segments:
+            codec = self._tail_codec
+        else:
+            codec = self._codec
+        d.addCallback(lambda (shares, shareids): codec.decode(shares, shareids))
         # once the codec is done, we drop back to 1*segment_size, because
         # 'shares' goes out of scope. The memory usage is all in the
         # plaintext now, spread out into a bunch of tiny buffers.
@@ -993,91 +954,66 @@ class FileDownloader(log.PrefixingLogMixin):
 
         # pause/check-for-stop just before writing, to honor stopProducing
         d.addCallback(self._check_for_pause)
-        def _done(buffers):
-            # we start by joining all these buffers together into a single
-            # string. This makes Output.write easier, since it wants to hash
-            # data one segment at a time anyways, and doesn't impact our
-            # memory footprint since we're already peaking at 2*segment_size
-            # inside the codec a moment ago.
-            segment = "".join(buffers)
-            del buffers
-            # we're down to 1*segment_size right now, but write_segment()
-            # will decrypt a copy of the segment internally, which will push
-            # us up to 2*segment_size while it runs.
-            started_decrypt = time.time()
-            self._output.write_segment(segment)
-            if self._results:
-                elapsed = time.time() - started_decrypt
-                self._results.timings["cumulative_decrypt"] += elapsed
-        d.addCallback(_done)
+        d.addCallback(self._got_segment)
         return d
 
-    def _download_tail_segment(self, res, segnum):
-        assert isinstance(self._uri, uri.CHKFileURI), self._uri
-        self.log("downloading seg#%d of %d (%d%%)"
-                 % (segnum, self._vup.num_segments,
-                    100.0 * segnum / self._vup.num_segments))
-        segmentdler = SegmentDownloader(self, segnum, self._uri.needed_shares,
-                                        self._results)
-        started = time.time()
-        d = segmentdler.start()
-        def _finished_fetching(res):
-            elapsed = time.time() - started
-            self._results.timings["cumulative_fetch"] += elapsed
-            return res
-        if self._results:
-            d.addCallback(_finished_fetching)
-        # pause before using more memory
-        d.addCallback(self._check_for_pause)
-        def _started_decode(res):
-            self._started_decode = time.time()
-            return res
-        if self._results:
-            d.addCallback(_started_decode)
-        d.addCallback(lambda (shares, shareids):
-                      self._tail_codec.decode(shares, shareids))
-        def _finished_decode(res):
-            elapsed = time.time() - self._started_decode
-            self._results.timings["cumulative_decode"] += elapsed
-            return res
+    def _got_segment(self, buffers):
+        precondition(self._crypttext_hash_tree)
+        started_decrypt = time.time()
+        self._status.set_progress(float(self._current_segnum)/self._uri.size)
+
+        if self._current_segnum + 1 == self._vup.num_segments:
+            # This is the last segment.
+            # Trim off any padding added by the upload side.  We never send empty segments. If
+            # the data was an exact multiple of the segment size, the last segment will be full.
+            tail_buf_size = mathutil.div_ceil(self._vup.tail_segment_size, self._uri.needed_shares)
+            num_buffers_used = mathutil.div_ceil(self._vup.tail_data_size, tail_buf_size)
+            # Remove buffers which don't contain any part of the tail.
+            del buffers[num_buffers_used:]
+            # Remove the past-the-tail-part of the last buffer.
+            tail_in_last_buf = self._vup.tail_data_size % tail_buf_size
+            if tail_in_last_buf == 0:
+                tail_in_last_buf = tail_buf_size
+            buffers[-1] = buffers[-1][:tail_in_last_buf]
+
+        # First compute the hash of this segment and check that it fits.
+        ch = hashutil.crypttext_segment_hasher()
+        for buffer in buffers:
+            self._ciphertext_hasher.update(buffer)
+            ch.update(buffer)
+        self._crypttext_hash_tree.set_hashes(leaves={self._current_segnum: ch.digest()})
+
+        # Then write this segment to the target.
+        if not self._opened:
+            self._opened = True
+            self._downloadable.open(self._uri.size)
+
+        for buffer in buffers:
+            self._downloadable.write(buffer)
+            self._bytes_done += len(buffer)
+
+        self._status.set_progress(float(self._bytes_done)/self._uri.size)
+        self._current_segnum += 1
+
         if self._results:
-            d.addCallback(_finished_decode)
-        # pause/check-for-stop just before writing, to honor stopProducing
-        d.addCallback(self._check_for_pause)
-        def _done(buffers):
-            # trim off any padding added by the upload side
-            segment = "".join(buffers)
-            del buffers
-            # we never send empty segments. If the data was an exact multiple
-            # of the segment size, the last segment will be full.
-            pad_size = mathutil.pad_size(self._uri.size, self._vup.segment_size)
-            tail_size = self._vup.segment_size - pad_size
-            segment = segment[:tail_size]
-            started_decrypt = time.time()
-            self._output.write_segment(segment)
-            if self._results:
-                elapsed = time.time() - started_decrypt
-                self._results.timings["cumulative_decrypt"] += elapsed
-        d.addCallback(_done)
-        return d
+            elapsed = time.time() - started_decrypt
+            self._results.timings["cumulative_decrypt"] += elapsed
 
     def _done(self, res):
-        assert isinstance(self._uri, uri.CHKFileURI), self._uri
         self.log("download done")
         if self._results:
             now = time.time()
             self._results.timings["total"] = now - self._started
             self._results.timings["segments"] = now - self._started_fetching
-        self._output.close()
         if self._vup.crypttext_hash:
-            _assert(self._vup.crypttext_hash == self._output.crypttext_hash,
+            _assert(self._vup.crypttext_hash == self._ciphertext_hasher.digest(),
                     "bad crypttext_hash: computed=%s, expected=%s" %
-                    (base32.b2a(self._output.crypttext_hash),
+                    (base32.b2a(self._ciphertext_hasher.digest()),
                      base32.b2a(self._vup.crypttext_hash)))
-        _assert(self._output.length == self._uri.size,
-                got=self._output.length, expected=self._uri.size)
-        return self._output.finish()
-
+        _assert(self._bytes_done == self._uri.size, self._bytes_done, self._uri.size)
+        self._status.set_progress(1)
+        self._downloadable.close()
+        return self._downloadable.finish()
     def get_download_status(self):
         return self._status
 
@@ -1200,7 +1136,9 @@ class Downloader(service.MultiService):
             # include LIT files
             self.stats_provider.count('downloader.files_downloaded', 1)
             self.stats_provider.count('downloader.bytes_downloaded', u.get_size())
-        dl = FileDownloader(self.parent, u, t)
+
+        target = DecryptingTarget(t, u.key, _log_msg_id=_log_msg_id)
+        dl = FileDownloader(self.parent, u.get_verify_cap(), target)
         self._add_download(dl)
         d = dl.start()
         return d
index 4b1150417f981f8322c830932d5eb91a463f8508..de37b2f841c26a2dd41f2a10d08b8f062f30d61a 100644 (file)
@@ -9,6 +9,7 @@ from foolscap.eventual import eventually
 from allmydata.interfaces import IFileNode, IFileURI, ICheckable, \
      IDownloadTarget
 from allmydata.util import log, base32
+from allmydata.util.assertutil import precondition
 from allmydata import uri as urimodule
 from allmydata.immutable.checker import Checker
 from allmydata.check_results import CheckAndRepairResults
@@ -19,6 +20,7 @@ class _ImmutableFileNodeBase(object):
     implements(IFileNode, ICheckable)
 
     def __init__(self, uri, client):
+        precondition(urimodule.IImmutableFileURI.providedBy(uri), uri)
         self.u = IFileURI(uri)
         self._client = client
 
@@ -172,7 +174,7 @@ class FileNode(_ImmutableFileNodeBase, log.PrefixingLogMixin):
     def __init__(self, uri, client, cachefile):
         _ImmutableFileNodeBase.__init__(self, uri, client)
         self.download_cache = DownloadCache(self, cachefile)
-        prefix = urimodule.from_string(uri).get_verify_cap().to_string()
+        prefix = uri.get_verify_cap().to_string()
         log.PrefixingLogMixin.__init__(self, "allmydata.immutable.filenode", prefix=prefix)
         self.log("starting", level=log.OPERATIONAL)
 
@@ -250,6 +252,7 @@ class LiteralProducer:
 class LiteralFileNode(_ImmutableFileNodeBase):
 
     def __init__(self, uri, client):
+        precondition(urimodule.IImmutableFileURI.providedBy(uri), uri)
         _ImmutableFileNodeBase.__init__(self, uri, client)
 
     def get_uri(self):
index 0954530ded9973875ea68ce4d66ebf5e2d39add4..e4133ff2cb7d844ea49328efa55ad2e6dfd6b54a 100644 (file)
@@ -493,7 +493,8 @@ class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin):
         client = FakeClient()
         if not target:
             target = download.Data()
-        fd = download.FileDownloader(client, u, target)
+        target = download.DecryptingTarget(target, u.key)
+        fd = download.FileDownloader(client, u.get_verify_cap(), target)
 
         # we manually cycle the FileDownloader through a number of steps that
         # would normally be sequenced by a Deferred chain in
index abf82220fd7e8470882557b9e23c99ef81d3704c..48d1b14036266acbb50a1e519264c852054a71f9 100644 (file)
@@ -27,8 +27,8 @@ class Node(unittest.TestCase):
                            size=1000)
         c = FakeClient()
         cf = cachedir.CacheFile("none")
-        fn1 = filenode.FileNode(u.to_string(), c, cf)
-        fn2 = filenode.FileNode(u.to_string(), c, cf)
+        fn1 = filenode.FileNode(u, c, cf)
+        fn2 = filenode.FileNode(u, c, cf)
         self.failUnlessEqual(fn1, fn2)
         self.failIfEqual(fn1, "I am not a filenode")
         self.failIfEqual(fn1, NotANode())
@@ -49,7 +49,7 @@ class Node(unittest.TestCase):
         u = uri.LiteralFileURI(data=DATA)
         c = None
         fn1 = filenode.LiteralFileNode(u, c)
-        fn2 = filenode.LiteralFileNode(u.to_string(), c)
+        fn2 = filenode.LiteralFileNode(u, c)
         self.failUnlessEqual(fn1, fn2)
         self.failIfEqual(fn1, "I am not a filenode")
         self.failIfEqual(fn1, NotANode())