]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blob - src/allmydata/mutable/retrieve.py
break storage.py into smaller pieces in storage/*.py . No behavioral changes.
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / mutable / retrieve.py
1
2 import struct, time
3 from itertools import count
4 from zope.interface import implements
5 from twisted.internet import defer
6 from twisted.python import failure
7 from foolscap import DeadReferenceError
8 from foolscap.eventual import eventually, fireEventually
9 from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError
10 from allmydata.util import hashutil, idlib, log
11 from allmydata import hashtree, codec
12 from allmydata.storage.server import si_b2a
13 from pycryptopp.cipher.aes import AES
14 from pycryptopp.publickey import rsa
15
16 from common import DictOfSets, CorruptShareError, UncoordinatedWriteError
17 from layout import SIGNED_PREFIX, unpack_share_data
18
class RetrieveStatus:
    """Progress, timing and problem record for one retrieve operation.

    Every instance draws the next value from a class-wide counter and
    remembers the wall-clock time at which it was created. The remaining
    fields are written through the set_*/add_* methods and read back
    through the get_* methods.
    """
    implements(IRetrieveStatus)
    # shared by all instances, so each status object gets its own counter
    statusid_counter = count(0)

    def __init__(self):
        # "fetch_per_server" maps peerid to a list of elapsed times, one
        # entry per call to add_fetch_timing()
        self.timings = {"fetch_per_server": {},
                        "cumulative_verify": 0.0,
                        }
        self.problems = {}
        self.active = True
        self.storage_index = None
        self.helper = False
        # (k, N), unknown until set_encoding() is called
        self.encoding = ("?","?")
        self.size = None
        self.status = "Not started"
        self.progress = 0.0
        self.counter = self.statusid_counter.next()
        self.started = time.time()

    # read accessors

    def get_started(self):
        return self.started

    def get_storage_index(self):
        return self.storage_index

    def get_encoding(self):
        return self.encoding

    def using_helper(self):
        return self.helper

    def get_size(self):
        return self.size

    def get_status(self):
        return self.status

    def get_progress(self):
        return self.progress

    def get_active(self):
        return self.active

    def get_counter(self):
        return self.counter

    # write accessors

    def add_fetch_timing(self, peerid, elapsed):
        """Record one more fetch duration for the given peer."""
        per_server = self.timings["fetch_per_server"]
        per_server.setdefault(peerid, []).append(elapsed)

    def set_storage_index(self, si):
        self.storage_index = si

    def set_helper(self, helper):
        self.helper = helper

    def set_encoding(self, k, n):
        self.encoding = (k, n)

    def set_size(self, size):
        self.size = size

    def set_status(self, status):
        self.status = status

    def set_progress(self, value):
        self.progress = value

    def set_active(self, value):
        self.active = value
74
class Marker:
    """Empty class whose instances serve as unique, hashable tokens.

    Retrieve uses one instance per outstanding query as the key into its
    table of in-flight requests.
    """
77
78 class Retrieve:
79     # this class is currently single-use. Eventually (in MDMF) we will make
80     # it multi-use, in which case you can call download(range) multiple
81     # times, and each will have a separate response chain. However the
82     # Retrieve object will remain tied to a specific version of the file, and
83     # will use a single ServerMap instance.
84
85     def __init__(self, filenode, servermap, verinfo, fetch_privkey=False):
86         self._node = filenode
87         assert self._node._pubkey
88         self._storage_index = filenode.get_storage_index()
89         assert self._node._readkey
90         self._last_failure = None
91         prefix = si_b2a(self._storage_index)[:5]
92         self._log_number = log.msg("Retrieve(%s): starting" % prefix)
93         self._outstanding_queries = {} # maps (peerid,shnum) to start_time
94         self._running = True
95         self._decoding = False
96         self._bad_shares = set()
97
98         self.servermap = servermap
99         assert self._node._pubkey
100         self.verinfo = verinfo
101         # during repair, we may be called upon to grab the private key, since
102         # it wasn't picked up during a verify=False checker run, and we'll
103         # need it for repair to generate the a new version.
104         self._need_privkey = fetch_privkey
105         if self._node._privkey:
106             self._need_privkey = False
107
108         self._status = RetrieveStatus()
109         self._status.set_storage_index(self._storage_index)
110         self._status.set_helper(False)
111         self._status.set_progress(0.0)
112         self._status.set_active(True)
113         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
114          offsets_tuple) = self.verinfo
115         self._status.set_size(datalength)
116         self._status.set_encoding(k, N)
117
118     def get_status(self):
119         return self._status
120
121     def log(self, *args, **kwargs):
122         if "parent" not in kwargs:
123             kwargs["parent"] = self._log_number
124         if "facility" not in kwargs:
125             kwargs["facility"] = "tahoe.mutable.retrieve"
126         return log.msg(*args, **kwargs)
127
128     def download(self):
129         self._done_deferred = defer.Deferred()
130         self._started = time.time()
131         self._status.set_status("Retrieving Shares")
132
133         # first, which servers can we use?
134         versionmap = self.servermap.make_versionmap()
135         shares = versionmap[self.verinfo]
136         # this sharemap is consumed as we decide to send requests
137         self.remaining_sharemap = DictOfSets()
138         for (shnum, peerid, timestamp) in shares:
139             self.remaining_sharemap.add(shnum, peerid)
140
141         self.shares = {} # maps shnum to validated blocks
142
143         # how many shares do we need?
144         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
145          offsets_tuple) = self.verinfo
146         assert len(self.remaining_sharemap) >= k
147         # we start with the lowest shnums we have available, since FEC is
148         # faster if we're using "primary shares"
149         self.active_shnums = set(sorted(self.remaining_sharemap.keys())[:k])
150         for shnum in self.active_shnums:
151             # we use an arbitrary peer who has the share. If shares are
152             # doubled up (more than one share per peer), we could make this
153             # run faster by spreading the load among multiple peers. But the
154             # algorithm to do that is more complicated than I want to write
155             # right now, and a well-provisioned grid shouldn't have multiple
156             # shares per peer.
157             peerid = list(self.remaining_sharemap[shnum])[0]
158             self.get_data(shnum, peerid)
159
160         # control flow beyond this point: state machine. Receiving responses
161         # from queries is the input. We might send out more queries, or we
162         # might produce a result.
163
164         return self._done_deferred
165
166     def get_data(self, shnum, peerid):
167         self.log(format="sending sh#%(shnum)d request to [%(peerid)s]",
168                  shnum=shnum,
169                  peerid=idlib.shortnodeid_b2a(peerid),
170                  level=log.NOISY)
171         ss = self.servermap.connections[peerid]
172         started = time.time()
173         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
174          offsets_tuple) = self.verinfo
175         offsets = dict(offsets_tuple)
176
177         # we read the checkstring, to make sure that the data we grab is from
178         # the right version.
179         readv = [ (0, struct.calcsize(SIGNED_PREFIX)) ]
180
181         # We also read the data, and the hashes necessary to validate them
182         # (share_hash_chain, block_hash_tree, share_data). We don't read the
183         # signature or the pubkey, since that was handled during the
184         # servermap phase, and we'll be comparing the share hash chain
185         # against the roothash that was validated back then.
186
187         readv.append( (offsets['share_hash_chain'],
188                        offsets['enc_privkey'] - offsets['share_hash_chain'] ) )
189
190         # if we need the private key (for repair), we also fetch that
191         if self._need_privkey:
192             readv.append( (offsets['enc_privkey'],
193                            offsets['EOF'] - offsets['enc_privkey']) )
194
195         m = Marker()
196         self._outstanding_queries[m] = (peerid, shnum, started)
197
198         # ask the cache first
199         got_from_cache = False
200         datavs = []
201         for (offset, length) in readv:
202             (data, timestamp) = self._node._cache.read(self.verinfo, shnum,
203                                                        offset, length)
204             if data is not None:
205                 datavs.append(data)
206         if len(datavs) == len(readv):
207             self.log("got data from cache")
208             got_from_cache = True
209             d = fireEventually({shnum: datavs})
210             # datavs is a dict mapping shnum to a pair of strings
211         else:
212             d = self._do_read(ss, peerid, self._storage_index, [shnum], readv)
213         self.remaining_sharemap.discard(shnum, peerid)
214
215         d.addCallback(self._got_results, m, peerid, started, got_from_cache)
216         d.addErrback(self._query_failed, m, peerid)
217         # errors that aren't handled by _query_failed (and errors caused by
218         # _query_failed) get logged, but we still want to check for doneness.
219         def _oops(f):
220             self.log(format="problem in _query_failed for sh#%(shnum)d to %(peerid)s",
221                      shnum=shnum,
222                      peerid=idlib.shortnodeid_b2a(peerid),
223                      failure=f,
224                      level=log.WEIRD, umid="W0xnQA")
225         d.addErrback(_oops)
226         d.addBoth(self._check_for_done)
227         # any error during _check_for_done means the download fails. If the
228         # download is successful, _check_for_done will fire _done by itself.
229         d.addErrback(self._done)
230         d.addErrback(log.err)
231         return d # purely for testing convenience
232
233     def _do_read(self, ss, peerid, storage_index, shnums, readv):
234         # isolate the callRemote to a separate method, so tests can subclass
235         # Publish and override it
236         d = ss.callRemote("slot_readv", storage_index, shnums, readv)
237         return d
238
239     def remove_peer(self, peerid):
240         for shnum in list(self.remaining_sharemap.keys()):
241             self.remaining_sharemap.discard(shnum, peerid)
242
243     def _got_results(self, datavs, marker, peerid, started, got_from_cache):
244         now = time.time()
245         elapsed = now - started
246         if not got_from_cache:
247             self._status.add_fetch_timing(peerid, elapsed)
248         self.log(format="got results (%(shares)d shares) from [%(peerid)s]",
249                  shares=len(datavs),
250                  peerid=idlib.shortnodeid_b2a(peerid),
251                  level=log.NOISY)
252         self._outstanding_queries.pop(marker, None)
253         if not self._running:
254             return
255
256         # note that we only ask for a single share per query, so we only
257         # expect a single share back. On the other hand, we use the extra
258         # shares if we get them.. seems better than an assert().
259
260         for shnum,datav in datavs.items():
261             (prefix, hash_and_data) = datav[:2]
262             try:
263                 self._got_results_one_share(shnum, peerid,
264                                             prefix, hash_and_data)
265             except CorruptShareError, e:
266                 # log it and give the other shares a chance to be processed
267                 f = failure.Failure()
268                 self.log(format="bad share: %(f_value)s",
269                          f_value=str(f.value), failure=f,
270                          level=log.WEIRD, umid="7fzWZw")
271                 self.notify_server_corruption(peerid, shnum, str(e))
272                 self.remove_peer(peerid)
273                 self.servermap.mark_bad_share(peerid, shnum, prefix)
274                 self._bad_shares.add( (peerid, shnum) )
275                 self._status.problems[peerid] = f
276                 self._last_failure = f
277                 pass
278             if self._need_privkey and len(datav) > 2:
279                 lp = None
280                 self._try_to_validate_privkey(datav[2], peerid, shnum, lp)
281         # all done!
282
283     def notify_server_corruption(self, peerid, shnum, reason):
284         ss = self.servermap.connections[peerid]
285         ss.callRemoteOnly("advise_corrupt_share",
286                           "mutable", self._storage_index, shnum, reason)
287
288     def _got_results_one_share(self, shnum, peerid,
289                                got_prefix, got_hash_and_data):
290         self.log("_got_results: got shnum #%d from peerid %s"
291                  % (shnum, idlib.shortnodeid_b2a(peerid)))
292         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
293          offsets_tuple) = self.verinfo
294         assert len(got_prefix) == len(prefix), (len(got_prefix), len(prefix))
295         if got_prefix != prefix:
296             msg = "someone wrote to the data since we read the servermap: prefix changed"
297             raise UncoordinatedWriteError(msg)
298         (share_hash_chain, block_hash_tree,
299          share_data) = unpack_share_data(self.verinfo, got_hash_and_data)
300
301         assert isinstance(share_data, str)
302         # build the block hash tree. SDMF has only one leaf.
303         leaves = [hashutil.block_hash(share_data)]
304         t = hashtree.HashTree(leaves)
305         if list(t) != block_hash_tree:
306             raise CorruptShareError(peerid, shnum, "block hash tree failure")
307         share_hash_leaf = t[0]
308         t2 = hashtree.IncompleteHashTree(N)
309         # root_hash was checked by the signature
310         t2.set_hashes({0: root_hash})
311         try:
312             t2.set_hashes(hashes=share_hash_chain,
313                           leaves={shnum: share_hash_leaf})
314         except (hashtree.BadHashError, hashtree.NotEnoughHashesError,
315                 IndexError), e:
316             msg = "corrupt hashes: %s" % (e,)
317             raise CorruptShareError(peerid, shnum, msg)
318         self.log(" data valid! len=%d" % len(share_data))
319         # each query comes down to this: placing validated share data into
320         # self.shares
321         self.shares[shnum] = share_data
322
323     def _try_to_validate_privkey(self, enc_privkey, peerid, shnum, lp):
324
325         alleged_privkey_s = self._node._decrypt_privkey(enc_privkey)
326         alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s)
327         if alleged_writekey != self._node.get_writekey():
328             self.log("invalid privkey from %s shnum %d" %
329                      (idlib.nodeid_b2a(peerid)[:8], shnum),
330                      parent=lp, level=log.WEIRD, umid="YIw4tA")
331             return
332
333         # it's good
334         self.log("got valid privkey from shnum %d on peerid %s" %
335                  (shnum, idlib.shortnodeid_b2a(peerid)),
336                  parent=lp)
337         privkey = rsa.create_signing_key_from_string(alleged_privkey_s)
338         self._node._populate_encprivkey(enc_privkey)
339         self._node._populate_privkey(privkey)
340         self._need_privkey = False
341
342     def _query_failed(self, f, marker, peerid):
343         self.log(format="query to [%(peerid)s] failed",
344                  peerid=idlib.shortnodeid_b2a(peerid),
345                  level=log.NOISY)
346         self._status.problems[peerid] = f
347         self._outstanding_queries.pop(marker, None)
348         if not self._running:
349             return
350         self._last_failure = f
351         self.remove_peer(peerid)
352         level = log.WEIRD
353         if f.check(DeadReferenceError):
354             level = log.UNUSUAL
355         self.log(format="error during query: %(f_value)s",
356                  f_value=str(f.value), failure=f, level=level, umid="gOJB5g")
357
358     def _check_for_done(self, res):
359         # exit paths:
360         #  return : keep waiting, no new queries
361         #  return self._send_more_queries(outstanding) : send some more queries
362         #  fire self._done(plaintext) : download successful
363         #  raise exception : download fails
364
365         self.log(format="_check_for_done: running=%(running)s, decoding=%(decoding)s",
366                  running=self._running, decoding=self._decoding,
367                  level=log.NOISY)
368         if not self._running:
369             return
370         if self._decoding:
371             return
372         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
373          offsets_tuple) = self.verinfo
374
375         if len(self.shares) < k:
376             # we don't have enough shares yet
377             return self._maybe_send_more_queries(k)
378         if self._need_privkey:
379             # we got k shares, but none of them had a valid privkey. TODO:
380             # look further. Adding code to do this is a bit complicated, and
381             # I want to avoid that complication, and this should be pretty
382             # rare (k shares with bitflips in the enc_privkey but not in the
383             # data blocks). If we actually do get here, the subsequent repair
384             # will fail for lack of a privkey.
385             self.log("got k shares but still need_privkey, bummer",
386                      level=log.WEIRD, umid="MdRHPA")
387
388         # we have enough to finish. All the shares have had their hashes
389         # checked, so if something fails at this point, we don't know how
390         # to fix it, so the download will fail.
391
392         self._decoding = True # avoid reentrancy
393         self._status.set_status("decoding")
394         now = time.time()
395         elapsed = now - self._started
396         self._status.timings["fetch"] = elapsed
397
398         d = defer.maybeDeferred(self._decode)
399         d.addCallback(self._decrypt, IV, self._node._readkey)
400         d.addBoth(self._done)
401         return d # purely for test convenience
402
403     def _maybe_send_more_queries(self, k):
404         # we don't have enough shares yet. Should we send out more queries?
405         # There are some number of queries outstanding, each for a single
406         # share. If we can generate 'needed_shares' additional queries, we do
407         # so. If we can't, then we know this file is a goner, and we raise
408         # NotEnoughSharesError.
409         self.log(format=("_maybe_send_more_queries, have=%(have)d, k=%(k)d, "
410                          "outstanding=%(outstanding)d"),
411                  have=len(self.shares), k=k,
412                  outstanding=len(self._outstanding_queries),
413                  level=log.NOISY)
414
415         remaining_shares = k - len(self.shares)
416         needed = remaining_shares - len(self._outstanding_queries)
417         if not needed:
418             # we have enough queries in flight already
419
420             # TODO: but if they've been in flight for a long time, and we
421             # have reason to believe that new queries might respond faster
422             # (i.e. we've seen other queries come back faster, then consider
423             # sending out new queries. This could help with peers which have
424             # silently gone away since the servermap was updated, for which
425             # we're still waiting for the 15-minute TCP disconnect to happen.
426             self.log("enough queries are in flight, no more are needed",
427                      level=log.NOISY)
428             return
429
430         outstanding_shnums = set([shnum
431                                   for (peerid, shnum, started)
432                                   in self._outstanding_queries.values()])
433         # prefer low-numbered shares, they are more likely to be primary
434         available_shnums = sorted(self.remaining_sharemap.keys())
435         for shnum in available_shnums:
436             if shnum in outstanding_shnums:
437                 # skip ones that are already in transit
438                 continue
439             if shnum not in self.remaining_sharemap:
440                 # no servers for that shnum. note that DictOfSets removes
441                 # empty sets from the dict for us.
442                 continue
443             peerid = list(self.remaining_sharemap[shnum])[0]
444             # get_data will remove that peerid from the sharemap, and add the
445             # query to self._outstanding_queries
446             self._status.set_status("Retrieving More Shares")
447             self.get_data(shnum, peerid)
448             needed -= 1
449             if not needed:
450                 break
451
452         # at this point, we have as many outstanding queries as we can. If
453         # needed!=0 then we might not have enough to recover the file.
454         if needed:
455             format = ("ran out of peers: "
456                       "have %(have)d shares (k=%(k)d), "
457                       "%(outstanding)d queries in flight, "
458                       "need %(need)d more, "
459                       "found %(bad)d bad shares")
460             args = {"have": len(self.shares),
461                     "k": k,
462                     "outstanding": len(self._outstanding_queries),
463                     "need": needed,
464                     "bad": len(self._bad_shares),
465                     }
466             self.log(format=format,
467                      level=log.WEIRD, umid="ezTfjw", **args)
468             err = NotEnoughSharesError("%s, last failure: %s" %
469                                       (format % args, self._last_failure))
470             if self._bad_shares:
471                 self.log("We found some bad shares this pass. You should "
472                          "update the servermap and try again to check "
473                          "more peers",
474                          level=log.WEIRD, umid="EFkOlA")
475                 err.servermap = self.servermap
476             raise err
477
478         return
479
480     def _decode(self):
481         started = time.time()
482         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
483          offsets_tuple) = self.verinfo
484
485         # shares_dict is a dict mapping shnum to share data, but the codec
486         # wants two lists.
487         shareids = []; shares = []
488         for shareid, share in self.shares.items():
489             shareids.append(shareid)
490             shares.append(share)
491
492         assert len(shareids) >= k, len(shareids)
493         # zfec really doesn't want extra shares
494         shareids = shareids[:k]
495         shares = shares[:k]
496
497         fec = codec.CRSDecoder()
498         fec.set_params(segsize, k, N)
499
500         self.log("params %s, we have %d shares" % ((segsize, k, N), len(shares)))
501         self.log("about to decode, shareids=%s" % (shareids,))
502         d = defer.maybeDeferred(fec.decode, shares, shareids)
503         def _done(buffers):
504             self._status.timings["decode"] = time.time() - started
505             self.log(" decode done, %d buffers" % len(buffers))
506             segment = "".join(buffers)
507             self.log(" joined length %d, datalength %d" %
508                      (len(segment), datalength))
509             segment = segment[:datalength]
510             self.log(" segment len=%d" % len(segment))
511             return segment
512         def _err(f):
513             self.log(" decode failed: %s" % f)
514             return f
515         d.addCallback(_done)
516         d.addErrback(_err)
517         return d
518
519     def _decrypt(self, crypttext, IV, readkey):
520         self._status.set_status("decrypting")
521         started = time.time()
522         key = hashutil.ssk_readkey_data_hash(IV, readkey)
523         decryptor = AES(key)
524         plaintext = decryptor.process(crypttext)
525         self._status.timings["decrypt"] = time.time() - started
526         return plaintext
527
528     def _done(self, res):
529         if not self._running:
530             return
531         self._running = False
532         self._status.set_active(False)
533         self._status.timings["total"] = time.time() - self._started
534         # res is either the new contents, or a Failure
535         if isinstance(res, failure.Failure):
536             self.log("Retrieve done, with failure", failure=res,
537                      level=log.UNUSUAL)
538             self._status.set_status("Failed")
539         else:
540             self.log("Retrieve done, success!")
541             self._status.set_status("Done")
542             self._status.set_progress(1.0)
543             # remember the encoding parameters, use them again next time
544             (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
545              offsets_tuple) = self.verinfo
546             self._node._populate_required_shares(k)
547             self._node._populate_total_shares(N)
548         eventually(self._done_deferred.callback, res)
549