
import struct, time
from itertools import count
from zope.interface import implements
from twisted.internet import defer
from twisted.python import failure
from foolscap.eventual import eventually
from allmydata.interfaces import IRetrieveStatus
from allmydata.util import hashutil, idlib, log
from allmydata import hashtree, codec, storage
from allmydata.encode import NotEnoughSharesError
from pycryptopp.cipher.aes import AES

from common import DictOfSets, CorruptShareError, UncoordinatedWriteError
from layout import SIGNED_PREFIX, unpack_share_data

class RetrieveStatus:
    implements(IRetrieveStatus)
    statusid_counter = count(0)
    def __init__(self):
        self.timings = {}
        self.timings["fetch_per_server"] = {}
        self.timings["cumulative_verify"] = 0.0
        self.problems = {}
        self.active = True
        self.storage_index = None
        self.helper = False
        self.encoding = ("?","?")
        self.size = None
        self.status = "Not started"
        self.progress = 0.0
        self.counter = self.statusid_counter.next()
        self.started = time.time()

    def get_started(self):
        return self.started
    def get_storage_index(self):
        return self.storage_index
    def get_encoding(self):
        return self.encoding
    def using_helper(self):
        return self.helper
    def get_size(self):
        return self.size
    def get_status(self):
        return self.status
    def get_progress(self):
        return self.progress
    def get_active(self):
        return self.active
    def get_counter(self):
        return self.counter

    def add_fetch_timing(self, peerid, elapsed):
        if peerid not in self.timings["fetch_per_server"]:
            self.timings["fetch_per_server"][peerid] = []
        self.timings["fetch_per_server"][peerid].append(elapsed)
    def set_storage_index(self, si):
        self.storage_index = si
    def set_helper(self, helper):
        self.helper = helper
    def set_encoding(self, k, n):
        self.encoding = (k, n)
    def set_size(self, size):
        self.size = size
    def set_status(self, status):
        self.status = status
    def set_progress(self, value):
        self.progress = value
    def set_active(self, value):
        self.active = value

class Marker:
    pass

class Retrieve:
    # this class is currently single-use. Eventually (in MDMF) we will make
    # it multi-use, in which case you can call download(range) multiple
    # times, and each will have a separate response chain. However the
    # Retrieve object will remain tied to a specific version of the file, and
    # will use a single ServerMap instance.
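    #
    # Rough flow: download() picks the k lowest-numbered shares listed in
    # the servermap and asks one holder for each via get_data().
    # _got_results_one_share() checks each response against its block hash
    # tree and the share hash chain, _check_for_done() decides when enough
    # validated shares are in hand, and _decode()/_decrypt() then turn them
    # back into plaintext.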

    def __init__(self, filenode, servermap, verinfo):
        self._node = filenode
        assert self._node._pubkey
        self._storage_index = filenode.get_storage_index()
        assert self._node._readkey
        self._last_failure = None
        prefix = storage.si_b2a(self._storage_index)[:5]
        self._log_number = log.msg("Retrieve(%s): starting" % prefix)
        self._outstanding_queries = {} # maps Marker to (peerid, shnum, start_time)
        self._running = True
        self._decoding = False
        self._bad_shares = set()

        self.servermap = servermap
        assert self._node._pubkey
        self.verinfo = verinfo

        self._status = RetrieveStatus()
        self._status.set_storage_index(self._storage_index)
        self._status.set_helper(False)
        self._status.set_progress(0.0)
        self._status.set_active(True)
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo
        self._status.set_size(datalength)
        self._status.set_encoding(k, N)

    def get_status(self):
        return self._status

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_number
        return log.msg(*args, **kwargs)

    def download(self):
        self._done_deferred = defer.Deferred()
        self._started = time.time()
        self._status.set_status("Retrieving Shares")

        # first, which servers can we use?
        versionmap = self.servermap.make_versionmap()
        shares = versionmap[self.verinfo]
        # this sharemap is consumed as we decide to send requests
        self.remaining_sharemap = DictOfSets()
        for (shnum, peerid, timestamp) in shares:
            self.remaining_sharemap.add(shnum, peerid)

        self.shares = {} # maps shnum to validated blocks

        # how many shares do we need?
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo
        assert len(self.remaining_sharemap) >= k
        # we start with the lowest shnums we have available, since FEC is
        # faster if we're using "primary shares"
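        # ("primary" shares are shnums 0..k-1; zfec is a systematic code, so
        # when all k primary shares are present, decoding is essentially just
        # concatenation)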
        self.active_shnums = set(sorted(self.remaining_sharemap.keys())[:k])
        for shnum in self.active_shnums:
            # we use an arbitrary peer who has the share. If shares are
            # doubled up (more than one share per peer), we could make this
            # run faster by spreading the load among multiple peers. But the
            # algorithm to do that is more complicated than I want to write
            # right now, and a well-provisioned grid shouldn't have multiple
            # shares per peer.
            peerid = list(self.remaining_sharemap[shnum])[0]
            self.get_data(shnum, peerid)

        # control flow beyond this point: state machine. Receiving responses
        # from queries is the input. We might send out more queries, or we
        # might produce a result.

        return self._done_deferred

    def get_data(self, shnum, peerid):
        self.log(format="sending sh#%(shnum)d request to [%(peerid)s]",
                 shnum=shnum,
                 peerid=idlib.shortnodeid_b2a(peerid),
                 level=log.NOISY)
        ss = self.servermap.connections[peerid]
        started = time.time()
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo
        offsets = dict(offsets_tuple)
        # we read the checkstring, to make sure that the data we grab is from
        # the right version. We also read the data, and the hashes necessary
        # to validate them (share_hash_chain, block_hash_tree, share_data).
        # We don't read the signature or the pubkey, since that was handled
        # during the servermap phase, and we'll be comparing the share hash
        # chain against the roothash that was validated back then.
        readv = [ (0, struct.calcsize(SIGNED_PREFIX)),
                  (offsets['share_hash_chain'],
                   offsets['enc_privkey'] - offsets['share_hash_chain']),
                  ]
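        # (in the SDMF share layout, share_hash_chain, block_hash_tree, and
        # share_data sit contiguously just before enc_privkey, so the second
        # read vector covers all three in one span)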

        m = Marker()
        self._outstanding_queries[m] = (peerid, shnum, started)

        # ask the cache first
        got_from_cache = False
        datav = []
        #for (offset, length) in readv:
        #    (data, timestamp) = self._node._cache.read(self.verinfo, shnum,
        #                                               offset, length)
        #    if data is not None:
        #        datav.append(data)
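        # (the cache read above is commented out in this revision, so datav
        # stays empty and we always fall through to the network read below)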
        if len(datav) == len(readv):
            self.log("got data from cache")
            got_from_cache = True
            d = defer.succeed(datav)
        else:
            self.remaining_sharemap[shnum].remove(peerid)
            d = self._do_read(ss, peerid, self._storage_index, [shnum], readv)
            d.addCallback(self._fill_cache, readv)

        d.addCallback(self._got_results, m, peerid, started, got_from_cache)
        d.addErrback(self._query_failed, m, peerid)
        # errors that aren't handled by _query_failed (and errors caused by
        # _query_failed) get logged, but we still want to check for doneness.
        def _oops(f):
            self.log(format="problem in _query_failed for sh#%(shnum)d to %(peerid)s",
                     shnum=shnum,
                     peerid=idlib.shortnodeid_b2a(peerid),
                     failure=f,
                     level=log.WEIRD)
        d.addErrback(_oops)
        d.addBoth(self._check_for_done)
        # any error during _check_for_done means the download fails. If the
        # download is successful, _check_for_done will fire _done by itself.
        d.addErrback(self._done)
        d.addErrback(log.err)
        return d # purely for testing convenience

    def _fill_cache(self, datavs, readv):
        timestamp = time.time()
        for shnum,datav in datavs.items():
            for i, (offset, length) in enumerate(readv):
                data = datav[i]
                self._node._cache.add(self.verinfo, shnum, offset, data,
                                      timestamp)
        return datavs

    def _do_read(self, ss, peerid, storage_index, shnums, readv):
        # isolate the callRemote to a separate method, so tests can subclass
        # Retrieve and override it
        d = ss.callRemote("slot_readv", storage_index, shnums, readv)
        return d

    def remove_peer(self, peerid):
        for shnum in list(self.remaining_sharemap.keys()):
            self.remaining_sharemap.discard(shnum, peerid)

    def _got_results(self, datavs, marker, peerid, started, got_from_cache):
        now = time.time()
        elapsed = now - started
        if not got_from_cache:
            self._status.add_fetch_timing(peerid, elapsed)
        self.log(format="got results (%(shares)d shares) from [%(peerid)s]",
                 shares=len(datavs),
                 peerid=idlib.shortnodeid_b2a(peerid),
                 level=log.NOISY)
        self._outstanding_queries.pop(marker, None)
        if not self._running:
            return

        # note that we only ask for a single share per query, so we only
        # expect a single share back. On the other hand, we use the extra
        # shares if we get them; seems better than an assert().

        for shnum,datav in datavs.items():
            (prefix, hash_and_data) = datav
            try:
                self._got_results_one_share(shnum, peerid,
                                            prefix, hash_and_data)
            except CorruptShareError, e:
                # log it and give the other shares a chance to be processed
                f = failure.Failure()
                self.log("bad share: %s %s" % (f, f.value), level=log.WEIRD)
                self.remove_peer(peerid)
                self.servermap.mark_bad_share(peerid, shnum)
                self._bad_shares.add( (peerid, shnum) )
                self._status.problems[peerid] = f
                self._last_failure = f
        # all done!

    def _got_results_one_share(self, shnum, peerid,
                               got_prefix, got_hash_and_data):
        self.log("_got_results: got shnum #%d from peerid %s"
                 % (shnum, idlib.shortnodeid_b2a(peerid)))
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo
        assert len(got_prefix) == len(prefix), (len(got_prefix), len(prefix))
        if got_prefix != prefix:
            msg = "someone wrote to the data since we read the servermap: prefix changed"
            raise UncoordinatedWriteError(msg)
        (share_hash_chain, block_hash_tree,
         share_data) = unpack_share_data(self.verinfo, got_hash_and_data)

        assert isinstance(share_data, str)
        # build the block hash tree. SDMF has only one leaf.
        leaves = [hashutil.block_hash(share_data)]
        t = hashtree.HashTree(leaves)
        if list(t) != block_hash_tree:
            raise CorruptShareError(peerid, shnum, "block hash tree failure")
        share_hash_leaf = t[0]
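        # share_hash_leaf should be leaf[shnum] of the N-leaf share hash
        # tree; set_hashes() below verifies it against root_hash using the
        # chain of uncle hashes the server supplied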
        t2 = hashtree.IncompleteHashTree(N)
        # root_hash was checked by the signature
        t2.set_hashes({0: root_hash})
        try:
            t2.set_hashes(hashes=share_hash_chain,
                          leaves={shnum: share_hash_leaf})
        except (hashtree.BadHashError, hashtree.NotEnoughHashesError,
                IndexError), e:
            msg = "corrupt hashes: %s" % (e,)
            raise CorruptShareError(peerid, shnum, msg)
        self.log(" data valid! len=%d" % len(share_data))
        # each query comes down to this: placing validated share data into
        # self.shares
        self.shares[shnum] = share_data

    def _query_failed(self, f, marker, peerid):
        self.log(format="query to [%(peerid)s] failed",
                 peerid=idlib.shortnodeid_b2a(peerid),
                 level=log.NOISY)
        self._status.problems[peerid] = f
        self._outstanding_queries.pop(marker, None)
        if not self._running:
            return
        self._last_failure = f
        self.remove_peer(peerid)
        self.log("error during query: %s %s" % (f, f.value), level=log.WEIRD)

    def _check_for_done(self, res):
        # exit paths:
        #  return : keep waiting, no new queries
        #  return self._maybe_send_more_queries(k) : send some more queries
        #  fire self._done(plaintext) : download successful
        #  raise exception : download fails

        self.log(format="_check_for_done: running=%(running)s, decoding=%(decoding)s",
                 running=self._running, decoding=self._decoding,
                 level=log.NOISY)
        if not self._running:
            return
        if self._decoding:
            return
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo

        if len(self.shares) < k:
            # we don't have enough shares yet
            return self._maybe_send_more_queries(k)

        # we have enough to finish. All the shares have had their hashes
        # checked, so if something fails at this point, we don't know how
        # to fix it, so the download will fail.

        self._decoding = True # avoid reentrancy
        self._status.set_status("decoding")
        now = time.time()
        elapsed = now - self._started
        self._status.timings["fetch"] = elapsed

        d = defer.maybeDeferred(self._decode)
        d.addCallback(self._decrypt, IV, self._node._readkey)
        d.addBoth(self._done)
        return d # purely for test convenience

    def _maybe_send_more_queries(self, k):
        # we don't have enough shares yet. Should we send out more queries?
        # There are some number of queries outstanding, each for a single
        # share. If we can generate 'needed_shares' additional queries, we do
        # so. If we can't, then we know this file is a goner, and we raise
        # NotEnoughSharesError.
        self.log(format=("_maybe_send_more_queries, have=%(have)d, k=%(k)d, "
                         "outstanding=%(outstanding)d"),
                 have=len(self.shares), k=k,
                 outstanding=len(self._outstanding_queries),
                 level=log.NOISY)

        remaining_shares = k - len(self.shares)
        needed = remaining_shares - len(self._outstanding_queries)
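        # (e.g. with k=3, one validated share, and one query still in
        # flight: remaining_shares is 2 and needed is 1, so we try to start
        # one more query)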
        if not needed:
            # we have enough queries in flight already

            # TODO: but if they've been in flight for a long time, and we
            # have reason to believe that new queries might respond faster
            # (i.e. we've seen other queries come back faster), then consider
            # sending out new queries. This could help with peers which have
            # silently gone away since the servermap was updated, for which
            # we're still waiting for the 15-minute TCP disconnect to happen.
            self.log("enough queries are in flight, no more are needed",
                     level=log.NOISY)
            return

        outstanding_shnums = set([shnum
                                  for (peerid, shnum, started)
                                  in self._outstanding_queries.values()])
        # prefer low-numbered shares, they are more likely to be primary
        available_shnums = sorted(self.remaining_sharemap.keys())
        for shnum in available_shnums:
            if shnum in outstanding_shnums:
                # skip ones that are already in transit
                continue
            if shnum not in self.remaining_sharemap:
                # no servers for that shnum. note that DictOfSets removes
                # empty sets from the dict for us.
                continue
            peerid = list(self.remaining_sharemap[shnum])[0]
            # get_data will remove that peerid from the sharemap, and add the
            # query to self._outstanding_queries
            self._status.set_status("Retrieving More Shares")
            self.get_data(shnum, peerid)
            needed -= 1
            if not needed:
                break

        # at this point, we have as many outstanding queries as we can. If
        # needed!=0 then we might not have enough to recover the file.
        if needed:
            format = ("ran out of peers: "
                      "have %(have)d shares (k=%(k)d), "
                      "%(outstanding)d queries in flight, "
                      "need %(need)d more, "
                      "found %(bad)d bad shares")
            args = {"have": len(self.shares),
                    "k": k,
                    "outstanding": len(self._outstanding_queries),
                    "need": needed,
                    "bad": len(self._bad_shares),
                    }
            self.log(format=format,
                     level=log.WEIRD, **args)
            err = NotEnoughSharesError("%s, last failure: %s" %
                                       (format % args, self._last_failure))
            if self._bad_shares:
                self.log("We found some bad shares this pass. You should "
                         "update the servermap and try again to check "
                         "more peers",
                         level=log.WEIRD)
                err.servermap = self.servermap
            raise err

        return

    def _decode(self):
        started = time.time()
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo

        # self.shares is a dict mapping shnum to share data, but the codec
        # wants two lists.
        shareids = []; shares = []
        for shareid, share in self.shares.items():
            shareids.append(shareid)
            shares.append(share)

        assert len(shareids) >= k, len(shareids)
        # zfec really doesn't want extra shares
        shareids = shareids[:k]
        shares = shares[:k]

        fec = codec.CRSDecoder()
        params = "%d-%d-%d" % (segsize, k, N)
        fec.set_serialized_params(params)
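        # the serialized params encode "segsize-k-N"; given any k distinct
        # shares plus their shareids, the decoder rebuilds the single SDMF
        # segment (which is why the extras are trimmed off above)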

        self.log("params %s, we have %d shares" % (params, len(shares)))
        self.log("about to decode, shareids=%s" % (shareids,))
        d = defer.maybeDeferred(fec.decode, shares, shareids)
        def _done(buffers):
            self._status.timings["decode"] = time.time() - started
            self.log(" decode done, %d buffers" % len(buffers))
            segment = "".join(buffers)
            self.log(" joined length %d, datalength %d" %
                     (len(segment), datalength))
            segment = segment[:datalength]
            self.log(" segment len=%d" % len(segment))
            return segment
        def _err(f):
            self.log(" decode failed: %s" % f)
            return f
        d.addCallback(_done)
        d.addErrback(_err)
        return d

    def _decrypt(self, crypttext, IV, readkey):
        self._status.set_status("decrypting")
        started = time.time()
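        # the AES data key is derived by hashing the per-version IV together
        # with the file's readkey, so each published version of the file is
        # encrypted under its own key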
        key = hashutil.ssk_readkey_data_hash(IV, readkey)
        decryptor = AES(key)
        plaintext = decryptor.process(crypttext)
        self._status.timings["decrypt"] = time.time() - started
        return plaintext

    def _done(self, res):
        if not self._running:
            return
        self._running = False
        self._status.set_active(False)
        self._status.timings["total"] = time.time() - self._started
        # res is either the new contents, or a Failure
        if isinstance(res, failure.Failure):
            self.log("Retrieve done, with failure", failure=res)
            self._status.set_status("Failed")
        else:
            self.log("Retrieve done, success!")
            self._status.set_status("Done")
            self._status.set_progress(1.0)
            # remember the encoding parameters, use them again next time
            (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
             offsets_tuple) = self.verinfo
            self._node._populate_required_shares(k)
            self._node._populate_total_shares(N)
        eventually(self._done_deferred.callback, res)