From 600196f57143911e8f832b06747ac9a8dc17cb91 Mon Sep 17 00:00:00 2001 From: Zooko O'Whielacronx Date: Thu, 8 Jan 2009 11:53:49 -0700 Subject: [PATCH] immutable: refactor download to do only download-and-decode, not decryption FileDownloader takes a verify cap and produces ciphertext, instead of taking a read cap and producing plaintext. FileDownloader does all integrity checking including the mandatory ciphertext hash tree and the optional ciphertext flat hash, rather than expecting its target to do some of that checking. Rename immutable.download.Output to immutable.download.DecryptingOutput. An instance of DecryptingOutput can be passed to FileDownloader to use as the latter's target. Text pushed to the DecryptingOutput is decrypted and then pushed to *its* target. DecryptingOutput satisfies the IConsumer interface, and if its target also satisfies IConsumer, then it forwards and pause/unpause signals to its producer (which is the FileDownloader). This patch also changes some logging code to use the new logging mixin class. Check integrity of a segment and decrypt the segment one block-sized buffer at a time instead of copying the buffers together into one segment-sized buffer (reduces peak memory usage, I think, and is probably a tad faster/less CPU, depending on your encoding parameters). Refactor FileDownloader so that processing of segments and of tail-segment share as much code is possible. FileDownloader and FileNode take caps as instances of URI (Python objects), not as strings. --- src/allmydata/client.py | 2 +- src/allmydata/immutable/download.py | 248 +++++++++++----------------- src/allmydata/immutable/filenode.py | 5 +- src/allmydata/test/test_encode.py | 3 +- src/allmydata/test/test_filenode.py | 6 +- 5 files changed, 103 insertions(+), 161 deletions(-) diff --git a/src/allmydata/client.py b/src/allmydata/client.py index 3395f7f6..d1465d6d 100644 --- a/src/allmydata/client.py +++ b/src/allmydata/client.py @@ -361,7 +361,7 @@ class Client(node.Node, pollmixin.PollMixin): else: key = base32.b2a(u.storage_index) cachefile = self.download_cache.get_file(key) - node = FileNode(u.to_string(), self, cachefile) # CHK + node = FileNode(u, self, cachefile) # CHK else: assert IMutableFileURI.providedBy(u), u node = MutableFileNode(self).init_from_uri(u) diff --git a/src/allmydata/immutable/download.py b/src/allmydata/immutable/download.py index 2ba3492d..f4921f53 100644 --- a/src/allmydata/immutable/download.py +++ b/src/allmydata/immutable/download.py @@ -42,75 +42,26 @@ class DownloadResults: self.timings = {} self.file_size = None -class Output: - def __init__(self, downloadable, key, total_length, log_parent, - download_status): +class DecryptingTarget(log.PrefixingLogMixin): + implements(IDownloadTarget, IConsumer) + def __init__(self, downloadable, key, _log_msg_id=None): + precondition(IDownloadTarget.providedBy(downloadable), downloadable) self.downloadable = downloadable self._decryptor = AES(key) - self._crypttext_hasher = hashutil.crypttext_hasher() - self.length = 0 - self.total_length = total_length - self._segment_number = 0 - self._crypttext_hash_tree = None - self._opened = False - self._log_parent = log_parent - self._status = download_status - self._status.set_progress(0.0) - - def log(self, *args, **kwargs): - if "parent" not in kwargs: - kwargs["parent"] = self._log_parent - if "facility" not in kwargs: - kwargs["facility"] = "download.output" - return log.msg(*args, **kwargs) - - def got_crypttext_hash_tree(self, crypttext_hash_tree): - self._crypttext_hash_tree = crypttext_hash_tree - - def write_segment(self, crypttext): - self.length += len(crypttext) - self._status.set_progress( float(self.length) / self.total_length ) - - # memory footprint: 'crypttext' is the only segment_size usage - # outstanding. While we decrypt it into 'plaintext', we hit - # 2*segment_size. - self._crypttext_hasher.update(crypttext) - if self._crypttext_hash_tree: - ch = hashutil.crypttext_segment_hasher() - ch.update(crypttext) - crypttext_leaves = {self._segment_number: ch.digest()} - self.log(format="crypttext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s", - bytes=len(crypttext), - segnum=self._segment_number, hash=base32.b2a(ch.digest()), - level=log.NOISY) - self._crypttext_hash_tree.set_hashes(leaves=crypttext_leaves) - - plaintext = self._decryptor.process(crypttext) - del crypttext - - # now we're back down to 1*segment_size. - self._segment_number += 1 - # We're still at 1*segment_size. The Downloadable is responsible for - # any memory usage beyond this. - if not self._opened: - self._opened = True - self.downloadable.open(self.total_length) + prefix = str(downloadable) + log.PrefixingLogMixin.__init__(self, "allmydata.immutable.download", _log_msg_id, prefix=prefix) + def registerProducer(self, producer, streaming): + if IConsumer.providedBy(self.downloadable): + self.downloadable.registerProducer(producer, streaming) + def unregisterProducer(self): + if IConsumer.providedBy(self.downloadable): + self.downloadable.unregisterProducer() + def write(self, ciphertext): + plaintext = self._decryptor.process(ciphertext) self.downloadable.write(plaintext) - - def fail(self, why): - # this is really unusual, and deserves maximum forensics - if why.check(DownloadStopped): - # except DownloadStopped just means the consumer aborted the - # download, not so scary - self.log("download stopped", level=log.UNUSUAL) - else: - self.log("download failed!", failure=why, - level=log.SCARY, umid="lp1vaQ") - self.downloadable.fail(why) - + def open(self, size): + self.downloadable.open(size) def close(self): - self.crypttext_hash = self._crypttext_hasher.digest() - self.log("download finished, closing IDownloadable", level=log.NOISY) self.downloadable.close() def finish(self): return self.downloadable.finish() @@ -653,11 +604,14 @@ class DownloadStatus: self.results = value class FileDownloader(log.PrefixingLogMixin): + """ I download shares, check their integrity, then decode them, check the integrity of the + resulting ciphertext, then and write it to my target. """ implements(IPushProducer) _status = None def __init__(self, client, u, downloadable): - precondition(isinstance(u, uri.CHKFileURI), u) + precondition(IVerifierURI.providedBy(u), u) + precondition(IDownloadTarget.providedBy(downloadable), downloadable) prefix=base32.b2a_l(u.get_storage_index()[:8], 60) log.PrefixingLogMixin.__init__(self, facility="tahoe.immutable.download", prefix=prefix) @@ -691,8 +645,7 @@ class FileDownloader(log.PrefixingLogMixin): if IConsumer.providedBy(downloadable): downloadable.registerProducer(self, True) self._downloadable = downloadable - self._output = Output(downloadable, u.key, self._uri.size, self._parentmsgid, - self._status) + self._opened = False self.active_buckets = {} # k: shnum, v: bucket self._share_buckets = [] # list of (sharenum, bucket) tuples @@ -700,8 +653,15 @@ class FileDownloader(log.PrefixingLogMixin): self._fetch_failures = {"uri_extension": 0, "crypttext_hash_tree": 0, } - self._share_hash_tree = None - self._crypttext_hash_tree = None + self._ciphertext_hasher = hashutil.crypttext_hasher() + + self._bytes_done = 0 + self._status.set_progress(float(self._bytes_done)/self._uri.size) + + # _got_uri_extension() will create the following: + # self._crypttext_hash_tree + # self._share_hash_tree + # self._current_segnum = 0 def pauseProducing(self): if self._paused: @@ -730,7 +690,6 @@ class FileDownloader(log.PrefixingLogMixin): self._status.set_active(False) def start(self): - assert isinstance(self._uri, uri.CHKFileURI), (self._uri, type(self._uri)) self.log("starting download") # first step: who should we download from? @@ -754,7 +713,12 @@ class FileDownloader(log.PrefixingLogMixin): if self._status: self._status.set_status("Failed") self._status.set_active(False) - self._output.fail(why) + if why.check(DownloadStopped): + # DownloadStopped just means the consumer aborted the download; not so scary. + self.log("download stopped", level=log.UNUSUAL) + else: + # This is really unusual, and deserves maximum forensics. + self.log("download failed!", failure=why, level=log.SCARY, umid="lp1vaQ") return why d.addErrback(_failed) d.addCallback(self._done) @@ -818,7 +782,6 @@ class FileDownloader(log.PrefixingLogMixin): del self._share_vbuckets[shnum] def _got_all_shareholders(self, res): - assert isinstance(self._uri, uri.CHKFileURI), (self._uri, type(self._uri)) if self._results: now = time.time() self._results.timings["peer_selection"] = now - self._started @@ -832,7 +795,6 @@ class FileDownloader(log.PrefixingLogMixin): # "vb is %s but should be a ValidatedReadBucketProxy" % (vb,) def _obtain_uri_extension(self, ignored): - assert isinstance(self._uri, uri.CHKFileURI), self._uri # all shareholders are supposed to have a copy of uri_extension, and # all are supposed to be identical. We compute the hash of the data # that comes back, and compare it against the version in our URI. If @@ -844,7 +806,7 @@ class FileDownloader(log.PrefixingLogMixin): vups = [] for sharenum, bucket in self._share_buckets: - vups.append(ValidatedExtendedURIProxy(bucket, self._uri.get_verify_cap(), self._fetch_failures)) + vups.append(ValidatedExtendedURIProxy(bucket, self._uri, self._fetch_failures)) vto = ValidatedThingObtainer(vups, debugname="vups", log_id=self._parentmsgid) d = vto.start() @@ -886,7 +848,6 @@ class FileDownloader(log.PrefixingLogMixin): def _got_crypttext_hash_tree(res): # Good -- the self._crypttext_hash_tree that we passed to vchtp is now populated # with hashes. - self._output.got_crypttext_hash_tree(self._crypttext_hash_tree) if self._results: elapsed = time.time() - _get_crypttext_hash_tree_started self._results.timings["hashtrees"] = elapsed @@ -896,7 +857,6 @@ class FileDownloader(log.PrefixingLogMixin): def _activate_enough_buckets(self): """either return a mapping from shnum to a ValidatedReadBucketProxy that can provide data for that share, or raise NotEnoughSharesError""" - assert isinstance(self._uri, uri.CHKFileURI), self._uri while len(self.active_buckets) < self._uri.needed_shares: # need some more @@ -934,12 +894,11 @@ class FileDownloader(log.PrefixingLogMixin): self._started_fetching = time.time() d = defer.succeed(None) - for segnum in range(self._vup.num_segments-1): + for segnum in range(self._vup.num_segments): d.addCallback(self._download_segment, segnum) # this pause, at the end of write, prevents pre-fetch from # happening until the consumer is ready for more data. d.addCallback(self._check_for_pause) - d.addCallback(self._download_tail_segment, self._vup.num_segments-1) return d def _check_for_pause(self, res): @@ -952,7 +911,6 @@ class FileDownloader(log.PrefixingLogMixin): return res def _download_segment(self, res, segnum): - assert isinstance(self._uri, uri.CHKFileURI), self._uri if self._status: self._status.set_status("Downloading segment %d of %d" % (segnum+1, self._vup.num_segments)) @@ -979,8 +937,11 @@ class FileDownloader(log.PrefixingLogMixin): return res if self._results: d.addCallback(_started_decode) - d.addCallback(lambda (shares, shareids): - self._codec.decode(shares, shareids)) + if segnum + 1 == self._vup.num_segments: + codec = self._tail_codec + else: + codec = self._codec + d.addCallback(lambda (shares, shareids): codec.decode(shares, shareids)) # once the codec is done, we drop back to 1*segment_size, because # 'shares' goes out of scope. The memory usage is all in the # plaintext now, spread out into a bunch of tiny buffers. @@ -993,91 +954,66 @@ class FileDownloader(log.PrefixingLogMixin): # pause/check-for-stop just before writing, to honor stopProducing d.addCallback(self._check_for_pause) - def _done(buffers): - # we start by joining all these buffers together into a single - # string. This makes Output.write easier, since it wants to hash - # data one segment at a time anyways, and doesn't impact our - # memory footprint since we're already peaking at 2*segment_size - # inside the codec a moment ago. - segment = "".join(buffers) - del buffers - # we're down to 1*segment_size right now, but write_segment() - # will decrypt a copy of the segment internally, which will push - # us up to 2*segment_size while it runs. - started_decrypt = time.time() - self._output.write_segment(segment) - if self._results: - elapsed = time.time() - started_decrypt - self._results.timings["cumulative_decrypt"] += elapsed - d.addCallback(_done) + d.addCallback(self._got_segment) return d - def _download_tail_segment(self, res, segnum): - assert isinstance(self._uri, uri.CHKFileURI), self._uri - self.log("downloading seg#%d of %d (%d%%)" - % (segnum, self._vup.num_segments, - 100.0 * segnum / self._vup.num_segments)) - segmentdler = SegmentDownloader(self, segnum, self._uri.needed_shares, - self._results) - started = time.time() - d = segmentdler.start() - def _finished_fetching(res): - elapsed = time.time() - started - self._results.timings["cumulative_fetch"] += elapsed - return res - if self._results: - d.addCallback(_finished_fetching) - # pause before using more memory - d.addCallback(self._check_for_pause) - def _started_decode(res): - self._started_decode = time.time() - return res - if self._results: - d.addCallback(_started_decode) - d.addCallback(lambda (shares, shareids): - self._tail_codec.decode(shares, shareids)) - def _finished_decode(res): - elapsed = time.time() - self._started_decode - self._results.timings["cumulative_decode"] += elapsed - return res + def _got_segment(self, buffers): + precondition(self._crypttext_hash_tree) + started_decrypt = time.time() + self._status.set_progress(float(self._current_segnum)/self._uri.size) + + if self._current_segnum + 1 == self._vup.num_segments: + # This is the last segment. + # Trim off any padding added by the upload side. We never send empty segments. If + # the data was an exact multiple of the segment size, the last segment will be full. + tail_buf_size = mathutil.div_ceil(self._vup.tail_segment_size, self._uri.needed_shares) + num_buffers_used = mathutil.div_ceil(self._vup.tail_data_size, tail_buf_size) + # Remove buffers which don't contain any part of the tail. + del buffers[num_buffers_used:] + # Remove the past-the-tail-part of the last buffer. + tail_in_last_buf = self._vup.tail_data_size % tail_buf_size + if tail_in_last_buf == 0: + tail_in_last_buf = tail_buf_size + buffers[-1] = buffers[-1][:tail_in_last_buf] + + # First compute the hash of this segment and check that it fits. + ch = hashutil.crypttext_segment_hasher() + for buffer in buffers: + self._ciphertext_hasher.update(buffer) + ch.update(buffer) + self._crypttext_hash_tree.set_hashes(leaves={self._current_segnum: ch.digest()}) + + # Then write this segment to the target. + if not self._opened: + self._opened = True + self._downloadable.open(self._uri.size) + + for buffer in buffers: + self._downloadable.write(buffer) + self._bytes_done += len(buffer) + + self._status.set_progress(float(self._bytes_done)/self._uri.size) + self._current_segnum += 1 + if self._results: - d.addCallback(_finished_decode) - # pause/check-for-stop just before writing, to honor stopProducing - d.addCallback(self._check_for_pause) - def _done(buffers): - # trim off any padding added by the upload side - segment = "".join(buffers) - del buffers - # we never send empty segments. If the data was an exact multiple - # of the segment size, the last segment will be full. - pad_size = mathutil.pad_size(self._uri.size, self._vup.segment_size) - tail_size = self._vup.segment_size - pad_size - segment = segment[:tail_size] - started_decrypt = time.time() - self._output.write_segment(segment) - if self._results: - elapsed = time.time() - started_decrypt - self._results.timings["cumulative_decrypt"] += elapsed - d.addCallback(_done) - return d + elapsed = time.time() - started_decrypt + self._results.timings["cumulative_decrypt"] += elapsed def _done(self, res): - assert isinstance(self._uri, uri.CHKFileURI), self._uri self.log("download done") if self._results: now = time.time() self._results.timings["total"] = now - self._started self._results.timings["segments"] = now - self._started_fetching - self._output.close() if self._vup.crypttext_hash: - _assert(self._vup.crypttext_hash == self._output.crypttext_hash, + _assert(self._vup.crypttext_hash == self._ciphertext_hasher.digest(), "bad crypttext_hash: computed=%s, expected=%s" % - (base32.b2a(self._output.crypttext_hash), + (base32.b2a(self._ciphertext_hasher.digest()), base32.b2a(self._vup.crypttext_hash))) - _assert(self._output.length == self._uri.size, - got=self._output.length, expected=self._uri.size) - return self._output.finish() - + _assert(self._bytes_done == self._uri.size, self._bytes_done, self._uri.size) + self._status.set_progress(1) + self._downloadable.close() + return self._downloadable.finish() def get_download_status(self): return self._status @@ -1200,7 +1136,9 @@ class Downloader(service.MultiService): # include LIT files self.stats_provider.count('downloader.files_downloaded', 1) self.stats_provider.count('downloader.bytes_downloaded', u.get_size()) - dl = FileDownloader(self.parent, u, t) + + target = DecryptingTarget(t, u.key, _log_msg_id=_log_msg_id) + dl = FileDownloader(self.parent, u.get_verify_cap(), target) self._add_download(dl) d = dl.start() return d diff --git a/src/allmydata/immutable/filenode.py b/src/allmydata/immutable/filenode.py index 4b115041..de37b2f8 100644 --- a/src/allmydata/immutable/filenode.py +++ b/src/allmydata/immutable/filenode.py @@ -9,6 +9,7 @@ from foolscap.eventual import eventually from allmydata.interfaces import IFileNode, IFileURI, ICheckable, \ IDownloadTarget from allmydata.util import log, base32 +from allmydata.util.assertutil import precondition from allmydata import uri as urimodule from allmydata.immutable.checker import Checker from allmydata.check_results import CheckAndRepairResults @@ -19,6 +20,7 @@ class _ImmutableFileNodeBase(object): implements(IFileNode, ICheckable) def __init__(self, uri, client): + precondition(urimodule.IImmutableFileURI.providedBy(uri), uri) self.u = IFileURI(uri) self._client = client @@ -172,7 +174,7 @@ class FileNode(_ImmutableFileNodeBase, log.PrefixingLogMixin): def __init__(self, uri, client, cachefile): _ImmutableFileNodeBase.__init__(self, uri, client) self.download_cache = DownloadCache(self, cachefile) - prefix = urimodule.from_string(uri).get_verify_cap().to_string() + prefix = uri.get_verify_cap().to_string() log.PrefixingLogMixin.__init__(self, "allmydata.immutable.filenode", prefix=prefix) self.log("starting", level=log.OPERATIONAL) @@ -250,6 +252,7 @@ class LiteralProducer: class LiteralFileNode(_ImmutableFileNodeBase): def __init__(self, uri, client): + precondition(urimodule.IImmutableFileURI.providedBy(uri), uri) _ImmutableFileNodeBase.__init__(self, uri, client) def get_uri(self): diff --git a/src/allmydata/test/test_encode.py b/src/allmydata/test/test_encode.py index 0954530d..e4133ff2 100644 --- a/src/allmydata/test/test_encode.py +++ b/src/allmydata/test/test_encode.py @@ -493,7 +493,8 @@ class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin): client = FakeClient() if not target: target = download.Data() - fd = download.FileDownloader(client, u, target) + target = download.DecryptingTarget(target, u.key) + fd = download.FileDownloader(client, u.get_verify_cap(), target) # we manually cycle the FileDownloader through a number of steps that # would normally be sequenced by a Deferred chain in diff --git a/src/allmydata/test/test_filenode.py b/src/allmydata/test/test_filenode.py index abf82220..48d1b140 100644 --- a/src/allmydata/test/test_filenode.py +++ b/src/allmydata/test/test_filenode.py @@ -27,8 +27,8 @@ class Node(unittest.TestCase): size=1000) c = FakeClient() cf = cachedir.CacheFile("none") - fn1 = filenode.FileNode(u.to_string(), c, cf) - fn2 = filenode.FileNode(u.to_string(), c, cf) + fn1 = filenode.FileNode(u, c, cf) + fn2 = filenode.FileNode(u, c, cf) self.failUnlessEqual(fn1, fn2) self.failIfEqual(fn1, "I am not a filenode") self.failIfEqual(fn1, NotANode()) @@ -49,7 +49,7 @@ class Node(unittest.TestCase): u = uri.LiteralFileURI(data=DATA) c = None fn1 = filenode.LiteralFileNode(u, c) - fn2 = filenode.LiteralFileNode(u.to_string(), c) + fn2 = filenode.LiteralFileNode(u, c) self.failUnlessEqual(fn1, fn2) self.failIfEqual(fn1, "I am not a filenode") self.failIfEqual(fn1, NotANode()) -- 2.45.2