From: Kevan Carstensen Date: Tue, 2 Aug 2011 01:35:24 +0000 (-0700) Subject: mutable/retrieve: rework the mutable downloader to handle multiple-segment files X-Git-Tag: trac-5200~33 X-Git-Url: https://git.rkrishnan.org/pf/content/%22news.html/frontends/provisioning?a=commitdiff_plain;h=ac3b2647dd2c45cd1ddbf5b130ee5a780c66c73b;p=tahoe-lafs%2Ftahoe-lafs.git mutable/retrieve: rework the mutable downloader to handle multiple-segment files The downloader needs substantial reworking to handle multiple segment mutable files, which it needs to handle for MDMF. --- diff --git a/src/allmydata/mutable/retrieve.py b/src/allmydata/mutable/retrieve.py index 257cc5f3..166eec48 100644 --- a/src/allmydata/mutable/retrieve.py +++ b/src/allmydata/mutable/retrieve.py @@ -1,12 +1,14 @@ -import struct, time +import time from itertools import count from zope.interface import implements from twisted.internet import defer from twisted.python import failure -from foolscap.api import DeadReferenceError, eventually, fireEventually -from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError -from allmydata.util import hashutil, idlib, log +from twisted.internet.interfaces import IPushProducer, IConsumer +from foolscap.api import eventually, fireEventually +from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \ + MDMF_VERSION, SDMF_VERSION +from allmydata.util import hashutil, log, mathutil from allmydata.util.dictutil import DictOfSets from allmydata import hashtree, codec from allmydata.storage.server import si_b2a @@ -14,7 +16,7 @@ from pycryptopp.cipher.aes import AES from pycryptopp.publickey import rsa from allmydata.mutable.common import CorruptShareError, UncoordinatedWriteError -from allmydata.mutable.layout import SIGNED_PREFIX, unpack_share_data +from allmydata.mutable.layout import MDMFSlotReadProxy class RetrieveStatus: implements(IRetrieveStatus) @@ -81,8 +83,10 @@ class Retrieve: # times, and each will have a separate response chain. However the # Retrieve object will remain tied to a specific version of the file, and # will use a single ServerMap instance. + implements(IPushProducer) - def __init__(self, filenode, servermap, verinfo, fetch_privkey=False): + def __init__(self, filenode, servermap, verinfo, fetch_privkey=False, + verify=False): self._node = filenode assert self._node.get_pubkey() self._storage_index = filenode.get_storage_index() @@ -100,11 +104,32 @@ class Retrieve: self.verinfo = verinfo # during repair, we may be called upon to grab the private key, since # it wasn't picked up during a verify=False checker run, and we'll - # need it for repair to generate the a new version. - self._need_privkey = fetch_privkey - if self._node.get_privkey(): + # need it for repair to generate a new version. + self._need_privkey = fetch_privkey or verify + if self._node.get_privkey() and not verify: self._need_privkey = False + if self._need_privkey: + # TODO: Evaluate the need for this. We'll use it if we want + # to limit how many queries are on the wire for the privkey + # at once. + self._privkey_query_markers = [] # one Marker for each time we've + # tried to get the privkey. + + # verify means that we are using the downloader logic to verify all + # of our shares. This tells the downloader a few things. + # + # 1. We need to download all of the shares. + # 2. We don't need to decode or decrypt the shares, since our + # caller doesn't care about the plaintext, only the + # information about which shares are or are not valid. + # 3. When we are validating readers, we need to validate the + # signature on the prefix. Do we? We already do this in the + # servermap update? + self._verify = False + if verify: + self._verify = True + self._status = RetrieveStatus() self._status.set_storage_index(self._storage_index) self._status.set_helper(False) @@ -114,6 +139,13 @@ class Retrieve: offsets_tuple) = self.verinfo self._status.set_size(datalength) self._status.set_encoding(k, N) + self.readers = {} + self._paused = False + self._paused_deferred = None + self._offset = None + self._read_length = None + self.log("got seqnum %d" % self.verinfo[0]) + def get_status(self): return self._status @@ -125,11 +157,74 @@ class Retrieve: kwargs["facility"] = "tahoe.mutable.retrieve" return log.msg(*args, **kwargs) - def download(self): + + ################### + # IPushProducer + + def pauseProducing(self): + """ + I am called by my download target if we have produced too much + data for it to handle. I make the downloader stop producing new + data until my resumeProducing method is called. + """ + if self._paused: + return + + # fired when the download is unpaused. + self._old_status = self._status.get_status() + self._status.set_status("Paused") + + self._pause_deferred = defer.Deferred() + self._paused = True + + + def resumeProducing(self): + """ + I am called by my download target once it is ready to begin + receiving data again. + """ + if not self._paused: + return + + self._paused = False + p = self._pause_deferred + self._pause_deferred = None + self._status.set_status(self._old_status) + + eventually(p.callback, None) + + + def _check_for_paused(self, res): + """ + I am called just before a write to the consumer. I return a + Deferred that eventually fires with the data that is to be + written to the consumer. If the download has not been paused, + the Deferred fires immediately. Otherwise, the Deferred fires + when the downloader is unpaused. + """ + if self._paused: + d = defer.Deferred() + self._pause_deferred.addCallback(lambda ignored: d.callback(res)) + return d + return defer.succeed(res) + + + def download(self, consumer=None, offset=0, size=None): + assert IConsumer.providedBy(consumer) or self._verify + + if consumer: + self._consumer = consumer + # we provide IPushProducer, so streaming=True, per + # IConsumer. + self._consumer.registerProducer(self, streaming=True) + self._done_deferred = defer.Deferred() self._started = time.time() self._status.set_status("Retrieving Shares") + self._offset = offset + self._read_length = size + # first, which servers can we use? versionmap = self.servermap.make_versionmap() shares = versionmap[self.verinfo] @@ -137,412 +232,900 @@ class Retrieve: self.remaining_sharemap = DictOfSets() for (shnum, peerid, timestamp) in shares: self.remaining_sharemap.add(shnum, peerid) + # If the servermap update fetched anything, it fetched at least 1 + # KiB, so we ask for that much. + # TODO: Change the cache methods to allow us to fetch all of the + # data that they have, then change this method to do that. + any_cache = self._node._read_from_cache(self.verinfo, shnum, + 0, 1000) + ss = self.servermap.connections[peerid] + reader = MDMFSlotReadProxy(ss, + self._storage_index, + shnum, + any_cache) + reader.peerid = peerid + self.readers[shnum] = reader + self.shares = {} # maps shnum to validated blocks + self._active_readers = [] # list of active readers for this dl. + self._validated_readers = set() # set of readers that we have + # validated the prefix of + self._block_hash_trees = {} # shnum => hashtree # how many shares do we need? - (seqnum, root_hash, IV, segsize, datalength, k, N, prefix, + (seqnum, + root_hash, + IV, + segsize, + datalength, + k, + N, + prefix, offsets_tuple) = self.verinfo + + + # We need one share hash tree for the entire file; its leaves + # are the roots of the block hash trees for the shares that + # comprise it, and its root is in the verinfo. + self.share_hash_tree = hashtree.IncompleteHashTree(N) + self.share_hash_tree.set_hashes({0: root_hash}) + + # This will set up both the segment decoder and the tail segment + # decoder, as well as a variety of other instance variables that + # the download process will use. + self._setup_encoding_parameters() assert len(self.remaining_sharemap) >= k - # we start with the lowest shnums we have available, since FEC is - # faster if we're using "primary shares" - self.active_shnums = set(sorted(self.remaining_sharemap.keys())[:k]) - for shnum in self.active_shnums: - # we use an arbitrary peer who has the share. If shares are - # doubled up (more than one share per peer), we could make this - # run faster by spreading the load among multiple peers. But the - # algorithm to do that is more complicated than I want to write - # right now, and a well-provisioned grid shouldn't have multiple - # shares per peer. - peerid = list(self.remaining_sharemap[shnum])[0] - self.get_data(shnum, peerid) - - # control flow beyond this point: state machine. Receiving responses - # from queries is the input. We might send out more queries, or we - # might produce a result. + self.log("starting download") + self._paused = False + self._started_fetching = time.time() + + self._add_active_peers() + # The download process beyond this is a state machine. + # _add_active_peers will select the peers that we want to use + # for the download, and then attempt to start downloading. After + # each segment, it will check for doneness, reacting to broken + # peers and corrupt shares as necessary. If it runs out of good + # peers before downloading all of the segments, _done_deferred + # will errback. Otherwise, it will eventually callback with the + # contents of the mutable file. return self._done_deferred - def get_data(self, shnum, peerid): - self.log(format="sending sh#%(shnum)d request to [%(peerid)s]", - shnum=shnum, - peerid=idlib.shortnodeid_b2a(peerid), - level=log.NOISY) - ss = self.servermap.connections[peerid] - started = time.time() - (seqnum, root_hash, IV, segsize, datalength, k, N, prefix, + + def decode(self, blocks_and_salts, segnum): + """ + I am a helper method that the mutable file update process uses + as a shortcut to decode and decrypt the segments that it needs + to fetch in order to perform a file update. I take in a + collection of blocks and salts, and pick some of those to make a + segment with. I return the plaintext associated with that + segment. + """ + # shnum => block hash tree. Unusued, but setup_encoding_parameters will + # want to set this. + # XXX: Make it so that it won't set this if we're just decoding. + self._block_hash_trees = {} + self._setup_encoding_parameters() + # This is the form expected by decode. + blocks_and_salts = blocks_and_salts.items() + blocks_and_salts = [(True, [d]) for d in blocks_and_salts] + + d = self._decode_blocks(blocks_and_salts, segnum) + d.addCallback(self._decrypt_segment) + return d + + + def _setup_encoding_parameters(self): + """ + I set up the encoding parameters, including k, n, the number + of segments associated with this file, and the segment decoder. + """ + (seqnum, + root_hash, + IV, + segsize, + datalength, + k, + n, + known_prefix, offsets_tuple) = self.verinfo - offsets = dict(offsets_tuple) + self._required_shares = k + self._total_shares = n + self._segment_size = segsize + self._data_length = datalength - # we read the checkstring, to make sure that the data we grab is from - # the right version. - readv = [ (0, struct.calcsize(SIGNED_PREFIX)) ] + if not IV: + self._version = MDMF_VERSION + else: + self._version = SDMF_VERSION - # We also read the data, and the hashes necessary to validate them - # (share_hash_chain, block_hash_tree, share_data). We don't read the - # signature or the pubkey, since that was handled during the - # servermap phase, and we'll be comparing the share hash chain - # against the roothash that was validated back then. + if datalength and segsize: + self._num_segments = mathutil.div_ceil(datalength, segsize) + self._tail_data_size = datalength % segsize + else: + self._num_segments = 0 + self._tail_data_size = 0 - readv.append( (offsets['share_hash_chain'], - offsets['enc_privkey'] - offsets['share_hash_chain'] ) ) + self._segment_decoder = codec.CRSDecoder() + self._segment_decoder.set_params(segsize, k, n) - # if we need the private key (for repair), we also fetch that - if self._need_privkey: - readv.append( (offsets['enc_privkey'], - offsets['EOF'] - offsets['enc_privkey']) ) - - m = Marker() - self._outstanding_queries[m] = (peerid, shnum, started) - - # ask the cache first - got_from_cache = False - datavs = [] - for (offset, length) in readv: - data = self._node._read_from_cache(self.verinfo, shnum, offset, length) - if data is not None: - datavs.append(data) - if len(datavs) == len(readv): - self.log("got data from cache") - got_from_cache = True - d = fireEventually({shnum: datavs}) - # datavs is a dict mapping shnum to a pair of strings + if not self._tail_data_size: + self._tail_data_size = segsize + + self._tail_segment_size = mathutil.next_multiple(self._tail_data_size, + self._required_shares) + if self._tail_segment_size == self._segment_size: + self._tail_decoder = self._segment_decoder else: - d = self._do_read(ss, peerid, self._storage_index, [shnum], readv) - self.remaining_sharemap.discard(shnum, peerid) - - d.addCallback(self._got_results, m, peerid, started, got_from_cache) - d.addErrback(self._query_failed, m, peerid) - # errors that aren't handled by _query_failed (and errors caused by - # _query_failed) get logged, but we still want to check for doneness. - def _oops(f): - self.log(format="problem in _query_failed for sh#%(shnum)d to %(peerid)s", - shnum=shnum, - peerid=idlib.shortnodeid_b2a(peerid), - failure=f, - level=log.WEIRD, umid="W0xnQA") - d.addErrback(_oops) - d.addBoth(self._check_for_done) - # any error during _check_for_done means the download fails. If the - # download is successful, _check_for_done will fire _done by itself. - d.addErrback(self._done) - d.addErrback(log.err) - return d # purely for testing convenience - - def _do_read(self, ss, peerid, storage_index, shnums, readv): - # isolate the callRemote to a separate method, so tests can subclass - # Publish and override it - d = ss.callRemote("slot_readv", storage_index, shnums, readv) - return d + self._tail_decoder = codec.CRSDecoder() + self._tail_decoder.set_params(self._tail_segment_size, + self._required_shares, + self._total_shares) - def remove_peer(self, peerid): - for shnum in list(self.remaining_sharemap.keys()): - self.remaining_sharemap.discard(shnum, peerid) + self.log("got encoding parameters: " + "k: %d " + "n: %d " + "%d segments of %d bytes each (%d byte tail segment)" % \ + (k, n, self._num_segments, self._segment_size, + self._tail_segment_size)) - def _got_results(self, datavs, marker, peerid, started, got_from_cache): - now = time.time() - elapsed = now - started - if not got_from_cache: - self._status.add_fetch_timing(peerid, elapsed) - self.log(format="got results (%(shares)d shares) from [%(peerid)s]", - shares=len(datavs), - peerid=idlib.shortnodeid_b2a(peerid), - level=log.NOISY) - self._outstanding_queries.pop(marker, None) - if not self._running: - return + for i in xrange(self._total_shares): + # So we don't have to do this later. + self._block_hash_trees[i] = hashtree.IncompleteHashTree(self._num_segments) - # note that we only ask for a single share per query, so we only - # expect a single share back. On the other hand, we use the extra - # shares if we get them.. seems better than an assert(). + # Our last task is to tell the downloader where to start and + # where to stop. We use three parameters for that: + # - self._start_segment: the segment that we need to start + # downloading from. + # - self._current_segment: the next segment that we need to + # download. + # - self._last_segment: The last segment that we were asked to + # download. + # + # We say that the download is complete when + # self._current_segment > self._last_segment. We use + # self._start_segment and self._last_segment to know when to + # strip things off of segments, and how much to strip. + if self._offset: + self.log("got offset: %d" % self._offset) + # our start segment is the first segment containing the + # offset we were given. + start = mathutil.div_ceil(self._offset, + self._segment_size) + # this gets us the first segment after self._offset. Then + # our start segment is the one before it. + start -= 1 - for shnum,datav in datavs.items(): - (prefix, hash_and_data) = datav[:2] - try: - self._got_results_one_share(shnum, peerid, - prefix, hash_and_data) - except CorruptShareError, e: - # log it and give the other shares a chance to be processed - f = failure.Failure() - self.log(format="bad share: %(f_value)s", - f_value=str(f.value), failure=f, - level=log.WEIRD, umid="7fzWZw") - self.notify_server_corruption(peerid, shnum, str(e)) - self.remove_peer(peerid) - self.servermap.mark_bad_share(peerid, shnum, prefix) - self._bad_shares.add( (peerid, shnum) ) - self._status.problems[peerid] = f - self._last_failure = f - pass - if self._need_privkey and len(datav) > 2: - lp = None - self._try_to_validate_privkey(datav[2], peerid, shnum, lp) - # all done! + assert start < self._num_segments + self._start_segment = start + self.log("got start segment: %d" % self._start_segment) + else: + self._start_segment = 0 - def notify_server_corruption(self, peerid, shnum, reason): - ss = self.servermap.connections[peerid] - ss.callRemoteOnly("advise_corrupt_share", - "mutable", self._storage_index, shnum, reason) - def _got_results_one_share(self, shnum, peerid, - got_prefix, got_hash_and_data): - self.log("_got_results: got shnum #%d from peerid %s" - % (shnum, idlib.shortnodeid_b2a(peerid))) - (seqnum, root_hash, IV, segsize, datalength, k, N, prefix, + if self._read_length: + # our end segment is the last segment containing part of the + # segment that we were asked to read. + self.log("got read length %d" % self._read_length) + end_data = self._offset + self._read_length + end = mathutil.div_ceil(end_data, + self._segment_size) + end -= 1 + assert end < self._num_segments + self._last_segment = end + self.log("got end segment: %d" % self._last_segment) + else: + self._last_segment = self._num_segments - 1 + + self._current_segment = self._start_segment + + def _add_active_peers(self): + """ + I populate self._active_readers with enough active readers to + retrieve the contents of this mutable file. I am called before + downloading starts, and (eventually) after each validation + error, connection error, or other problem in the download. + """ + # TODO: It would be cool to investigate other heuristics for + # reader selection. For instance, the cost (in time the user + # spends waiting for their file) of selecting a really slow peer + # that happens to have a primary share is probably more than + # selecting a really fast peer that doesn't have a primary + # share. Maybe the servermap could be extended to provide this + # information; it could keep track of latency information while + # it gathers more important data, and then this routine could + # use that to select active readers. + # + # (these and other questions would be easier to answer with a + # robust, configurable tahoe-lafs simulator, which modeled node + # failures, differences in node speed, and other characteristics + # that we expect storage servers to have. You could have + # presets for really stable grids (like allmydata.com), + # friendnets, make it easy to configure your own settings, and + # then simulate the effect of big changes on these use cases + # instead of just reasoning about what the effect might be. Out + # of scope for MDMF, though.) + + # We need at least self._required_shares readers to download a + # segment. + if self._verify: + needed = self._total_shares + else: + needed = self._required_shares - len(self._active_readers) + # XXX: Why don't format= log messages work here? + self.log("adding %d peers to the active peers list" % needed) + + # We favor lower numbered shares, since FEC is faster with + # primary shares than with other shares, and lower-numbered + # shares are more likely to be primary than higher numbered + # shares. + active_shnums = set(sorted(self.remaining_sharemap.keys())) + # We shouldn't consider adding shares that we already have; this + # will cause problems later. + active_shnums -= set([reader.shnum for reader in self._active_readers]) + active_shnums = list(active_shnums)[:needed] + if len(active_shnums) < needed and not self._verify: + # We don't have enough readers to retrieve the file; fail. + return self._failed() + + for shnum in active_shnums: + self._active_readers.append(self.readers[shnum]) + self.log("added reader for share %d" % shnum) + assert len(self._active_readers) >= self._required_shares + # Conceptually, this is part of the _add_active_peers step. It + # validates the prefixes of newly added readers to make sure + # that they match what we are expecting for self.verinfo. If + # validation is successful, _validate_active_prefixes will call + # _download_current_segment for us. If validation is + # unsuccessful, then _validate_prefixes will remove the peer and + # call _add_active_peers again, where we will attempt to rectify + # the problem by choosing another peer. + return self._validate_active_prefixes() + + + def _validate_active_prefixes(self): + """ + I check to make sure that the prefixes on the peers that I am + currently reading from match the prefix that we want to see, as + said in self.verinfo. + + If I find that all of the active peers have acceptable prefixes, + I pass control to _download_current_segment, which will use + those peers to do cool things. If I find that some of the active + peers have unacceptable prefixes, I will remove them from active + peers (and from further consideration) and call + _add_active_peers to attempt to rectify the situation. I keep + track of which peers I have already validated so that I don't + need to do so again. + """ + assert self._active_readers, "No more active readers" + + ds = [] + new_readers = set(self._active_readers) - self._validated_readers + self.log('validating %d newly-added active readers' % len(new_readers)) + + for reader in new_readers: + # We force a remote read here -- otherwise, we are relying + # on cached data that we already verified as valid, and we + # won't detect an uncoordinated write that has occurred + # since the last servermap update. + d = reader.get_prefix(force_remote=True) + d.addCallback(self._try_to_validate_prefix, reader) + ds.append(d) + dl = defer.DeferredList(ds, consumeErrors=True) + def _check_results(results): + # Each result in results will be of the form (success, msg). + # We don't care about msg, but success will tell us whether + # or not the checkstring validated. If it didn't, we need to + # remove the offending (peer,share) from our active readers, + # and ensure that active readers is again populated. + bad_readers = [] + for i, result in enumerate(results): + if not result[0]: + reader = self._active_readers[i] + f = result[1] + assert isinstance(f, failure.Failure) + + self.log("The reader %s failed to " + "properly validate: %s" % \ + (reader, str(f.value))) + bad_readers.append((reader, f)) + else: + reader = self._active_readers[i] + self.log("the reader %s checks out, so we'll use it" % \ + reader) + self._validated_readers.add(reader) + # Each time we validate a reader, we check to see if + # we need the private key. If we do, we politely ask + # for it and then continue computing. If we find + # that we haven't gotten it at the end of + # segment decoding, then we'll take more drastic + # measures. + if self._need_privkey and not self._node.is_readonly(): + d = reader.get_encprivkey() + d.addCallback(self._try_to_validate_privkey, reader) + if bad_readers: + # We do them all at once, or else we screw up list indexing. + for (reader, f) in bad_readers: + self._mark_bad_share(reader, f) + if self._verify: + if len(self._active_readers) >= self._required_shares: + return self._download_current_segment() + else: + return self._failed() + else: + return self._add_active_peers() + else: + return self._download_current_segment() + # The next step will assert that it has enough active + # readers to fetch shares; we just need to remove it. + dl.addCallback(_check_results) + return dl + + + def _try_to_validate_prefix(self, prefix, reader): + """ + I check that the prefix returned by a candidate server for + retrieval matches the prefix that the servermap knows about + (and, hence, the prefix that was validated earlier). If it does, + I return True, which means that I approve of the use of the + candidate server for segment retrieval. If it doesn't, I return + False, which means that another server must be chosen. + """ + (seqnum, + root_hash, + IV, + segsize, + datalength, + k, + N, + known_prefix, offsets_tuple) = self.verinfo - assert len(got_prefix) == len(prefix), (len(got_prefix), len(prefix)) - if got_prefix != prefix: - msg = "someone wrote to the data since we read the servermap: prefix changed" - raise UncoordinatedWriteError(msg) - (share_hash_chain, block_hash_tree, - share_data) = unpack_share_data(self.verinfo, got_hash_and_data) - - assert isinstance(share_data, str) - # build the block hash tree. SDMF has only one leaf. - leaves = [hashutil.block_hash(share_data)] - t = hashtree.HashTree(leaves) - if list(t) != block_hash_tree: - raise CorruptShareError(peerid, shnum, "block hash tree failure") - share_hash_leaf = t[0] - t2 = hashtree.IncompleteHashTree(N) - # root_hash was checked by the signature - t2.set_hashes({0: root_hash}) - try: - t2.set_hashes(hashes=share_hash_chain, - leaves={shnum: share_hash_leaf}) - except (hashtree.BadHashError, hashtree.NotEnoughHashesError, - IndexError), e: - msg = "corrupt hashes: %s" % (e,) - raise CorruptShareError(peerid, shnum, msg) - self.log(" data valid! len=%d" % len(share_data)) - # each query comes down to this: placing validated share data into - # self.shares - self.shares[shnum] = share_data + if known_prefix != prefix: + self.log("prefix from share %d doesn't match" % reader.shnum) + raise UncoordinatedWriteError("Mismatched prefix -- this could " + "indicate an uncoordinated write") + # Otherwise, we're okay -- no issues. - def _try_to_validate_privkey(self, enc_privkey, peerid, shnum, lp): - alleged_privkey_s = self._node._decrypt_privkey(enc_privkey) - alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s) - if alleged_writekey != self._node.get_writekey(): - self.log("invalid privkey from %s shnum %d" % - (idlib.nodeid_b2a(peerid)[:8], shnum), - parent=lp, level=log.WEIRD, umid="YIw4tA") - return + def _remove_reader(self, reader): + """ + At various points, we will wish to remove a peer from + consideration and/or use. These include, but are not necessarily + limited to: - # it's good - self.log("got valid privkey from shnum %d on peerid %s" % - (shnum, idlib.shortnodeid_b2a(peerid)), - parent=lp) - privkey = rsa.create_signing_key_from_string(alleged_privkey_s) - self._node._populate_encprivkey(enc_privkey) - self._node._populate_privkey(privkey) - self._need_privkey = False + - A connection error. + - A mismatched prefix (that is, a prefix that does not match + our conception of the version information string). + - A failing block hash, salt hash, or share hash, which can + indicate disk failure/bit flips, or network trouble. - def _query_failed(self, f, marker, peerid): - self.log(format="query to [%(peerid)s] failed", - peerid=idlib.shortnodeid_b2a(peerid), - level=log.NOISY) - self._status.problems[peerid] = f - self._outstanding_queries.pop(marker, None) - if not self._running: - return + This method will do that. I will make sure that the + (shnum,reader) combination represented by my reader argument is + not used for anything else during this download. I will not + advise the reader of any corruption, something that my callers + may wish to do on their own. + """ + # TODO: When you're done writing this, see if this is ever + # actually used for something that _mark_bad_share isn't. I have + # a feeling that they will be used for very similar things, and + # that having them both here is just going to be an epic amount + # of code duplication. + # + # (well, okay, not epic, but meaningful) + self.log("removing reader %s" % reader) + # Remove the reader from _active_readers + self._active_readers.remove(reader) + # TODO: self.readers.remove(reader)? + for shnum in list(self.remaining_sharemap.keys()): + self.remaining_sharemap.discard(shnum, reader.peerid) + + + def _mark_bad_share(self, reader, f): + """ + I mark the (peerid, shnum) encapsulated by my reader argument as + a bad share, which means that it will not be used anywhere else. + + There are several reasons to want to mark something as a bad + share. These include: + + - A connection error to the peer. + - A mismatched prefix (that is, a prefix that does not match + our local conception of the version information string). + - A failing block hash, salt hash, share hash, or other + integrity check. + + This method will ensure that readers that we wish to mark bad + (for these reasons or other reasons) are not used for the rest + of the download. Additionally, it will attempt to tell the + remote peer (with no guarantee of success) that its share is + corrupt. + """ + self.log("marking share %d on server %s as bad" % \ + (reader.shnum, reader)) + prefix = self.verinfo[-2] + self.servermap.mark_bad_share(reader.peerid, + reader.shnum, + prefix) + self._remove_reader(reader) + self._bad_shares.add((reader.peerid, reader.shnum, f)) + self._status.problems[reader.peerid] = f self._last_failure = f - self.remove_peer(peerid) - level = log.WEIRD - if f.check(DeadReferenceError): - level = log.UNUSUAL - self.log(format="error during query: %(f_value)s", - f_value=str(f.value), failure=f, level=level, umid="gOJB5g") + self.notify_server_corruption(reader.peerid, reader.shnum, + str(f.value)) - def _check_for_done(self, res): - # exit paths: - # return : keep waiting, no new queries - # return self._send_more_queries(outstanding) : send some more queries - # fire self._done(plaintext) : download successful - # raise exception : download fails - - self.log(format="_check_for_done: running=%(running)s, decoding=%(decoding)s", - running=self._running, decoding=self._decoding, - level=log.NOISY) - if not self._running: - return - if self._decoding: - return - (seqnum, root_hash, IV, segsize, datalength, k, N, prefix, - offsets_tuple) = self.verinfo - if len(self.shares) < k: - # we don't have enough shares yet - return self._maybe_send_more_queries(k) - if self._need_privkey: - # we got k shares, but none of them had a valid privkey. TODO: - # look further. Adding code to do this is a bit complicated, and - # I want to avoid that complication, and this should be pretty - # rare (k shares with bitflips in the enc_privkey but not in the - # data blocks). If we actually do get here, the subsequent repair - # will fail for lack of a privkey. - self.log("got k shares but still need_privkey, bummer", - level=log.WEIRD, umid="MdRHPA") - - # we have enough to finish. All the shares have had their hashes - # checked, so if something fails at this point, we don't know how - # to fix it, so the download will fail. - - self._decoding = True # avoid reentrancy - self._status.set_status("decoding") - now = time.time() - elapsed = now - self._started - self._status.timings["fetch"] = elapsed - - d = defer.maybeDeferred(self._decode) - d.addCallback(self._decrypt, IV, self._node.get_readkey()) - d.addBoth(self._done) - return d # purely for test convenience - - def _maybe_send_more_queries(self, k): - # we don't have enough shares yet. Should we send out more queries? - # There are some number of queries outstanding, each for a single - # share. If we can generate 'needed_shares' additional queries, we do - # so. If we can't, then we know this file is a goner, and we raise - # NotEnoughSharesError. - self.log(format=("_maybe_send_more_queries, have=%(have)d, k=%(k)d, " - "outstanding=%(outstanding)d"), - have=len(self.shares), k=k, - outstanding=len(self._outstanding_queries), - level=log.NOISY) - - remaining_shares = k - len(self.shares) - needed = remaining_shares - len(self._outstanding_queries) - if not needed: - # we have enough queries in flight already - - # TODO: but if they've been in flight for a long time, and we - # have reason to believe that new queries might respond faster - # (i.e. we've seen other queries come back faster, then consider - # sending out new queries. This could help with peers which have - # silently gone away since the servermap was updated, for which - # we're still waiting for the 15-minute TCP disconnect to happen. - self.log("enough queries are in flight, no more are needed", - level=log.NOISY) - return + def _download_current_segment(self): + """ + I download, validate, decode, decrypt, and assemble the segment + that this Retrieve is currently responsible for downloading. + """ + assert len(self._active_readers) >= self._required_shares + if self._current_segment <= self._last_segment: + d = self._process_segment(self._current_segment) + else: + d = defer.succeed(None) + d.addBoth(self._turn_barrier) + d.addCallback(self._check_for_done) + return d + + + def _turn_barrier(self, result): + """ + I help the download process avoid the recursion limit issues + discussed in #237. + """ + return fireEventually(result) - outstanding_shnums = set([shnum - for (peerid, shnum, started) - in self._outstanding_queries.values()]) - # prefer low-numbered shares, they are more likely to be primary - available_shnums = sorted(self.remaining_sharemap.keys()) - for shnum in available_shnums: - if shnum in outstanding_shnums: - # skip ones that are already in transit - continue - if shnum not in self.remaining_sharemap: - # no servers for that shnum. note that DictOfSets removes - # empty sets from the dict for us. - continue - peerid = list(self.remaining_sharemap[shnum])[0] - # get_data will remove that peerid from the sharemap, and add the - # query to self._outstanding_queries - self._status.set_status("Retrieving More Shares") - self.get_data(shnum, peerid) - needed -= 1 - if not needed: + + def _process_segment(self, segnum): + """ + I download, validate, decode, and decrypt one segment of the + file that this Retrieve is retrieving. This means coordinating + the process of getting k blocks of that file, validating them, + assembling them into one segment with the decoder, and then + decrypting them. + """ + self.log("processing segment %d" % segnum) + + # TODO: The old code uses a marker. Should this code do that + # too? What did the Marker do? + assert len(self._active_readers) >= self._required_shares + + # We need to ask each of our active readers for its block and + # salt. We will then validate those. If validation is + # successful, we will assemble the results into plaintext. + ds = [] + for reader in self._active_readers: + started = time.time() + d = reader.get_block_and_salt(segnum, queue=True) + d2 = self._get_needed_hashes(reader, segnum) + dl = defer.DeferredList([d, d2], consumeErrors=True) + dl.addCallback(self._validate_block, segnum, reader, started) + dl.addErrback(self._validation_or_decoding_failed, [reader]) + ds.append(dl) + reader.flush() + dl = defer.DeferredList(ds) + if self._verify: + dl.addCallback(lambda ignored: "") + dl.addCallback(self._set_segment) + else: + dl.addCallback(self._maybe_decode_and_decrypt_segment, segnum) + return dl + + + def _maybe_decode_and_decrypt_segment(self, blocks_and_salts, segnum): + """ + I take the results of fetching and validating the blocks from a + callback chain in another method. If the results are such that + they tell me that validation and fetching succeeded without + incident, I will proceed with decoding and decryption. + Otherwise, I will do nothing. + """ + self.log("trying to decode and decrypt segment %d" % segnum) + failures = False + for block_and_salt in blocks_and_salts: + if not block_and_salt[0] or block_and_salt[1] == None: + self.log("some validation operations failed; not proceeding") + failures = True break + if not failures: + self.log("everything looks ok, building segment %d" % segnum) + d = self._decode_blocks(blocks_and_salts, segnum) + d.addCallback(self._decrypt_segment) + d.addErrback(self._validation_or_decoding_failed, + self._active_readers) + # check to see whether we've been paused before writing + # anything. + d.addCallback(self._check_for_paused) + d.addCallback(self._set_segment) + return d + else: + return defer.succeed(None) - # at this point, we have as many outstanding queries as we can. If - # needed!=0 then we might not have enough to recover the file. - if needed: - format = ("ran out of peers: " - "have %(have)d shares (k=%(k)d), " - "%(outstanding)d queries in flight, " - "need %(need)d more, " - "found %(bad)d bad shares") - args = {"have": len(self.shares), - "k": k, - "outstanding": len(self._outstanding_queries), - "need": needed, - "bad": len(self._bad_shares), - } - self.log(format=format, - level=log.WEIRD, umid="ezTfjw", **args) - err = NotEnoughSharesError("%s, last failure: %s" % - (format % args, self._last_failure)) - if self._bad_shares: - self.log("We found some bad shares this pass. You should " - "update the servermap and try again to check " - "more peers", - level=log.WEIRD, umid="EFkOlA") - err.servermap = self.servermap - raise err + def _set_segment(self, segment): + """ + Given a plaintext segment, I register that segment with the + target that is handling the file download. + """ + self.log("got plaintext for segment %d" % self._current_segment) + if self._current_segment == self._start_segment: + # We're on the first segment. It's possible that we want + # only some part of the end of this segment, and that we + # just downloaded the whole thing to get that part. If so, + # we need to account for that and give the reader just the + # data that they want. + n = self._offset % self._segment_size + self.log("stripping %d bytes off of the first segment" % n) + self.log("original segment length: %d" % len(segment)) + segment = segment[n:] + self.log("new segment length: %d" % len(segment)) + + if self._current_segment == self._last_segment and self._read_length is not None: + # We're on the last segment. It's possible that we only want + # part of the beginning of this segment, and that we + # downloaded the whole thing anyway. Make sure to give the + # caller only the portion of the segment that they want to + # receive. + extra = self._read_length + if self._start_segment != self._last_segment: + extra -= self._segment_size - \ + (self._offset % self._segment_size) + extra %= self._segment_size + self.log("original segment length: %d" % len(segment)) + segment = segment[:extra] + self.log("new segment length: %d" % len(segment)) + self.log("only taking %d bytes of the last segment" % extra) + + if not self._verify: + self._consumer.write(segment) + else: + # we don't care about the plaintext if we are doing a verify. + segment = None + self._current_segment += 1 + + + def _validation_or_decoding_failed(self, f, readers): + """ + I am called when a block or a salt fails to correctly validate, or when + the decryption or decoding operation fails for some reason. I react to + this failure by notifying the remote server of corruption, and then + removing the remote peer from further activity. + """ + assert isinstance(readers, list) + bad_shnums = [reader.shnum for reader in readers] + + self.log("validation or decoding failed on share(s) %s, peer(s) %s " + ", segment %d: %s" % \ + (bad_shnums, readers, self._current_segment, str(f))) + for reader in readers: + self._mark_bad_share(reader, f) return - def _decode(self): - started = time.time() - (seqnum, root_hash, IV, segsize, datalength, k, N, prefix, - offsets_tuple) = self.verinfo - # shares_dict is a dict mapping shnum to share data, but the codec - # wants two lists. - shareids = []; shares = [] - for shareid, share in self.shares.items(): + def _validate_block(self, results, segnum, reader, started): + """ + I validate a block from one share on a remote server. + """ + # Grab the part of the block hash tree that is necessary to + # validate this block, then generate the block hash root. + self.log("validating share %d for segment %d" % (reader.shnum, + segnum)) + self._status.add_fetch_timing(reader.peerid, started) + self._status.set_status("Valdiating blocks for segment %d" % segnum) + # Did we fail to fetch either of the things that we were + # supposed to? Fail if so. + if not results[0][0] and results[1][0]: + # handled by the errback handler. + + # These all get batched into one query, so the resulting + # failure should be the same for all of them, so we can just + # use the first one. + assert isinstance(results[0][1], failure.Failure) + + f = results[0][1] + raise CorruptShareError(reader.peerid, + reader.shnum, + "Connection error: %s" % str(f)) + + block_and_salt, block_and_sharehashes = results + block, salt = block_and_salt[1] + blockhashes, sharehashes = block_and_sharehashes[1] + + blockhashes = dict(enumerate(blockhashes[1])) + self.log("the reader gave me the following blockhashes: %s" % \ + blockhashes.keys()) + self.log("the reader gave me the following sharehashes: %s" % \ + sharehashes[1].keys()) + bht = self._block_hash_trees[reader.shnum] + + if bht.needed_hashes(segnum, include_leaf=True): + try: + bht.set_hashes(blockhashes) + except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \ + IndexError), e: + raise CorruptShareError(reader.peerid, + reader.shnum, + "block hash tree failure: %s" % e) + + if self._version == MDMF_VERSION: + blockhash = hashutil.block_hash(salt + block) + else: + blockhash = hashutil.block_hash(block) + # If this works without an error, then validation is + # successful. + try: + bht.set_hashes(leaves={segnum: blockhash}) + except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \ + IndexError), e: + raise CorruptShareError(reader.peerid, + reader.shnum, + "block hash tree failure: %s" % e) + + # Reaching this point means that we know that this segment + # is correct. Now we need to check to see whether the share + # hash chain is also correct. + # SDMF wrote share hash chains that didn't contain the + # leaves, which would be produced from the block hash tree. + # So we need to validate the block hash tree first. If + # successful, then bht[0] will contain the root for the + # shnum, which will be a leaf in the share hash tree, which + # will allow us to validate the rest of the tree. + if self.share_hash_tree.needed_hashes(reader.shnum, + include_leaf=True) or \ + self._verify: + try: + self.share_hash_tree.set_hashes(hashes=sharehashes[1], + leaves={reader.shnum: bht[0]}) + except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \ + IndexError), e: + raise CorruptShareError(reader.peerid, + reader.shnum, + "corrupt hashes: %s" % e) + + self.log('share %d is valid for segment %d' % (reader.shnum, + segnum)) + return {reader.shnum: (block, salt)} + + + def _get_needed_hashes(self, reader, segnum): + """ + I get the hashes needed to validate segnum from the reader, then return + to my caller when this is done. + """ + bht = self._block_hash_trees[reader.shnum] + needed = bht.needed_hashes(segnum, include_leaf=True) + # The root of the block hash tree is also a leaf in the share + # hash tree. So we don't need to fetch it from the remote + # server. In the case of files with one segment, this means that + # we won't fetch any block hash tree from the remote server, + # since the hash of each share of the file is the entire block + # hash tree, and is a leaf in the share hash tree. This is fine, + # since any share corruption will be detected in the share hash + # tree. + #needed.discard(0) + self.log("getting blockhashes for segment %d, share %d: %s" % \ + (segnum, reader.shnum, str(needed))) + d1 = reader.get_blockhashes(needed, queue=True, force_remote=True) + if self.share_hash_tree.needed_hashes(reader.shnum): + need = self.share_hash_tree.needed_hashes(reader.shnum) + self.log("also need sharehashes for share %d: %s" % (reader.shnum, + str(need))) + d2 = reader.get_sharehashes(need, queue=True, force_remote=True) + else: + d2 = defer.succeed({}) # the logic in the next method + # expects a dict + dl = defer.DeferredList([d1, d2], consumeErrors=True) + return dl + + + def _decode_blocks(self, blocks_and_salts, segnum): + """ + I take a list of k blocks and salts, and decode that into a + single encrypted segment. + """ + d = {} + # We want to merge our dictionaries to the form + # {shnum: blocks_and_salts} + # + # The dictionaries come from validate block that way, so we just + # need to merge them. + for block_and_salt in blocks_and_salts: + d.update(block_and_salt[1]) + + # All of these blocks should have the same salt; in SDMF, it is + # the file-wide IV, while in MDMF it is the per-segment salt. In + # either case, we just need to get one of them and use it. + # + # d.items()[0] is like (shnum, (block, salt)) + # d.items()[0][1] is like (block, salt) + # d.items()[0][1][1] is the salt. + salt = d.items()[0][1][1] + # Next, extract just the blocks from the dict. We'll use the + # salt in the next step. + share_and_shareids = [(k, v[0]) for k, v in d.items()] + d2 = dict(share_and_shareids) + shareids = [] + shares = [] + for shareid, share in d2.items(): shareids.append(shareid) shares.append(share) - assert len(shareids) >= k, len(shareids) + self._status.set_status("Decoding") + started = time.time() + assert len(shareids) >= self._required_shares, len(shareids) # zfec really doesn't want extra shares - shareids = shareids[:k] - shares = shares[:k] - - fec = codec.CRSDecoder() - fec.set_params(segsize, k, N) - - self.log("params %s, we have %d shares" % ((segsize, k, N), len(shares))) - self.log("about to decode, shareids=%s" % (shareids,)) - d = defer.maybeDeferred(fec.decode, shares, shareids) - def _done(buffers): - self._status.timings["decode"] = time.time() - started - self.log(" decode done, %d buffers" % len(buffers)) + shareids = shareids[:self._required_shares] + shares = shares[:self._required_shares] + self.log("decoding segment %d" % segnum) + if segnum == self._num_segments - 1: + d = defer.maybeDeferred(self._tail_decoder.decode, shares, shareids) + else: + d = defer.maybeDeferred(self._segment_decoder.decode, shares, shareids) + def _process(buffers): segment = "".join(buffers) + self.log(format="now decoding segment %(segnum)s of %(numsegs)s", + segnum=segnum, + numsegs=self._num_segments, + level=log.NOISY) self.log(" joined length %d, datalength %d" % - (len(segment), datalength)) - segment = segment[:datalength] + (len(segment), self._data_length)) + if segnum == self._num_segments - 1: + size_to_use = self._tail_data_size + else: + size_to_use = self._segment_size + segment = segment[:size_to_use] self.log(" segment len=%d" % len(segment)) - return segment - def _err(f): - self.log(" decode failed: %s" % f) - return f - d.addCallback(_done) - d.addErrback(_err) + self._status.timings.setdefault("decode", 0) + self._status.timings['decode'] = time.time() - started + return segment, salt + d.addCallback(_process) return d - def _decrypt(self, crypttext, IV, readkey): + + def _decrypt_segment(self, segment_and_salt): + """ + I take a single segment and its salt, and decrypt it. I return + the plaintext of the segment that is in my argument. + """ + segment, salt = segment_and_salt self._status.set_status("decrypting") + self.log("decrypting segment %d" % self._current_segment) started = time.time() - key = hashutil.ssk_readkey_data_hash(IV, readkey) + key = hashutil.ssk_readkey_data_hash(salt, self._node.get_readkey()) decryptor = AES(key) - plaintext = decryptor.process(crypttext) - self._status.timings["decrypt"] = time.time() - started + plaintext = decryptor.process(segment) + self._status.timings.setdefault("decrypt", 0) + self._status.timings['decrypt'] = time.time() - started return plaintext - def _done(self, res): - if not self._running: + + def notify_server_corruption(self, peerid, shnum, reason): + ss = self.servermap.connections[peerid] + ss.callRemoteOnly("advise_corrupt_share", + "mutable", self._storage_index, shnum, reason) + + + def _try_to_validate_privkey(self, enc_privkey, reader): + alleged_privkey_s = self._node._decrypt_privkey(enc_privkey) + alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s) + if alleged_writekey != self._node.get_writekey(): + self.log("invalid privkey from %s shnum %d" % + (reader, reader.shnum), + level=log.WEIRD, umid="YIw4tA") + if self._verify: + self.servermap.mark_bad_share(reader.peerid, reader.shnum, + self.verinfo[-2]) + e = CorruptShareError(reader.peerid, + reader.shnum, + "invalid privkey") + f = failure.Failure(e) + self._bad_shares.add((reader.peerid, reader.shnum, f)) return + + # it's good + self.log("got valid privkey from shnum %d on reader %s" % + (reader.shnum, reader)) + privkey = rsa.create_signing_key_from_string(alleged_privkey_s) + self._node._populate_encprivkey(enc_privkey) + self._node._populate_privkey(privkey) + self._need_privkey = False + + + def _check_for_done(self, res): + """ + I check to see if this Retrieve object has successfully finished + its work. + + I can exit in the following ways: + - If there are no more segments to download, then I exit by + causing self._done_deferred to fire with the plaintext + content requested by the caller. + - If there are still segments to be downloaded, and there + are enough active readers (readers which have not broken + and have not given us corrupt data) to continue + downloading, I send control back to + _download_current_segment. + - If there are still segments to be downloaded but there are + not enough active peers to download them, I ask + _add_active_peers to add more peers. If it is successful, + it will call _download_current_segment. If there are not + enough peers to retrieve the file, then that will cause + _done_deferred to errback. + """ + self.log("checking for doneness") + if self._current_segment > self._last_segment: + # No more segments to download, we're done. + self.log("got plaintext, done") + return self._done() + + if len(self._active_readers) >= self._required_shares: + # More segments to download, but we have enough good peers + # in self._active_readers that we can do that without issue, + # so go nab the next segment. + self.log("not done yet: on segment %d of %d" % \ + (self._current_segment + 1, self._num_segments)) + return self._download_current_segment() + + self.log("not done yet: on segment %d of %d, need to add peers" % \ + (self._current_segment + 1, self._num_segments)) + return self._add_active_peers() + + + def _done(self): + """ + I am called by _check_for_done when the download process has + finished successfully. After making some useful logging + statements, I return the decrypted contents to the owner of this + Retrieve object through self._done_deferred. + """ self._running = False self._status.set_active(False) - self._status.timings["total"] = time.time() - self._started - # res is either the new contents, or a Failure - if isinstance(res, failure.Failure): - self.log("Retrieve done, with failure", failure=res, - level=log.UNUSUAL) - self._status.set_status("Failed") + now = time.time() + self._status.timings['total'] = now - self._started + self._status.timings['fetch'] = now - self._started_fetching + + if self._verify: + ret = list(self._bad_shares) + self.log("done verifying, found %d bad shares" % len(ret)) else: - self.log("Retrieve done, success!") - self._status.set_status("Finished") - self._status.set_progress(1.0) - # remember the encoding parameters, use them again next time - (seqnum, root_hash, IV, segsize, datalength, k, N, prefix, - offsets_tuple) = self.verinfo - self._node._populate_required_shares(k) - self._node._populate_total_shares(N) - eventually(self._done_deferred.callback, res) + # TODO: upload status here? + ret = self._consumer + self._consumer.unregisterProducer() + eventually(self._done_deferred.callback, ret) + + + def _failed(self): + """ + I am called by _add_active_peers when there are not enough + active peers left to complete the download. After making some + useful logging statements, I return an exception to that effect + to the caller of this Retrieve object through + self._done_deferred. + """ + self._running = False + self._status.set_active(False) + now = time.time() + self._status.timings['total'] = now - self._started + self._status.timings['fetch'] = now - self._started_fetching + if self._verify: + ret = list(self._bad_shares) + else: + format = ("ran out of peers: " + "have %(have)d of %(total)d segments " + "found %(bad)d bad shares " + "encoding %(k)d-of-%(n)d") + args = {"have": self._current_segment, + "total": self._num_segments, + "need": self._last_segment, + "k": self._required_shares, + "n": self._total_shares, + "bad": len(self._bad_shares)} + e = NotEnoughSharesError("%s, last failure: %s" % \ + (format % args, str(self._last_failure))) + f = failure.Failure(e) + ret = f + eventually(self._done_deferred.callback, ret)