# tahoe-lafs: src/allmydata/mutable/retrieve.py

import struct, time
from itertools import count
from zope.interface import implements
from twisted.internet import defer
from twisted.python import failure
from foolscap.eventual import eventually, fireEventually
from allmydata.interfaces import IRetrieveStatus
from allmydata.util import hashutil, idlib, log
from allmydata import hashtree, codec, storage
from allmydata.immutable.encode import NotEnoughSharesError
from pycryptopp.cipher.aes import AES

from common import DictOfSets, CorruptShareError, UncoordinatedWriteError
from layout import SIGNED_PREFIX, unpack_share_data

class RetrieveStatus:
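    """Record the progress of one mutable-file retrieve operation: status
    text, progress fraction, encoding parameters, per-server fetch timings,
    and any problems encountered along the way."""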
    implements(IRetrieveStatus)
    statusid_counter = count(0)
    def __init__(self):
        self.timings = {}
        self.timings["fetch_per_server"] = {}
        self.timings["cumulative_verify"] = 0.0
        self.problems = {}
        self.active = True
        self.storage_index = None
        self.helper = False
        self.encoding = ("?","?")
        self.size = None
        self.status = "Not started"
        self.progress = 0.0
        self.counter = self.statusid_counter.next()
        self.started = time.time()

    def get_started(self):
        return self.started
    def get_storage_index(self):
        return self.storage_index
    def get_encoding(self):
        return self.encoding
    def using_helper(self):
        return self.helper
    def get_size(self):
        return self.size
    def get_status(self):
        return self.status
    def get_progress(self):
        return self.progress
    def get_active(self):
        return self.active
    def get_counter(self):
        return self.counter

    def add_fetch_timing(self, peerid, elapsed):
        if peerid not in self.timings["fetch_per_server"]:
            self.timings["fetch_per_server"][peerid] = []
        self.timings["fetch_per_server"][peerid].append(elapsed)
    def set_storage_index(self, si):
        self.storage_index = si
    def set_helper(self, helper):
        self.helper = helper
    def set_encoding(self, k, n):
        self.encoding = (k, n)
    def set_size(self, size):
        self.size = size
    def set_status(self, status):
        self.status = status
    def set_progress(self, value):
        self.progress = value
    def set_active(self, value):
        self.active = value

class Marker:
    pass

class Retrieve:
    # this class is currently single-use. Eventually (in MDMF) we will make
    # it multi-use, in which case you can call download(range) multiple
    # times, and each will have a separate response chain. However the
    # Retrieve object will remain tied to a specific version of the file, and
    # will use a single ServerMap instance.
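    #
    # Typical use (a sketch; the variable names simply mirror the
    # constructor arguments below):
    #
    #   r = Retrieve(filenode, servermap, verinfo)
    #   d = r.download()
    #   d.addCallback(lambda plaintext: ...)
    #
    # download() returns a Deferred that fires with the recovered plaintext,
    # or with a Failure (e.g. NotEnoughSharesError) if the file cannot be
    # recovered.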

    def __init__(self, filenode, servermap, verinfo):
        self._node = filenode
        assert self._node._pubkey
        self._storage_index = filenode.get_storage_index()
        assert self._node._readkey
        self._last_failure = None
        prefix = storage.si_b2a(self._storage_index)[:5]
        self._log_number = log.msg("Retrieve(%s): starting" % prefix)
        self._outstanding_queries = {} # maps (peerid,shnum) to start_time
        self._running = True
        self._decoding = False
        self._bad_shares = set()

        self.servermap = servermap
        assert self._node._pubkey
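        # verinfo is the 9-tuple (seqnum, root_hash, IV, segsize, datalength,
        # k, N, prefix, offsets_tuple) that identifies the exact version of
        # the file this Retrieve will fetch.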
        self.verinfo = verinfo

        self._status = RetrieveStatus()
        self._status.set_storage_index(self._storage_index)
        self._status.set_helper(False)
        self._status.set_progress(0.0)
        self._status.set_active(True)
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo
        self._status.set_size(datalength)
        self._status.set_encoding(k, N)

    def get_status(self):
        return self._status

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_number
        return log.msg(*args, **kwargs)

    def download(self):
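        """Fetch enough shares of the chosen version, decode and decrypt
        them, and return a Deferred that fires with the plaintext (or with a
        Failure if the file cannot be recovered)."""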
        self._done_deferred = defer.Deferred()
        self._started = time.time()
        self._status.set_status("Retrieving Shares")

        # first, which servers can we use?
        versionmap = self.servermap.make_versionmap()
        shares = versionmap[self.verinfo]
        # this sharemap is consumed as we decide to send requests
        self.remaining_sharemap = DictOfSets()
        for (shnum, peerid, timestamp) in shares:
            self.remaining_sharemap.add(shnum, peerid)

        self.shares = {} # maps shnum to validated blocks

        # how many shares do we need?
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo
        assert len(self.remaining_sharemap) >= k
        # we start with the lowest shnums we have available, since FEC is
        # faster if we're using "primary shares"
        self.active_shnums = set(sorted(self.remaining_sharemap.keys())[:k])
        for shnum in self.active_shnums:
            # we use an arbitrary peer who has the share. If shares are
            # doubled up (more than one share per peer), we could make this
            # run faster by spreading the load among multiple peers. But the
            # algorithm to do that is more complicated than I want to write
            # right now, and a well-provisioned grid shouldn't have multiple
            # shares per peer.
            peerid = list(self.remaining_sharemap[shnum])[0]
            self.get_data(shnum, peerid)

        # control flow beyond this point: state machine. Receiving responses
        # from queries is the input. We might send out more queries, or we
        # might produce a result.

        return self._done_deferred

    def get_data(self, shnum, peerid):
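        """Issue a single read query for share shnum to peerid (or satisfy
        it from the node's cache), and wire the response back into the state
        machine via _got_results / _query_failed / _check_for_done."""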
        self.log(format="sending sh#%(shnum)d request to [%(peerid)s]",
                 shnum=shnum,
                 peerid=idlib.shortnodeid_b2a(peerid),
                 level=log.NOISY)
        ss = self.servermap.connections[peerid]
        started = time.time()
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo
        offsets = dict(offsets_tuple)
        # we read the checkstring, to make sure that the data we grab is from
        # the right version. We also read the data, and the hashes necessary
        # to validate them (share_hash_chain, block_hash_tree, share_data).
        # We don't read the signature or the pubkey, since that was handled
        # during the servermap phase, and we'll be comparing the share hash
        # chain against the roothash that was validated back then.
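        # Each entry in readv is an (offset, length) pair: the first covers
        # the signed prefix (the checkstring), the second covers everything
        # from the share hash chain up to (but not including) the encrypted
        # private key.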
        readv = [ (0, struct.calcsize(SIGNED_PREFIX)),
                  (offsets['share_hash_chain'],
                   offsets['enc_privkey'] - offsets['share_hash_chain']),
                  ]

        m = Marker()
        self._outstanding_queries[m] = (peerid, shnum, started)

        # ask the cache first
        got_from_cache = False
        datavs = []
        for (offset, length) in readv:
            (data, timestamp) = self._node._cache.read(self.verinfo, shnum,
                                                       offset, length)
            if data is not None:
                datavs.append(data)
        if len(datavs) == len(readv):
            self.log("got data from cache")
            got_from_cache = True
            d = fireEventually({shnum: datavs})
            # the fired value mimics a server response: a dict mapping
            # shnum to a pair of strings
        else:
            d = self._do_read(ss, peerid, self._storage_index, [shnum], readv)
        self.remaining_sharemap.discard(shnum, peerid)

        d.addCallback(self._got_results, m, peerid, started, got_from_cache)
        d.addErrback(self._query_failed, m, peerid)
        # errors that aren't handled by _query_failed (and errors caused by
        # _query_failed) get logged, but we still want to check for doneness.
        def _oops(f):
            self.log(format="problem in _query_failed for sh#%(shnum)d to %(peerid)s",
                     shnum=shnum,
                     peerid=idlib.shortnodeid_b2a(peerid),
                     failure=f,
                     level=log.WEIRD)
        d.addErrback(_oops)
        d.addBoth(self._check_for_done)
        # any error during _check_for_done means the download fails. If the
        # download is successful, _check_for_done will fire _done by itself.
        d.addErrback(self._done)
        d.addErrback(log.err)
        return d # purely for testing convenience

    def _do_read(self, ss, peerid, storage_index, shnums, readv):
        # isolate the callRemote to a separate method, so tests can subclass
        # Retrieve and override it
        d = ss.callRemote("slot_readv", storage_index, shnums, readv)
        return d

    def remove_peer(self, peerid):
        for shnum in list(self.remaining_sharemap.keys()):
            self.remaining_sharemap.discard(shnum, peerid)

    def _got_results(self, datavs, marker, peerid, started, got_from_cache):
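        """Handle the response to one read query: record the fetch timing
        (unless the data came from the cache), retire the query, and try to
        validate each share in the response."""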
        now = time.time()
        elapsed = now - started
        if not got_from_cache:
            self._status.add_fetch_timing(peerid, elapsed)
        self.log(format="got results (%(shares)d shares) from [%(peerid)s]",
                 shares=len(datavs),
                 peerid=idlib.shortnodeid_b2a(peerid),
                 level=log.NOISY)
        self._outstanding_queries.pop(marker, None)
        if not self._running:
            return

        # note that we only ask for a single share per query, so we only
        # expect a single share back. On the other hand, we use any extra
        # shares if we get them; that seems better than an assert().

        for shnum,datav in datavs.items():
            (prefix, hash_and_data) = datav
            try:
                self._got_results_one_share(shnum, peerid,
                                            prefix, hash_and_data)
            except CorruptShareError, e:
                # log it and give the other shares a chance to be processed
                f = failure.Failure()
                self.log("bad share: %s %s" % (f, f.value), level=log.WEIRD)
                self.remove_peer(peerid)
                self.servermap.mark_bad_share(peerid, shnum, prefix)
                self._bad_shares.add( (peerid, shnum) )
                self._status.problems[peerid] = f
                self._last_failure = f
                pass
        # all done!

    def _got_results_one_share(self, shnum, peerid,
                               got_prefix, got_hash_and_data):
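        """Validate a single share: check that its prefix still matches the
        version we want, verify its block hash tree and share hash chain
        against the already-validated root hash, and stash the share data in
        self.shares. Raises UncoordinatedWriteError or CorruptShareError on
        failure."""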
        self.log("_got_results: got shnum #%d from peerid %s"
                 % (shnum, idlib.shortnodeid_b2a(peerid)))
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo
        assert len(got_prefix) == len(prefix), (len(got_prefix), len(prefix))
        if got_prefix != prefix:
            msg = "someone wrote to the data since we read the servermap: prefix changed"
            raise UncoordinatedWriteError(msg)
        (share_hash_chain, block_hash_tree,
         share_data) = unpack_share_data(self.verinfo, got_hash_and_data)

        assert isinstance(share_data, str)
        # build the block hash tree. SDMF has only one leaf.
        leaves = [hashutil.block_hash(share_data)]
        t = hashtree.HashTree(leaves)
        if list(t) != block_hash_tree:
            raise CorruptShareError(peerid, shnum, "block hash tree failure")
        share_hash_leaf = t[0]
        t2 = hashtree.IncompleteHashTree(N)
        # root_hash was checked by the signature
        t2.set_hashes({0: root_hash})
        try:
            t2.set_hashes(hashes=share_hash_chain,
                          leaves={shnum: share_hash_leaf})
        except (hashtree.BadHashError, hashtree.NotEnoughHashesError,
                IndexError), e:
            msg = "corrupt hashes: %s" % (e,)
            raise CorruptShareError(peerid, shnum, msg)
        self.log(" data valid! len=%d" % len(share_data))
        # each query comes down to this: placing validated share data into
        # self.shares
        self.shares[shnum] = share_data

    def _query_failed(self, f, marker, peerid):
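        """Record a failed query and remove the peer from further
        consideration."""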
        self.log(format="query to [%(peerid)s] failed",
                 peerid=idlib.shortnodeid_b2a(peerid),
                 level=log.NOISY)
        self._status.problems[peerid] = f
        self._outstanding_queries.pop(marker, None)
        if not self._running:
            return
        self._last_failure = f
        self.remove_peer(peerid)
        self.log("error during query: %s %s" % (f, f.value), level=log.WEIRD)

    def _check_for_done(self, res):
        # exit paths:
        #  return : keep waiting, no new queries
        #  return self._maybe_send_more_queries(k) : send some more queries
        #  fire self._done(plaintext) : download successful
        #  raise exception : download fails

        self.log(format="_check_for_done: running=%(running)s, decoding=%(decoding)s",
                 running=self._running, decoding=self._decoding,
                 level=log.NOISY)
        if not self._running:
            return
        if self._decoding:
            return
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo

        if len(self.shares) < k:
            # we don't have enough shares yet
            return self._maybe_send_more_queries(k)

        # we have enough to finish. All the shares have had their hashes
        # checked, so if something fails at this point, we don't know how
        # to fix it, so the download will fail.

        self._decoding = True # avoid reentrancy
        self._status.set_status("decoding")
        now = time.time()
        elapsed = now - self._started
        self._status.timings["fetch"] = elapsed

        d = defer.maybeDeferred(self._decode)
        d.addCallback(self._decrypt, IV, self._node._readkey)
        d.addBoth(self._done)
        return d # purely for test convenience

    def _maybe_send_more_queries(self, k):
        # we don't have enough shares yet. Should we send out more queries?
        # There are some number of queries outstanding, each for a single
        # share. If we can generate 'needed_shares' additional queries, we do
        # so. If we can't, then we know this file is a goner, and we raise
        # NotEnoughSharesError.
        self.log(format=("_maybe_send_more_queries, have=%(have)d, k=%(k)d, "
                         "outstanding=%(outstanding)d"),
                 have=len(self.shares), k=k,
                 outstanding=len(self._outstanding_queries),
                 level=log.NOISY)

        remaining_shares = k - len(self.shares)
        needed = remaining_shares - len(self._outstanding_queries)
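        # e.g. with k=3, one validated share in self.shares, and one query
        # already outstanding, remaining_shares is 2 and needed is 1, so we
        # send out exactly one more query.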
        if not needed:
            # we have enough queries in flight already

            # TODO: but if they've been in flight for a long time, and we
            # have reason to believe that new queries might respond faster
            # (i.e. we've seen other queries come back faster), consider
            # sending out new queries. This could help with peers which have
            # silently gone away since the servermap was updated, for which
            # we're still waiting for the 15-minute TCP disconnect to happen.
            self.log("enough queries are in flight, no more are needed",
                     level=log.NOISY)
            return

        outstanding_shnums = set([shnum
                                  for (peerid, shnum, started)
                                  in self._outstanding_queries.values()])
        # prefer low-numbered shares, they are more likely to be primary
        available_shnums = sorted(self.remaining_sharemap.keys())
        for shnum in available_shnums:
            if shnum in outstanding_shnums:
                # skip ones that are already in transit
                continue
            if shnum not in self.remaining_sharemap:
                # no servers for that shnum. note that DictOfSets removes
                # empty sets from the dict for us.
                continue
            peerid = list(self.remaining_sharemap[shnum])[0]
            # get_data will remove that peerid from the sharemap, and add the
            # query to self._outstanding_queries
            self._status.set_status("Retrieving More Shares")
            self.get_data(shnum, peerid)
            needed -= 1
            if not needed:
                break

        # at this point, we have as many queries outstanding as we can send.
        # If needed != 0 then we might not have enough to recover the file.
        if needed:
            format = ("ran out of peers: "
                      "have %(have)d shares (k=%(k)d), "
                      "%(outstanding)d queries in flight, "
                      "need %(need)d more, "
                      "found %(bad)d bad shares")
            args = {"have": len(self.shares),
                    "k": k,
                    "outstanding": len(self._outstanding_queries),
                    "need": needed,
                    "bad": len(self._bad_shares),
                    }
            self.log(format=format,
                     level=log.WEIRD, **args)
            err = NotEnoughSharesError("%s, last failure: %s" %
                                      (format % args, self._last_failure))
            if self._bad_shares:
                self.log("We found some bad shares this pass. You should "
                         "update the servermap and try again to check "
                         "more peers",
                         level=log.WEIRD)
                err.servermap = self.servermap
            raise err

        return

    def _decode(self):
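        """Run the k-of-N erasure decoder over the validated shares and
        return a Deferred that fires with the recovered ciphertext segment,
        trimmed to datalength."""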
        started = time.time()
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo

        # self.shares is a dict mapping shnum to share data, but the codec
        # wants two parallel lists.
        shareids = []; shares = []
        for shareid, share in self.shares.items():
            shareids.append(shareid)
            shares.append(share)

        assert len(shareids) >= k, len(shareids)
        # zfec really doesn't want extra shares
        shareids = shareids[:k]
        shares = shares[:k]

        fec = codec.CRSDecoder()
        params = "%d-%d-%d" % (segsize, k, N)
        fec.set_serialized_params(params)

        self.log("params %s, we have %d shares" % (params, len(shares)))
        self.log("about to decode, shareids=%s" % (shareids,))
        d = defer.maybeDeferred(fec.decode, shares, shareids)
        def _done(buffers):
            self._status.timings["decode"] = time.time() - started
            self.log(" decode done, %d buffers" % len(buffers))
            segment = "".join(buffers)
            self.log(" joined length %d, datalength %d" %
                     (len(segment), datalength))
            segment = segment[:datalength]
            self.log(" segment len=%d" % len(segment))
            return segment
        def _err(f):
            self.log(" decode failed: %s" % f)
            return f
        d.addCallback(_done)
        d.addErrback(_err)
        return d

    def _decrypt(self, crypttext, IV, readkey):
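        """Decrypt the decoded segment with AES, using a key derived from
        the file's readkey and the per-version IV, and return the
        plaintext."""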
        self._status.set_status("decrypting")
        started = time.time()
        key = hashutil.ssk_readkey_data_hash(IV, readkey)
        decryptor = AES(key)
        plaintext = decryptor.process(crypttext)
        self._status.timings["decrypt"] = time.time() - started
        return plaintext

    def _done(self, res):
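        """Finish the retrieve: record the total elapsed time, update the
        status object, and fire the Deferred returned by download() with
        either the plaintext or a Failure; on success, also remember the
        encoding parameters on the node."""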
        if not self._running:
            return
        self._running = False
        self._status.set_active(False)
        self._status.timings["total"] = time.time() - self._started
        # res is either the new contents, or a Failure
        if isinstance(res, failure.Failure):
            self.log("Retrieve done, with failure", failure=res,
                     level=log.UNUSUAL)
            self._status.set_status("Failed")
        else:
            self.log("Retrieve done, success!")
            self._status.set_status("Done")
            self._status.set_progress(1.0)
            # remember the encoding parameters, use them again next time
            (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
             offsets_tuple) = self.verinfo
            self._node._populate_required_shares(k)
            self._node._populate_total_shares(N)
        eventually(self._done_deferred.callback, res)
