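# src/allmydata/mutable/retrieve.py: download side of the mutable-file (SDMF)
# implementation. Retrieve fetches enough shares of one version of a mutable
# file (as described by a ServerMap and a verinfo tuple), validates them
# against the block and share hash trees, FEC-decodes the single segment, and
# decrypts it back into plaintext.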

import struct, time
from itertools import count
from zope.interface import implements
from twisted.internet import defer
from twisted.python import failure
from foolscap.eventual import eventually
from allmydata.interfaces import IRetrieveStatus
from allmydata.util import hashutil, idlib, log
from allmydata import hashtree, codec, storage
from allmydata.encode import NotEnoughPeersError
from pycryptopp.cipher.aes import AES

from common import DictOfSets, CorruptShareError, UncoordinatedWriteError
from layout import SIGNED_PREFIX, unpack_share_data

class RetrieveStatus:
    implements(IRetrieveStatus)
    statusid_counter = count(0)
    def __init__(self):
        self.timings = {}
        self.timings["fetch_per_server"] = {}
        self.timings["cumulative_verify"] = 0.0
        self.sharemap = {}
        self.problems = {}
        self.active = True
        self.storage_index = None
        self.helper = False
        self.encoding = ("?","?")
        self.search_distance = None
        self.size = None
        self.status = "Not started"
        self.progress = 0.0
        self.counter = self.statusid_counter.next()
        self.started = time.time()

    def get_started(self):
        return self.started
    def get_storage_index(self):
        return self.storage_index
    def get_encoding(self):
        return self.encoding
    def get_search_distance(self):
        return self.search_distance
    def using_helper(self):
        return self.helper
    def get_size(self):
        return self.size
    def get_status(self):
        return self.status
    def get_progress(self):
        return self.progress
    def get_active(self):
        return self.active
    def get_counter(self):
        return self.counter

    def set_storage_index(self, si):
        self.storage_index = si
    def set_helper(self, helper):
        self.helper = helper
    def set_encoding(self, k, n):
        self.encoding = (k, n)
    def set_search_distance(self, value):
        self.search_distance = value
    def set_size(self, size):
        self.size = size
    def set_status(self, status):
        self.status = status
    def set_progress(self, value):
        self.progress = value
    def set_active(self, value):
        self.active = value

class Marker:
    pass

class Retrieve:
    # this class is currently single-use. Eventually (in MDMF) we will make
    # it multi-use, in which case you can call download(range) multiple
    # times, and each will have a separate response chain. However the
    # Retrieve object will remain tied to a specific version of the file, and
    # will use a single ServerMap instance.
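    #
    # A rough usage sketch (assuming the caller already holds an up-to-date
    # ServerMap): build a Retrieve around the filenode, the servermap, and
    # the verinfo tuple describing the version to fetch, then call
    # download(). The returned Deferred fires with the plaintext on success,
    # or with a Failure:
    #
    #   r = Retrieve(filenode, servermap, verinfo)
    #   d = r.download()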

    def __init__(self, filenode, servermap, verinfo):
        self._node = filenode
        assert self._node._pubkey
        self._storage_index = filenode.get_storage_index()
        assert self._node._readkey
        self._last_failure = None
        prefix = storage.si_b2a(self._storage_index)[:5]
        self._log_number = log.msg("Retrieve(%s): starting" % prefix)
        self._outstanding_queries = {} # maps Marker to (peerid, shnum, start_time)
        self._running = True
        self._decoding = False

        self.servermap = servermap
        assert self._node._pubkey
        self.verinfo = verinfo

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_number
        return log.msg(*args, **kwargs)

    def download(self):
        self._done_deferred = defer.Deferred()

        # first, which servers can we use?
        versionmap = self.servermap.make_versionmap()
        shares = versionmap[self.verinfo]
        # this sharemap is consumed as we decide to send requests
        self.remaining_sharemap = DictOfSets()
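        # remaining_sharemap maps shnum to a set of peerids that are still
        # believed to hold that share; get_data() drains it as queries go out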
        for (shnum, peerid, timestamp) in shares:
            self.remaining_sharemap.add(shnum, peerid)

        self.shares = {} # maps shnum to validated blocks

        # how many shares do we need?
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo
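        # k is the number of shares needed to reconstruct the file, N is the
        # total number of shares that were created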
        assert len(self.remaining_sharemap) >= k
        # we start with the lowest shnums we have available, since FEC is
        # faster if we're using "primary shares"
        self.active_shnums = set(sorted(self.remaining_sharemap.keys())[:k])
        for shnum in self.active_shnums:
            # we use an arbitrary peer who has the share. If shares are
            # doubled up (more than one share per peer), we could make this
            # run faster by spreading the load among multiple peers. But the
            # algorithm to do that is more complicated than I want to write
            # right now, and a well-provisioned grid shouldn't have multiple
            # shares per peer.
            peerid = list(self.remaining_sharemap[shnum])[0]
            self.get_data(shnum, peerid)

        # control flow beyond this point: state machine. Receiving responses
        # from queries is the input. We might send out more queries, or we
        # might produce a result.

        return self._done_deferred

    def get_data(self, shnum, peerid):
        self.log(format="sending sh#%(shnum)d request to [%(peerid)s]",
                 shnum=shnum,
                 peerid=idlib.shortnodeid_b2a(peerid),
                 level=log.NOISY)
        ss = self.servermap.connections[peerid]
        started = time.time()
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo
        offsets = dict(offsets_tuple)
        # we read the checkstring, to make sure that the data we grab is from
        # the right version. We also read the data, and the hashes necessary
        # to validate them (share_hash_chain, block_hash_tree, share_data).
        # We don't read the signature or the pubkey, since that was handled
        # during the servermap phase, and we'll be comparing the share hash
        # chain against the roothash that was validated back then.
        readv = [ (0, struct.calcsize(SIGNED_PREFIX)),
                  (offsets['share_hash_chain'],
                   offsets['enc_privkey'] - offsets['share_hash_chain']),
                  ]
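        # (the second read spans share_hash_chain, block_hash_tree, and
        # share_data, which unpack_share_data() will split apart)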

        m = Marker()
        self._outstanding_queries[m] = (peerid, shnum, started)

        # ask the cache first
        datav = []
        #for (offset, length) in readv:
        #    (data, timestamp) = self._node._cache.read(self.verinfo, shnum,
        #                                               offset, length)
        #    if data is not None:
        #        datav.append(data)
        if len(datav) == len(readv):
            self.log("got data from cache")
            d = defer.succeed(datav)
        else:
            self.remaining_sharemap[shnum].remove(peerid)
            d = self._do_read(ss, peerid, self._storage_index, [shnum], readv)
            d.addCallback(self._fill_cache, readv)

        d.addCallback(self._got_results, m, peerid, started)
        d.addErrback(self._query_failed, m, peerid)
        # errors that aren't handled by _query_failed (and errors caused by
        # _query_failed) get logged, but we still want to check for doneness.
        def _oops(f):
            self.log(format="problem in _query_failed for sh#%(shnum)d to %(peerid)s",
                     shnum=shnum,
                     peerid=idlib.shortnodeid_b2a(peerid),
                     failure=f,
                     level=log.WEIRD)
        d.addErrback(_oops)
        d.addBoth(self._check_for_done)
        # any error during _check_for_done means the download fails. If the
        # download is successful, _check_for_done will fire _done by itself.
        d.addErrback(self._done)
        d.addErrback(log.err)
        return d # purely for testing convenience

    def _fill_cache(self, datavs, readv):
        timestamp = time.time()
        for shnum,datav in datavs.items():
            for i, (offset, length) in enumerate(readv):
                data = datav[i]
                self._node._cache.add(self.verinfo, shnum, offset, data,
                                      timestamp)
        return datavs

    def _do_read(self, ss, peerid, storage_index, shnums, readv):
        # isolate the callRemote to a separate method, so tests can subclass
        # Retrieve and override it
        d = ss.callRemote("slot_readv", storage_index, shnums, readv)
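        # the response is a dict mapping shnum to a list of strings, one per
        # (offset, length) pair in readv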
        return d

    def remove_peer(self, peerid):
        for shnum in list(self.remaining_sharemap.keys()):
            self.remaining_sharemap.discard(shnum, peerid)

    def _got_results(self, datavs, marker, peerid, started):
        self.log(format="got results (%(shares)d shares) from [%(peerid)s]",
                 shares=len(datavs),
                 peerid=idlib.shortnodeid_b2a(peerid),
                 level=log.NOISY)
        self._outstanding_queries.pop(marker, None)
        if not self._running:
            return

        # note that we only ask for a single share per query, so we expect
        # only a single share back. On the other hand, we use any extra
        # shares if we do get them; that seems better than an assert().

        for shnum,datav in datavs.items():
            (prefix, hash_and_data) = datav
            try:
                self._got_results_one_share(shnum, peerid,
                                            prefix, hash_and_data)
            except CorruptShareError, e:
                # log it and give the other shares a chance to be processed
                f = failure.Failure()
                self.log("bad share: %s %s" % (f, f.value), level=log.WEIRD)
                self.remove_peer(peerid)
                self._last_failure = f
                pass
        # all done!

    def _got_results_one_share(self, shnum, peerid,
                               got_prefix, got_hash_and_data):
        self.log("_got_results: got shnum #%d from peerid %s"
                 % (shnum, idlib.shortnodeid_b2a(peerid)))
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo
        assert len(got_prefix) == len(prefix), (len(got_prefix), len(prefix))
        if got_prefix != prefix:
            msg = "someone wrote to the data since we read the servermap: prefix changed"
            raise UncoordinatedWriteError(msg)
        (share_hash_chain, block_hash_tree,
         share_data) = unpack_share_data(self.verinfo, got_hash_and_data)

        assert isinstance(share_data, str)
        # build the block hash tree. SDMF has only one leaf.
        leaves = [hashutil.block_hash(share_data)]
        t = hashtree.HashTree(leaves)
        if list(t) != block_hash_tree:
            raise CorruptShareError(peerid, shnum, "block hash tree failure")
        share_hash_leaf = t[0]
        t2 = hashtree.IncompleteHashTree(N)
        # root_hash was checked by the signature
        t2.set_hashes({0: root_hash})
        try:
            t2.set_hashes(hashes=share_hash_chain,
                          leaves={shnum: share_hash_leaf})
        except (hashtree.BadHashError, hashtree.NotEnoughHashesError,
                IndexError), e:
            msg = "corrupt hashes: %s" % (e,)
            raise CorruptShareError(peerid, shnum, msg)
        self.log(" data valid! len=%d" % len(share_data))
        # each query comes down to this: placing validated share data into
        # self.shares
        self.shares[shnum] = share_data

    def _query_failed(self, f, marker, peerid):
        self.log(format="query to [%(peerid)s] failed",
                 peerid=idlib.shortnodeid_b2a(peerid),
                 level=log.NOISY)
        self._outstanding_queries.pop(marker, None)
        if not self._running:
            return
        self._last_failure = f
        self.remove_peer(peerid)
        self.log("error during query: %s %s" % (f, f.value), level=log.WEIRD)

    def _check_for_done(self, res):
        # exit paths:
        #  return : keep waiting, no new queries
        #  return self._maybe_send_more_queries(k) : send some more queries
        #  fire self._done(plaintext) : download successful
        #  raise exception : download fails

        self.log(format="_check_for_done: running=%(running)s, decoding=%(decoding)s",
                 running=self._running, decoding=self._decoding,
                 level=log.NOISY)
        if not self._running:
            return
        if self._decoding:
            return
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo

        if len(self.shares) < k:
            # we don't have enough shares yet
            return self._maybe_send_more_queries(k)

        # we have enough to finish. All the shares have had their hashes
        # checked, so if something fails at this point, we don't know how
        # to fix it, so the download will fail.

        self._decoding = True # avoid reentrancy

        d = defer.maybeDeferred(self._decode)
        d.addCallback(self._decrypt, IV, self._node._readkey)
        d.addBoth(self._done)
        return d # purely for test convenience

    def _maybe_send_more_queries(self, k):
        # we don't have enough shares yet. Should we send out more queries?
        # There are some number of queries outstanding, each for a single
        # share. If we can generate enough additional queries to cover the
        # remaining shares, we do so. If we can't, then we know this file is
        # a goner, and we raise NotEnoughPeersError.
        self.log(format=("_maybe_send_more_queries, have=%(have)d, k=%(k)d, "
                         "outstanding=%(outstanding)d"),
                 have=len(self.shares), k=k,
                 outstanding=len(self._outstanding_queries),
                 level=log.NOISY)

        remaining_shares = k - len(self.shares)
        needed = remaining_shares - len(self._outstanding_queries)
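        # e.g. with k=3, one validated share, and one query still in flight:
        # remaining_shares=2 and needed=1, so we should issue one more query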
        if not needed:
            # we have enough queries in flight already

            # TODO: but if they've been in flight for a long time, and we
            # have reason to believe that new queries might respond faster
            # (i.e. we've seen other queries come back faster), then consider
            # sending out new queries. This could help with peers which have
            # silently gone away since the servermap was updated, for which
            # we're still waiting for the 15-minute TCP disconnect to happen.
            self.log("enough queries are in flight, no more are needed",
                     level=log.NOISY)
            return

        outstanding_shnums = set([shnum
                                  for (peerid, shnum, started)
                                  in self._outstanding_queries.values()])
        # prefer low-numbered shares, they are more likely to be primary
        available_shnums = sorted(self.remaining_sharemap.keys())
        for shnum in available_shnums:
            if shnum in outstanding_shnums:
                # skip ones that are already in transit
                continue
            if shnum not in self.remaining_sharemap:
                # no servers for that shnum. note that DictOfSets removes
                # empty sets from the dict for us.
                continue
            peerid = list(self.remaining_sharemap[shnum])[0]
            # get_data will remove that peerid from the sharemap, and add the
            # query to self._outstanding_queries
            self.get_data(shnum, peerid)
            needed -= 1
            if not needed:
                break

        # at this point, we have as many outstanding queries as we can. If
        # needed!=0 then we might not have enough to recover the file.
        if needed:
            format = ("ran out of peers: "
                      "have %(have)d shares (k=%(k)d), "
                      "%(outstanding)d queries in flight, "
                      "need %(need)d more")
            self.log(format=format,
                     have=len(self.shares), k=k,
                     outstanding=len(self._outstanding_queries),
                     need=needed,
                     level=log.WEIRD)
            msg2 = format % {"have": len(self.shares),
                             "k": k,
                             "outstanding": len(self._outstanding_queries),
                             "need": needed,
                             }
            raise NotEnoughPeersError("%s, last failure: %s" %
                                      (msg2, self._last_failure))

        return

    def _decode(self):
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo

        # self.shares is a dict mapping shnum to validated share data, but
        # the codec wants two parallel lists.
        shareids = []; shares = []
        for shareid, share in self.shares.items():
            shareids.append(shareid)
            shares.append(share)

        assert len(shareids) >= k, len(shareids)
        # zfec really doesn't want extra shares
        shareids = shareids[:k]
        shares = shares[:k]

        fec = codec.CRSDecoder()
        params = "%d-%d-%d" % (segsize, k, N)
        fec.set_serialized_params(params)
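        # the decoder gets the same segsize/k/N parameters that were used at
        # encoding time; given any k distinct shares (and their shareids) it
        # can rebuild the original segment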

        self.log("params %s, we have %d shares" % (params, len(shares)))
        self.log("about to decode, shareids=%s" % (shareids,))
        d = defer.maybeDeferred(fec.decode, shares, shareids)
        def _done(buffers):
            self.log(" decode done, %d buffers" % len(buffers))
            segment = "".join(buffers)
            self.log(" joined length %d, datalength %d" %
                     (len(segment), datalength))
            segment = segment[:datalength]
            self.log(" segment len=%d" % len(segment))
            return segment
        def _err(f):
            self.log(" decode failed: %s" % f)
            return f
        d.addCallback(_done)
        d.addErrback(_err)
        return d

    def _decrypt(self, crypttext, IV, readkey):
        started = time.time()
        key = hashutil.ssk_readkey_data_hash(IV, readkey)
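        # the AES data key is derived by hashing the per-version IV together
        # with the file's readkey (ssk_readkey_data_hash), so each version of
        # the file is decrypted under its own key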
        decryptor = AES(key)
        plaintext = decryptor.process(crypttext)
        return plaintext

    def _done(self, res):
        if not self._running:
            return
        self._running = False
        # res is either the new contents, or a Failure
        if isinstance(res, failure.Failure):
            self.log("Retrieve done, with failure", failure=res)
        else:
            self.log("Retrieve done, success!: res=%s" % (res,))
            # remember the encoding parameters, use them again next time
            (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
             offsets_tuple) = self.verinfo
            self._node._populate_required_shares(k)
            self._node._populate_total_shares(N)
        eventually(self._done_deferred.callback, res)