import struct, time
from itertools import count
from zope.interface import implements
from twisted.internet import defer
from twisted.python import failure
from foolscap.eventual import eventually
from allmydata.interfaces import IRetrieveStatus
from allmydata.util import hashutil, idlib, log
from allmydata import hashtree, codec, storage
from allmydata.encode import NotEnoughPeersError
from pycryptopp.cipher.aes import AES

from common import DictOfSets, CorruptShareError, UncoordinatedWriteError
from layout import SIGNED_PREFIX, unpack_share_data

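# RetrieveStatus tracks the progress, timing, and problems of a single
# retrieve operation; callers observe it through the IRetrieveStatus
# interface.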
class RetrieveStatus:
    implements(IRetrieveStatus)
    statusid_counter = count(0)
    def __init__(self):
        self.timings = {}
        self.timings["fetch_per_server"] = {}
        self.timings["cumulative_verify"] = 0.0
        self.sharemap = {}
        self.problems = {}
        self.active = True
        self.storage_index = None
        self.helper = False
        self.encoding = ("?","?")
        self.search_distance = None
        self.size = None
        self.status = "Not started"
        self.progress = 0.0
        self.counter = self.statusid_counter.next()
        self.started = time.time()

    def get_started(self):
        return self.started
    def get_storage_index(self):
        return self.storage_index
    def get_encoding(self):
        return self.encoding
    def get_search_distance(self):
        return self.search_distance
    def using_helper(self):
        return self.helper
    def get_size(self):
        return self.size
    def get_status(self):
        return self.status
    def get_progress(self):
        return self.progress
    def get_active(self):
        return self.active
    def get_counter(self):
        return self.counter

    def set_storage_index(self, si):
        self.storage_index = si
    def set_helper(self, helper):
        self.helper = helper
    def set_encoding(self, k, n):
        self.encoding = (k, n)
    def set_search_distance(self, value):
        self.search_distance = value
    def set_size(self, size):
        self.size = size
    def set_status(self, status):
        self.status = status
    def set_progress(self, value):
        self.progress = value
    def set_active(self, value):
        self.active = value

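# Marker instances carry no data of their own: each one is just a unique,
# hashable key used to identify one outstanding query in
# Retrieve._outstanding_queries.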
class Marker:
    pass

class Retrieve:
    # this class is currently single-use. Eventually (in MDMF) we will make
    # it multi-use, in which case you can call download(range) multiple
    # times, and each will have a separate response chain. However the
    # Retrieve object will remain tied to a specific version of the file, and
    # will use a single ServerMap instance.

    def __init__(self, filenode, servermap, verinfo):
        self._node = filenode
        assert self._node._pubkey
        self._storage_index = filenode.get_storage_index()
        assert self._node._readkey
        self._last_failure = None
        prefix = storage.si_b2a(self._storage_index)[:5]
        self._log_number = log.msg("Retrieve(%s): starting" % prefix)
        self._outstanding_queries = {} # maps Marker instance to (peerid, shnum, started)
        self._running = True
        self._decoding = False
        self._bad_shares = set()

        self.servermap = servermap
        assert self._node._pubkey
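        # verinfo is the 9-tuple (seqnum, root_hash, IV, segsize, datalength,
        # k, N, prefix, offsets_tuple) that identifies which version of the
        # file this Retrieve will fetch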
        self.verinfo = verinfo

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_number
        return log.msg(*args, **kwargs)

    def download(self):
        self._done_deferred = defer.Deferred()

        # first, which servers can we use?
        versionmap = self.servermap.make_versionmap()
        shares = versionmap[self.verinfo]
        # this sharemap is consumed as we decide to send requests
        self.remaining_sharemap = DictOfSets()
        for (shnum, peerid, timestamp) in shares:
            self.remaining_sharemap.add(shnum, peerid)

        self.shares = {} # maps shnum to validated blocks

        # how many shares do we need?
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo
        assert len(self.remaining_sharemap) >= k
        # we start with the lowest shnums we have available, since FEC is
        # faster if we're using "primary shares"
        self.active_shnums = set(sorted(self.remaining_sharemap.keys())[:k])
        for shnum in self.active_shnums:
            # we use an arbitrary peer who has the share. If shares are
            # doubled up (more than one share per peer), we could make this
            # run faster by spreading the load among multiple peers. But the
            # algorithm to do that is more complicated than I want to write
            # right now, and a well-provisioned grid shouldn't have multiple
            # shares per peer.
            peerid = list(self.remaining_sharemap[shnum])[0]
            self.get_data(shnum, peerid)

        # control flow beyond this point: state machine. Receiving responses
        # from queries is the input. We might send out more queries, or we
        # might produce a result.

        return self._done_deferred

    def get_data(self, shnum, peerid):
        self.log(format="sending sh#%(shnum)d request to [%(peerid)s]",
                 shnum=shnum,
                 peerid=idlib.shortnodeid_b2a(peerid),
                 level=log.NOISY)
        ss = self.servermap.connections[peerid]
        started = time.time()
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo
        offsets = dict(offsets_tuple)
        # we read the checkstring, to make sure that the data we grab is from
        # the right version. We also read the data, and the hashes necessary
        # to validate them (share_hash_chain, block_hash_tree, share_data).
        # We don't read the signature or the pubkey, since that was handled
        # during the servermap phase, and we'll be comparing the share hash
        # chain against the roothash that was validated back then.
        readv = [ (0, struct.calcsize(SIGNED_PREFIX)),
                  (offsets['share_hash_chain'],
                   offsets['enc_privkey'] - offsets['share_hash_chain']),
                  ]

        m = Marker()
        self._outstanding_queries[m] = (peerid, shnum, started)

        # ask the cache first
        datav = []
        #for (offset, length) in readv:
        #    (data, timestamp) = self._node._cache.read(self.verinfo, shnum,
        #                                               offset, length)
        #    if data is not None:
        #        datav.append(data)
        if len(datav) == len(readv):
            self.log("got data from cache")
            d = defer.succeed(datav)
        else:
            self.remaining_sharemap[shnum].remove(peerid)
            d = self._do_read(ss, peerid, self._storage_index, [shnum], readv)
            d.addCallback(self._fill_cache, readv)

        d.addCallback(self._got_results, m, peerid, started)
        d.addErrback(self._query_failed, m, peerid)
        # errors that aren't handled by _query_failed (and errors caused by
        # _query_failed) get logged, but we still want to check for doneness.
        def _oops(f):
            self.log(format="problem in _query_failed for sh#%(shnum)d to %(peerid)s",
                     shnum=shnum,
                     peerid=idlib.shortnodeid_b2a(peerid),
                     failure=f,
                     level=log.WEIRD)
        d.addErrback(_oops)
        d.addBoth(self._check_for_done)
        # any error during _check_for_done means the download fails. If the
        # download is successful, _check_for_done will fire _done by itself.
        d.addErrback(self._done)
        d.addErrback(log.err)
        return d # purely for testing convenience

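    # _fill_cache records each byte range we just read in the node's cache,
    # keyed by version, share number, and offset, so the data is available
    # for later reads without another network round-trip.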
    def _fill_cache(self, datavs, readv):
        timestamp = time.time()
        for shnum,datav in datavs.items():
            for i, (offset, length) in enumerate(readv):
                data = datav[i]
                self._node._cache.add(self.verinfo, shnum, offset, data,
                                      timestamp)
        return datavs

    def _do_read(self, ss, peerid, storage_index, shnums, readv):
        # isolate the callRemote to a separate method, so tests can subclass
        # Retrieve and override it
        d = ss.callRemote("slot_readv", storage_index, shnums, readv)
        return d

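    # remove_peer forgets the given peer entirely: it is dropped from every
    # entry in remaining_sharemap, so no further queries will be sent to it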
    def remove_peer(self, peerid):
        for shnum in list(self.remaining_sharemap.keys()):
            self.remaining_sharemap.discard(shnum, peerid)

    def _got_results(self, datavs, marker, peerid, started):
        self.log(format="got results (%(shares)d shares) from [%(peerid)s]",
                 shares=len(datavs),
                 peerid=idlib.shortnodeid_b2a(peerid),
                 level=log.NOISY)
        self._outstanding_queries.pop(marker, None)
        if not self._running:
            return

        # note that we only ask for a single share per query, so we only
        # expect a single share back. On the other hand, we use the extra
        # shares if we get them; that seems better than an assert().

        for shnum,datav in datavs.items():
            (prefix, hash_and_data) = datav
            try:
                self._got_results_one_share(shnum, peerid,
                                            prefix, hash_and_data)
            except CorruptShareError, e:
                # log it and give the other shares a chance to be processed
                f = failure.Failure()
                self.log("bad share: %s %s" % (f, f.value), level=log.WEIRD)
                self.remove_peer(peerid)
                self.servermap.mark_bad_share(peerid, shnum)
                self._bad_shares.add( (peerid, shnum) )
                self._last_failure = f
                pass
        # all done!

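    # _got_results_one_share validates a single share: the prefix must match
    # the version we expect, the block hash tree must cover the share data,
    # and the share hash chain must connect this share to the
    # already-verified root_hash. Only then is the data added to self.shares.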
    def _got_results_one_share(self, shnum, peerid,
                               got_prefix, got_hash_and_data):
        self.log("_got_results: got shnum #%d from peerid %s"
                 % (shnum, idlib.shortnodeid_b2a(peerid)))
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo
        assert len(got_prefix) == len(prefix), (len(got_prefix), len(prefix))
        if got_prefix != prefix:
            msg = "someone wrote to the data since we read the servermap: prefix changed"
            raise UncoordinatedWriteError(msg)
        (share_hash_chain, block_hash_tree,
         share_data) = unpack_share_data(self.verinfo, got_hash_and_data)

        assert isinstance(share_data, str)
        # build the block hash tree. SDMF has only one leaf.
        leaves = [hashutil.block_hash(share_data)]
        t = hashtree.HashTree(leaves)
        if list(t) != block_hash_tree:
            raise CorruptShareError(peerid, shnum, "block hash tree failure")
        share_hash_leaf = t[0]
        t2 = hashtree.IncompleteHashTree(N)
        # root_hash was checked by the signature
        t2.set_hashes({0: root_hash})
        try:
            t2.set_hashes(hashes=share_hash_chain,
                          leaves={shnum: share_hash_leaf})
        except (hashtree.BadHashError, hashtree.NotEnoughHashesError,
                IndexError), e:
            msg = "corrupt hashes: %s" % (e,)
            raise CorruptShareError(peerid, shnum, msg)
        self.log(" data valid! len=%d" % len(share_data))
        # each query comes down to this: placing validated share data into
        # self.shares
        self.shares[shnum] = share_data

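    # _query_failed handles a whole-query failure (e.g. a network error): we
    # stop counting on that peer and remember the failure so it can be
    # reported if the download ultimately fails.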
    def _query_failed(self, f, marker, peerid):
        self.log(format="query to [%(peerid)s] failed",
                 peerid=idlib.shortnodeid_b2a(peerid),
                 level=log.NOISY)
        self._outstanding_queries.pop(marker, None)
        if not self._running:
            return
        self._last_failure = f
        self.remove_peer(peerid)
        self.log("error during query: %s %s" % (f, f.value), level=log.WEIRD)

    def _check_for_done(self, res):
        # exit paths:
        #  return : keep waiting, no new queries
        #  return self._maybe_send_more_queries(k) : send some more queries
        #  fire self._done(plaintext) : download successful
        #  raise exception : download fails

        self.log(format="_check_for_done: running=%(running)s, decoding=%(decoding)s",
                 running=self._running, decoding=self._decoding,
                 level=log.NOISY)
        if not self._running:
            return
        if self._decoding:
            return
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo

        if len(self.shares) < k:
            # we don't have enough shares yet
            return self._maybe_send_more_queries(k)

        # we have enough to finish. All the shares have had their hashes
        # checked, so if something fails at this point, we don't know how
        # to fix it, so the download will fail.

        self._decoding = True # avoid reentrancy

        d = defer.maybeDeferred(self._decode)
        d.addCallback(self._decrypt, IV, self._node._readkey)
        d.addBoth(self._done)
        return d # purely for test convenience

    def _maybe_send_more_queries(self, k):
        # we don't have enough shares yet. Should we send out more queries?
        # There are some number of queries outstanding, each for a single
        # share. If we can generate 'needed_shares' additional queries, we do
        # so. If we can't, then we know this file is a goner, and we raise
        # NotEnoughPeersError.
        self.log(format=("_maybe_send_more_queries, have=%(have)d, k=%(k)d, "
                         "outstanding=%(outstanding)d"),
                 have=len(self.shares), k=k,
                 outstanding=len(self._outstanding_queries),
                 level=log.NOISY)

        remaining_shares = k - len(self.shares)
        needed = remaining_shares - len(self._outstanding_queries)
        if not needed:
            # we have enough queries in flight already

            # TODO: but if they've been in flight for a long time, and we
            # have reason to believe that new queries might respond faster
            # (i.e. we've seen other queries come back faster), then consider
            # sending out new queries. This could help with peers which have
            # silently gone away since the servermap was updated, for which
            # we're still waiting for the 15-minute TCP disconnect to happen.
            self.log("enough queries are in flight, no more are needed",
                     level=log.NOISY)
            return

        outstanding_shnums = set([shnum
                                  for (peerid, shnum, started)
                                  in self._outstanding_queries.values()])
        # prefer low-numbered shares, they are more likely to be primary
        available_shnums = sorted(self.remaining_sharemap.keys())
        for shnum in available_shnums:
            if shnum in outstanding_shnums:
                # skip ones that are already in transit
                continue
            if shnum not in self.remaining_sharemap:
                # no servers for that shnum. note that DictOfSets removes
                # empty sets from the dict for us.
                continue
            peerid = list(self.remaining_sharemap[shnum])[0]
            # get_data will remove that peerid from the sharemap, and add the
            # query to self._outstanding_queries
            self.get_data(shnum, peerid)
            needed -= 1
            if not needed:
                break

        # at this point, we have as many outstanding queries as we can. If
        # needed!=0 then we might not have enough to recover the file.
        if needed:
            format = ("ran out of peers: "
                      "have %(have)d shares (k=%(k)d), "
                      "%(outstanding)d queries in flight, "
                      "need %(need)d more, "
                      "found %(bad)d bad shares")
            args = {"have": len(self.shares),
                    "k": k,
                    "outstanding": len(self._outstanding_queries),
                    "need": needed,
                    "bad": len(self._bad_shares),
                    }
            self.log(format=format,
                     level=log.WEIRD, **args)
            err = NotEnoughPeersError("%s, last failure: %s" %
                                      (format % args, self._last_failure))
            if self._bad_shares:
                self.log("We found some bad shares this pass. You should "
                         "update the servermap and try again to check "
                         "more peers",
                         level=log.WEIRD)
                err.worth_retrying = True
                err.servermap = self.servermap
            raise err

        return

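    # _decode hands the validated shares to the CRS (zfec-based) decoder; the
    # decoded buffers are joined into the single SDMF segment and trimmed to
    # datalength, since the decoded output is padded out to segsize.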
    def _decode(self):
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo

        # self.shares is a dict mapping shnum to share data, but the codec
        # wants two lists.
        shareids = []; shares = []
        for shareid, share in self.shares.items():
            shareids.append(shareid)
            shares.append(share)

        assert len(shareids) >= k, len(shareids)
        # zfec really doesn't want extra shares
        shareids = shareids[:k]
        shares = shares[:k]

        fec = codec.CRSDecoder()
        params = "%d-%d-%d" % (segsize, k, N)
        fec.set_serialized_params(params)

        self.log("params %s, we have %d shares" % (params, len(shares)))
        self.log("about to decode, shareids=%s" % (shareids,))
        d = defer.maybeDeferred(fec.decode, shares, shareids)
        def _done(buffers):
            self.log(" decode done, %d buffers" % len(buffers))
            segment = "".join(buffers)
            self.log(" joined length %d, datalength %d" %
                     (len(segment), datalength))
            segment = segment[:datalength]
            self.log(" segment len=%d" % len(segment))
            return segment
        def _err(f):
            self.log(" decode failed: %s" % f)
            return f
        d.addCallback(_done)
        d.addErrback(_err)
        return d

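    # the AES key for the ciphertext is derived from the file's readkey and
    # the IV stored with this version, mirroring the derivation used when the
    # data was encrypted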
    def _decrypt(self, crypttext, IV, readkey):
        started = time.time()
        key = hashutil.ssk_readkey_data_hash(IV, readkey)
        decryptor = AES(key)
        plaintext = decryptor.process(crypttext)
        return plaintext

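    # _done runs exactly once: it records success or failure, stores the
    # encoding parameters back on the node on success, and fires the caller's
    # Deferred via eventually() so the callback runs in a later reactor turn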
    def _done(self, res):
        if not self._running:
            return
        self._running = False
        # res is either the new contents, or a Failure
        if isinstance(res, failure.Failure):
            self.log("Retrieve done, with failure", failure=res)
        else:
            self.log("Retrieve done, success!: res=%s" % (res,))
            # remember the encoding parameters, use them again next time
            (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
             offsets_tuple) = self.verinfo
            self._node._populate_required_shares(k)
            self._node._populate_total_shares(N)
        eventually(self._done_deferred.callback, res)