# src/allmydata/mutable/retrieve.py

import struct, time
from itertools import count
from zope.interface import implements
from twisted.internet import defer
from twisted.python import failure
from foolscap import DeadReferenceError
from foolscap.eventual import eventually, fireEventually
from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError
from allmydata.util import hashutil, idlib, log
from allmydata import hashtree, codec, storage
from pycryptopp.cipher.aes import AES
from pycryptopp.publickey import rsa

from common import DictOfSets, CorruptShareError, UncoordinatedWriteError
from layout import SIGNED_PREFIX, unpack_share_data

class RetrieveStatus:
    implements(IRetrieveStatus)
    statusid_counter = count(0)
    def __init__(self):
        self.timings = {}
        self.timings["fetch_per_server"] = {}
        self.timings["cumulative_verify"] = 0.0
        self.problems = {}
        self.active = True
        self.storage_index = None
        self.helper = False
        self.encoding = ("?","?")
        self.size = None
        self.status = "Not started"
        self.progress = 0.0
        self.counter = self.statusid_counter.next()
        self.started = time.time()

    def get_started(self):
        return self.started
    def get_storage_index(self):
        return self.storage_index
    def get_encoding(self):
        return self.encoding
    def using_helper(self):
        return self.helper
    def get_size(self):
        return self.size
    def get_status(self):
        return self.status
    def get_progress(self):
        return self.progress
    def get_active(self):
        return self.active
    def get_counter(self):
        return self.counter

    def add_fetch_timing(self, peerid, elapsed):
        if peerid not in self.timings["fetch_per_server"]:
            self.timings["fetch_per_server"][peerid] = []
        self.timings["fetch_per_server"][peerid].append(elapsed)
    def set_storage_index(self, si):
        self.storage_index = si
    def set_helper(self, helper):
        self.helper = helper
    def set_encoding(self, k, n):
        self.encoding = (k, n)
    def set_size(self, size):
        self.size = size
    def set_status(self, status):
        self.status = status
    def set_progress(self, value):
        self.progress = value
    def set_active(self, value):
        self.active = value

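# Illustrative sketch (not part of the original module): a Retrieve operation
# creates one RetrieveStatus, updates it as the download progresses, and a
# status display can later poll the getters. Roughly, with made-up values:
#
#     status = RetrieveStatus()
#     status.set_storage_index(si)            # si obtained elsewhere
#     status.add_fetch_timing(peerid, 0.12)   # peerid/0.12 are hypothetical
#     status.set_progress(0.5)
#     print status.get_status(), status.get_progress()
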
class Marker:
    pass

class Retrieve:
    # this class is currently single-use. Eventually (in MDMF) we will make
    # it multi-use, in which case you can call download(range) multiple
    # times, and each will have a separate response chain. However the
    # Retrieve object will remain tied to a specific version of the file, and
    # will use a single ServerMap instance.

    def __init__(self, filenode, servermap, verinfo, fetch_privkey=False):
        self._node = filenode
        assert self._node._pubkey
        self._storage_index = filenode.get_storage_index()
        assert self._node._readkey
        self._last_failure = None
        prefix = storage.si_b2a(self._storage_index)[:5]
        self._log_number = log.msg("Retrieve(%s): starting" % prefix)
        self._outstanding_queries = {} # maps (peerid,shnum) to start_time
        self._running = True
        self._decoding = False
        self._bad_shares = set()

        self.servermap = servermap
        assert self._node._pubkey
        self.verinfo = verinfo
        # during repair, we may be called upon to grab the private key, since
        # it wasn't picked up during a verify=False checker run, and we'll
        # need it for repair to generate a new version.
        self._need_privkey = fetch_privkey
        if self._node._privkey:
            self._need_privkey = False

        self._status = RetrieveStatus()
        self._status.set_storage_index(self._storage_index)
        self._status.set_helper(False)
        self._status.set_progress(0.0)
        self._status.set_active(True)
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo
        self._status.set_size(datalength)
        self._status.set_encoding(k, N)

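    # For orientation (not part of the original code): verinfo is the
    # version-identifying tuple obtained from the servermap, unpacked above as
    #
    #     (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
    #      offsets_tuple)
    #
    # where k-of-N is the erasure-coding parameter pair: e.g. k=3, N=10 means
    # any 3 of the 10 shares suffice to rebuild the ciphertext.
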
    def get_status(self):
        return self._status

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_number
        if "facility" not in kwargs:
            kwargs["facility"] = "tahoe.mutable.retrieve"
        return log.msg(*args, **kwargs)

    def download(self):
        self._done_deferred = defer.Deferred()
        self._started = time.time()
        self._status.set_status("Retrieving Shares")

        # first, which servers can we use?
        versionmap = self.servermap.make_versionmap()
        shares = versionmap[self.verinfo]
        # this sharemap is consumed as we decide to send requests
        self.remaining_sharemap = DictOfSets()
        for (shnum, peerid, timestamp) in shares:
            self.remaining_sharemap.add(shnum, peerid)

        self.shares = {} # maps shnum to validated blocks

        # how many shares do we need?
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo
        assert len(self.remaining_sharemap) >= k
        # we start with the lowest shnums we have available, since FEC is
        # faster if we're using "primary shares"
        self.active_shnums = set(sorted(self.remaining_sharemap.keys())[:k])
        for shnum in self.active_shnums:
            # we use an arbitrary peer who has the share. If shares are
            # doubled up (more than one share per peer), we could make this
            # run faster by spreading the load among multiple peers. But the
            # algorithm to do that is more complicated than I want to write
            # right now, and a well-provisioned grid shouldn't have multiple
            # shares per peer.
            peerid = list(self.remaining_sharemap[shnum])[0]
            self.get_data(shnum, peerid)

        # control flow beyond this point: state machine. Receiving responses
        # from queries is the input. We might send out more queries, or we
        # might produce a result.

        return self._done_deferred

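    # Illustrative driver sketch (not part of the original module); filenode,
    # servermap, and verinfo stand in for objects obtained elsewhere:
    #
    #     r = Retrieve(filenode, servermap, verinfo)
    #     d = r.download()
    #     # d eventually fires with the decrypted file contents, or with a
    #     # Failure if fewer than k shares could be fetched and validated
    #     d.addCallback(handle_plaintext)   # handle_plaintext is hypothetical
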
    def get_data(self, shnum, peerid):
        self.log(format="sending sh#%(shnum)d request to [%(peerid)s]",
                 shnum=shnum,
                 peerid=idlib.shortnodeid_b2a(peerid),
                 level=log.NOISY)
        ss = self.servermap.connections[peerid]
        started = time.time()
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo
        offsets = dict(offsets_tuple)

        # we read the checkstring, to make sure that the data we grab is from
        # the right version.
        readv = [ (0, struct.calcsize(SIGNED_PREFIX)) ]

        # We also read the data, and the hashes necessary to validate them
        # (share_hash_chain, block_hash_tree, share_data). We don't read the
        # signature or the pubkey, since that was handled during the
        # servermap phase, and we'll be comparing the share hash chain
        # against the roothash that was validated back then.

        readv.append( (offsets['share_hash_chain'],
                       offsets['enc_privkey'] - offsets['share_hash_chain'] ) )

        # if we need the private key (for repair), we also fetch that
        if self._need_privkey:
            readv.append( (offsets['enc_privkey'],
                           offsets['EOF'] - offsets['enc_privkey']) )

        m = Marker()
        self._outstanding_queries[m] = (peerid, shnum, started)

        # ask the cache first
        got_from_cache = False
        datavs = []
        for (offset, length) in readv:
            (data, timestamp) = self._node._cache.read(self.verinfo, shnum,
                                                       offset, length)
            if data is not None:
                datavs.append(data)
        if len(datavs) == len(readv):
            self.log("got data from cache")
            got_from_cache = True
            d = fireEventually({shnum: datavs})
            # the deferred fires with a dict mapping shnum to a list of
            # strings, the same shape as a _do_read response
        else:
            d = self._do_read(ss, peerid, self._storage_index, [shnum], readv)
        self.remaining_sharemap.discard(shnum, peerid)

        d.addCallback(self._got_results, m, peerid, started, got_from_cache)
        d.addErrback(self._query_failed, m, peerid)
        # errors that aren't handled by _query_failed (and errors caused by
        # _query_failed) get logged, but we still want to check for doneness.
        def _oops(f):
            self.log(format="problem in _query_failed for sh#%(shnum)d to %(peerid)s",
                     shnum=shnum,
                     peerid=idlib.shortnodeid_b2a(peerid),
                     failure=f,
                     level=log.WEIRD, umid="W0xnQA")
        d.addErrback(_oops)
        d.addBoth(self._check_for_done)
        # any error during _check_for_done means the download fails. If the
        # download is successful, _check_for_done will fire _done by itself.
        d.addErrback(self._done)
        d.addErrback(log.err)
        return d # purely for testing convenience

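    # Shape of the read vector used above (the offsets and lengths are
    # made-up numbers, for illustration only):
    #
    #     readv = [(0, 91),          # signed prefix / checkstring
    #              (2044, 1328)]     # share_hash_chain through share data
    #     # a slot_readv-style response maps shnum to the fetched strings,
    #     # e.g. {3: ["<prefix>", "<hashes and share data>"]}
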
    def _do_read(self, ss, peerid, storage_index, shnums, readv):
        # isolate the callRemote to a separate method, so tests can subclass
        # Retrieve and override it
        d = ss.callRemote("slot_readv", storage_index, shnums, readv)
        return d

    def remove_peer(self, peerid):
        for shnum in list(self.remaining_sharemap.keys()):
            self.remaining_sharemap.discard(shnum, peerid)

    def _got_results(self, datavs, marker, peerid, started, got_from_cache):
        now = time.time()
        elapsed = now - started
        if not got_from_cache:
            self._status.add_fetch_timing(peerid, elapsed)
        self.log(format="got results (%(shares)d shares) from [%(peerid)s]",
                 shares=len(datavs),
                 peerid=idlib.shortnodeid_b2a(peerid),
                 level=log.NOISY)
        self._outstanding_queries.pop(marker, None)
        if not self._running:
            return

        # note that we only ask for a single share per query, so we only
        # expect a single share back. On the other hand, we use the extra
        # shares if we get them; that seems better than an assert().

        for shnum,datav in datavs.items():
            (prefix, hash_and_data) = datav[:2]
            try:
                self._got_results_one_share(shnum, peerid,
                                            prefix, hash_and_data)
            except CorruptShareError, e:
                # log it and give the other shares a chance to be processed
                f = failure.Failure()
                self.log(format="bad share: %(f_value)s",
                         f_value=str(f.value), failure=f,
                         level=log.WEIRD, umid="7fzWZw")
                self.notify_server_corruption(peerid, shnum, str(e))
                self.remove_peer(peerid)
                self.servermap.mark_bad_share(peerid, shnum, prefix)
                self._bad_shares.add( (peerid, shnum) )
                self._status.problems[peerid] = f
                self._last_failure = f
            if self._need_privkey and len(datav) > 2:
                lp = None
                self._try_to_validate_privkey(datav[2], peerid, shnum, lp)
        # all done!

    def notify_server_corruption(self, peerid, shnum, reason):
        ss = self.servermap.connections[peerid]
        ss.callRemoteOnly("advise_corrupt_share",
                          "mutable", self._storage_index, shnum, reason)

    def _got_results_one_share(self, shnum, peerid,
                               got_prefix, got_hash_and_data):
        self.log("_got_results: got shnum #%d from peerid %s"
                 % (shnum, idlib.shortnodeid_b2a(peerid)))
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo
        assert len(got_prefix) == len(prefix), (len(got_prefix), len(prefix))
        if got_prefix != prefix:
            msg = "someone wrote to the data since we read the servermap: prefix changed"
            raise UncoordinatedWriteError(msg)
        (share_hash_chain, block_hash_tree,
         share_data) = unpack_share_data(self.verinfo, got_hash_and_data)

        assert isinstance(share_data, str)
        # build the block hash tree. SDMF has only one leaf.
        leaves = [hashutil.block_hash(share_data)]
        t = hashtree.HashTree(leaves)
        if list(t) != block_hash_tree:
            raise CorruptShareError(peerid, shnum, "block hash tree failure")
        share_hash_leaf = t[0]
        t2 = hashtree.IncompleteHashTree(N)
        # root_hash was checked by the signature
        t2.set_hashes({0: root_hash})
        try:
            t2.set_hashes(hashes=share_hash_chain,
                          leaves={shnum: share_hash_leaf})
        except (hashtree.BadHashError, hashtree.NotEnoughHashesError,
                IndexError), e:
            msg = "corrupt hashes: %s" % (e,)
            raise CorruptShareError(peerid, shnum, msg)
        self.log(" data valid! len=%d" % len(share_data))
        # each query comes down to this: placing validated share data into
        # self.shares
        self.shares[shnum] = share_data

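    # The validation above has two layers (summary, not new behavior): a
    # per-share block hash tree whose single SDMF leaf covers share_data, and
    # the file-wide share hash tree, which ties that share's root into
    # root_hash (already vouched for by the signature). Roughly:
    #
    #     share_hash_leaf = hashtree.HashTree([block_hash(share_data)])[0]
    #     t2 = hashtree.IncompleteHashTree(N)
    #     t2.set_hashes({0: root_hash})
    #     t2.set_hashes(hashes=share_hash_chain,
    #                   leaves={shnum: share_hash_leaf})
    #     # any mismatch raises BadHashError / NotEnoughHashesError
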
    def _try_to_validate_privkey(self, enc_privkey, peerid, shnum, lp):

        alleged_privkey_s = self._node._decrypt_privkey(enc_privkey)
        alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s)
        if alleged_writekey != self._node.get_writekey():
            self.log("invalid privkey from %s shnum %d" %
                     (idlib.nodeid_b2a(peerid)[:8], shnum),
                     parent=lp, level=log.WEIRD, umid="YIw4tA")
            return

        # it's good
        self.log("got valid privkey from shnum %d on peerid %s" %
                 (shnum, idlib.shortnodeid_b2a(peerid)),
                 parent=lp)
        privkey = rsa.create_signing_key_from_string(alleged_privkey_s)
        self._node._populate_encprivkey(enc_privkey)
        self._node._populate_privkey(privkey)
        self._need_privkey = False

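    # In short (restating the check above): the decrypted RSA key is only
    # accepted if hashing it with ssk_writekey_hash reproduces the writekey
    # we already hold, i.e.
    #
    #     hashutil.ssk_writekey_hash(alleged_privkey_s) == node.get_writekey()
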
    def _query_failed(self, f, marker, peerid):
        self.log(format="query to [%(peerid)s] failed",
                 peerid=idlib.shortnodeid_b2a(peerid),
                 level=log.NOISY)
        self._status.problems[peerid] = f
        self._outstanding_queries.pop(marker, None)
        if not self._running:
            return
        self._last_failure = f
        self.remove_peer(peerid)
        level = log.WEIRD
        if f.check(DeadReferenceError):
            level = log.UNUSUAL
        self.log(format="error during query: %(f_value)s",
                 f_value=str(f.value), failure=f, level=level, umid="gOJB5g")

    def _check_for_done(self, res):
        # exit paths:
        #  return : keep waiting, no new queries
        #  return self._maybe_send_more_queries(k) : send some more queries
        #  fire self._done(plaintext) : download successful
        #  raise exception : download fails

        self.log(format="_check_for_done: running=%(running)s, decoding=%(decoding)s",
                 running=self._running, decoding=self._decoding,
                 level=log.NOISY)
        if not self._running:
            return
        if self._decoding:
            return
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo

        if len(self.shares) < k:
            # we don't have enough shares yet
            return self._maybe_send_more_queries(k)
        if self._need_privkey:
            # we got k shares, but none of them had a valid privkey. TODO:
            # look further. Adding code to do this is a bit complicated, and
            # I want to avoid that complication, and this should be pretty
            # rare (k shares with bitflips in the enc_privkey but not in the
            # data blocks). If we actually do get here, the subsequent repair
            # will fail for lack of a privkey.
            self.log("got k shares but still need_privkey, bummer",
                     level=log.WEIRD, umid="MdRHPA")

        # we have enough to finish. All the shares have had their hashes
        # checked, so if something fails at this point, we don't know how
        # to fix it, so the download will fail.

        self._decoding = True # avoid reentrancy
        self._status.set_status("decoding")
        now = time.time()
        elapsed = now - self._started
        self._status.timings["fetch"] = elapsed

        d = defer.maybeDeferred(self._decode)
        d.addCallback(self._decrypt, IV, self._node._readkey)
        d.addBoth(self._done)
        return d # purely for test convenience

    def _maybe_send_more_queries(self, k):
        # we don't have enough shares yet. Should we send out more queries?
        # There are some number of queries outstanding, each for a single
        # share. If we can generate 'needed_shares' additional queries, we do
        # so. If we can't, then we know this file is a goner, and we raise
        # NotEnoughSharesError.
        self.log(format=("_maybe_send_more_queries, have=%(have)d, k=%(k)d, "
                         "outstanding=%(outstanding)d"),
                 have=len(self.shares), k=k,
                 outstanding=len(self._outstanding_queries),
                 level=log.NOISY)

        remaining_shares = k - len(self.shares)
        needed = remaining_shares - len(self._outstanding_queries)
        if not needed:
            # we have enough queries in flight already

            # TODO: but if they've been in flight for a long time, and we
            # have reason to believe that new queries might respond faster
            # (i.e. we've seen other queries come back faster), then consider
            # sending out new queries. This could help with peers which have
            # silently gone away since the servermap was updated, for which
            # we're still waiting for the 15-minute TCP disconnect to happen.
            self.log("enough queries are in flight, no more are needed",
                     level=log.NOISY)
            return

        outstanding_shnums = set([shnum
                                  for (peerid, shnum, started)
                                  in self._outstanding_queries.values()])
        # prefer low-numbered shares, they are more likely to be primary
        available_shnums = sorted(self.remaining_sharemap.keys())
        for shnum in available_shnums:
            if shnum in outstanding_shnums:
                # skip ones that are already in transit
                continue
            if shnum not in self.remaining_sharemap:
                # no servers for that shnum. note that DictOfSets removes
                # empty sets from the dict for us.
                continue
            peerid = list(self.remaining_sharemap[shnum])[0]
            # get_data will remove that peerid from the sharemap, and add the
            # query to self._outstanding_queries
            self._status.set_status("Retrieving More Shares")
            self.get_data(shnum, peerid)
            needed -= 1
            if not needed:
                break

        # at this point, we have as many outstanding queries as we can. If
        # needed!=0 then we might not have enough to recover the file.
        if needed:
            format = ("ran out of peers: "
                      "have %(have)d shares (k=%(k)d), "
                      "%(outstanding)d queries in flight, "
                      "need %(need)d more, "
                      "found %(bad)d bad shares")
            args = {"have": len(self.shares),
                    "k": k,
                    "outstanding": len(self._outstanding_queries),
                    "need": needed,
                    "bad": len(self._bad_shares),
                    }
            self.log(format=format,
                     level=log.WEIRD, umid="ezTfjw", **args)
            err = NotEnoughSharesError("%s, last failure: %s" %
                                       (format % args, self._last_failure))
            if self._bad_shares:
                self.log("We found some bad shares this pass. You should "
                         "update the servermap and try again to check "
                         "more peers",
                         level=log.WEIRD, umid="EFkOlA")
                err.servermap = self.servermap
            raise err

        return

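    # Worked example of the arithmetic above (made-up numbers): with k=3,
    # one validated share, and one query still in flight,
    #
    #     remaining_shares = 3 - 1 = 2
    #     needed           = 2 - 1 = 1
    #
    # so one additional query is sent; if no unqueried shnum remains,
    # needed stays nonzero and NotEnoughSharesError is raised.
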
    def _decode(self):
        started = time.time()
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo

        # self.shares is a dict mapping shnum to share data, but the codec
        # wants two parallel lists.
        shareids = []; shares = []
        for shareid, share in self.shares.items():
            shareids.append(shareid)
            shares.append(share)

        assert len(shareids) >= k, len(shareids)
        # zfec really doesn't want extra shares
        shareids = shareids[:k]
        shares = shares[:k]

        fec = codec.CRSDecoder()
        fec.set_params(segsize, k, N)

        self.log("params %s, we have %d shares" % ((segsize, k, N), len(shares)))
        self.log("about to decode, shareids=%s" % (shareids,))
        d = defer.maybeDeferred(fec.decode, shares, shareids)
        def _done(buffers):
            self._status.timings["decode"] = time.time() - started
            self.log(" decode done, %d buffers" % len(buffers))
            segment = "".join(buffers)
            self.log(" joined length %d, datalength %d" %
                     (len(segment), datalength))
            segment = segment[:datalength]
            self.log(" segment len=%d" % len(segment))
            return segment
        def _err(f):
            self.log(" decode failed: %s" % f)
            return f
        d.addCallback(_done)
        d.addErrback(_err)
        return d

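    # Decoding sketch (an illustrative restatement of the code above):
    #
    #     fec = codec.CRSDecoder()
    #     fec.set_params(segsize, k, N)
    #     d = fec.decode(shares[:k], shareids[:k])
    #     # fires with a list of buffers; "".join(buffers)[:datalength]
    #     # is the ciphertext segment that gets handed to _decrypt
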
    def _decrypt(self, crypttext, IV, readkey):
        self._status.set_status("decrypting")
        started = time.time()
        key = hashutil.ssk_readkey_data_hash(IV, readkey)
        decryptor = AES(key)
        plaintext = decryptor.process(crypttext)
        self._status.timings["decrypt"] = time.time() - started
        return plaintext

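    # Keying note (summary of the code above): the AES key is not the readkey
    # itself but a hash that also covers this version's IV,
    #
    #     key = hashutil.ssk_readkey_data_hash(IV, readkey)
    #     plaintext = AES(key).process(crypttext)
    #
    # so a different IV yields a different data key even though the readkey
    # stays the same.
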
    def _done(self, res):
        if not self._running:
            return
        self._running = False
        self._status.set_active(False)
        self._status.timings["total"] = time.time() - self._started
        # res is either the new contents, or a Failure
        if isinstance(res, failure.Failure):
            self.log("Retrieve done, with failure", failure=res,
                     level=log.UNUSUAL)
            self._status.set_status("Failed")
        else:
            self.log("Retrieve done, success!")
            self._status.set_status("Done")
            self._status.set_progress(1.0)
            # remember the encoding parameters, use them again next time
            (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
             offsets_tuple) = self.verinfo
            self._node._populate_required_shares(k)
            self._node._populate_total_shares(N)
        eventually(self._done_deferred.callback, res)