From: Brian Warner Date: Sat, 5 Apr 2008 00:09:26 +0000 (-0700) Subject: mutable.py: checkpointing #303 work: retrieve does what I want, now starting in on... X-Git-Tag: allmydata-tahoe-1.1.0~249 X-Git-Url: https://git.rkrishnan.org/pf/vdrive?a=commitdiff_plain;h=2c939bfdd31bce3ba3d8dab41a27bd33b1fbe8d8;p=tahoe-lafs%2Ftahoe-lafs.git mutable.py: checkpointing #303 work: retrieve does what I want, now starting in on publish --- diff --git a/src/allmydata/mutable.py b/src/allmydata/mutable.py index d6d54e9d..b5d002f9 100644 --- a/src/allmydata/mutable.py +++ b/src/allmydata/mutable.py @@ -1,6 +1,6 @@ import os, struct, time, weakref -from itertools import islice, count +from itertools import count from zope.interface import implements from twisted.internet import defer from twisted.python import failure @@ -44,6 +44,20 @@ class CorruptShareError(Exception): self.shnum, self.reason) +class DictOfSets(dict): + def add(self, key, value): + if key in self: + self[key].add(value) + else: + self[key] = set([value]) + + def discard(self, key, value): + if not key in self: + return + self[key].discard(value) + if not self[key]: + del self[key] + PREFIX = ">BQ32s16s" # each version has a different prefix SIGNED_PREFIX = ">BQ32s16s BBQQ" # this is covered by the signature HEADER = ">BQ32s16s BBQQ LLLLQQ" # includes offsets @@ -65,7 +79,7 @@ def unpack_header(data): return (version, seqnum, root_hash, IV, k, N, segsize, datalen, o) def unpack_prefix_and_signature(data): - assert len(data) >= HEADER_LENGTH + assert len(data) >= HEADER_LENGTH, len(data) prefix = data[:struct.calcsize(SIGNED_PREFIX)] (version, @@ -131,29 +145,18 @@ def unpack_share(data): pubkey, signature, share_hash_chain, block_hash_tree, share_data, enc_privkey) -def unpack_share_data(data): - assert len(data) >= HEADER_LENGTH - o = {} - (version, - seqnum, - root_hash, - IV, - k, N, segsize, datalen, - o['signature'], - o['share_hash_chain'], - o['block_hash_tree'], - o['share_data'], - o['enc_privkey'], - o['EOF']) = struct.unpack(HEADER, data[:HEADER_LENGTH]) +def unpack_share_data(verinfo, hash_and_data): + (seqnum, root_hash, IV, segsize, datalength, k, N, prefix, o_t) = verinfo - assert version == 0 - if len(data) < o['enc_privkey']: - raise NeedMoreDataError(o['enc_privkey'], - o['enc_privkey'], o['EOF']-o['enc_privkey']) + # hash_and_data starts with the share_hash_chain, so figure out what the + # offsets really are + o = dict(o_t) + o_share_hash_chain = 0 + o_block_hash_tree = o['block_hash_tree'] - o['share_hash_chain'] + o_share_data = o['share_data'] - o['share_hash_chain'] + o_enc_privkey = o['enc_privkey'] - o['share_hash_chain'] - pubkey = data[HEADER_LENGTH:o['signature']] - signature = data[o['signature']:o['share_hash_chain']] - share_hash_chain_s = data[o['share_hash_chain']:o['block_hash_tree']] + share_hash_chain_s = hash_and_data[o_share_hash_chain:o_block_hash_tree] share_hash_format = ">H32s" hsize = struct.calcsize(share_hash_format) assert len(share_hash_chain_s) % hsize == 0, len(share_hash_chain_s) @@ -163,17 +166,15 @@ def unpack_share_data(data): (hid, h) = struct.unpack(share_hash_format, chunk) share_hash_chain.append( (hid, h) ) share_hash_chain = dict(share_hash_chain) - block_hash_tree_s = data[o['block_hash_tree']:o['share_data']] + block_hash_tree_s = hash_and_data[o_block_hash_tree:o_share_data] assert len(block_hash_tree_s) % 32 == 0, len(block_hash_tree_s) block_hash_tree = [] for i in range(0, len(block_hash_tree_s), 32): block_hash_tree.append(block_hash_tree_s[i:i+32]) - share_data = data[o['share_data']:o['enc_privkey']] + share_data = hash_and_data[o_share_data:o_enc_privkey] - return (seqnum, root_hash, IV, k, N, segsize, datalen, - pubkey, signature, share_hash_chain, block_hash_tree, - share_data) + return (share_hash_chain, block_hash_tree, share_data) def pack_checkstring(seqnum, root_hash, IV): @@ -250,6 +251,107 @@ def pack_share(prefix, verification_key, signature, encprivkey]) return final_share +class ServerMap: + """I record the placement of mutable shares. + + This object records which shares (of various versions) are located on + which servers. + + One purpose I serve is to inform callers about which versions of the + mutable file are recoverable and 'current'. + + A second purpose is to serve as a state marker for test-and-set + operations. I am passed out of retrieval operations and back into publish + operations, which means 'publish this new version, but only if nothing + has changed since I last retrieved this data'. This reduces the chances + of clobbering a simultaneous (uncoordinated) write. + """ + + def __init__(self): + # 'servermap' maps peerid to sets of (shnum, versionid, timestamp) + # tuples. Each 'versionid' is a (seqnum, root_hash, IV, segsize, + # datalength, k, N, signed_prefix, offsets) tuple + self.servermap = DictOfSets() + self.connections = {} # maps peerid to a RemoteReference + self.problems = [] # mostly for debugging + + def make_versionmap(self): + """Return a dict that maps versionid to sets of (shnum, peerid, + timestamp) tuples.""" + versionmap = DictOfSets() + for (peerid, shares) in self.servermap.items(): + for (shnum, verinfo, timestamp) in shares: + versionmap.add(verinfo, (shnum, peerid, timestamp)) + return versionmap + + def shares_available(self): + """Return a dict that maps versionid to tuples of + (num_distinct_shares, k) tuples.""" + versionmap = self.make_versionmap() + all_shares = {} + for versionid, shares in versionmap.items(): + s = set() + for (shnum, peerid, timestamp) in shares: + s.add(shnum) + (seqnum, root_hash, IV, segsize, datalength, k, N, prefix, + offsets_tuple) = versionid + all_shares[versionid] = (len(s), k) + return all_shares + + def recoverable_versions(self): + """Return a set of versionids, one for each version that is currently + recoverable.""" + versionmap = self.make_versionmap() + + recoverable_versions = set() + for (verinfo, shares) in versionmap.items(): + (seqnum, root_hash, IV, segsize, datalength, k, N, prefix, + offsets_tuple) = verinfo + shnums = set([shnum for (shnum, peerid, timestamp) in shares]) + if len(shnums) >= k: + # this one is recoverable + recoverable_versions.add(verinfo) + + return recoverable_versions + + def unrecoverable_versions(self): + """Return a set of versionids, one for each version that is currently + unrecoverable.""" + versionmap = self.make_versionmap() + + unrecoverable_versions = set() + for (verinfo, shares) in versionmap.items(): + (seqnum, root_hash, IV, segsize, datalength, k, N, prefix, + offsets_tuple) = verinfo + shnums = set([shnum for (shnum, peerid, timestamp) in shares]) + if len(shnums) < k: + unrecoverable_versions.add(verinfo) + + return unrecoverable_versions + + def best_recoverable_version(self): + """Return a single versionid, for the so-called 'best' recoverable + version. Sequence number is the primary sort criteria, followed by + root hash. Returns None if there are no recoverable versions.""" + recoverable = list(self.recoverable_versions()) + recoverable.sort() + if recoverable: + return recoverable[-1] + return None + + def unrecoverable_newer_versions(self): + # Return a dict of versionid -> health, for versions that are + # unrecoverable and have later seqnums than any recoverable versions. + # These indicate that a write will lose data. + pass + + def needs_merge(self): + # return True if there are multiple recoverable versions with the + # same seqnum, meaning that MutableFileNode.read_best_version is not + # giving you the whole story, and that using its data to do a + # subsequent publish will lose information. + pass + class RetrieveStatus: implements(IRetrieveStatus) @@ -309,149 +411,143 @@ class RetrieveStatus: def set_active(self, value): self.active = value -class Retrieve: - def __init__(self, filenode): +MODE_CHECK = "query all peers" +MODE_ANYTHING = "one recoverable version" +MODE_WRITE = "replace all shares, probably" # not for initial creation +MODE_ENOUGH = "enough" + +class ServermapUpdater: + def __init__(self, filenode, servermap, mode=MODE_ENOUGH): self._node = filenode - self._contents = None - # if the filenode already has a copy of the pubkey, use it. Otherwise - # we'll grab a copy from the first peer we talk to. - self._pubkey = filenode.get_pubkey() + self._servermap = servermap + self.mode = mode + self._running = True + self._storage_index = filenode.get_storage_index() - self._readkey = filenode.get_readkey() self._last_failure = None - self._log_number = None - self._log_prefix = prefix = storage.si_b2a(self._storage_index)[:5] - num = self._node._client.log("Retrieve(%s): starting" % prefix) - self._log_number = num - self._status = RetrieveStatus() - self._status.set_storage_index(self._storage_index) - self._status.set_helper(False) - self._status.set_progress(0.0) - self._status.set_active(True) - # how much data should be read on the first fetch? It would be nice - # if we could grab small directories in a single RTT. The way we pack - # dirnodes consumes about 112 bytes per child. The way we pack - # mutable files puts about 935 bytes of pubkey+sig+hashes, then our - # data, then about 1216 bytes of encprivkey. So 2kB ought to get us - # about 9 entries, which seems like a good default. + + # how much data should we read? + # * if we only need the checkstring, then [0:75] + # * if we need to validate the checkstring sig, then [543ish:799ish] + # * if we need the verification key, then [107:436ish] + # * the offset table at [75:107] tells us about the 'ish' + # A future version of the SMDF slot format should consider using + # fixed-size slots so we can retrieve less data. For now, we'll just + # read 2000 bytes, which also happens to read enough actual data to + # pre-fetch a 9-entry dirnode. self._read_size = 2000 + if mode == MODE_CHECK: + # we use unpack_prefix_and_signature, so we need 1k + self._read_size = 1000 - def log(self, msg, **kwargs): - prefix = self._log_prefix - num = self._node._client.log("Retrieve(%s): %s" % (prefix, msg), - parent=self._log_number, **kwargs) - return num - - def log_err(self, f): - num = log.err(f, parent=self._log_number) - return num - - def retrieve(self): - """Retrieve the filenode's current contents. Returns a Deferred that - fires with a string when the contents have been retrieved.""" - - # 1: make a guess as to how many peers we should send requests to. We - # want to hear from k+EPSILON (k because we have to, EPSILON extra - # because that helps us resist rollback attacks). [TRADEOFF: - # EPSILON>0 means extra work] [TODO: implement EPSILON>0] - # 2: build the permuted peerlist, taking the first k+E peers - # 3: send readv requests to all of them in parallel, asking for the - # first 2KB of data from all shares - # 4: when the first of the responses comes back, extract information: - # 4a: extract the pubkey, hash it, compare against the URI. If this - # check fails, log a WEIRD and ignore the peer. - # 4b: extract the prefix (seqnum, roothash, k, N, segsize, datalength) - # and verify the signature on it. If this is wrong, log a WEIRD - # and ignore the peer. Save the prefix string in a dict that's - # keyed by (seqnum,roothash) and has (prefixstring, sharemap) as - # values. We'll use the prefixstring again later to avoid doing - # multiple signature checks - # 4c: extract the share size (offset of the last byte of sharedata). - # if it is larger than 2k, send new readv requests to pull down - # the extra data - # 4d: if the extracted 'k' is more than we guessed, rebuild a larger - # permuted peerlist and send out more readv requests. - # 5: as additional responses come back, extract the prefix and compare - # against the ones we've already seen. If they match, add the - # peerid to the corresponing sharemap dict - # 6: [TRADEOFF]: if EPSILON==0, when we get k responses for the - # same (seqnum,roothash) key, attempt to reconstruct that data. - # if EPSILON>0, wait for k+EPSILON responses, then attempt to - # reconstruct the most popular version.. If we do not have enough - # shares and there are still requests outstanding, wait. If there - # are not still requests outstanding (todo: configurable), send - # more requests. Never send queries to more than 2*N servers. If - # we've run out of servers, fail. - # 7: if we discover corrupt shares during the reconstruction process, - # remove that share from the sharemap. and start step#6 again. - - initial_query_count = 5 - - # self._valid_versions is a dictionary in which the keys are - # 'verinfo' tuples (seqnum, root_hash, IV, segsize, datalength, k, - # N). Every time we hear about a new potential version of the file, - # we check its signature, and the valid ones are added to this - # dictionary. The values of the dictionary are (prefix, sharemap) - # tuples, where 'prefix' is just the first part of the share - # (containing the serialized verinfo), for easier comparison. - # 'sharemap' is a DictOfSets, in which the keys are sharenumbers, and - # the values are sets of (peerid, data) tuples. There is a (peerid, - # data) tuple for every instance of a given share that we've seen. - # The 'data' in this tuple is a full copy of the SDMF share, starting - # with the \x00 version byte and continuing through the last byte of - # sharedata. - self._valid_versions = {} - - # self._valid_shares is a dict mapping (peerid,data) tuples to - # validated sharedata strings. Each time we examine the hash chains - # inside a share and validate them against a signed root_hash, we add - # the share to self._valid_shares . We use this to avoid re-checking - # the hashes over and over again. - self._valid_shares = {} + prefix = storage.si_b2a(self._storage_index)[:5] + self._log_number = log.msg("SharemapUpdater(%s): starting" % prefix) + + def log(self, *args, **kwargs): + if "parent" not in kwargs: + kwargs["parent"] = self._log_number + return log.msg(*args, **kwargs) + + def update(self): + """Update the servermap to reflect current conditions. Returns a + Deferred that fires with the servermap once the update has finished.""" + + # self._valid_versions is a set of validated verinfo tuples. We just + # use it to remember which versions had valid signatures, so we can + # avoid re-checking the signatures for each share. + self._valid_versions = set() + + # self.versionmap maps verinfo tuples to sets of (shnum, peerid, + # timestamp) tuples. This is used to figure out which versions might + # be retrievable, and to make the eventual data download faster. + self.versionmap = DictOfSets() self._started = time.time() self._done_deferred = defer.Deferred() - d = defer.succeed(initial_query_count) - d.addCallback(self._choose_initial_peers) + # first, which peers should be talk to? Any that were in our old + # servermap, plus "enough" others. + + self._queries_completed = 0 + + client = self._node._client + full_peerlist = client.get_permuted_peers("storage", + self._node._storage_index) + self.full_peerlist = full_peerlist # for use later, immutable + self.extra_peers = full_peerlist[:] # peers are removed as we use them + self._good_peers = set() # peers who had some shares + self._empty_peers = set() # peers who don't have any shares + self._bad_peers = set() # peers to whom our queries failed + + k = self._node.get_required_shares() + if k is None: + # make a guess + k = 3 + N = self._node.get_required_shares() + if N is None: + N = 10 + self.EPSILON = k + # we want to send queries to at least this many peers (although we + # might not wait for all of their answers to come back) + self.num_peers_to_query = k + self.EPSILON + + # TODO: initial_peers_to_query needs to be ordered list of (peerid, + # ss) tuples + + if self.mode == MODE_CHECK: + initial_peers_to_query = dict(full_peerlist) + must_query = set(initial_peers_to_query.keys()) + self.extra_peers = [] + elif self.mode == MODE_WRITE: + # we're planning to replace all the shares, so we want a good + # chance of finding them all. We will keep searching until we've + # seen epsilon that don't have a share. + self.num_peers_to_query = N + self.EPSILON + initial_peers_to_query, must_query = self._build_initial_querylist() + self.required_num_empty_peers = self.EPSILON + else: + initial_peers_to_query, must_query = self._build_initial_querylist() + + # this is a set of peers that we are required to get responses from: + # they are peers who used to have a share, so we need to know where + # they currently stand, even if that means we have to wait for a + # silently-lost TCP connection to time out. We remove peers from this + # set as we get responses. + self._must_query = must_query + + # now initial_peers_to_query contains the peers that we should ask, + # self.must_query contains the peers that we must have heard from + # before we can consider ourselves finished, and self.extra_peers + # contains the overflow (peers that we should tap if we don't get + # enough responses) + + d = defer.succeed(initial_peers_to_query) d.addCallback(self._send_initial_requests) - d.addCallback(self._wait_for_finish) + d.addCallback(lambda res: self._done_deferred) return d - def _wait_for_finish(self, res): - return self._done_deferred + def _build_initial_querylist(self): + initial_peers_to_query = {} + must_query = set() + for peerid in self._servermap.servermap.keys(): + ss = self._servermap.connections[peerid] + # we send queries to everyone who was already in the sharemap + initial_peers_to_query[peerid] = ss + # and we must wait for responses from them + must_query.add(peerid) - def _choose_initial_peers(self, numqueries): - n = self._node - started = time.time() - full_peerlist = n._client.get_permuted_peers("storage", - self._storage_index) - - # _peerlist is a list of (peerid,conn) tuples for peers that are - # worth talking too. This starts with the first numqueries in the - # permuted list. If that's not enough to get us a recoverable - # version, we expand this to include the first 2*total_shares peerids - # (assuming we learn what total_shares is from one of the first - # numqueries peers) - self._peerlist = [p for p in islice(full_peerlist, numqueries)] - # _peerlist_limit is the query limit we used to build this list. If - # we later increase this limit, it may be useful to re-scan the - # permuted list. - self._peerlist_limit = numqueries - self._status.set_search_distance(len(self._peerlist)) - elapsed = time.time() - started - self._status.timings["peer_selection"] = elapsed - return self._peerlist + while ((self.num_peers_to_query > len(initial_peers_to_query)) + and self.extra_peers): + (peerid, ss) = self.extra_peers.pop(0) + initial_peers_to_query[peerid] = ss + + return initial_peers_to_query, must_query def _send_initial_requests(self, peerlist): - self._first_query_sent = time.time() - self._bad_peerids = set() - self._running = True self._queries_outstanding = set() - self._used_peers = set() self._sharemap = DictOfSets() # shnum -> [(peerid, seqnum, R)..] dl = [] - for (peerid, ss) in peerlist: + for (peerid, ss) in peerlist.items(): self._queries_outstanding.add(peerid) self._do_query(ss, peerid, self._storage_index, self._read_size) @@ -465,6 +561,11 @@ class Retrieve: return d def _do_query(self, ss, peerid, storage_index, readsize): + self.log(format="sending query to [%(peerid)s], readsize=%(readsize)d", + peerid=idlib.shortnodeid_b2a(peerid), + readsize=readsize, + level=log.NOISY) + self._servermap.connections[peerid] = ss started = time.time() self._queries_outstanding.add(peerid) d = self._do_read(ss, peerid, storage_index, [], [(0, readsize)]) @@ -475,6 +576,7 @@ class Retrieve: # _query_failed) get logged, but we still want to check for doneness. d.addErrback(log.err) d.addBoth(self._check_for_done) + d.addErrback(log.err) return d def _deserialize_pubkey(self, pubkey_s): @@ -482,81 +584,68 @@ class Retrieve: return verifier def _got_results(self, datavs, peerid, readsize, stuff, started): + self.log(format="got result from [%(peerid)s], %(numshares)d shares", + peerid=idlib.shortnodeid_b2a(peerid), + numshares=len(datavs), + level=log.NOISY) self._queries_outstanding.discard(peerid) - self._used_peers.add(peerid) + self._must_query.discard(peerid) + self._queries_completed += 1 if not self._running: + self.log("but we're not running, so we'll ignore it") return - elapsed = time.time() - started - if peerid not in self._status.timings["fetch_per_server"]: - self._status.timings["fetch_per_server"][peerid] = [] - self._status.timings["fetch_per_server"][peerid].append(elapsed) - - if peerid not in self._status.sharemap: - self._status.sharemap[peerid] = set() + if datavs: + self._good_peers.add(peerid) + else: + self._empty_peers.add(peerid) for shnum,datav in datavs.items(): data = datav[0] try: self._got_results_one_share(shnum, data, peerid) - except NeedMoreDataError, e: - # ah, just re-send the query then. - self.log("need more data from %(peerid)s, got %(got)d, need %(needed)d", - peerid=idlib.shortnodeid_b2a(peerid), - got=len(data), needed=e.needed_bytes, - level=log.NOISY) - self._read_size = max(self._read_size, e.needed_bytes) - # TODO: for MDMF, sanity-check self._read_size: don't let one - # server cause us to try to read gigabytes of data from all - # other servers. - (ss, storage_index) = stuff - self._do_query(ss, peerid, storage_index, self._read_size) - return except CorruptShareError, e: # log it and give the other shares a chance to be processed f = failure.Failure() self.log("bad share: %s %s" % (f, f.value), level=log.WEIRD) - self._bad_peerids.add(peerid) + self._bad_peers.add(peerid) self._last_failure = f + self._servermap.problems.append(f) pass # all done! + self.log("DONE") def _got_results_one_share(self, shnum, data, peerid): - self.log("_got_results: got shnum #%d from peerid %s" - % (shnum, idlib.shortnodeid_b2a(peerid))) + self.log(format="_got_results: got shnum #%(shnum)d from peerid %(peerid)s", + shnum=shnum, + peerid=idlib.shortnodeid_b2a(peerid)) - # this might raise NeedMoreDataError, in which case the rest of - # the shares are probably short too. _query_failed() will take - # responsiblity for re-issuing the queries with a new length. + # this might raise NeedMoreDataError, if the pubkey and signature + # live at some weird offset. That shouldn't happen, so I'm going to + # treat it as a bad share. (seqnum, root_hash, IV, k, N, segsize, datalength, pubkey_s, signature, prefix) = unpack_prefix_and_signature(data) - if not self._pubkey: + if not self._node._pubkey: fingerprint = hashutil.ssk_pubkey_fingerprint_hash(pubkey_s) assert len(fingerprint) == 32 if fingerprint != self._node._fingerprint: - self._status.problems[peerid] = "sh#%d: pubkey doesn't match fingerprint" % shnum raise CorruptShareError(peerid, shnum, "pubkey doesn't match fingerprint") - self._pubkey = self._deserialize_pubkey(pubkey_s) - self._node._populate_pubkey(self._pubkey) + self._node._pubkey = self._deserialize_pubkey(pubkey_s) - verinfo = (seqnum, root_hash, IV, segsize, datalength, k, N) - self._status.sharemap[peerid].add(verinfo) + (ig_version, ig_seqnum, ig_root_hash, ig_IV, ig_k, ig_N, + ig_segsize, ig_datalen, offsets) = unpack_header(data) + offsets_tuple = tuple( [(key,value) for key,value in offsets.items()] ) + + verinfo = (seqnum, root_hash, IV, segsize, datalength, k, N, prefix, + offsets_tuple) if verinfo not in self._valid_versions: # it's a new pair. Verify the signature. - started = time.time() - valid = self._pubkey.verify(prefix, signature) - # this records the total verification time for all versions we've - # seen. This time is included in "fetch". - elapsed = time.time() - started - self._status.timings["cumulative_verify"] += elapsed - + valid = self._node._pubkey.verify(prefix, signature) if not valid: - self._status.problems[peerid] = "sh#%d: invalid signature" % shnum - raise CorruptShareError(peerid, shnum, - "signature is invalid") + raise CorruptShareError(peerid, shnum, "signature is invalid") # ok, it's a valid verinfo. Add it to the list of validated # versions. @@ -564,206 +653,378 @@ class Retrieve: % (seqnum, base32.b2a(root_hash)[:4], idlib.shortnodeid_b2a(peerid), shnum, k, N, segsize, datalength)) - self._valid_versions[verinfo] = (prefix, DictOfSets()) + self._valid_versions.add(verinfo) + # We now know that this is a valid candidate verinfo. - # We now know that this is a valid candidate verinfo. Accumulate the - # share info, if there's enough data present. If not, raise - # NeedMoreDataError, which will trigger a re-fetch. - _ignored = unpack_share_data(data) - self.log(" found enough data to add share contents") - self._valid_versions[verinfo][1].add(shnum, (peerid, data)) + # Add the info to our servermap. + timestamp = time.time() + self._servermap.servermap.add(peerid, (shnum, verinfo, timestamp)) + # and the versionmap + self.versionmap.add(verinfo, (shnum, peerid, timestamp)) def _query_failed(self, f, peerid): + self.log("error during query: %s %s" % (f, f.value), level=log.WEIRD) if not self._running: return + self._must_query.discard(peerid) self._queries_outstanding.discard(peerid) - self._used_peers.add(peerid) + self._bad_peers.add(peerid) + self._servermap.problems.append(f) + self._queries_completed += 1 self._last_failure = f - self._bad_peerids.add(peerid) - self.log("error during query: %s %s" % (f, f.value), level=log.WEIRD) def _check_for_done(self, res): - if not self._running: - self.log("ODD: _check_for_done but we're not running") + # exit paths: + # return self._send_more_queries(outstanding) : send some more queries + # return self._done() : all done + # return : keep waiting, no new queries + + self.log(format=("_check_for_done, mode is '%(mode)s', " + "%(outstanding)d queries outstanding, " + "%(extra)d extra peers available, " + "%(must)d 'must query' peers left" + ), + mode=self.mode, + outstanding=len(self._queries_outstanding), + extra=len(self.extra_peers), + must=len(self._must_query), + ) + + if self._must_query: + # we are still waiting for responses from peers that used to have + # a share, so we must continue to wait. No additional queries are + # required at this time. + self.log("%d 'must query' peers left" % len(self._must_query)) return - share_prefixes = {} - versionmap = DictOfSets() - max_N = 0 - for verinfo, (prefix, sharemap) in self._valid_versions.items(): - # sharemap is a dict that maps shnums to sets of (peerid,data). - # len(sharemap) is the number of distinct shares that appear to - # be available. - (seqnum, root_hash, IV, segsize, datalength, k, N) = verinfo - max_N = max(max_N, N) - if len(sharemap) >= k: - # this one looks retrievable. TODO: our policy of decoding - # the first version that we can get is a bit troublesome: in - # a small grid with a large expansion factor, a single - # out-of-date server can cause us to retrieve an older - # version. Fixing this is equivalent to protecting ourselves - # against a rollback attack, and the best approach is - # probably to say that we won't do _attempt_decode until: - # (we've received at least k+EPSILON shares or - # we've received at least k shares and ran out of servers) - # in that case, identify the verinfos that are decodeable and - # attempt the one with the highest (seqnum,R) value. If the - # highest seqnum can't be recovered, only then might we fall - # back to an older version. - d = defer.maybeDeferred(self._attempt_decode, verinfo, sharemap) - def _problem(f): - self._last_failure = f - if f.check(CorruptShareError): - self.log("saw corrupt share, rescheduling", - level=log.WEIRD) - # _attempt_decode is responsible for removing the bad - # share, so we can just try again - eventually(self._check_for_done, None) - return - return f - d.addCallbacks(self._done, _problem) - # TODO: create an errback-routing mechanism to make sure that - # weird coding errors will cause the retrieval to fail rather - # than hanging forever. Any otherwise-unhandled exceptions - # should follow this path. A simple way to test this is to - # raise BadNameError in _validate_share_and_extract_data . - return - # we don't have enough shares yet. Should we send out more queries? - if self._queries_outstanding: - # there are some running, so just wait for them to come back. - # TODO: if our initial guess at k was too low, waiting for these - # responses before sending new queries will increase our latency, - # so we could speed things up by sending new requests earlier. - self.log("ROUTINE: %d queries outstanding" % - len(self._queries_outstanding)) - return + if (not self._queries_outstanding and not self.extra_peers): + # all queries have retired, and we have no peers left to ask. No + # more progress can be made, therefore we are done. + self.log("all queries are retired, no extra peers: done") + return self._done() + + recoverable_versions = self._servermap.recoverable_versions() + unrecoverable_versions = self._servermap.unrecoverable_versions() + + # what is our completion policy? how hard should we work? + + if self.mode == MODE_ANYTHING: + if recoverable_versions: + self.log("MODE_ANYTHING and %d recoverable versions: done" + % len(recoverable_versions)) + return self._done() + + if self.mode == MODE_CHECK: + # we used self._must_query, and we know there aren't any + # responses still waiting, so that means we must be done + self.log("MODE_CHECK: done") + return self._done() + + MAX_IN_FLIGHT = 5 + if self.mode == MODE_ENOUGH: + # if we've queried k+epsilon servers, and we see a recoverable + # version, and we haven't seen any unrecoverable higher-seqnum'ed + # versions, then we're done. + + if self._queries_completed < self.num_peers_to_query: + self.log(format="ENOUGH, %(completed)d completed, %(query)d to query: need more", + completed=self._queries_completed, + query=self.num_peers_to_query) + return self._send_more_queries(MAX_IN_FLIGHT) + if not recoverable_versions: + self.log("ENOUGH, no recoverable versions: need more") + return self._send_more_queries(MAX_IN_FLIGHT) + highest_recoverable = max(recoverable_versions) + highest_recoverable_seqnum = highest_recoverable[0] + for unrec_verinfo in unrecoverable_versions: + if unrec_verinfo[0] > highest_recoverable_seqnum: + # there is evidence of a higher-seqnum version, but we + # don't yet see enough shares to recover it. Try harder. + # TODO: consider sending more queries. + # TODO: consider limiting the search distance + self.log("ENOUGH, evidence of higher seqnum: need more") + return self._send_more_queries(MAX_IN_FLIGHT) + # all the unrecoverable versions were old or concurrent with a + # recoverable version. Good enough. + self.log("ENOUGH: no higher-seqnum: done") + return self._done() + + if self.mode == MODE_WRITE: + # we want to keep querying until we've seen a few that don't have + # any shares, to be sufficiently confident that we've seen all + # the shares. This is still less work than MODE_CHECK, which asks + # every server in the world. + + if not recoverable_versions: + self.log("WRITE, no recoverable versions: need more") + return self._send_more_queries(MAX_IN_FLIGHT) + + last_found = -1 + last_not_responded = -1 + num_not_responded = 0 + num_not_found = 0 + states = [] + for i,(peerid,ss) in enumerate(self.full_peerlist): + if peerid in self._bad_peers: + # query failed + states.append("x") + #self.log("loop [%s]: x" % idlib.shortnodeid_b2a(peerid)) + elif peerid in self._empty_peers: + # no shares + states.append("0") + #self.log("loop [%s]: 0" % idlib.shortnodeid_b2a(peerid)) + if last_found != -1: + num_not_found += 1 + if num_not_found >= self.EPSILON: + self.log("MODE_WRITE: found our boundary, %s" % + "".join(states)) + # we need to know that we've gotten answers from + # everybody to the left of here + if last_not_responded == -1: + # we're done + self.log("have all our answers") + return self._done() + # still waiting for somebody + return self._send_more_queries(num_not_responded) + + elif peerid in self._good_peers: + # yes shares + states.append("1") + #self.log("loop [%s]: 1" % idlib.shortnodeid_b2a(peerid)) + last_found = i + num_not_found = 0 + else: + # not responded yet + states.append("?") + #self.log("loop [%s]: ?" % idlib.shortnodeid_b2a(peerid)) + last_not_responded = i + num_not_responded += 1 + + # if we hit here, we didn't find our boundary, so we're still + # waiting for peers + self.log("MODE_WRITE: no boundary yet, %s" % "".join(states)) + return self._send_more_queries(MAX_IN_FLIGHT) + + # otherwise, keep up to 5 queries in flight. TODO: this is pretty + # arbitrary, really I want this to be something like k - + # max(known_version_sharecounts) + some extra + self.log("catchall: need more") + return self._send_more_queries(MAX_IN_FLIGHT) + + def _send_more_queries(self, num_outstanding): + assert self.extra_peers # we shouldn't get here with nothing in reserve + more_queries = [] + + while True: + self.log(" there are %d queries outstanding" % len(self._queries_outstanding)) + active_queries = len(self._queries_outstanding) + len(more_queries) + if active_queries >= num_outstanding: + break + if not self.extra_peers: + break + more_queries.append(self.extra_peers.pop(0)) + + self.log(format="sending %(more)d more queries: %(who)s", + more=len(more_queries), + who=" ".join(["[%s]" % idlib.shortnodeid_b2a(peerid) + for (peerid,ss) in more_queries]), + level=log.NOISY) - # no more queries are outstanding. Can we send out more? First, - # should we be looking at more peers? - self.log("need more peers: max_N=%s, peerlist=%d peerlist_limit=%d" % - (max_N, len(self._peerlist), - self._peerlist_limit), level=log.UNUSUAL) - if max_N: - search_distance = max_N * 2 - else: - search_distance = 20 - self.log("search_distance=%d" % search_distance, level=log.UNUSUAL) - if self._peerlist_limit < search_distance: - # we might be able to get some more peers from the list - peers = self._node._client.get_permuted_peers("storage", - self._storage_index) - self._peerlist = [p for p in islice(peers, search_distance)] - self._peerlist_limit = search_distance - self.log("added peers, peerlist=%d, peerlist_limit=%d" - % (len(self._peerlist), self._peerlist_limit), - level=log.UNUSUAL) - # are there any peers on the list that we haven't used? - new_query_peers = [] - peer_indicies = [] - for i, (peerid, ss) in enumerate(self._peerlist): - if peerid not in self._used_peers: - new_query_peers.append( (peerid, ss) ) - peer_indicies.append(i) - if len(new_query_peers) > 5: - # only query in batches of 5. TODO: this is pretty - # arbitrary, really I want this to be something like - # k - max(known_version_sharecounts) + some extra - break - if new_query_peers: - self.log("sending %d new queries (read %d bytes)" % - (len(new_query_peers), self._read_size), level=log.UNUSUAL) - new_search_distance = max(max(peer_indicies), - self._status.get_search_distance()) - self._status.set_search_distance(new_search_distance) - for (peerid, ss) in new_query_peers: - self._do_query(ss, peerid, self._storage_index, self._read_size) + for (peerid, ss) in more_queries: + self._do_query(ss, peerid, self._storage_index, self._read_size) # we'll retrigger when those queries come back + + def _done(self): + if not self._running: return + self._running = False + # the servermap will not be touched after this + eventually(self._done_deferred.callback, self._servermap) + + +class Marker: + pass + +class Retrieve: + # this class is currently single-use. Eventually (in MDMF) we will make + # it multi-use, in which case you can call download(range) multiple + # times, and each will have a separate response chain. However the + # Retrieve object will remain tied to a specific version of the file, and + # will use a single ServerMap instance. + + def __init__(self, filenode, servermap, verinfo): + self._node = filenode + assert self._node._pubkey + self._storage_index = filenode.get_storage_index() + assert self._node._readkey + self._last_failure = None + prefix = storage.si_b2a(self._storage_index)[:5] + self._log_number = log.msg("Retrieve(%s): starting" % prefix) + self._outstanding_queries = {} # maps (peerid,shnum) to start_time + self._running = True + self._decoding = False + + self.servermap = servermap + assert self._node._pubkey + self.verinfo = verinfo + + def log(self, *args, **kwargs): + if "parent" not in kwargs: + kwargs["parent"] = self._log_number + return log.msg(*args, **kwargs) + + def download(self): + self._done_deferred = defer.Deferred() + + # first, which servers can we use? + versionmap = self.servermap.make_versionmap() + shares = versionmap[self.verinfo] + # this sharemap is consumed as we decide to send requests + self.remaining_sharemap = DictOfSets() + for (shnum, peerid, timestamp) in shares: + self.remaining_sharemap.add(shnum, peerid) + + self.shares = {} # maps shnum to validated blocks + + # how many shares do we need? + (seqnum, root_hash, IV, segsize, datalength, k, N, prefix, + offsets_tuple) = self.verinfo + assert len(self.remaining_sharemap) >= k + # we start with the lowest shnums we have available, since FEC is + # faster if we're using "primary shares" + self.active_shnums = set(sorted(self.remaining_sharemap.keys())[:k]) + for shnum in self.active_shnums: + # we use an arbitrary peer who has the share. If shares are + # doubled up (more than one share per peer), we could make this + # run faster by spreading the load among multiple peers. But the + # algorithm to do that is more complicated than I want to write + # right now, and a well-provisioned grid shouldn't have multiple + # shares per peer. + peerid = list(self.remaining_sharemap[shnum])[0] + self.get_data(shnum, peerid) - # we've used up all the peers we're allowed to search. Failure. - self.log("ran out of peers", level=log.WEIRD) - e = NotEnoughPeersError("last failure: %s" % self._last_failure) - return self._done(failure.Failure(e)) - - def _attempt_decode(self, verinfo, sharemap): - # sharemap is a dict which maps shnum to [(peerid,data)..] sets. - (seqnum, root_hash, IV, segsize, datalength, k, N) = verinfo - - assert len(sharemap) >= k, len(sharemap) - - shares_s = [] - for shnum in sorted(sharemap.keys()): - for shareinfo in sharemap[shnum]: - shares_s.append("#%d" % shnum) - shares_s = ",".join(shares_s) - self.log("_attempt_decode: version %d-%s, shares: %s" % - (seqnum, base32.b2a(root_hash)[:4], shares_s)) - - # first, validate each share that we haven't validated yet. We use - # self._valid_shares to remember which ones we've already checked. - - shares = {} - for shnum, shareinfos in sharemap.items(): - assert len(shareinfos) > 0 - for shareinfo in shareinfos: - # have we already validated the hashes on this share? - if shareinfo not in self._valid_shares: - # nope: must check the hashes and extract the actual data - (peerid,data) = shareinfo - try: - # The (seqnum+root_hash+IV) tuple for this share was - # already verified: specifically, all shares in the - # sharemap have a (seqnum+root_hash+IV) pair that was - # present in a validly signed prefix. The remainder - # of the prefix for this particular share has *not* - # been validated, but we don't care since we don't - # use it. self._validate_share() is required to check - # the hashes on the share data (and hash chains) to - # make sure they match root_hash, but is not required - # (and is in fact prohibited, because we don't - # validate the prefix on all shares) from using - # anything else in the share. - validator = self._validate_share_and_extract_data - sharedata = validator(peerid, root_hash, shnum, data) - assert isinstance(sharedata, str) - except CorruptShareError, e: - self.log("share was corrupt: %s" % e, level=log.WEIRD) - sharemap[shnum].discard(shareinfo) - if not sharemap[shnum]: - # remove the key so the test in _check_for_done - # can accurately decide that we don't have enough - # shares to try again right now. - del sharemap[shnum] - # If there are enough remaining shares, - # _check_for_done() will try again - raise - # share is valid: remember it so we won't need to check - # (or extract) it again - self._valid_shares[shareinfo] = sharedata - - # the share is now in _valid_shares, so just copy over the - # sharedata - shares[shnum] = self._valid_shares[shareinfo] - - # now that the big loop is done, all shares in the sharemap are - # valid, and they're all for the same seqnum+root_hash version, so - # it's now down to doing FEC and decrypt. - elapsed = time.time() - self._started - self._status.timings["fetch"] = elapsed - assert len(shares) >= k, len(shares) - d = defer.maybeDeferred(self._decode, shares, segsize, datalength, k, N) - d.addCallback(self._decrypt, IV, seqnum, root_hash) + # control flow beyond this point: state machine. Receiving responses + # from queries is the input. We might send out more queries, or we + # might produce a result. + + return self._done_deferred + + def get_data(self, shnum, peerid): + self.log(format="sending sh#%(shnum)d request to [%(peerid)s]", + shnum=shnum, + peerid=idlib.shortnodeid_b2a(peerid), + level=log.NOISY) + ss = self.servermap.connections[peerid] + started = time.time() + (seqnum, root_hash, IV, segsize, datalength, k, N, prefix, + offsets_tuple) = self.verinfo + offsets = dict(offsets_tuple) + # we read the checkstring, to make sure that the data we grab is from + # the right version. We also read the data, and the hashes necessary + # to validate them (share_hash_chain, block_hash_tree, share_data). + # We don't read the signature or the pubkey, since that was handled + # during the servermap phase, and we'll be comparing the share hash + # chain against the roothash that was validated back then. + readv = [ (0, struct.calcsize(SIGNED_PREFIX)), + (offsets['share_hash_chain'], + offsets['enc_privkey'] - offsets['share_hash_chain']), + ] + + m = Marker() + self._outstanding_queries[m] = (peerid, shnum, started) + + # ask the cache first + datav = [] + #for (offset, length) in readv: + # (data, timestamp) = self._node._cache.read(self.verinfo, shnum, + # offset, length) + # if data is not None: + # datav.append(data) + if len(datav) == len(readv): + self.log("got data from cache") + d = defer.succeed(datav) + else: + self.remaining_sharemap[shnum].remove(peerid) + d = self._do_read(ss, peerid, self._storage_index, [shnum], readv) + d.addCallback(self._fill_cache, readv) + + d.addCallback(self._got_results, m, peerid, started) + d.addErrback(self._query_failed, m, peerid) + # errors that aren't handled by _query_failed (and errors caused by + # _query_failed) get logged, but we still want to check for doneness. + def _oops(f): + self.log(format="problem in _query_failed for sh#%(shnum)d to %(peerid)s", + shnum=shnum, + peerid=idlib.shortnodeid_b2a(peerid), + failure=f, + level=log.WEIRD) + d.addErrback(_oops) + d.addBoth(self._check_for_done) + # any error during _check_for_done means the download fails. If the + # download is successful, _check_for_done will fire _done by itself. + d.addErrback(self._done) + d.addErrback(log.err) + return d # purely for testing convenience + + def _fill_cache(self, datavs, readv): + timestamp = time.time() + for shnum,datav in datavs.items(): + for i, (offset, length) in enumerate(readv): + data = datav[i] + self._node._cache.add(self.verinfo, shnum, offset, data, + timestamp) + return datavs + + def _do_read(self, ss, peerid, storage_index, shnums, readv): + # isolate the callRemote to a separate method, so tests can subclass + # Publish and override it + d = ss.callRemote("slot_readv", storage_index, shnums, readv) return d - def _validate_share_and_extract_data(self, peerid, root_hash, shnum, data): - # 'data' is the whole SMDF share - self.log("_validate_share_and_extract_data[%d]" % shnum) - assert data[0] == "\x00" - pieces = unpack_share_data(data) - (seqnum, root_hash_copy, IV, k, N, segsize, datalen, - pubkey, signature, share_hash_chain, block_hash_tree, - share_data) = pieces + def remove_peer(self, peerid): + for shnum in list(self.remaining_sharemap.keys()): + self.remaining_sharemap.discard(shnum, peerid) + + def _got_results(self, datavs, marker, peerid, started): + self.log(format="got results (%(shares)d shares) from [%(peerid)s]", + shares=len(datavs), + peerid=idlib.shortnodeid_b2a(peerid), + level=log.NOISY) + self._outstanding_queries.pop(marker, None) + if not self._running: + return + + # note that we only ask for a single share per query, so we only + # expect a single share back. On the other hand, we use the extra + # shares if we get them.. seems better than an assert(). + + for shnum,datav in datavs.items(): + (prefix, hash_and_data) = datav + try: + self._got_results_one_share(shnum, peerid, + prefix, hash_and_data) + except CorruptShareError, e: + # log it and give the other shares a chance to be processed + f = failure.Failure() + self.log("bad share: %s %s" % (f, f.value), level=log.WEIRD) + self.remove_peer(peerid) + self._last_failure = f + pass + # all done! + + def _got_results_one_share(self, shnum, peerid, + got_prefix, got_hash_and_data): + self.log("_got_results: got shnum #%d from peerid %s" + % (shnum, idlib.shortnodeid_b2a(peerid))) + (seqnum, root_hash, IV, segsize, datalength, k, N, prefix, + offsets_tuple) = self.verinfo + assert len(got_prefix) == len(prefix), (len(got_prefix), len(prefix)) + if got_prefix != prefix: + msg = "someone wrote to the data since we read the servermap: prefix changed" + raise UncoordinatedWriteError(msg) + (share_hash_chain, block_hash_tree, + share_data) = unpack_share_data(self.verinfo, got_hash_and_data) assert isinstance(share_data, str) # build the block hash tree. SDMF has only one leaf. @@ -783,14 +1044,131 @@ class Retrieve: msg = "corrupt hashes: %s" % (e,) raise CorruptShareError(peerid, shnum, msg) self.log(" data valid! len=%d" % len(share_data)) - return share_data + # each query comes down to this: placing validated share data into + # self.shares + self.shares[shnum] = share_data + + def _query_failed(self, f, marker, peerid): + self.log(format="query to [%(peerid)s] failed", + peerid=idlib.shortnodeid_b2a(peerid), + level=log.NOISY) + self._outstanding_queries.pop(marker, None) + if not self._running: + return + self._last_failure = f + self.remove_peer(peerid) + self.log("error during query: %s %s" % (f, f.value), level=log.WEIRD) + + def _check_for_done(self, res): + # exit paths: + # return : keep waiting, no new queries + # return self._send_more_queries(outstanding) : send some more queries + # fire self._done(plaintext) : download successful + # raise exception : download fails + + self.log(format="_check_for_done: running=%(running)s, decoding=%(decoding)s", + running=self._running, decoding=self._decoding, + level=log.NOISY) + if not self._running: + return + if self._decoding: + return + (seqnum, root_hash, IV, segsize, datalength, k, N, prefix, + offsets_tuple) = self.verinfo + + if len(self.shares) < k: + # we don't have enough shares yet + return self._maybe_send_more_queries(k) + + # we have enough to finish. All the shares have had their hashes + # checked, so if something fails at this point, we don't know how + # to fix it, so the download will fail. + + self._decoding = True # avoid reentrancy + + d = defer.maybeDeferred(self._decode) + d.addCallback(self._decrypt, IV, self._node._readkey) + d.addBoth(self._done) + return d # purely for test convenience + + def _maybe_send_more_queries(self, k): + # we don't have enough shares yet. Should we send out more queries? + # There are some number of queries outstanding, each for a single + # share. If we can generate 'needed_shares' additional queries, we do + # so. If we can't, then we know this file is a goner, and we raise + # NotEnoughPeersError. + self.log(format=("_maybe_send_more_queries, have=%(have)d, k=%(k)d, " + "outstanding=%(outstanding)d"), + have=len(self.shares), k=k, + outstanding=len(self._outstanding_queries), + level=log.NOISY) + + remaining_shares = k - len(self.shares) + needed = remaining_shares - len(self._outstanding_queries) + if not needed: + # we have enough queries in flight already + + # TODO: but if they've been in flight for a long time, and we + # have reason to believe that new queries might respond faster + # (i.e. we've seen other queries come back faster, then consider + # sending out new queries. This could help with peers which have + # silently gone away since the servermap was updated, for which + # we're still waiting for the 15-minute TCP disconnect to happen. + self.log("enough queries are in flight, no more are needed", + level=log.NOISY) + return - def _decode(self, shares_dict, segsize, datalength, k, N): + outstanding_shnums = set([shnum + for (peerid, shnum, started) + in self._outstanding_queries.values()]) + # prefer low-numbered shares, they are more likely to be primary + available_shnums = sorted(self.remaining_sharemap.keys()) + for shnum in available_shnums: + if shnum in outstanding_shnums: + # skip ones that are already in transit + continue + if shnum not in self.remaining_sharemap: + # no servers for that shnum. note that DictOfSets removes + # empty sets from the dict for us. + continue + peerid = list(self.remaining_sharemap[shnum])[0] + # get_data will remove that peerid from the sharemap, and add the + # query to self._outstanding_queries + self.get_data(shnum, peerid) + needed -= 1 + if not needed: + break + + # at this point, we have as many outstanding queries as we can. If + # needed!=0 then we might not have enough to recover the file. + if needed: + format = ("ran out of peers: " + "have %(have)d shares (k=%(k)d), " + "%(outstanding)d queries in flight, " + "need %(need)d more") + self.log(format=format, + have=len(self.shares), k=k, + outstanding=len(self._outstanding_queries), + need=needed, + level=log.WEIRD) + msg2 = format % {"have": len(self.shares), + "k": k, + "outstanding": len(self._outstanding_queries), + "need": needed, + } + raise NotEnoughPeersError("%s, last failure: %s" % + (msg2, self._last_failure)) + + return + + def _decode(self): + (seqnum, root_hash, IV, segsize, datalength, k, N, prefix, + offsets_tuple) = self.verinfo # shares_dict is a dict mapping shnum to share data, but the codec # wants two lists. shareids = []; shares = [] - for shareid, share in shares_dict.items(): + for shareid, share in self.shares.items(): shareids.append(shareid) shares.append(share) @@ -805,17 +1183,8 @@ class Retrieve: self.log("params %s, we have %d shares" % (params, len(shares))) self.log("about to decode, shareids=%s" % (shareids,)) - started = time.time() d = defer.maybeDeferred(fec.decode, shares, shareids) def _done(buffers): - elapsed = time.time() - started - self._status.timings["decode"] = elapsed - self._status.set_encoding(k, N) - - # stash these in the MutableFileNode to speed up the next pass - self._node._populate_required_shares(k) - self._node._populate_total_shares(N) - self.log(" decode done, %d buffers" % len(buffers)) segment = "".join(buffers) self.log(" joined length %d, datalength %d" % @@ -830,41 +1199,24 @@ class Retrieve: d.addErrback(_err) return d - def _decrypt(self, crypttext, IV, seqnum, root_hash): + def _decrypt(self, crypttext, IV, readkey): started = time.time() - key = hashutil.ssk_readkey_data_hash(IV, self._readkey) + key = hashutil.ssk_readkey_data_hash(IV, readkey) decryptor = AES(key) plaintext = decryptor.process(crypttext) - elapsed = time.time() - started - self._status.timings["decrypt"] = elapsed - # it worked, so record the seqnum and root_hash for next time - self._node._populate_seqnum(seqnum) - self._node._populate_root_hash(root_hash) return plaintext def _done(self, res): - # res is either the new contents, or a Failure - self.log("DONE") + if not self._running: + return self._running = False - self._status.set_active(False) - self._status.set_status("Done") - self._status.set_progress(1.0) - if isinstance(res, str): - self._status.set_size(len(res)) - elapsed = time.time() - self._started - self._status.timings["total"] = elapsed + # res is either the new contents, or a Failure + if isinstance(res, failure.Failure): + self.log("DONE, with failure", failure=res) + else: + self.log("DONE, success!: res=%s" % (res,)) eventually(self._done_deferred.callback, res) - def get_status(self): - return self._status - - -class DictOfSets(dict): - def add(self, key, value): - if key in self: - self[key].add(value) - else: - self[key] = set([value]) class PublishStatus: implements(IPublishStatus) @@ -943,17 +1295,23 @@ class Publish: self._status.set_active(True) self._started = time.time() + def new__init__(self, filenode, servermap): + self._node = filenode + self._servermap =servermap + self._storage_index = self._node.get_storage_index() + self._log_prefix = prefix = storage.si_b2a(self._storage_index)[:5] + num = self._node._client.log("Publish(%s): starting" % prefix) + self._log_number = num + def log(self, *args, **kwargs): if 'parent' not in kwargs: kwargs['parent'] = self._log_number - num = log.msg(*args, **kwargs) - return num + return log.msg(*args, **kwargs) def log_err(self, *args, **kwargs): if 'parent' not in kwargs: kwargs['parent'] = self._log_number - num = log.err(*args, **kwargs) - return num + return log.err(*args, **kwargs) def publish(self, newdata): """Publish the filenode's current contents. Returns a Deferred that @@ -1618,6 +1976,7 @@ class MutableFileNode: self._required_shares = None # ditto self._total_shares = None # ditto self._sharemap = {} # known shares, shnum-to-[nodeids] + self._cache = ResponseCache() self._current_data = None # SDMF: we're allowed to cache the contents self._current_roothash = None # ditto @@ -1788,6 +2147,45 @@ class MutableFileNode: def get_verifier(self): return IMutableFileURI(self._uri).get_verifier() + def obtain_lock(self, res=None): + # stub, get real version from zooko's #265 patch + d = defer.Deferred() + d.callback(res) + return d + + def release_lock(self, res): + # stub + return res + + ############################ + + # methods exposed to the higher-layer application + + def update_servermap(self, old_map=None, mode=MODE_ENOUGH): + servermap = old_map or ServerMap() + d = self.obtain_lock() + d.addCallback(lambda res: + ServermapUpdater(self, servermap, mode).update()) + d.addCallback(self.release_lock) + return d + + def download_version(self, servermap, versionid): + """Returns a Deferred that fires with a string.""" + d = self.obtain_lock() + d.addCallback(lambda res: + Retrieve(self, servermap, versionid).download()) + d.addCallback(self.release_lock) + return d + + def publish(self, servermap, newdata): + assert self._pubkey, "update_servermap must be called before publish" + d = self.obtain_lock() + d.addCallback(lambda res: Publish(self, servermap).publish(newdata)) + d.addCallback(self.release_lock) + return d + + ################################# + def check(self): verifier = self.get_verifier() return self._client.getServiceNamed("checker").check(verifier) @@ -1804,26 +2202,19 @@ class MutableFileNode: return d def download_to_data(self): - r = self.retrieve_class(self) - self._client.notify_retrieve(r) - return r.retrieve() + d = self.obtain_lock() + d.addCallback(lambda res: self.update_servermap(mode=MODE_ENOUGH)) + d.addCallback(lambda smap: + self.download_version(smap, + smap.best_recoverable_version())) + d.addCallback(self.release_lock) + return d def update(self, newdata): - # this must be called after a retrieve - assert self._pubkey, "download_to_data() must be called before update()" - assert self._current_seqnum is not None, "download_to_data() must be called before update()" return self._publish(newdata) def overwrite(self, newdata): - # we do retrieve just to get the seqnum. We ignore the contents. - # TODO: use a smaller form of retrieve that doesn't try to fetch the - # data. Also, replace Publish with a form that uses the cached - # sharemap from the previous retrieval. - r = self.retrieve_class(self) - self._client.notify_retrieve(r) - d = r.retrieve() - d.addCallback(lambda ignored: self._publish(newdata)) - return d + return self._publish(newdata) class MutableWatcher(service.MultiService): MAX_PUBLISH_STATUSES = 20 @@ -1872,3 +2263,74 @@ class MutableWatcher(service.MultiService): if p.get_status().get_active()] def list_recent_retrieve(self): return self._recent_retrieve_status + +class ResponseCache: + """I cache share data, to reduce the number of round trips used during + mutable file operations. All of the data in my cache is for a single + storage index, but I will keep information on multiple shares (and + multiple versions) for that storage index. + + My cache is indexed by a (verinfo, shnum) tuple. + + My cache entries contain a set of non-overlapping byteranges: (start, + data, timestamp) tuples. + """ + + def __init__(self): + self.cache = DictOfSets() + + def _does_overlap(self, x_start, x_length, y_start, y_length): + if x_start < y_start: + x_start, y_start = y_start, x_start + x_length, y_length = y_length, x_length + x_end = x_start + x_length + y_end = y_start + y_length + # this just returns a boolean. Eventually we'll want a form that + # returns a range. + if not x_length: + return False + if not y_length: + return False + if x_start >= y_end: + return False + if y_start >= x_end: + return False + return True + + + def _inside(self, x_start, x_length, y_start, y_length): + x_end = x_start + x_length + y_end = y_start + y_length + if x_start < y_start: + return False + if x_start >= y_end: + return False + if x_end < y_start: + return False + if x_end > y_end: + return False + return True + + def add(self, verinfo, shnum, offset, data, timestamp): + index = (verinfo, shnum) + self.cache.add(index, (offset, data, timestamp) ) + + def read(self, verinfo, shnum, offset, length): + """Try to satisfy a read request from cache. + Returns (data, timestamp), or (None, None) if the cache did not hold + the requested data. + """ + + # TODO: join multiple fragments, instead of only returning a hit if + # we have a fragment that contains the whole request + + index = (verinfo, shnum) + end = offset+length + for entry in self.cache.get(index, set()): + (e_start, e_data, e_timestamp) = entry + if self._inside(offset, length, e_start, len(e_data)): + want_start = offset - e_start + want_end = offset+length - e_start + return (e_data[want_start:want_end], e_timestamp) + return None, None + diff --git a/src/allmydata/test/test_mutable.py b/src/allmydata/test/test_mutable.py index b3be3dc5..1b188c10 100644 --- a/src/allmydata/test/test_mutable.py +++ b/src/allmydata/test/test_mutable.py @@ -5,13 +5,14 @@ from twisted.trial import unittest from twisted.internet import defer, reactor from twisted.python import failure from allmydata import mutable, uri, dirnode, download +from allmydata.util import base32 from allmydata.util.idlib import shortnodeid_b2a from allmydata.util.hashutil import tagged_hash from allmydata.encode import NotEnoughPeersError from allmydata.interfaces import IURI, INewDirectoryURI, \ IMutableFileURI, IUploadable, IFileURI from allmydata.filenode import LiteralFileNode -from foolscap.eventual import eventually +from foolscap.eventual import eventually, fireEventually from foolscap.logging import log import sha @@ -110,7 +111,9 @@ class FakePublish(mutable.Publish): def _do_read(self, ss, peerid, storage_index, shnums, readv): assert ss[0] == peerid assert shnums == [] - return defer.maybeDeferred(self._storage.read, peerid, storage_index) + d = fireEventually() + d.addCallback(lambda res: self._storage.read(peerid, storage_index)) + return d def _do_testreadwrite(self, peerid, secrets, tw_vectors, read_vector): @@ -182,7 +185,6 @@ class FakeClient: return res def get_permuted_peers(self, service_name, key): - # TODO: include_myself=True """ @return: list of (peerid, connection,) """ @@ -303,6 +305,7 @@ class Publish(unittest.TestCase): CONTENTS = "some initial contents" fn.create(CONTENTS) p = mutable.Publish(fn) + r = mutable.Retrieve(fn) # make some fake shares shares_and_ids = ( ["%07d" % i for i in range(10)], range(10) ) target_info = None @@ -467,7 +470,27 @@ class Publish(unittest.TestCase): class FakeRetrieve(mutable.Retrieve): def _do_read(self, ss, peerid, storage_index, shnums, readv): - d = defer.maybeDeferred(self._storage.read, peerid, storage_index) + d = fireEventually() + d.addCallback(lambda res: self._storage.read(peerid, storage_index)) + def _read(shares): + response = {} + for shnum in shares: + if shnums and shnum not in shnums: + continue + vector = response[shnum] = [] + for (offset, length) in readv: + assert isinstance(offset, (int, long)), offset + assert isinstance(length, (int, long)), length + vector.append(shares[shnum][offset:offset+length]) + return response + d.addCallback(_read) + return d + +class FakeServermapUpdater(mutable.ServermapUpdater): + + def _do_read(self, ss, peerid, storage_index, shnums, readv): + d = fireEventually() + d.addCallback(lambda res: self._storage.read(peerid, storage_index)) def _read(shares): response = {} for shnum in shares: @@ -487,31 +510,217 @@ class FakeRetrieve(mutable.Retrieve): count = mo.group(1) return FakePubKey(int(count)) +class Sharemap(unittest.TestCase): + def setUp(self): + # publish a file and create shares, which can then be manipulated + # later. + num_peers = 20 + self._client = FakeClient(num_peers) + self._fn = FakeFilenode(self._client) + self._storage = FakeStorage() + d = self._fn.create("") + def _created(res): + p = FakePublish(self._fn) + p._storage = self._storage + contents = "New contents go here" + return p.publish(contents) + d.addCallback(_created) + return d + + def make_servermap(self, storage, mode=mutable.MODE_CHECK): + smu = FakeServermapUpdater(self._fn, mutable.ServerMap(), mode) + smu._storage = storage + d = smu.update() + return d + + def update_servermap(self, storage, oldmap, mode=mutable.MODE_CHECK): + smu = FakeServermapUpdater(self._fn, oldmap, mode) + smu._storage = storage + d = smu.update() + return d + + def failUnlessOneRecoverable(self, sm, num_shares): + self.failUnlessEqual(len(sm.recoverable_versions()), 1) + self.failUnlessEqual(len(sm.unrecoverable_versions()), 0) + best = sm.best_recoverable_version() + self.failIfEqual(best, None) + self.failUnlessEqual(sm.recoverable_versions(), set([best])) + self.failUnlessEqual(len(sm.shares_available()), 1) + self.failUnlessEqual(sm.shares_available()[best], (num_shares, 3)) + return sm + + def test_basic(self): + s = self._storage # unmangled + d = defer.succeed(None) + ms = self.make_servermap + us = self.update_servermap + + d.addCallback(lambda res: ms(s, mode=mutable.MODE_CHECK)) + d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10)) + d.addCallback(lambda res: ms(s, mode=mutable.MODE_WRITE)) + d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10)) + d.addCallback(lambda res: ms(s, mode=mutable.MODE_ENOUGH)) + # this more stops at k+epsilon, and epsilon=k, so 6 shares + d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6)) + d.addCallback(lambda res: ms(s, mode=mutable.MODE_ANYTHING)) + # this mode stops at 'k' shares + d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3)) + + # and can we re-use the same servermap? Note that these are sorted in + # increasing order of number of servers queried, since once a server + # gets into the servermap, we'll always ask it for an update. + d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 3)) + d.addCallback(lambda sm: us(s, sm, mode=mutable.MODE_ENOUGH)) + d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 6)) + d.addCallback(lambda sm: us(s, sm, mode=mutable.MODE_WRITE)) + d.addCallback(lambda sm: us(s, sm, mode=mutable.MODE_CHECK)) + d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10)) + d.addCallback(lambda sm: us(s, sm, mode=mutable.MODE_ANYTHING)) + d.addCallback(lambda sm: self.failUnlessOneRecoverable(sm, 10)) + + return d + + def failUnlessNoneRecoverable(self, sm): + self.failUnlessEqual(len(sm.recoverable_versions()), 0) + self.failUnlessEqual(len(sm.unrecoverable_versions()), 0) + best = sm.best_recoverable_version() + self.failUnlessEqual(best, None) + self.failUnlessEqual(len(sm.shares_available()), 0) + + def test_no_shares(self): + s = self._storage + s._peers = {} # delete all shares + ms = self.make_servermap + d = defer.succeed(None) + + d.addCallback(lambda res: ms(s, mode=mutable.MODE_CHECK)) + d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm)) + + d.addCallback(lambda res: ms(s, mode=mutable.MODE_ANYTHING)) + d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm)) + + d.addCallback(lambda res: ms(s, mode=mutable.MODE_WRITE)) + d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm)) + + d.addCallback(lambda res: ms(s, mode=mutable.MODE_ENOUGH)) + d.addCallback(lambda sm: self.failUnlessNoneRecoverable(sm)) + + return d + + def failUnlessNotQuiteEnough(self, sm): + self.failUnlessEqual(len(sm.recoverable_versions()), 0) + self.failUnlessEqual(len(sm.unrecoverable_versions()), 1) + best = sm.best_recoverable_version() + self.failUnlessEqual(best, None) + self.failUnlessEqual(len(sm.shares_available()), 1) + self.failUnlessEqual(sm.shares_available().values()[0], (2,3) ) + + def test_not_quite_enough_shares(self): + s = self._storage + ms = self.make_servermap + num_shares = len(s._peers) + for peerid in s._peers: + s._peers[peerid] = {} + num_shares -= 1 + if num_shares == 2: + break + # now there ought to be only two shares left + assert len([peerid for peerid in s._peers if s._peers[peerid]]) == 2 + + d = defer.succeed(None) + + d.addCallback(lambda res: ms(s, mode=mutable.MODE_CHECK)) + d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm)) + d.addCallback(lambda res: ms(s, mode=mutable.MODE_ANYTHING)) + d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm)) + d.addCallback(lambda res: ms(s, mode=mutable.MODE_WRITE)) + d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm)) + d.addCallback(lambda res: ms(s, mode=mutable.MODE_ENOUGH)) + d.addCallback(lambda sm: self.failUnlessNotQuiteEnough(sm)) + + return d + + class Roundtrip(unittest.TestCase): - def setup_for_publish(self, num_peers): - c = FakeClient(num_peers) - fn = FakeFilenode(c) - s = FakeStorage() - # .create usually returns a Deferred, but we happen to know it's - # synchronous - fn.create("") - p = FakePublish(fn) - p._storage = s - r = FakeRetrieve(fn) - r._storage = s - return c, s, fn, p, r + def setUp(self): + # publish a file and create shares, which can then be manipulated + # later. + self.CONTENTS = "New contents go here" + num_peers = 20 + self._client = FakeClient(num_peers) + self._fn = FakeFilenode(self._client) + self._storage = FakeStorage() + d = self._fn.create("") + def _created(res): + p = FakePublish(self._fn) + p._storage = self._storage + return p.publish(self.CONTENTS) + d.addCallback(_created) + return d + + def make_servermap(self, mode=mutable.MODE_ENOUGH, oldmap=None): + if oldmap is None: + oldmap = mutable.ServerMap() + smu = FakeServermapUpdater(self._fn, oldmap, mode) + smu._storage = self._storage + d = smu.update() + return d + + def abbrev_verinfo(self, verinfo): + if verinfo is None: + return None + (seqnum, root_hash, IV, segsize, datalength, k, N, prefix, + offsets_tuple) = verinfo + return "%d-%s" % (seqnum, base32.b2a(root_hash)[:4]) + + def abbrev_verinfo_dict(self, verinfo_d): + output = {} + for verinfo,value in verinfo_d.items(): + (seqnum, root_hash, IV, segsize, datalength, k, N, prefix, + offsets_tuple) = verinfo + output["%d-%s" % (seqnum, base32.b2a(root_hash)[:4])] = value + return output + + def dump_servermap(self, servermap): + print "SERVERMAP", servermap + print "RECOVERABLE", [self.abbrev_verinfo(v) + for v in servermap.recoverable_versions()] + print "BEST", self.abbrev_verinfo(servermap.best_recoverable_version()) + print "available", self.abbrev_verinfo_dict(servermap.shares_available()) + + def do_download(self, servermap, version=None): + if version is None: + version = servermap.best_recoverable_version() + r = FakeRetrieve(self._fn, servermap, version) + r._storage = self._storage + return r.download() def test_basic(self): - c, s, fn, p, r = self.setup_for_publish(20) - contents = "New contents go here" - d = p.publish(contents) - def _published(res): - return r.retrieve() - d.addCallback(_published) + d = self.make_servermap() + def _do_retrieve(servermap): + self._smap = servermap + #self.dump_servermap(servermap) + self.failUnlessEqual(len(servermap.recoverable_versions()), 1) + return self.do_download(servermap) + d.addCallback(_do_retrieve) def _retrieved(new_contents): - self.failUnlessEqual(contents, new_contents) + self.failUnlessEqual(new_contents, self.CONTENTS) + d.addCallback(_retrieved) + # we should be able to re-use the same servermap, both with and + # without updating it. + d.addCallback(lambda res: self.do_download(self._smap)) + d.addCallback(_retrieved) + d.addCallback(lambda res: self.make_servermap(oldmap=self._smap)) + d.addCallback(lambda res: self.do_download(self._smap)) + d.addCallback(_retrieved) + # clobbering the pubkey should make the servermap updater re-fetch it + def _clobber_pubkey(res): + self._fn._pubkey = None + d.addCallback(_clobber_pubkey) + d.addCallback(lambda res: self.make_servermap(oldmap=self._smap)) + d.addCallback(lambda res: self.do_download(self._smap)) d.addCallback(_retrieved) return d @@ -538,144 +747,139 @@ class Roundtrip(unittest.TestCase): d.addBoth(done) return d - def _corrupt_all(self, offset, substring, refetch_pubkey=False, - should_succeed=False): - c, s, fn, p, r = self.setup_for_publish(20) - contents = "New contents go here" - d = p.publish(contents) - def _published(res): - if refetch_pubkey: - # clear the pubkey, to force a fetch - r._pubkey = None - for peerid in s._peers: - shares = s._peers[peerid] - for shnum in shares: - data = shares[shnum] - (version, - seqnum, - root_hash, - IV, - k, N, segsize, datalen, - o) = mutable.unpack_header(data) - if isinstance(offset, tuple): - offset1, offset2 = offset - else: - offset1 = offset - offset2 = 0 - if offset1 == "pubkey": - real_offset = 107 - elif offset1 in o: - real_offset = o[offset1] - else: - real_offset = offset1 - real_offset = int(real_offset) + offset2 - assert isinstance(real_offset, int), offset - shares[shnum] = self.flip_bit(data, real_offset) - d.addCallback(_published) - if should_succeed: - d.addCallback(lambda res: r.retrieve()) - else: - d.addCallback(lambda res: - self.shouldFail(NotEnoughPeersError, - "_corrupt_all(offset=%s)" % (offset,), - substring, - r.retrieve)) + def _corrupt(self, res, s, offset, shnums_to_corrupt=None): + # if shnums_to_corrupt is None, corrupt all shares. Otherwise it is a + # list of shnums to corrupt. + for peerid in s._peers: + shares = s._peers[peerid] + for shnum in shares: + if (shnums_to_corrupt is not None + and shnum not in shnums_to_corrupt): + continue + data = shares[shnum] + (version, + seqnum, + root_hash, + IV, + k, N, segsize, datalen, + o) = mutable.unpack_header(data) + if isinstance(offset, tuple): + offset1, offset2 = offset + else: + offset1 = offset + offset2 = 0 + if offset1 == "pubkey": + real_offset = 107 + elif offset1 in o: + real_offset = o[offset1] + else: + real_offset = offset1 + real_offset = int(real_offset) + offset2 + assert isinstance(real_offset, int), offset + shares[shnum] = self.flip_bit(data, real_offset) + return res + + def _test_corrupt_all(self, offset, substring, + should_succeed=False, corrupt_early=True): + d = defer.succeed(None) + if corrupt_early: + d.addCallback(self._corrupt, self._storage, offset) + d.addCallback(lambda res: self.make_servermap()) + if not corrupt_early: + d.addCallback(self._corrupt, self._storage, offset) + def _do_retrieve(servermap): + ver = servermap.best_recoverable_version() + if ver is None and not should_succeed: + # no recoverable versions == not succeeding. The problem + # should be noted in the servermap's list of problems. + if substring: + allproblems = [str(f) for f in servermap.problems] + self.failUnless(substring in "".join(allproblems)) + return + r = FakeRetrieve(self._fn, servermap, ver) + r._storage = self._storage + if should_succeed: + d1 = r.download() + d1.addCallback(lambda new_contents: + self.failUnlessEqual(new_contents, self.CONTENTS)) + return d1 + else: + return self.shouldFail(NotEnoughPeersError, + "_corrupt_all(offset=%s)" % (offset,), + substring, + r.download) + d.addCallback(_do_retrieve) return d def test_corrupt_all_verbyte(self): # when the version byte is not 0, we hit an assertion error in # unpack_share(). - return self._corrupt_all(0, "AssertionError") + return self._test_corrupt_all(0, "AssertionError") def test_corrupt_all_seqnum(self): # a corrupt sequence number will trigger a bad signature - return self._corrupt_all(1, "signature is invalid") + return self._test_corrupt_all(1, "signature is invalid") def test_corrupt_all_R(self): # a corrupt root hash will trigger a bad signature - return self._corrupt_all(9, "signature is invalid") + return self._test_corrupt_all(9, "signature is invalid") def test_corrupt_all_IV(self): # a corrupt salt/IV will trigger a bad signature - return self._corrupt_all(41, "signature is invalid") + return self._test_corrupt_all(41, "signature is invalid") def test_corrupt_all_k(self): # a corrupt 'k' will trigger a bad signature - return self._corrupt_all(57, "signature is invalid") + return self._test_corrupt_all(57, "signature is invalid") def test_corrupt_all_N(self): # a corrupt 'N' will trigger a bad signature - return self._corrupt_all(58, "signature is invalid") + return self._test_corrupt_all(58, "signature is invalid") def test_corrupt_all_segsize(self): # a corrupt segsize will trigger a bad signature - return self._corrupt_all(59, "signature is invalid") + return self._test_corrupt_all(59, "signature is invalid") def test_corrupt_all_datalen(self): # a corrupt data length will trigger a bad signature - return self._corrupt_all(67, "signature is invalid") + return self._test_corrupt_all(67, "signature is invalid") def test_corrupt_all_pubkey(self): - # a corrupt pubkey won't match the URI's fingerprint - return self._corrupt_all("pubkey", "pubkey doesn't match fingerprint", - refetch_pubkey=True) + # a corrupt pubkey won't match the URI's fingerprint. We need to + # remove the pubkey from the filenode, or else it won't bother trying + # to update it. + self._fn._pubkey = None + return self._test_corrupt_all("pubkey", + "pubkey doesn't match fingerprint") def test_corrupt_all_sig(self): # a corrupt signature is a bad one # the signature runs from about [543:799], depending upon the length # of the pubkey - return self._corrupt_all("signature", "signature is invalid", - refetch_pubkey=True) + return self._test_corrupt_all("signature", "signature is invalid") def test_corrupt_all_share_hash_chain_number(self): # a corrupt share hash chain entry will show up as a bad hash. If we # mangle the first byte, that will look like a bad hash number, # causing an IndexError - return self._corrupt_all("share_hash_chain", "corrupt hashes") + return self._test_corrupt_all("share_hash_chain", "corrupt hashes") def test_corrupt_all_share_hash_chain_hash(self): # a corrupt share hash chain entry will show up as a bad hash. If we # mangle a few bytes in, that will look like a bad hash. - return self._corrupt_all(("share_hash_chain",4), "corrupt hashes") + return self._test_corrupt_all(("share_hash_chain",4), "corrupt hashes") def test_corrupt_all_block_hash_tree(self): - return self._corrupt_all("block_hash_tree", "block hash tree failure") + return self._test_corrupt_all("block_hash_tree", + "block hash tree failure") def test_corrupt_all_block(self): - return self._corrupt_all("share_data", "block hash tree failure") + return self._test_corrupt_all("share_data", "block hash tree failure") def test_corrupt_all_encprivkey(self): - # a corrupted privkey won't even be noticed by the reader - return self._corrupt_all("enc_privkey", None, should_succeed=True) - - def test_short_read(self): - c, s, fn, p, r = self.setup_for_publish(20) - contents = "New contents go here" - d = p.publish(contents) - def _published(res): - # force a short read, to make Retrieve._got_results re-send the - # queries. But don't make it so short that we can't read the - # header. - r._read_size = mutable.HEADER_LENGTH + 10 - return r.retrieve() - d.addCallback(_published) - def _retrieved(new_contents): - self.failUnlessEqual(contents, new_contents) - d.addCallback(_retrieved) - return d - - def test_basic_sequenced(self): - c, s, fn, p, r = self.setup_for_publish(20) - s._sequence = c._peerids[:] - contents = "New contents go here" - d = p.publish(contents) - def _published(res): - return r.retrieve() - d.addCallback(_published) - def _retrieved(new_contents): - self.failUnlessEqual(contents, new_contents) - d.addCallback(_retrieved) - return d + # a corrupted privkey won't even be noticed by the reader, only by a + # writer. + return self._test_corrupt_all("enc_privkey", None, should_succeed=True) def test_basic_pubkey_at_end(self): # we corrupt the pubkey in all but the last 'k' shares, allowing the @@ -683,33 +887,25 @@ class Roundtrip(unittest.TestCase): # this is rather pessimistic: our Retrieve process will throw away # the whole share if the pubkey is bad, even though the rest of the # share might be good. - c, s, fn, p, r = self.setup_for_publish(20) - s._sequence = c._peerids[:] - contents = "New contents go here" - d = p.publish(contents) - def _published(res): - r._pubkey = None - homes = [peerid for peerid in c._peerids - if s._peers.get(peerid, {})] - k = fn.get_required_shares() - homes_to_corrupt = homes[:-k] - for peerid in homes_to_corrupt: - shares = s._peers[peerid] - for shnum in shares: - data = shares[shnum] - (version, - seqnum, - root_hash, - IV, - k, N, segsize, datalen, - o) = mutable.unpack_header(data) - offset = 107 # pubkey - shares[shnum] = self.flip_bit(data, offset) - return r.retrieve() - d.addCallback(_published) - def _retrieved(new_contents): - self.failUnlessEqual(contents, new_contents) - d.addCallback(_retrieved) + + self._fn._pubkey = None + k = self._fn.get_required_shares() + N = self._fn.get_total_shares() + d = defer.succeed(None) + d.addCallback(self._corrupt, self._storage, "pubkey", + shnums_to_corrupt=range(0, N-k)) + d.addCallback(lambda res: self.make_servermap()) + def _do_retrieve(servermap): + self.failUnless(servermap.problems) + self.failUnless("pubkey doesn't match fingerprint" + in str(servermap.problems[0])) + ver = servermap.best_recoverable_version() + r = FakeRetrieve(self._fn, servermap, ver) + r._storage = self._storage + return r.download() + d.addCallback(_do_retrieve) + d.addCallback(lambda new_contents: + self.failUnlessEqual(new_contents, self.CONTENTS)) return d def _encode(self, c, s, fn, k, n, data): @@ -741,6 +937,32 @@ class Roundtrip(unittest.TestCase): d.addCallback(_published) return d +class MultipleEncodings(unittest.TestCase): + + def publish(self): + # publish a file and create shares, which can then be manipulated + # later. + self.CONTENTS = "New contents go here" + num_peers = 20 + self._client = FakeClient(num_peers) + self._fn = FakeFilenode(self._client) + self._storage = FakeStorage() + d = self._fn.create("") + def _created(res): + p = FakePublish(self._fn) + p._storage = self._storage + return p.publish(self.CONTENTS) + d.addCallback(_created) + return d + + def make_servermap(self, mode=mutable.MODE_ENOUGH, oldmap=None): + if oldmap is None: + oldmap = mutable.ServerMap() + smu = FakeServermapUpdater(self._fn, oldmap, mode) + smu._storage = self._storage + d = smu.update() + return d + def test_multiple_encodings(self): # we encode the same file in two different ways (3-of-10 and 4-of-9), # then mix up the shares, to make sure that download survives seeing @@ -842,3 +1064,87 @@ class Roundtrip(unittest.TestCase): d.addCallback(_retrieved) return d + +class Utils(unittest.TestCase): + def test_dict_of_sets(self): + ds = mutable.DictOfSets() + ds.add(1, "a") + ds.add(2, "b") + ds.add(2, "b") + ds.add(2, "c") + self.failUnlessEqual(ds[1], set(["a"])) + self.failUnlessEqual(ds[2], set(["b", "c"])) + ds.discard(3, "d") # should not raise an exception + ds.discard(2, "b") + self.failUnlessEqual(ds[2], set(["c"])) + ds.discard(2, "c") + self.failIf(2 in ds) + + def _do_inside(self, c, x_start, x_length, y_start, y_length): + # we compare this against sets of integers + x = set(range(x_start, x_start+x_length)) + y = set(range(y_start, y_start+y_length)) + should_be_inside = x.issubset(y) + self.failUnlessEqual(should_be_inside, c._inside(x_start, x_length, + y_start, y_length), + str((x_start, x_length, y_start, y_length))) + + def test_cache_inside(self): + c = mutable.ResponseCache() + x_start = 10 + x_length = 5 + for y_start in range(8, 17): + for y_length in range(8): + self._do_inside(c, x_start, x_length, y_start, y_length) + + def _do_overlap(self, c, x_start, x_length, y_start, y_length): + # we compare this against sets of integers + x = set(range(x_start, x_start+x_length)) + y = set(range(y_start, y_start+y_length)) + overlap = bool(x.intersection(y)) + self.failUnlessEqual(overlap, c._does_overlap(x_start, x_length, + y_start, y_length), + str((x_start, x_length, y_start, y_length))) + + def test_cache_overlap(self): + c = mutable.ResponseCache() + x_start = 10 + x_length = 5 + for y_start in range(8, 17): + for y_length in range(8): + self._do_overlap(c, x_start, x_length, y_start, y_length) + + def test_cache(self): + c = mutable.ResponseCache() + # xdata = base62.b2a(os.urandom(100))[:100] + xdata = "1Ex4mdMaDyOl9YnGBM3I4xaBF97j8OQAg1K3RBR01F2PwTP4HohB3XpACuku8Xj4aTQjqJIR1f36mEj3BCNjXaJmPBEZnnHL0U9l" + ydata = "4DCUQXvkEPnnr9Lufikq5t21JsnzZKhzxKBhLhrBB6iIcBOWRuT4UweDhjuKJUre8A4wOObJnl3Kiqmlj4vjSLSqUGAkUD87Y3vs" + nope = (None, None) + c.add("v1", 1, 0, xdata, "time0") + c.add("v1", 1, 2000, ydata, "time1") + self.failUnlessEqual(c.read("v2", 1, 10, 11), nope) + self.failUnlessEqual(c.read("v1", 2, 10, 11), nope) + self.failUnlessEqual(c.read("v1", 1, 0, 10), (xdata[:10], "time0")) + self.failUnlessEqual(c.read("v1", 1, 90, 10), (xdata[90:], "time0")) + self.failUnlessEqual(c.read("v1", 1, 300, 10), nope) + self.failUnlessEqual(c.read("v1", 1, 2050, 5), (ydata[50:55], "time1")) + self.failUnlessEqual(c.read("v1", 1, 0, 101), nope) + self.failUnlessEqual(c.read("v1", 1, 99, 1), (xdata[99:100], "time0")) + self.failUnlessEqual(c.read("v1", 1, 100, 1), nope) + self.failUnlessEqual(c.read("v1", 1, 1990, 9), nope) + self.failUnlessEqual(c.read("v1", 1, 1990, 10), nope) + self.failUnlessEqual(c.read("v1", 1, 1990, 11), nope) + self.failUnlessEqual(c.read("v1", 1, 1990, 15), nope) + self.failUnlessEqual(c.read("v1", 1, 1990, 19), nope) + self.failUnlessEqual(c.read("v1", 1, 1990, 20), nope) + self.failUnlessEqual(c.read("v1", 1, 1990, 21), nope) + self.failUnlessEqual(c.read("v1", 1, 1990, 25), nope) + self.failUnlessEqual(c.read("v1", 1, 1999, 25), nope) + + # optional: join fragments + c = mutable.ResponseCache() + c.add("v1", 1, 0, xdata[:10], "time0") + c.add("v1", 1, 10, xdata[10:20], "time1") + #self.failUnlessEqual(c.read("v1", 1, 0, 20), (xdata[:20], "time0")) + +