]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blob - src/allmydata/mutable/retrieve.py
2be92163ec822df2c972bee9005a5ee264aa6a1f
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / mutable / retrieve.py
1
2 import time
3 from itertools import count
4 from zope.interface import implements
5 from twisted.internet import defer
6 from twisted.python import failure
7 from twisted.internet.interfaces import IPushProducer, IConsumer
8 from foolscap.api import eventually, fireEventually, DeadReferenceError, \
9      RemoteException
10 from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \
11      DownloadStopped, MDMF_VERSION, SDMF_VERSION
12 from allmydata.util import hashutil, log, mathutil, deferredutil
13 from allmydata.util.dictutil import DictOfSets
14 from allmydata import hashtree, codec
15 from allmydata.storage.server import si_b2a
16 from pycryptopp.cipher.aes import AES
17 from pycryptopp.publickey import rsa
18
19 from allmydata.mutable.common import CorruptShareError, BadShareError, \
20      UncoordinatedWriteError
21 from allmydata.mutable.layout import MDMFSlotReadProxy
22
class RetrieveStatus:
    """Progress and diagnostic record for one mutable-file Retrieve.

    Exposed through IRetrieveStatus so status displays can show the
    storage index, encoding parameters, size, a free-form status string,
    a progress fraction, per-server fetch timings, and per-server problems.
    """
    implements(IRetrieveStatus)
    # Shared by all instances, so every status object gets a distinct,
    # increasing counter value.
    statusid_counter = count(0)

    def __init__(self):
        self.timings = {}
        # server -> list of elapsed fetch times (seconds, from time.time())
        self.timings["fetch_per_server"] = {}
        # Accumulated per-phase times; presumably seconds -- callers of the
        # accumulate_* methods are not visible here.
        self.timings["decode"] = 0.0
        self.timings["decrypt"] = 0.0
        self.timings["cumulative_verify"] = 0.0
        self._problems = {}  # serverid -> failure recorded by add_problem
        self.active = True
        self.storage_index = None
        self.helper = False
        self.encoding = ("?","?")
        self.size = None
        self.status = "Not started"
        self.progress = 0.0
        # The next() builtin works on Python 2.6+ and Python 3, unlike the
        # Python-2-only iterator method .next().
        self.counter = next(self.statusid_counter)
        self.started = time.time()

    def get_started(self):
        return self.started
    def get_storage_index(self):
        return self.storage_index
    def get_encoding(self):
        return self.encoding
    def using_helper(self):
        return self.helper
    def get_size(self):
        return self.size
    def get_status(self):
        return self.status
    def get_progress(self):
        return self.progress
    def get_active(self):
        return self.active
    def get_counter(self):
        return self.counter
    def get_problems(self):
        return self._problems

    def add_fetch_timing(self, server, elapsed):
        """Record that one fetch from *server* took *elapsed* seconds."""
        self.timings["fetch_per_server"].setdefault(server, []).append(elapsed)
    def accumulate_decode_time(self, elapsed):
        self.timings["decode"] += elapsed
    def accumulate_decrypt_time(self, elapsed):
        self.timings["decrypt"] += elapsed
    def set_storage_index(self, si):
        self.storage_index = si
    def set_helper(self, helper):
        self.helper = helper
    def set_encoding(self, k, n):
        self.encoding = (k, n)
    def set_size(self, size):
        self.size = size
    def set_status(self, status):
        self.status = status
    def set_progress(self, value):
        self.progress = value
    def set_active(self, value):
        self.active = value
    def add_problem(self, server, f):
        """Remember failure *f* for *server*, keyed by its server id."""
        serverid = server.get_serverid()
        self._problems[serverid] = f
89
class Marker:
    """Empty class whose instances serve as unique token objects.

    It carries no state and no behavior of its own.
    """
