# src/allmydata/mutable/retrieve.py

import struct, time
from itertools import count
from zope.interface import implements
from twisted.internet import defer
from twisted.python import failure
from foolscap.api import DeadReferenceError, eventually, fireEventually
from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError
from allmydata.util import hashutil, idlib, log
from allmydata import hashtree, codec
from allmydata.storage.server import si_b2a
from pycryptopp.cipher.aes import AES
from pycryptopp.publickey import rsa

from common import DictOfSets, CorruptShareError, UncoordinatedWriteError
from layout import SIGNED_PREFIX, unpack_share_data

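# RetrieveStatus collects progress and timing information about a single
# retrieve operation; callers read it back through the IRetrieveStatus
# accessors defined below.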
class RetrieveStatus:
    implements(IRetrieveStatus)
    statusid_counter = count(0)
    def __init__(self):
        self.timings = {}
        self.timings["fetch_per_server"] = {}
        self.timings["cumulative_verify"] = 0.0
        self.problems = {}
        self.active = True
        self.storage_index = None
        self.helper = False
        self.encoding = ("?","?")
        self.size = None
        self.status = "Not started"
        self.progress = 0.0
        self.counter = self.statusid_counter.next()
        self.started = time.time()

    def get_started(self):
        return self.started
    def get_storage_index(self):
        return self.storage_index
    def get_encoding(self):
        return self.encoding
    def using_helper(self):
        return self.helper
    def get_size(self):
        return self.size
    def get_status(self):
        return self.status
    def get_progress(self):
        return self.progress
    def get_active(self):
        return self.active
    def get_counter(self):
        return self.counter

    def add_fetch_timing(self, peerid, elapsed):
        if peerid not in self.timings["fetch_per_server"]:
            self.timings["fetch_per_server"][peerid] = []
        self.timings["fetch_per_server"][peerid].append(elapsed)
    def set_storage_index(self, si):
        self.storage_index = si
    def set_helper(self, helper):
        self.helper = helper
    def set_encoding(self, k, n):
        self.encoding = (k, n)
    def set_size(self, size):
        self.size = size
    def set_status(self, status):
        self.status = status
    def set_progress(self, value):
        self.progress = value
    def set_active(self, value):
        self.active = value

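# Marker instances carry no data of their own; each one just serves as a
# unique hashable key for an entry in Retrieve._outstanding_queries.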
class Marker:
    pass

class Retrieve:
    # this class is currently single-use. Eventually (in MDMF) we will make
    # it multi-use, in which case you can call download(range) multiple
    # times, and each will have a separate response chain. However the
    # Retrieve object will remain tied to a specific version of the file, and
    # will use a single ServerMap instance.
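    #
    # Rough flow, as implemented below: download() picks k shnums from the
    # servermap and calls get_data() for each. Responses arrive in
    # _got_results(), which validates each share via _got_results_one_share().
    # _check_for_done() then either keeps waiting, asks
    # _maybe_send_more_queries() for replacements, or hands the k validated
    # shares to _decode() and _decrypt() before _done() fires the Deferred
    # that download() returned.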

    def __init__(self, filenode, servermap, verinfo, fetch_privkey=False):
        self._node = filenode
        assert self._node._pubkey
        self._storage_index = filenode.get_storage_index()
        assert self._node._readkey
        self._last_failure = None
        prefix = si_b2a(self._storage_index)[:5]
        self._log_number = log.msg("Retrieve(%s): starting" % prefix)
        self._outstanding_queries = {} # maps Marker -> (peerid, shnum, start_time)
        self._running = True
        self._decoding = False
        self._bad_shares = set()

        self.servermap = servermap
        assert self._node._pubkey
        self.verinfo = verinfo
        # during repair, we may be called upon to grab the private key, since
        # it wasn't picked up during a verify=False checker run, and we'll
        # need it for repair to generate a new version.
        self._need_privkey = fetch_privkey
        if self._node._privkey:
            self._need_privkey = False

        self._status = RetrieveStatus()
        self._status.set_storage_index(self._storage_index)
        self._status.set_helper(False)
        self._status.set_progress(0.0)
        self._status.set_active(True)
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo
        self._status.set_size(datalength)
        self._status.set_encoding(k, N)

    def get_status(self):
        return self._status

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_number
        if "facility" not in kwargs:
            kwargs["facility"] = "tahoe.mutable.retrieve"
        return log.msg(*args, **kwargs)

    def download(self):
        self._done_deferred = defer.Deferred()
        self._started = time.time()
        self._status.set_status("Retrieving Shares")

        # first, which servers can we use?
        versionmap = self.servermap.make_versionmap()
        shares = versionmap[self.verinfo]
        # this sharemap is consumed as we decide to send requests
        self.remaining_sharemap = DictOfSets()
        for (shnum, peerid, timestamp) in shares:
            self.remaining_sharemap.add(shnum, peerid)
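        # remaining_sharemap maps shnum -> set of peerids that claim to hold
        # that share, e.g. roughly {0: set([peerA]), 1: set([peerA, peerB])}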

        self.shares = {} # maps shnum to validated blocks

        # how many shares do we need?
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo
        assert len(self.remaining_sharemap) >= k
        # we start with the lowest shnums we have available, since FEC is
        # faster if we're using "primary shares"
        self.active_shnums = set(sorted(self.remaining_sharemap.keys())[:k])
        for shnum in self.active_shnums:
            # we use an arbitrary peer who has the share. If shares are
            # doubled up (more than one share per peer), we could make this
            # run faster by spreading the load among multiple peers. But the
            # algorithm to do that is more complicated than I want to write
            # right now, and a well-provisioned grid shouldn't have multiple
            # shares per peer.
            peerid = list(self.remaining_sharemap[shnum])[0]
            self.get_data(shnum, peerid)

        # control flow beyond this point: state machine. Receiving responses
        # from queries is the input. We might send out more queries, or we
        # might produce a result.

        return self._done_deferred

    def get_data(self, shnum, peerid):
        self.log(format="sending sh#%(shnum)d request to [%(peerid)s]",
                 shnum=shnum,
                 peerid=idlib.shortnodeid_b2a(peerid),
                 level=log.NOISY)
        ss = self.servermap.connections[peerid]
        started = time.time()
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo
        offsets = dict(offsets_tuple)

        # we read the checkstring, to make sure that the data we grab is from
        # the right version.
        readv = [ (0, struct.calcsize(SIGNED_PREFIX)) ]

        # We also read the data, and the hashes necessary to validate them
        # (share_hash_chain, block_hash_tree, share_data). We don't read the
        # signature or the pubkey, since that was handled during the
        # servermap phase, and we'll be comparing the share hash chain
        # against the roothash that was validated back then.

        readv.append( (offsets['share_hash_chain'],
                       offsets['enc_privkey'] - offsets['share_hash_chain'] ) )

        # if we need the private key (for repair), we also fetch that
        if self._need_privkey:
            readv.append( (offsets['enc_privkey'],
                           offsets['EOF'] - offsets['enc_privkey']) )
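        # readv is now a list of (offset, length) pairs: the checkstring,
        # the hash-chain/block-hash-tree/share-data region, and (optionally)
        # the encrypted private key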

        m = Marker()
        self._outstanding_queries[m] = (peerid, shnum, started)

        # ask the cache first
        got_from_cache = False
        datavs = []
        for (offset, length) in readv:
            (data, timestamp) = self._node._cache.read(self.verinfo, shnum,
                                                       offset, length)
            if data is not None:
                datavs.append(data)
        if len(datavs) == len(readv):
            self.log("got data from cache")
            got_from_cache = True
            d = fireEventually({shnum: datavs})
            # {shnum: datavs} mimics a _do_read response: a dict mapping
            # shnum to a list of data strings, one per read vector
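            # fireEventually() fires the Deferred in a later reactor turn,
            # so the cache-hit path stays asynchronous and the callback
            # chain below behaves the same as for a real remote read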
        else:
            d = self._do_read(ss, peerid, self._storage_index, [shnum], readv)
        self.remaining_sharemap.discard(shnum, peerid)

        d.addCallback(self._got_results, m, peerid, started, got_from_cache)
        d.addErrback(self._query_failed, m, peerid)
        # errors that aren't handled by _query_failed (and errors caused by
        # _query_failed) get logged, but we still want to check for doneness.
        def _oops(f):
            self.log(format="problem in _query_failed for sh#%(shnum)d to %(peerid)s",
                     shnum=shnum,
                     peerid=idlib.shortnodeid_b2a(peerid),
                     failure=f,
                     level=log.WEIRD, umid="W0xnQA")
        d.addErrback(_oops)
        d.addBoth(self._check_for_done)
        # any error during _check_for_done means the download fails. If the
        # download is successful, _check_for_done will fire _done by itself.
        d.addErrback(self._done)
        d.addErrback(log.err)
        return d # purely for testing convenience

    def _do_read(self, ss, peerid, storage_index, shnums, readv):
        # isolate the callRemote to a separate method, so tests can subclass
        # Retrieve and override it
        d = ss.callRemote("slot_readv", storage_index, shnums, readv)
        return d

    def remove_peer(self, peerid):
        for shnum in list(self.remaining_sharemap.keys()):
            self.remaining_sharemap.discard(shnum, peerid)

    def _got_results(self, datavs, marker, peerid, started, got_from_cache):
        now = time.time()
        elapsed = now - started
        if not got_from_cache:
            self._status.add_fetch_timing(peerid, elapsed)
        self.log(format="got results (%(shares)d shares) from [%(peerid)s]",
                 shares=len(datavs),
                 peerid=idlib.shortnodeid_b2a(peerid),
                 level=log.NOISY)
        self._outstanding_queries.pop(marker, None)
        if not self._running:
            return

        # note that we only ask for a single share per query, so we only
        # expect a single share back. On the other hand, we use the extra
        # shares if we get them; that seems better than an assert().

        for shnum,datav in datavs.items():
            (prefix, hash_and_data) = datav[:2]
            try:
                self._got_results_one_share(shnum, peerid,
                                            prefix, hash_and_data)
            except CorruptShareError, e:
                # log it and give the other shares a chance to be processed
                f = failure.Failure()
                self.log(format="bad share: %(f_value)s",
                         f_value=str(f.value), failure=f,
                         level=log.WEIRD, umid="7fzWZw")
                self.notify_server_corruption(peerid, shnum, str(e))
                self.remove_peer(peerid)
                self.servermap.mark_bad_share(peerid, shnum, prefix)
                self._bad_shares.add( (peerid, shnum) )
                self._status.problems[peerid] = f
                self._last_failure = f
                pass
            if self._need_privkey and len(datav) > 2:
                lp = None
                self._try_to_validate_privkey(datav[2], peerid, shnum, lp)
        # all done!

    def notify_server_corruption(self, peerid, shnum, reason):
        ss = self.servermap.connections[peerid]
        ss.callRemoteOnly("advise_corrupt_share",
                          "mutable", self._storage_index, shnum, reason)

    def _got_results_one_share(self, shnum, peerid,
                               got_prefix, got_hash_and_data):
        self.log("_got_results: got shnum #%d from peerid %s"
                 % (shnum, idlib.shortnodeid_b2a(peerid)))
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo
        assert len(got_prefix) == len(prefix), (len(got_prefix), len(prefix))
        if got_prefix != prefix:
            msg = "someone wrote to the data since we read the servermap: prefix changed"
            raise UncoordinatedWriteError(msg)
        (share_hash_chain, block_hash_tree,
         share_data) = unpack_share_data(self.verinfo, got_hash_and_data)

        assert isinstance(share_data, str)
        # build the block hash tree. SDMF has only one leaf.
        leaves = [hashutil.block_hash(share_data)]
        t = hashtree.HashTree(leaves)
        if list(t) != block_hash_tree:
            raise CorruptShareError(peerid, shnum, "block hash tree failure")
        share_hash_leaf = t[0]
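        # share_hash_leaf is the root of this share's block hash tree. It
        # must chain up, via the share_hash_chain nodes, to root_hash, which
        # the servermap already validated against the signed prefix.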
        t2 = hashtree.IncompleteHashTree(N)
        # root_hash was checked by the signature
        t2.set_hashes({0: root_hash})
        try:
            t2.set_hashes(hashes=share_hash_chain,
                          leaves={shnum: share_hash_leaf})
        except (hashtree.BadHashError, hashtree.NotEnoughHashesError,
                IndexError), e:
            msg = "corrupt hashes: %s" % (e,)
            raise CorruptShareError(peerid, shnum, msg)
        self.log(" data valid! len=%d" % len(share_data))
        # each query comes down to this: placing validated share data into
        # self.shares
        self.shares[shnum] = share_data

    def _try_to_validate_privkey(self, enc_privkey, peerid, shnum, lp):
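        # decrypt the fetched private key and make sure it hashes to our
        # writekey before storing it on the filenode for the repairer to use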

        alleged_privkey_s = self._node._decrypt_privkey(enc_privkey)
        alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s)
        if alleged_writekey != self._node.get_writekey():
            self.log("invalid privkey from %s shnum %d" %
                     (idlib.nodeid_b2a(peerid)[:8], shnum),
                     parent=lp, level=log.WEIRD, umid="YIw4tA")
            return

        # it's good
        self.log("got valid privkey from shnum %d on peerid %s" %
                 (shnum, idlib.shortnodeid_b2a(peerid)),
                 parent=lp)
        privkey = rsa.create_signing_key_from_string(alleged_privkey_s)
        self._node._populate_encprivkey(enc_privkey)
        self._node._populate_privkey(privkey)
        self._need_privkey = False

    def _query_failed(self, f, marker, peerid):
        self.log(format="query to [%(peerid)s] failed",
                 peerid=idlib.shortnodeid_b2a(peerid),
                 level=log.NOISY)
        self._status.problems[peerid] = f
        self._outstanding_queries.pop(marker, None)
        if not self._running:
            return
        self._last_failure = f
        self.remove_peer(peerid)
        level = log.WEIRD
        if f.check(DeadReferenceError):
            level = log.UNUSUAL
        self.log(format="error during query: %(f_value)s",
                 f_value=str(f.value), failure=f, level=level, umid="gOJB5g")

    def _check_for_done(self, res):
        # exit paths:
        #  return : keep waiting, no new queries
        #  return self._maybe_send_more_queries(k) : send some more queries
        #  fire self._done(plaintext) : download successful
        #  raise exception : download fails

        self.log(format="_check_for_done: running=%(running)s, decoding=%(decoding)s",
                 running=self._running, decoding=self._decoding,
                 level=log.NOISY)
        if not self._running:
            return
        if self._decoding:
            return
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo

        if len(self.shares) < k:
            # we don't have enough shares yet
            return self._maybe_send_more_queries(k)
        if self._need_privkey:
            # we got k shares, but none of them had a valid privkey. TODO:
            # look further. Adding code to do this is a bit complicated, and
            # I want to avoid that complication, and this should be pretty
            # rare (k shares with bitflips in the enc_privkey but not in the
            # data blocks). If we actually do get here, the subsequent repair
            # will fail for lack of a privkey.
            self.log("got k shares but still need_privkey, bummer",
                     level=log.WEIRD, umid="MdRHPA")

        # we have enough to finish. All the shares have had their hashes
        # checked, so if something fails at this point, we don't know how
        # to fix it, so the download will fail.

        self._decoding = True # avoid reentrancy
        self._status.set_status("decoding")
        now = time.time()
        elapsed = now - self._started
        self._status.timings["fetch"] = elapsed

        d = defer.maybeDeferred(self._decode)
        d.addCallback(self._decrypt, IV, self._node._readkey)
        d.addBoth(self._done)
        return d # purely for test convenience

    def _maybe_send_more_queries(self, k):
        # we don't have enough shares yet. Should we send out more queries?
        # There are some number of queries outstanding, each for a single
        # share. If we can generate 'needed_shares' additional queries, we do
        # so. If we can't, then we know this file is a goner, and we raise
        # NotEnoughSharesError.
        self.log(format=("_maybe_send_more_queries, have=%(have)d, k=%(k)d, "
                         "outstanding=%(outstanding)d"),
                 have=len(self.shares), k=k,
                 outstanding=len(self._outstanding_queries),
                 level=log.NOISY)

        remaining_shares = k - len(self.shares)
        needed = remaining_shares - len(self._outstanding_queries)
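        # e.g. with k=3, one validated share, and one query still in flight,
        # remaining_shares=2 and needed=1, so we try to start one more query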
        if not needed:
            # we have enough queries in flight already

            # TODO: but if they've been in flight for a long time, and we
            # have reason to believe that new queries might respond faster
            # (i.e. we've seen other queries come back faster), then consider
            # sending out new queries. This could help with peers which have
            # silently gone away since the servermap was updated, for which
            # we're still waiting for the 15-minute TCP disconnect to happen.
            self.log("enough queries are in flight, no more are needed",
                     level=log.NOISY)
            return

        outstanding_shnums = set([shnum
                                  for (peerid, shnum, started)
                                  in self._outstanding_queries.values()])
        # prefer low-numbered shares, they are more likely to be primary
        available_shnums = sorted(self.remaining_sharemap.keys())
        for shnum in available_shnums:
            if shnum in outstanding_shnums:
                # skip ones that are already in transit
                continue
            if shnum not in self.remaining_sharemap:
                # no servers for that shnum. note that DictOfSets removes
                # empty sets from the dict for us.
                continue
            peerid = list(self.remaining_sharemap[shnum])[0]
            # get_data will remove that peerid from the sharemap, and add the
            # query to self._outstanding_queries
            self._status.set_status("Retrieving More Shares")
            self.get_data(shnum, peerid)
            needed -= 1
            if not needed:
                break

        # at this point, we have as many outstanding queries as we can. If
        # needed!=0 then we might not have enough to recover the file.
        if needed:
            format = ("ran out of peers: "
                      "have %(have)d shares (k=%(k)d), "
                      "%(outstanding)d queries in flight, "
                      "need %(need)d more, "
                      "found %(bad)d bad shares")
            args = {"have": len(self.shares),
                    "k": k,
                    "outstanding": len(self._outstanding_queries),
                    "need": needed,
                    "bad": len(self._bad_shares),
                    }
            self.log(format=format,
                     level=log.WEIRD, umid="ezTfjw", **args)
            err = NotEnoughSharesError("%s, last failure: %s" %
                                       (format % args, self._last_failure))
            if self._bad_shares:
                self.log("We found some bad shares this pass. You should "
                         "update the servermap and try again to check "
                         "more peers",
                         level=log.WEIRD, umid="EFkOlA")
                err.servermap = self.servermap
            raise err

        return

    def _decode(self):
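        # Feed any k of the validated shares (with their shareids) to zfec;
        # joining the decoded buffers and trimming to datalength yields the
        # ciphertext segment (SDMF files have a single segment).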
        started = time.time()
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo

        # self.shares is a dict mapping shnum to share data, but the codec
        # wants two lists.
        shareids = []; shares = []
        for shareid, share in self.shares.items():
            shareids.append(shareid)
            shares.append(share)

        assert len(shareids) >= k, len(shareids)
        # zfec really doesn't want extra shares
        shareids = shareids[:k]
        shares = shares[:k]

        fec = codec.CRSDecoder()
        fec.set_params(segsize, k, N)

        self.log("params %s, we have %d shares" % ((segsize, k, N), len(shares)))
        self.log("about to decode, shareids=%s" % (shareids,))
        d = defer.maybeDeferred(fec.decode, shares, shareids)
        def _done(buffers):
            self._status.timings["decode"] = time.time() - started
            self.log(" decode done, %d buffers" % len(buffers))
            segment = "".join(buffers)
            self.log(" joined length %d, datalength %d" %
                     (len(segment), datalength))
            segment = segment[:datalength]
            self.log(" segment len=%d" % len(segment))
            return segment
        def _err(f):
            self.log(" decode failed: %s" % f)
            return f
        d.addCallback(_done)
        d.addErrback(_err)
        return d

    def _decrypt(self, crypttext, IV, readkey):
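        # the crypttext was encrypted with AES under a key derived from the
        # file's readkey and this version's IV, so the same derivation here
        # recovers the plaintext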
        self._status.set_status("decrypting")
        started = time.time()
        key = hashutil.ssk_readkey_data_hash(IV, readkey)
        decryptor = AES(key)
        plaintext = decryptor.process(crypttext)
        self._status.timings["decrypt"] = time.time() - started
        return plaintext

    def _done(self, res):
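        # wrap up: note success or failure in the status object, then fire
        # the Deferred that download() returned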
        if not self._running:
            return
        self._running = False
        self._status.set_active(False)
        self._status.timings["total"] = time.time() - self._started
        # res is either the new contents, or a Failure
        if isinstance(res, failure.Failure):
            self.log("Retrieve done, with failure", failure=res,
                     level=log.UNUSUAL)
            self._status.set_status("Failed")
        else:
            self.log("Retrieve done, success!")
            self._status.set_status("Done")
            self._status.set_progress(1.0)
            # remember the encoding parameters, use them again next time
            (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
             offsets_tuple) = self.verinfo
            self._node._populate_required_shares(k)
            self._node._populate_total_shares(N)
        eventually(self._done_deferred.callback, res)