92
93 class Retrieve:
94     # this class is currently single-use. Eventually (in MDMF) we will make
95     # it multi-use, in which case you can call download(range) multiple
96     # times, and each will have a separate response chain. However the
97     # Retrieve object will remain tied to a specific version of the file, and
98     # will use a single ServerMap instance.
99     implements(IPushProducer)
100
101     def __init__(self, filenode, storage_broker, servermap, verinfo,
102                  fetch_privkey=False, verify=False):
103         self._node = filenode
104         assert self._node.get_pubkey()
105         self._storage_broker = storage_broker
106         self._storage_index = filenode.get_storage_index()
107         assert self._node.get_readkey()
108         self._last_failure = None
109         prefix = si_b2a(self._storage_index)[:5]
110         self._log_number = log.msg("Retrieve(%s): starting" % prefix)
111         self._running = True
112         self._decoding = False
113         self._bad_shares = set()
114
115         self.servermap = servermap
116         assert self._node.get_pubkey()
117         self.verinfo = verinfo
118         # during repair, we may be called upon to grab the private key, since
119         # it wasn't picked up during a verify=False checker run, and we'll
120         # need it for repair to generate a new version.
121         self._need_privkey = verify or (fetch_privkey
122                                         and not self._node.get_privkey())
123
124         if self._need_privkey:
125             # TODO: Evaluate the need for this. We'll use it if we want
126             # to limit how many queries are on the wire for the privkey
127             # at once.
128             self._privkey_query_markers = [] # one Marker for each time we've
129                                              # tried to get the privkey.
130
131         # verify means that we are using the downloader logic to verify all
132         # of our shares. This tells the downloader a few things.
133         #
134         # 1. We need to download all of the shares.
135         # 2. We don't need to decode or decrypt the shares, since our
136         #    caller doesn't care about the plaintext, only the
137         #    information about which shares are or are not valid.
138         # 3. When we are validating readers, we need to validate the
139         #    signature on the prefix. Do we? We already do this in the
140         #    servermap update?
141         self._verify = verify
142
143         self._status = RetrieveStatus()
144         self._status.set_storage_index(self._storage_index)
145         self._status.set_helper(False)
146         self._status.set_progress(0.0)
147         self._status.set_active(True)
148         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
149          offsets_tuple) = self.verinfo
150         self._status.set_size(datalength)
151         self._status.set_encoding(k, N)
152         self.readers = {}
153         self._stopped = False
154         self._pause_deferred = None
155         self._offset = None
156         self._read_length = None
157         self.log("got seqnum %d" % self.verinfo[0])
158
159
160     def get_status(self):
161         return self._status
162
163     def log(self, *args, **kwargs):
164         if "parent" not in kwargs:
165             kwargs["parent"] = self._log_number
166         if "facility" not in kwargs:
167             kwargs["facility"] = "tahoe.mutable.retrieve"
168         return log.msg(*args, **kwargs)
169
170     def _set_current_status(self, state):
171         seg = "%d/%d" % (self._current_segment, self._last_segment)
172         self._status.set_status("segment %s (%s)" % (seg, state))
173
174     ###################
175     # IPushProducer
176
177     def pauseProducing(self):
178         """
179         I am called by my download target if we have produced too much
180         data for it to handle. I make the downloader stop producing new
181         data until my resumeProducing method is called.
182         """
183         if self._pause_deferred is not None:
184             return
185
186         # fired when the download is unpaused.
187         self._old_status = self._status.get_status()
188         self._set_current_status("paused")
189
190         self._pause_deferred = defer.Deferred()
191
192
193     def resumeProducing(self):
194         """
195         I am called by my download target once it is ready to begin
196         receiving data again.
197         """
198         if self._pause_deferred is None:
199             return
200
201         p = self._pause_deferred
202         self._pause_deferred = None
203         self._status.set_status(self._old_status)
204
205         eventually(p.callback, None)
206
207     def stopProducing(self):
208         self._stopped = True
209         self.resumeProducing()
210
211
212     def _check_for_paused(self, res):
213         """
214         I am called just before a write to the consumer. I return a
215         Deferred that eventually fires with the data that is to be
216         written to the consumer. If the download has not been paused,
217         the Deferred fires immediately. Otherwise, the Deferred fires
218         when the downloader is unpaused.
219         """
220         if self._pause_deferred is not None:
221             d = defer.Deferred()
222             self._pause_deferred.addCallback(lambda ignored: d.callback(res))
223             return d
224         return res
225
226     def _check_for_stopped(self, res):
227         if self._stopped:
228             raise DownloadStopped("our Consumer called stopProducing()")
229         return res
230
231
232     def download(self, consumer=None, offset=0, size=None):
233         assert IConsumer.providedBy(consumer) or self._verify
234
235         if consumer:
236             self._consumer = consumer
237             # we provide IPushProducer, so streaming=True, per
238             # IConsumer.
239             self._consumer.registerProducer(self, streaming=True)
240
241         self._done_deferred = defer.Deferred()
242         self._offset = offset
243         self._read_length = size
244         self._setup_encoding_parameters()
245         self._setup_download()
246         self.log("starting download")
247         self._started_fetching = time.time()
248         # The download process beyond this is a state machine.
249         # _add_active_servers will select the servers that we want to use
250         # for the download, and then attempt to start downloading. After
251         # each segment, it will check for doneness, reacting to broken
252         # servers and corrupt shares as necessary. If it runs out of good
253         # servers before downloading all of the segments, _done_deferred
254         # will errback.  Otherwise, it will eventually callback with the
255         # contents of the mutable file.
256         self.loop()
257         return self._done_deferred
258
259     def loop(self):
260         d = fireEventually(None) # avoid #237 recursion limit problem
261         d.addCallback(lambda ign: self._activate_enough_servers())
262         d.addCallback(lambda ign: self._download_current_segment())
263         # when we're done, _download_current_segment will call _done. If we
264         # aren't, it will call loop() again.
265         d.addErrback(self._error)
266
267     def _setup_download(self):
268         self._started = time.time()
269         self._status.set_status("Retrieving Shares")
270
271         # how many shares do we need?
272         (seqnum,
273          root_hash,
274          IV,
275          segsize,
276          datalength,
277          k,
278          N,
279          prefix,
280          offsets_tuple) = self.verinfo
281
282         # first, which servers can we use?
283         versionmap = self.servermap.make_versionmap()
284         shares = versionmap[self.verinfo]
285         # this sharemap is consumed as we decide to send requests
286         self.remaining_sharemap = DictOfSets()
287         for (shnum, server, timestamp) in shares:
288             self.remaining_sharemap.add(shnum, server)
289             # Reuse the SlotReader from the servermap.
290             key = (self.verinfo, server.get_serverid(),
291                    self._storage_index, shnum)
292             if key in self.servermap.proxies:
293                 reader = self.servermap.proxies[key]
294             else:
295                 reader = MDMFSlotReadProxy(server.get_rref(),
296                                            self._storage_index, shnum, None)
297             reader.server = server
298             self.readers[shnum] = reader
299
300         if len(self.remaining_sharemap) < k:
301             self._raise_notenoughshareserror()
302
303         self.shares = {} # maps shnum to validated blocks
304         self._active_readers = [] # list of active readers for this dl.
305         self._block_hash_trees = {} # shnum => hashtree
306
307         for i in xrange(self._total_shares):
308             # So we don't have to do this later.
309             self._block_hash_trees[i] = hashtree.IncompleteHashTree(self._num_segments)
310
311         # We need one share hash tree for the entire file; its leaves
312         # are the roots of the block hash trees for the shares that
313         # comprise it, and its root is in the verinfo.
314         self.share_hash_tree = hashtree.IncompleteHashTree(N)
315         self.share_hash_tree.set_hashes({0: root_hash})
316
317     def decode(self, blocks_and_salts, segnum):
318         """
319         I am a helper method that the mutable file update process uses
320         as a shortcut to decode and decrypt the segments that it needs
321         to fetch in order to perform a file update. I take in a
322         collection of blocks and salts, and pick some of those to make a
323         segment with. I return the plaintext associated with that
324         segment.
325         """
326         # We don't need the block hash trees in this case.
327         self._block_hash_trees = None
328         self._setup_encoding_parameters()
329
330         # _decode_blocks() expects the output of a gatherResults that
331         # contains the outputs of _validate_block() (each of which is a dict
332         # mapping shnum to (block,salt) bytestrings).
333         d = self._decode_blocks([blocks_and_salts], segnum)
334         d.addCallback(self._decrypt_segment)
335         return d
336
337
338     def _setup_encoding_parameters(self):
339         """
340         I set up the encoding parameters, including k, n, the number
341         of segments associated with this file, and the segment decoders.
342         """
343         (seqnum,
344          root_hash,
345          IV,
346          segsize,
347          datalength,
348          k,
349          n,
350          known_prefix,
351          offsets_tuple) = self.verinfo
352         self._required_shares = k
353         self._total_shares = n
354         self._segment_size = segsize
355         self._data_length = datalength
356
357         if not IV:
358             self._version = MDMF_VERSION
359         else:
360             self._version = SDMF_VERSION
361
362         if datalength and segsize:
363             self._num_segments = mathutil.div_ceil(datalength, segsize)
364             self._tail_data_size = datalength % segsize
365         else:
366             self._num_segments = 0
367             self._tail_data_size = 0
368
369         self._segment_decoder = codec.CRSDecoder()
370         self._segment_decoder.set_params(segsize, k, n)
371
372         if  not self._tail_data_size:
373             self._tail_data_size = segsize
374
375         self._tail_segment_size = mathutil.next_multiple(self._tail_data_size,
376                                                          self._required_shares)
377         if self._tail_segment_size == self._segment_size:
378             self._tail_decoder = self._segment_decoder
379         else:
380             self._tail_decoder = codec.CRSDecoder()
381             self._tail_decoder.set_params(self._tail_segment_size,
382                                           self._required_shares,
383                                           self._total_shares)
384
385         self.log("got encoding parameters: "
386                  "k: %d "
387                  "n: %d "
388                  "%d segments of %d bytes each (%d byte tail segment)" % \
389                  (k, n, self._num_segments, self._segment_size,
390                   self._tail_segment_size))
391
392         # Our last task is to tell the downloader where to start and
393         # where to stop. We use three parameters for that:
394         #   - self._start_segment: the segment that we need to start
395         #     downloading from.
396         #   - self._current_segment: the next segment that we need to
397         #     download.
398         #   - self._last_segment: The last segment that we were asked to
399         #     download.
400         #
401         #  We say that the download is complete when
402         #  self._current_segment > self._last_segment. We use
403         #  self._start_segment and self._last_segment to know when to
404         #  strip things off of segments, and how much to strip.
405         if self._offset:
406             self.log("got offset: %d" % self._offset)
407             # our start segment is the first segment containing the
408             # offset we were given.
409             start = self._offset // self._segment_size
410
411             assert start < self._num_segments
412             self._start_segment = start
413             self.log("got start segment: %d" % self._start_segment)
414         else:
415             self._start_segment = 0
416
417
418         # If self._read_length is None, then we want to read the whole
419         # file. Otherwise, we want to read only part of the file, and
420         # need to figure out where to stop reading.
421         if self._read_length is not None:
422             # our end segment is the last segment containing part of the
423             # segment that we were asked to read.
424             self.log("got read length %d" % self._read_length)
425             if self._read_length != 0:
426                 end_data = self._offset + self._read_length
427
428                 # We don't actually need to read the byte at end_data,
429                 # but the one before it.
430                 end = (end_data - 1) // self._segment_size
431
432                 assert end < self._num_segments
433                 self._last_segment = end
434             else:
435                 self._last_segment = self._start_segment
436             self.log("got end segment: %d" % self._last_segment)
437         else:
438             self._last_segment = self._num_segments - 1
439
440         self._current_segment = self._start_segment
441
442     def _activate_enough_servers(self):
443         """
444         I populate self._active_readers with enough active readers to
445         retrieve the contents of this mutable file. I am called before
446         downloading starts, and (eventually) after each validation
447         error, connection error, or other problem in the download.
448         """
449         # TODO: It would be cool to investigate other heuristics for
450         # reader selection. For instance, the cost (in time the user
451         # spends waiting for their file) of selecting a really slow server
452         # that happens to have a primary share is probably more than
453         # selecting a really fast server that doesn't have a primary
454         # share. Maybe the servermap could be extended to provide this
455         # information; it could keep track of latency information while
456         # it gathers more important data, and then this routine could
457         # use that to select active readers.
458         #
459         # (these and other questions would be easier to answer with a
460         #  robust, configurable tahoe-lafs simulator, which modeled node
461         #  failures, differences in node speed, and other characteristics
462         #  that we expect storage servers to have.  You could have
463         #  presets for really stable grids (like allmydata.com),
464         #  friendnets, make it easy to configure your own settings, and
465         #  then simulate the effect of big changes on these use cases
466         #  instead of just reasoning about what the effect might be. Out
467         #  of scope for MDMF, though.)
468
469         # XXX: Why don't format= log messages work here?
470
471         known_shnums = set(self.remaining_sharemap.keys())
472         used_shnums = set([r.shnum for r in self._active_readers])
473         unused_shnums = known_shnums - used_shnums
474
475         if self._verify:
476             new_shnums = unused_shnums # use them all
477         elif len(self._active_readers) < self._required_shares:
478             # need more shares
479             more = self._required_shares - len(self._active_readers)
480             # We favor lower numbered shares, since FEC is faster with
481             # primary shares than with other shares, and lower-numbered
482             # shares are more likely to be primary than higher numbered
483             # shares.
484             new_shnums = sorted(unused_shnums)[:more]
485             if len(new_shnums) < more:
486                 # We don't have enough readers to retrieve the file; fail.
487                 self._raise_notenoughshareserror()
488         else:
489             new_shnums = []
490
491         self.log("adding %d new servers to the active list" % len(new_shnums))
492         for shnum in new_shnums:
493             reader = self.readers[shnum]
494             self._active_readers.append(reader)
495             self.log("added reader for share %d" % shnum)
496             # Each time we add a reader, we check to see if we need the
497             # private key. If we do, we politely ask for it and then continue
498             # computing. If we find that we haven't gotten it at the end of
499             # segment decoding, then we'll take more drastic measures.
500             if self._need_privkey and not self._node.is_readonly():
501                 d = reader.get_encprivkey()
502                 d.addCallback(self._try_to_validate_privkey, reader, reader.server)
503                 # XXX: don't just drop the Deferred. We need error-reporting
504                 # but not flow-control here.
505
506     def _try_to_validate_prefix(self, prefix, reader):
507         """
508         I check that the prefix returned by a candidate server for
509         retrieval matches the prefix that the servermap knows about
510         (and, hence, the prefix that was validated earlier). If it does,
511         I return True, which means that I approve of the use of the
512         candidate server for segment retrieval. If it doesn't, I return
513         False, which means that another server must be chosen.
514         """
515         (seqnum,
516          root_hash,
517          IV,
518          segsize,
519          datalength,
520          k,
521          N,
522          known_prefix,
523          offsets_tuple) = self.verinfo
524         if known_prefix != prefix:
525             self.log("prefix from share %d doesn't match" % reader.shnum)
526             raise UncoordinatedWriteError("Mismatched prefix -- this could "
527                                           "indicate an uncoordinated write")
528         # Otherwise, we're okay -- no issues.
529
530     def _mark_bad_share(self, server, shnum, reader, f):
531         """
532         I mark the given (server, shnum) as a bad share, which means that it
533         will not be used anywhere else.
534
535         There are several reasons to want to mark something as a bad
536         share. These include:
537
538             - A connection error to the server.
539             - A mismatched prefix (that is, a prefix that does not match
540               our local conception of the version information string).
541             - A failing block hash, salt hash, share hash, or other
542               integrity check.
543
544         This method will ensure that readers that we wish to mark bad
545         (for these reasons or other reasons) are not used for the rest
546         of the download. Additionally, it will attempt to tell the
547         remote server (with no guarantee of success) that its share is
548         corrupt.
549         """
550         self.log("marking share %d on server %s as bad" % \
551                  (shnum, server.get_name()))
552         prefix = self.verinfo[-2]
553         self.servermap.mark_bad_share(server, shnum, prefix)
554         self._bad_shares.add((server, shnum, f))
555         self._status.add_problem(server, f)
556         self._last_failure = f
557
558         # Remove the reader from _active_readers
559         self._active_readers.remove(reader)
560         for shnum in list(self.remaining_sharemap.keys()):
561             self.remaining_sharemap.discard(shnum, reader.server)
562
563         if f.check(BadShareError):
564             self.notify_server_corruption(server, shnum, str(f.value))
565
566     def _download_current_segment(self):
567         """
568         I download, validate, decode, decrypt, and assemble the segment
569         that this Retrieve is currently responsible for downloading.
570         """
571         if self._current_segment > self._last_segment:
572             # No more segments to download, we're done.
573             self.log("got plaintext, done")
574             return self._done()
575         elif self._verify and len(self._active_readers) == 0:
576             self.log("no more good shares, no need to keep verifying")
577             return self._done()
578         self.log("on segment %d of %d" %
579                  (self._current_segment + 1, self._num_segments))
580         d = self._process_segment(self._current_segment)
581         d.addCallback(lambda ign: self.loop())
582         return d
583
584     def _process_segment(self, segnum):
585         """
586         I download, validate, decode, and decrypt one segment of the
587         file that this Retrieve is retrieving. This means coordinating
588         the process of getting k blocks of that file, validating them,
589         assembling them into one segment with the decoder, and then
590         decrypting them.
591         """
592         self.log("processing segment %d" % segnum)
593
594         # TODO: The old code uses a marker. Should this code do that
595         # too? What did the Marker do?
596
597         # We need to ask each of our active readers for its block and
598         # salt. We will then validate those. If validation is
599         # successful, we will assemble the results into plaintext.
600         ds = []
601         for reader in self._active_readers:
602             started = time.time()
603             d1 = reader.get_block_and_salt(segnum)
604             d2,d3 = self._get_needed_hashes(reader, segnum)
605             d = deferredutil.gatherResults([d1,d2,d3])
606             d.addCallback(self._validate_block, segnum, reader, reader.server, started)
607             # _handle_bad_share takes care of recoverable errors (by dropping
608             # that share and returning None). Any other errors (i.e. code
609             # bugs) are passed through and cause the retrieve to fail.
610             d.addErrback(self._handle_bad_share, [reader])
611             ds.append(d)
612         dl = deferredutil.gatherResults(ds)
613         if self._verify:
614             dl.addCallback(lambda ignored: "")
615             dl.addCallback(self._set_segment)
616         else:
617             dl.addCallback(self._maybe_decode_and_decrypt_segment, segnum)
618         return dl
619
620
621     def _maybe_decode_and_decrypt_segment(self, results, segnum):
622         """
623         I take the results of fetching and validating the blocks from
624         _process_segment. If validation and fetching succeeded without
625         incident, I will proceed with decoding and decryption. Otherwise, I
626         will do nothing.
627         """
628         self.log("trying to decode and decrypt segment %d" % segnum)
629
630         # 'results' is the output of a gatherResults set up in
631         # _process_segment(). Each component Deferred will either contain the
632         # non-Failure output of _validate_block() for a single block (i.e.
633         # {segnum:(block,salt)}), or None if _validate_block threw an
634         # exception and _validation_or_decoding_failed handled it (by
635         # dropping that server).
636
637         if None in results:
638             self.log("some validation operations failed; not proceeding")
639             return defer.succeed(None)
640         self.log("everything looks ok, building segment %d" % segnum)
641         d = self._decode_blocks(results, segnum)
642         d.addCallback(self._decrypt_segment)
643         # check to see whether we've been paused before writing
644         # anything.
645         d.addCallback(self._check_for_paused)
646         d.addCallback(self._check_for_stopped)
647         d.addCallback(self._set_segment)
648         return d
649
650
651     def _set_segment(self, segment):
652         """
653         Given a plaintext segment, I register that segment with the
654         target that is handling the file download.
655         """
656         self.log("got plaintext for segment %d" % self._current_segment)
657         if self._current_segment == self._start_segment:
658             # We're on the first segment. It's possible that we want
659             # only some part of the end of this segment, and that we
660             # just downloaded the whole thing to get that part. If so,
661             # we need to account for that and give the reader just the
662             # data that they want.
663             n = self._offset % self._segment_size
664             self.log("stripping %d bytes off of the first segment" % n)
665             self.log("original segment length: %d" % len(segment))
666             segment = segment[n:]
667             self.log("new segment length: %d" % len(segment))
668
669         if self._current_segment == self._last_segment and self._read_length is not None:
670             # We're on the last segment. It's possible that we only want
671             # part of the beginning of this segment, and that we
672             # downloaded the whole thing anyway. Make sure to give the
673             # caller only the portion of the segment that they want to
674             # receive.
675             extra = self._read_length
676             if self._start_segment != self._last_segment:
677                 extra -= self._segment_size - \
678                             (self._offset % self._segment_size)
679             extra %= self._segment_size
680             self.log("original segment length: %d" % len(segment))
681             segment = segment[:extra]
682             self.log("new segment length: %d" % len(segment))
683             self.log("only taking %d bytes of the last segment" % extra)
684
685         if not self._verify:
686             self._consumer.write(segment)
687         else:
688             # we don't care about the plaintext if we are doing a verify.
689             segment = None
690         self._current_segment += 1
691
692
693     def _handle_bad_share(self, f, readers):
694         """
695         I am called when a block or a salt fails to correctly validate, or when
696         the decryption or decoding operation fails for some reason.  I react to
697         this failure by notifying the remote server of corruption, and then
698         removing the remote server from further activity.
699         """
700         # these are the errors we can tolerate: by giving up on this share
701         # and finding others to replace it. Any other errors (i.e. coding
702         # bugs) are re-raised, causing the download to fail.
703         f.trap(DeadReferenceError, RemoteException, BadShareError)
704
705         # DeadReferenceError happens when we try to fetch data from a server
706         # that has gone away. RemoteException happens if the server had an
707         # internal error. BadShareError encompasses: (UnknownVersionError,
708         # LayoutInvalid, struct.error) which happen when we get obviously
709         # wrong data, and CorruptShareError which happens later, when we
710         # perform integrity checks on the data.
711
712         assert isinstance(readers, list)
713         bad_shnums = [reader.shnum for reader in readers]
714
715         self.log("validation or decoding failed on share(s) %s, server(s) %s "
716                  ", segment %d: %s" % \
717                  (bad_shnums, readers, self._current_segment, str(f)))
718         for reader in readers:
719             self._mark_bad_share(reader.server, reader.shnum, reader, f)
720         return None
721
722
723     def _validate_block(self, results, segnum, reader, server, started):
724         """
725         I validate a block from one share on a remote server.
726         """
727         # Grab the part of the block hash tree that is necessary to
728         # validate this block, then generate the block hash root.
729         self.log("validating share %d for segment %d" % (reader.shnum,
730                                                              segnum))
731         elapsed = time.time() - started
732         self._status.add_fetch_timing(server, elapsed)
733         self._set_current_status("validating blocks")
734
735         block_and_salt, blockhashes, sharehashes = results
736         block, salt = block_and_salt
737         assert type(block) is str, (block, salt)
738
739         blockhashes = dict(enumerate(blockhashes))
740         self.log("the reader gave me the following blockhashes: %s" % \
741                  blockhashes.keys())
742         self.log("the reader gave me the following sharehashes: %s" % \
743                  sharehashes.keys())
744         bht = self._block_hash_trees[reader.shnum]
745
746         if bht.needed_hashes(segnum, include_leaf=True):
747             try:
748                 bht.set_hashes(blockhashes)
749             except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
750                     IndexError), e:
751                 raise CorruptShareError(server,
752                                         reader.shnum,
753                                         "block hash tree failure: %s" % e)
754
755         if self._version == MDMF_VERSION:
756             blockhash = hashutil.block_hash(salt + block)
757         else:
758             blockhash = hashutil.block_hash(block)
759         # If this works without an error, then validation is
760         # successful.
761         try:
762            bht.set_hashes(leaves={segnum: blockhash})
763         except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
764                 IndexError), e:
765             raise CorruptShareError(server,
766                                     reader.shnum,
767                                     "block hash tree failure: %s" % e)
768
769         # Reaching this point means that we know that this segment
770         # is correct. Now we need to check to see whether the share
771         # hash chain is also correct.
772         # SDMF wrote share hash chains that didn't contain the
773         # leaves, which would be produced from the block hash tree.
774         # So we need to validate the block hash tree first. If
775         # successful, then bht[0] will contain the root for the
776         # shnum, which will be a leaf in the share hash tree, which
777         # will allow us to validate the rest of the tree.
778         try:
779             self.share_hash_tree.set_hashes(hashes=sharehashes,
780                                         leaves={reader.shnum: bht[0]})
781         except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
782                 IndexError), e:
783             raise CorruptShareError(server,
784                                     reader.shnum,
785                                     "corrupt hashes: %s" % e)
786
787         self.log('share %d is valid for segment %d' % (reader.shnum,
788                                                        segnum))
789         return {reader.shnum: (block, salt)}
790
791
792     def _get_needed_hashes(self, reader, segnum):
793         """
794         I get the hashes needed to validate segnum from the reader, then return
795         to my caller when this is done.
796         """
797         bht = self._block_hash_trees[reader.shnum]
798         needed = bht.needed_hashes(segnum, include_leaf=True)
799         # The root of the block hash tree is also a leaf in the share
800         # hash tree. So we don't need to fetch it from the remote
801         # server. In the case of files with one segment, this means that
802         # we won't fetch any block hash tree from the remote server,
803         # since the hash of each share of the file is the entire block
804         # hash tree, and is a leaf in the share hash tree. This is fine,
805         # since any share corruption will be detected in the share hash
806         # tree.
807         #needed.discard(0)
808         self.log("getting blockhashes for segment %d, share %d: %s" % \
809                  (segnum, reader.shnum, str(needed)))
810         # TODO is force_remote necessary here?
811         d1 = reader.get_blockhashes(needed, force_remote=False)
812         if self.share_hash_tree.needed_hashes(reader.shnum):
813             need = self.share_hash_tree.needed_hashes(reader.shnum)
814             self.log("also need sharehashes for share %d: %s" % (reader.shnum,
815                                                                  str(need)))
816             d2 = reader.get_sharehashes(need, force_remote=False)
817         else:
818             d2 = defer.succeed({}) # the logic in the next method
819                                    # expects a dict
820         return d1,d2
821
822
823     def _decode_blocks(self, results, segnum):
824         """
825         I take a list of k blocks and salts, and decode that into a
826         single encrypted segment.
827         """
828         # 'results' is one or more dicts (each {shnum:(block,salt)}), and we
829         # want to merge them all
830         blocks_and_salts = {}
831         for d in results:
832             blocks_and_salts.update(d)
833
834         # All of these blocks should have the same salt; in SDMF, it is
835         # the file-wide IV, while in MDMF it is the per-segment salt. In
836         # either case, we just need to get one of them and use it.
837         #
838         # d.items()[0] is like (shnum, (block, salt))
839         # d.items()[0][1] is like (block, salt)
840         # d.items()[0][1][1] is the salt.
841         salt = blocks_and_salts.items()[0][1][1]
842         # Next, extract just the blocks from the dict. We'll use the
843         # salt in the next step.
844         share_and_shareids = [(k, v[0]) for k, v in blocks_and_salts.items()]
845         d2 = dict(share_and_shareids)
846         shareids = []
847         shares = []
848         for shareid, share in d2.items():
849             shareids.append(shareid)
850             shares.append(share)
851
852         self._set_current_status("decoding")
853         started = time.time()
854         assert len(shareids) >= self._required_shares, len(shareids)
855         # zfec really doesn't want extra shares
856         shareids = shareids[:self._required_shares]
857         shares = shares[:self._required_shares]
858         self.log("decoding segment %d" % segnum)
859         if segnum == self._num_segments - 1:
860             d = defer.maybeDeferred(self._tail_decoder.decode, shares, shareids)
861         else:
862             d = defer.maybeDeferred(self._segment_decoder.decode, shares, shareids)
863         def _process(buffers):
864             segment = "".join(buffers)
865             self.log(format="now decoding segment %(segnum)s of %(numsegs)s",
866                      segnum=segnum,
867                      numsegs=self._num_segments,
868                      level=log.NOISY)
869             self.log(" joined length %d, datalength %d" %
870                      (len(segment), self._data_length))
871             if segnum == self._num_segments - 1:
872                 size_to_use = self._tail_data_size
873             else:
874                 size_to_use = self._segment_size
875             segment = segment[:size_to_use]
876             self.log(" segment len=%d" % len(segment))
877             self._status.accumulate_decode_time(time.time() - started)
878             return segment, salt
879         d.addCallback(_process)
880         return d
881
882
883     def _decrypt_segment(self, segment_and_salt):
884         """
885         I take a single segment and its salt, and decrypt it. I return
886         the plaintext of the segment that is in my argument.
887         """
888         segment, salt = segment_and_salt
889         self._set_current_status("decrypting")
890         self.log("decrypting segment %d" % self._current_segment)
891         started = time.time()
892         key = hashutil.ssk_readkey_data_hash(salt, self._node.get_readkey())
893         decryptor = AES(key)
894         plaintext = decryptor.process(segment)
895         self._status.accumulate_decrypt_time(time.time() - started)
896         return plaintext
897
898
899     def notify_server_corruption(self, server, shnum, reason):
900         rref = server.get_rref()
901         rref.callRemoteOnly("advise_corrupt_share",
902                             "mutable", self._storage_index, shnum, reason)
903
904
905     def _try_to_validate_privkey(self, enc_privkey, reader, server):
906         alleged_privkey_s = self._node._decrypt_privkey(enc_privkey)
907         alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s)
908         if alleged_writekey != self._node.get_writekey():
909             self.log("invalid privkey from %s shnum %d" %
910                      (reader, reader.shnum),
911                      level=log.WEIRD, umid="YIw4tA")
912             if self._verify:
913                 self.servermap.mark_bad_share(server, reader.shnum,
914                                               self.verinfo[-2])
915                 e = CorruptShareError(server,
916                                       reader.shnum,
917                                       "invalid privkey")
918                 f = failure.Failure(e)
919                 self._bad_shares.add((server, reader.shnum, f))
920             return
921
922         # it's good
923         self.log("got valid privkey from shnum %d on reader %s" %
924                  (reader.shnum, reader))
925         privkey = rsa.create_signing_key_from_string(alleged_privkey_s)
926         self._node._populate_encprivkey(enc_privkey)
927         self._node._populate_privkey(privkey)
928         self._need_privkey = False
929
930
931
932     def _done(self):
933         """
934         I am called by _download_current_segment when the download process
935         has finished successfully. After making some useful logging
936         statements, I return the decrypted contents to the owner of this
937         Retrieve object through self._done_deferred.
938         """
939         self._running = False
940         self._status.set_active(False)
941         now = time.time()
942         self._status.timings['total'] = now - self._started
943         self._status.timings['fetch'] = now - self._started_fetching
944         self._status.set_status("Finished")
945         self._status.set_progress(1.0)
946
947         # remember the encoding parameters, use them again next time
948         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
949          offsets_tuple) = self.verinfo
950         self._node._populate_required_shares(k)
951         self._node._populate_total_shares(N)
952
953         if self._verify:
954             ret = self._bad_shares
955             self.log("done verifying, found %d bad shares" % len(ret))
956         else:
957             # TODO: upload status here?
958             ret = self._consumer
959             self._consumer.unregisterProducer()
960         eventually(self._done_deferred.callback, ret)
961
962
963     def _raise_notenoughshareserror(self):
964         """
965         I am called when there are not enough active servers left to complete
966         the download. After making some useful logging statements, I throw an
967         exception to that effect to the caller of this Retrieve object through
968         self._done_deferred.
969         """
970
971         format = ("ran out of servers: "
972                   "have %(have)d of %(total)d segments; "
973                   "found %(bad)d bad shares; "
974                   "have %(remaining)d remaining shares of the right version; "
975                   "encoding %(k)d-of-%(n)d")
976         args = {"have": self._current_segment,
977                 "total": self._num_segments,
978                 "need": self._last_segment,
979                 "k": self._required_shares,
980                 "n": self._total_shares,
981                 "bad": len(self._bad_shares),
982                 "remaining": len(self.remaining_sharemap),
983                }
984         raise NotEnoughSharesError("%s, last failure: %s" %
985                                    (format % args, str(self._last_failure)))
986
987     def _error(self, f):
988         # all errors, including NotEnoughSharesError, land here
989         self._running = False
990         self._status.set_active(False)
991         now = time.time()
992         self._status.timings['total'] = now - self._started
993         self._status.timings['fetch'] = now - self._started_fetching
994         self._status.set_status("Failed")
995         eventually(self._done_deferred.errback, f)