]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blob - src/allmydata/mutable/retrieve.py
Raise NotEnoughSharesError if there are initially not enough shares for a mutable...
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / mutable / retrieve.py
1
2 import time
3 from itertools import count
4 from zope.interface import implements
5 from twisted.internet import defer
6 from twisted.python import failure
7 from twisted.internet.interfaces import IPushProducer, IConsumer
8 from foolscap.api import eventually, fireEventually, DeadReferenceError, \
9      RemoteException
10 from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \
11      DownloadStopped, MDMF_VERSION, SDMF_VERSION
12 from allmydata.util import hashutil, log, mathutil, deferredutil
13 from allmydata.util.dictutil import DictOfSets
14 from allmydata import hashtree, codec
15 from allmydata.storage.server import si_b2a
16 from pycryptopp.cipher.aes import AES
17 from pycryptopp.publickey import rsa
18
19 from allmydata.mutable.common import CorruptShareError, BadShareError, \
20      UncoordinatedWriteError
21 from allmydata.mutable.layout import MDMFSlotReadProxy
22
class RetrieveStatus:
    """
    I hold the bookkeeping for one retrieve operation: timings, the
    problems seen per server, and the values reported through the
    IRetrieveStatus accessors below.
    """
    implements(IRetrieveStatus)
    # Shared by every instance; each new status object takes the next value.
    statusid_counter = count(0)

    def __init__(self):
        # "fetch_per_server" maps a server to a list of elapsed times; the
        # other entries are running totals.
        self.timings = {
            "fetch_per_server": {},
            "decode": 0.0,
            "decrypt": 0.0,
            "cumulative_verify": 0.0,
        }
        self._problems = {}  # serverid -> failure
        self.active = True
        self.storage_index = None
        self.helper = False
        self.encoding = ("?", "?")
        self.size = None
        self.status = "Not started"
        self.progress = 0.0
        self.counter = self.statusid_counter.next()
        self.started = time.time()

    # -- read accessors --------------------------------------------------

    def get_started(self):
        return self.started

    def get_storage_index(self):
        return self.storage_index

    def get_encoding(self):
        return self.encoding

    def using_helper(self):
        return self.helper

    def get_size(self):
        return self.size

    def get_status(self):
        return self.status

    def get_progress(self):
        return self.progress

    def get_active(self):
        return self.active

    def get_counter(self):
        return self.counter

    def get_problems(self):
        return self._problems

    # -- timing accumulation ---------------------------------------------

    def add_fetch_timing(self, server, elapsed):
        """Append one fetch duration to the list kept for 'server'."""
        per_server = self.timings["fetch_per_server"]
        per_server.setdefault(server, []).append(elapsed)

    def accumulate_decode_time(self, elapsed):
        self.timings["decode"] += elapsed

    def accumulate_decrypt_time(self, elapsed):
        self.timings["decrypt"] += elapsed

    # -- write accessors -------------------------------------------------

    def set_storage_index(self, si):
        self.storage_index = si

    def set_helper(self, helper):
        self.helper = helper

    def set_encoding(self, k, n):
        self.encoding = (k, n)

    def set_size(self, size):
        self.size = size

    def set_status(self, status):
        self.status = status

    def set_progress(self, value):
        self.progress = value

    def set_active(self, value):
        self.active = value

    def add_problem(self, server, f):
        """Record failure 'f' under the serverid of 'server'."""
        self._problems[server.get_serverid()] = f
89
class Marker:
    """I am an empty class; my instances carry no state of their own."""
92
93 class Retrieve:
94     # this class is currently single-use. Eventually (in MDMF) we will make
95     # it multi-use, in which case you can call download(range) multiple
96     # times, and each will have a separate response chain. However the
97     # Retrieve object will remain tied to a specific version of the file, and
98     # will use a single ServerMap instance.
99     implements(IPushProducer)
100
101     def __init__(self, filenode, storage_broker, servermap, verinfo,
102                  fetch_privkey=False, verify=False):
103         self._node = filenode
104         assert self._node.get_pubkey()
105         self._storage_broker = storage_broker
106         self._storage_index = filenode.get_storage_index()
107         assert self._node.get_readkey()
108         self._last_failure = None
109         prefix = si_b2a(self._storage_index)[:5]
110         self._log_number = log.msg("Retrieve(%s): starting" % prefix)
111         self._running = True
112         self._decoding = False
113         self._bad_shares = set()
114
115         self.servermap = servermap
116         assert self._node.get_pubkey()
117         self.verinfo = verinfo
118         # during repair, we may be called upon to grab the private key, since
119         # it wasn't picked up during a verify=False checker run, and we'll
120         # need it for repair to generate a new version.
121         self._need_privkey = verify or (fetch_privkey
122                                         and not self._node.get_privkey())
123
124         if self._need_privkey:
125             # TODO: Evaluate the need for this. We'll use it if we want
126             # to limit how many queries are on the wire for the privkey
127             # at once.
128             self._privkey_query_markers = [] # one Marker for each time we've
129                                              # tried to get the privkey.
130
131         # verify means that we are using the downloader logic to verify all
132         # of our shares. This tells the downloader a few things.
133         #
134         # 1. We need to download all of the shares.
135         # 2. We don't need to decode or decrypt the shares, since our
136         #    caller doesn't care about the plaintext, only the
137         #    information about which shares are or are not valid.
138         # 3. When we are validating readers, we need to validate the
139         #    signature on the prefix. Do we? We already do this in the
140         #    servermap update?
141         self._verify = verify
142
143         self._status = RetrieveStatus()
144         self._status.set_storage_index(self._storage_index)
145         self._status.set_helper(False)
146         self._status.set_progress(0.0)
147         self._status.set_active(True)
148         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
149          offsets_tuple) = self.verinfo
150         self._status.set_size(datalength)
151         self._status.set_encoding(k, N)
152         self.readers = {}
153         self._stopped = False
154         self._pause_deferred = None
155         self._offset = None
156         self._read_length = None
157         self.log("got seqnum %d" % self.verinfo[0])
158
159
160     def get_status(self):
161         return self._status
162
163     def log(self, *args, **kwargs):
164         if "parent" not in kwargs:
165             kwargs["parent"] = self._log_number
166         if "facility" not in kwargs:
167             kwargs["facility"] = "tahoe.mutable.retrieve"
168         return log.msg(*args, **kwargs)
169
170     def _set_current_status(self, state):
171         seg = "%d/%d" % (self._current_segment, self._last_segment)
172         self._status.set_status("segment %s (%s)" % (seg, state))
173
174     ###################
175     # IPushProducer
176
177     def pauseProducing(self):
178         """
179         I am called by my download target if we have produced too much
180         data for it to handle. I make the downloader stop producing new
181         data until my resumeProducing method is called.
182         """
183         if self._pause_deferred is not None:
184             return
185
186         # fired when the download is unpaused.
187         self._old_status = self._status.get_status()
188         self._set_current_status("paused")
189
190         self._pause_deferred = defer.Deferred()
191
192
193     def resumeProducing(self):
194         """
195         I am called by my download target once it is ready to begin
196         receiving data again.
197         """
198         if self._pause_deferred is None:
199             return
200
201         p = self._pause_deferred
202         self._pause_deferred = None
203         self._status.set_status(self._old_status)
204
205         eventually(p.callback, None)
206
207     def stopProducing(self):
208         self._stopped = True
209         self.resumeProducing()
210
211
212     def _check_for_paused(self, res):
213         """
214         I am called just before a write to the consumer. I return a
215         Deferred that eventually fires with the data that is to be
216         written to the consumer. If the download has not been paused,
217         the Deferred fires immediately. Otherwise, the Deferred fires
218         when the downloader is unpaused.
219         """
220         if self._pause_deferred is not None:
221             d = defer.Deferred()
222             self._pause_deferred.addCallback(lambda ignored: d.callback(res))
223             return d
224         return res
225
226     def _check_for_stopped(self, res):
227         if self._stopped:
228             raise DownloadStopped("our Consumer called stopProducing()")
229         return res
230
231
    def download(self, consumer=None, offset=0, size=None):
        """
        I start the retrieve and return a Deferred (self._done_deferred)
        that represents its completion.

        consumer must provide IConsumer unless this Retrieve was created
        with verify=True. When a consumer is given, I register myself
        with it as a streaming (push) producer. offset and size select
        the byte range to read; size=None means read to the end of the
        file.
        """
        assert IConsumer.providedBy(consumer) or self._verify

        if consumer:
            self._consumer = consumer
            # we provide IPushProducer, so streaming=True, per
            # IConsumer.
            self._consumer.registerProducer(self, streaming=True)

        self._done_deferred = defer.Deferred()
        # The read range must be stored before
        # _setup_encoding_parameters(), which turns it into
        # start/last segment numbers.
        self._offset = offset
        self._read_length = size
        self._setup_encoding_parameters()
        self._setup_download()
        self.log("starting download")
        self._started_fetching = time.time()
        # The download process beyond this is a state machine.
        # _add_active_servers will select the servers that we want to use
        # for the download, and then attempt to start downloading. After
        # each segment, it will check for doneness, reacting to broken
        # servers and corrupt shares as necessary. If it runs out of good
        # servers before downloading all of the segments, _done_deferred
        # will errback.  Otherwise, it will eventually callback with the
        # contents of the mutable file.
        self.loop()
        return self._done_deferred
258
259     def loop(self):
260         d = fireEventually(None) # avoid #237 recursion limit problem
261         d.addCallback(lambda ign: self._activate_enough_servers())
262         d.addCallback(lambda ign: self._download_current_segment())
263         # when we're done, _download_current_segment will call _done. If we
264         # aren't, it will call loop() again.
265         d.addErrback(self._error)
266
267     def _setup_download(self):
268         self._started = time.time()
269         self._status.set_status("Retrieving Shares")
270
271         # how many shares do we need?
272         (seqnum,
273          root_hash,
274          IV,
275          segsize,
276          datalength,
277          k,
278          N,
279          prefix,
280          offsets_tuple) = self.verinfo
281
282         # first, which servers can we use?
283         versionmap = self.servermap.make_versionmap()
284         shares = versionmap[self.verinfo]
285         # this sharemap is consumed as we decide to send requests
286         self.remaining_sharemap = DictOfSets()
287         for (shnum, server, timestamp) in shares:
288             self.remaining_sharemap.add(shnum, server)
289             # Reuse the SlotReader from the servermap.
290             key = (self.verinfo, server.get_serverid(),
291                    self._storage_index, shnum)
292             if key in self.servermap.proxies:
293                 reader = self.servermap.proxies[key]
294             else:
295                 reader = MDMFSlotReadProxy(server.get_rref(),
296                                            self._storage_index, shnum, None)
297             reader.server = server
298             self.readers[shnum] = reader
299
300         if len(self.remaining_sharemap) < k:
301             self._raise_notenoughshareserror()
302
303         self.shares = {} # maps shnum to validated blocks
304         self._active_readers = [] # list of active readers for this dl.
305         self._block_hash_trees = {} # shnum => hashtree
306
307         for i in xrange(self._total_shares):
308             # So we don't have to do this later.
309             self._block_hash_trees[i] = hashtree.IncompleteHashTree(self._num_segments)
310
311         # We need one share hash tree for the entire file; its leaves
312         # are the roots of the block hash trees for the shares that
313         # comprise it, and its root is in the verinfo.
314         self.share_hash_tree = hashtree.IncompleteHashTree(N)
315         self.share_hash_tree.set_hashes({0: root_hash})
316
317     def decode(self, blocks_and_salts, segnum):
318         """
319         I am a helper method that the mutable file update process uses
320         as a shortcut to decode and decrypt the segments that it needs
321         to fetch in order to perform a file update. I take in a
322         collection of blocks and salts, and pick some of those to make a
323         segment with. I return the plaintext associated with that
324         segment.
325         """
326         self._block_hash_trees = None
327         self._setup_encoding_parameters()
328
329         # _decode_blocks() expects the output of a gatherResults that
330         # contains the outputs of _validate_block() (each of which is a dict
331         # mapping shnum to (block,salt) bytestrings).
332         d = self._decode_blocks([blocks_and_salts], segnum)
333         d.addCallback(self._decrypt_segment)
334         return d
335
336
    def _setup_encoding_parameters(self):
        """
        I set up the encoding parameters, including k, n, the number
        of segments associated with this file, and the segment decoders.

        I read self.verinfo, self._offset and self._read_length, and I
        set the segment size and count, the tail segment size, the
        segment and tail decoders, the file format version, and the
        _start_segment / _last_segment / _current_segment counters that
        drive the download.
        """
        (seqnum,
         root_hash,
         IV,
         segsize,
         datalength,
         k,
         n,
         known_prefix,
         offsets_tuple) = self.verinfo
        self._required_shares = k
        self._total_shares = n
        self._segment_size = segsize
        self._data_length = datalength

        # A false (empty) IV in the verinfo is treated as MDMF; anything
        # else is treated as SDMF.
        if not IV:
            self._version = MDMF_VERSION
        else:
            self._version = SDMF_VERSION

        if datalength and segsize:
            self._num_segments = mathutil.div_ceil(datalength, segsize)
            self._tail_data_size = datalength % segsize
        else:
            # empty file (or zero segment size): no segments at all
            self._num_segments = 0
            self._tail_data_size = 0

        self._segment_decoder = codec.CRSDecoder()
        self._segment_decoder.set_params(segsize, k, n)

        # A tail data size of 0 means the data length is an exact
        # multiple of the segment size (or there is no data), so the
        # tail is treated as a full-size segment.
        if  not self._tail_data_size:
            self._tail_data_size = segsize

        # The tail segment is padded up to a multiple of k.
        self._tail_segment_size = mathutil.next_multiple(self._tail_data_size,
                                                         self._required_shares)
        if self._tail_segment_size == self._segment_size:
            # same geometry as every other segment: share the decoder
            self._tail_decoder = self._segment_decoder
        else:
            self._tail_decoder = codec.CRSDecoder()
            self._tail_decoder.set_params(self._tail_segment_size,
                                          self._required_shares,
                                          self._total_shares)

        self.log("got encoding parameters: "
                 "k: %d "
                 "n: %d "
                 "%d segments of %d bytes each (%d byte tail segment)" % \
                 (k, n, self._num_segments, self._segment_size,
                  self._tail_segment_size))

        # Our last task is to tell the downloader where to start and
        # where to stop. We use three parameters for that:
        #   - self._start_segment: the segment that we need to start
        #     downloading from.
        #   - self._current_segment: the next segment that we need to
        #     download.
        #   - self._last_segment: The last segment that we were asked to
        #     download.
        #
        #  We say that the download is complete when
        #  self._current_segment > self._last_segment. We use
        #  self._start_segment and self._last_segment to know when to
        #  strip things off of segments, and how much to strip.
        if self._offset:
            self.log("got offset: %d" % self._offset)
            # our start segment is the first segment containing the
            # offset we were given.
            start = self._offset // self._segment_size

            assert start < self._num_segments
            self._start_segment = start
            self.log("got start segment: %d" % self._start_segment)
        else:
            # offset is 0 or None: start from the beginning
            self._start_segment = 0


        # If self._read_length is None, then we want to read the whole
        # file. Otherwise, we want to read only part of the file, and
        # need to figure out where to stop reading.
        if self._read_length is not None:
            # our end segment is the last segment containing part of the
            # segment that we were asked to read.
            self.log("got read length %d" % self._read_length)
            if self._read_length != 0:
                end_data = self._offset + self._read_length

                # We don't actually need to read the byte at end_data,
                # but the one before it.
                end = (end_data - 1) // self._segment_size

                assert end < self._num_segments
                self._last_segment = end
            else:
                # zero-length read: start and last segment coincide
                self._last_segment = self._start_segment
            self.log("got end segment: %d" % self._last_segment)
        else:
            self._last_segment = self._num_segments - 1

        self._current_segment = self._start_segment
440
441     def _activate_enough_servers(self):
442         """
443         I populate self._active_readers with enough active readers to
444         retrieve the contents of this mutable file. I am called before
445         downloading starts, and (eventually) after each validation
446         error, connection error, or other problem in the download.
447         """
448         # TODO: It would be cool to investigate other heuristics for
449         # reader selection. For instance, the cost (in time the user
450         # spends waiting for their file) of selecting a really slow server
451         # that happens to have a primary share is probably more than
452         # selecting a really fast server that doesn't have a primary
453         # share. Maybe the servermap could be extended to provide this
454         # information; it could keep track of latency information while
455         # it gathers more important data, and then this routine could
456         # use that to select active readers.
457         #
458         # (these and other questions would be easier to answer with a
459         #  robust, configurable tahoe-lafs simulator, which modeled node
460         #  failures, differences in node speed, and other characteristics
461         #  that we expect storage servers to have.  You could have
462         #  presets for really stable grids (like allmydata.com),
463         #  friendnets, make it easy to configure your own settings, and
464         #  then simulate the effect of big changes on these use cases
465         #  instead of just reasoning about what the effect might be. Out
466         #  of scope for MDMF, though.)
467
468         # XXX: Why don't format= log messages work here?
469
470         known_shnums = set(self.remaining_sharemap.keys())
471         used_shnums = set([r.shnum for r in self._active_readers])
472         unused_shnums = known_shnums - used_shnums
473
474         if self._verify:
475             new_shnums = unused_shnums # use them all
476         elif len(self._active_readers) < self._required_shares:
477             # need more shares
478             more = self._required_shares - len(self._active_readers)
479             # We favor lower numbered shares, since FEC is faster with
480             # primary shares than with other shares, and lower-numbered
481             # shares are more likely to be primary than higher numbered
482             # shares.
483             new_shnums = sorted(unused_shnums)[:more]
484             if len(new_shnums) < more:
485                 # We don't have enough readers to retrieve the file; fail.
486                 self._raise_notenoughshareserror()
487         else:
488             new_shnums = []
489
490         self.log("adding %d new servers to the active list" % len(new_shnums))
491         for shnum in new_shnums:
492             reader = self.readers[shnum]
493             self._active_readers.append(reader)
494             self.log("added reader for share %d" % shnum)
495             # Each time we add a reader, we check to see if we need the
496             # private key. If we do, we politely ask for it and then continue
497             # computing. If we find that we haven't gotten it at the end of
498             # segment decoding, then we'll take more drastic measures.
499             if self._need_privkey and not self._node.is_readonly():
500                 d = reader.get_encprivkey()
501                 d.addCallback(self._try_to_validate_privkey, reader, reader.server)
502                 # XXX: don't just drop the Deferred. We need error-reporting
503                 # but not flow-control here.
504
505     def _try_to_validate_prefix(self, prefix, reader):
506         """
507         I check that the prefix returned by a candidate server for
508         retrieval matches the prefix that the servermap knows about
509         (and, hence, the prefix that was validated earlier). If it does,
510         I return True, which means that I approve of the use of the
511         candidate server for segment retrieval. If it doesn't, I return
512         False, which means that another server must be chosen.
513         """
514         (seqnum,
515          root_hash,
516          IV,
517          segsize,
518          datalength,
519          k,
520          N,
521          known_prefix,
522          offsets_tuple) = self.verinfo
523         if known_prefix != prefix:
524             self.log("prefix from share %d doesn't match" % reader.shnum)
525             raise UncoordinatedWriteError("Mismatched prefix -- this could "
526                                           "indicate an uncoordinated write")
527         # Otherwise, we're okay -- no issues.
528
529     def _mark_bad_share(self, server, shnum, reader, f):
530         """
531         I mark the given (server, shnum) as a bad share, which means that it
532         will not be used anywhere else.
533
534         There are several reasons to want to mark something as a bad
535         share. These include:
536
537             - A connection error to the server.
538             - A mismatched prefix (that is, a prefix that does not match
539               our local conception of the version information string).
540             - A failing block hash, salt hash, share hash, or other
541               integrity check.
542
543         This method will ensure that readers that we wish to mark bad
544         (for these reasons or other reasons) are not used for the rest
545         of the download. Additionally, it will attempt to tell the
546         remote server (with no guarantee of success) that its share is
547         corrupt.
548         """
549         self.log("marking share %d on server %s as bad" % \
550                  (shnum, server.get_name()))
551         prefix = self.verinfo[-2]
552         self.servermap.mark_bad_share(server, shnum, prefix)
553         self._bad_shares.add((server, shnum, f))
554         self._status.add_problem(server, f)
555         self._last_failure = f
556
557         # Remove the reader from _active_readers
558         self._active_readers.remove(reader)
559         for shnum in list(self.remaining_sharemap.keys()):
560             self.remaining_sharemap.discard(shnum, reader.server)
561
562         if f.check(BadShareError):
563             self.notify_server_corruption(server, shnum, str(f.value))
564
565     def _download_current_segment(self):
566         """
567         I download, validate, decode, decrypt, and assemble the segment
568         that this Retrieve is currently responsible for downloading.
569         """
570         if self._current_segment > self._last_segment:
571             # No more segments to download, we're done.
572             self.log("got plaintext, done")
573             return self._done()
574         elif self._verify and len(self._active_readers) == 0:
575             self.log("no more good shares, no need to keep verifying")
576             return self._done()
577         self.log("on segment %d of %d" %
578                  (self._current_segment + 1, self._num_segments))
579         d = self._process_segment(self._current_segment)
580         d.addCallback(lambda ign: self.loop())
581         return d
582
583     def _process_segment(self, segnum):
584         """
585         I download, validate, decode, and decrypt one segment of the
586         file that this Retrieve is retrieving. This means coordinating
587         the process of getting k blocks of that file, validating them,
588         assembling them into one segment with the decoder, and then
589         decrypting them.
590         """
591         self.log("processing segment %d" % segnum)
592
593         # TODO: The old code uses a marker. Should this code do that
594         # too? What did the Marker do?
595
596         # We need to ask each of our active readers for its block and
597         # salt. We will then validate those. If validation is
598         # successful, we will assemble the results into plaintext.
599         ds = []
600         for reader in self._active_readers:
601             started = time.time()
602             d1 = reader.get_block_and_salt(segnum)
603             d2,d3 = self._get_needed_hashes(reader, segnum)
604             d = deferredutil.gatherResults([d1,d2,d3])
605             d.addCallback(self._validate_block, segnum, reader, reader.server, started)
606             # _handle_bad_share takes care of recoverable errors (by dropping
607             # that share and returning None). Any other errors (i.e. code
608             # bugs) are passed through and cause the retrieve to fail.
609             d.addErrback(self._handle_bad_share, [reader])
610             ds.append(d)
611         dl = deferredutil.gatherResults(ds)
612         if self._verify:
613             dl.addCallback(lambda ignored: "")
614             dl.addCallback(self._set_segment)
615         else:
616             dl.addCallback(self._maybe_decode_and_decrypt_segment, segnum)
617         return dl
618
619
620     def _maybe_decode_and_decrypt_segment(self, results, segnum):
621         """
622         I take the results of fetching and validating the blocks from
623         _process_segment. If validation and fetching succeeded without
624         incident, I will proceed with decoding and decryption. Otherwise, I
625         will do nothing.
626         """
627         self.log("trying to decode and decrypt segment %d" % segnum)
628
629         # 'results' is the output of a gatherResults set up in
630         # _process_segment(). Each component Deferred will either contain the
631         # non-Failure output of _validate_block() for a single block (i.e.
632         # {segnum:(block,salt)}), or None if _validate_block threw an
633         # exception and _validation_or_decoding_failed handled it (by
634         # dropping that server).
635
636         if None in results:
637             self.log("some validation operations failed; not proceeding")
638             return defer.succeed(None)
639         self.log("everything looks ok, building segment %d" % segnum)
640         d = self._decode_blocks(results, segnum)
641         d.addCallback(self._decrypt_segment)
642         # check to see whether we've been paused before writing
643         # anything.
644         d.addCallback(self._check_for_paused)
645         d.addCallback(self._check_for_stopped)
646         d.addCallback(self._set_segment)
647         return d
648
649
650     def _set_segment(self, segment):
651         """
652         Given a plaintext segment, I register that segment with the
653         target that is handling the file download.
654         """
655         self.log("got plaintext for segment %d" % self._current_segment)
656         if self._current_segment == self._start_segment:
657             # We're on the first segment. It's possible that we want
658             # only some part of the end of this segment, and that we
659             # just downloaded the whole thing to get that part. If so,
660             # we need to account for that and give the reader just the
661             # data that they want.
662             n = self._offset % self._segment_size
663             self.log("stripping %d bytes off of the first segment" % n)
664             self.log("original segment length: %d" % len(segment))
665             segment = segment[n:]
666             self.log("new segment length: %d" % len(segment))
667
668         if self._current_segment == self._last_segment and self._read_length is not None:
669             # We're on the last segment. It's possible that we only want
670             # part of the beginning of this segment, and that we
671             # downloaded the whole thing anyway. Make sure to give the
672             # caller only the portion of the segment that they want to
673             # receive.
674             extra = self._read_length
675             if self._start_segment != self._last_segment:
676                 extra -= self._segment_size - \
677                             (self._offset % self._segment_size)
678             extra %= self._segment_size
679             self.log("original segment length: %d" % len(segment))
680             segment = segment[:extra]
681             self.log("new segment length: %d" % len(segment))
682             self.log("only taking %d bytes of the last segment" % extra)
683
684         if not self._verify:
685             self._consumer.write(segment)
686         else:
687             # we don't care about the plaintext if we are doing a verify.
688             segment = None
689         self._current_segment += 1
690
691
692     def _handle_bad_share(self, f, readers):
693         """
694         I am called when a block or a salt fails to correctly validate, or when
695         the decryption or decoding operation fails for some reason.  I react to
696         this failure by notifying the remote server of corruption, and then
697         removing the remote server from further activity.
698         """
699         # these are the errors we can tolerate: by giving up on this share
700         # and finding others to replace it. Any other errors (i.e. coding
701         # bugs) are re-raised, causing the download to fail.
702         f.trap(DeadReferenceError, RemoteException, BadShareError)
703
704         # DeadReferenceError happens when we try to fetch data from a server
705         # that has gone away. RemoteException happens if the server had an
706         # internal error. BadShareError encompasses: (UnknownVersionError,
707         # LayoutInvalid, struct.error) which happen when we get obviously
708         # wrong data, and CorruptShareError which happens later, when we
709         # perform integrity checks on the data.
710
711         assert isinstance(readers, list)
712         bad_shnums = [reader.shnum for reader in readers]
713
714         self.log("validation or decoding failed on share(s) %s, server(s) %s "
715                  ", segment %d: %s" % \
716                  (bad_shnums, readers, self._current_segment, str(f)))
717         for reader in readers:
718             self._mark_bad_share(reader.server, reader.shnum, reader, f)
719         return None
720
721
722     def _validate_block(self, results, segnum, reader, server, started):
723         """
724         I validate a block from one share on a remote server.
725         """
726         # Grab the part of the block hash tree that is necessary to
727         # validate this block, then generate the block hash root.
728         self.log("validating share %d for segment %d" % (reader.shnum,
729                                                              segnum))
730         elapsed = time.time() - started
731         self._status.add_fetch_timing(server, elapsed)
732         self._set_current_status("validating blocks")
733
734         block_and_salt, blockhashes, sharehashes = results
735         block, salt = block_and_salt
736         assert type(block) is str, (block, salt)
737
738         blockhashes = dict(enumerate(blockhashes))
739         self.log("the reader gave me the following blockhashes: %s" % \
740                  blockhashes.keys())
741         self.log("the reader gave me the following sharehashes: %s" % \
742                  sharehashes.keys())
743         bht = self._block_hash_trees[reader.shnum]
744
745         if bht.needed_hashes(segnum, include_leaf=True):
746             try:
747                 bht.set_hashes(blockhashes)
748             except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
749                     IndexError), e:
750                 raise CorruptShareError(server,
751                                         reader.shnum,
752                                         "block hash tree failure: %s" % e)
753
754         if self._version == MDMF_VERSION:
755             blockhash = hashutil.block_hash(salt + block)
756         else:
757             blockhash = hashutil.block_hash(block)
758         # If this works without an error, then validation is
759         # successful.
760         try:
761            bht.set_hashes(leaves={segnum: blockhash})
762         except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
763                 IndexError), e:
764             raise CorruptShareError(server,
765                                     reader.shnum,
766                                     "block hash tree failure: %s" % e)
767
768         # Reaching this point means that we know that this segment
769         # is correct. Now we need to check to see whether the share
770         # hash chain is also correct.
771         # SDMF wrote share hash chains that didn't contain the
772         # leaves, which would be produced from the block hash tree.
773         # So we need to validate the block hash tree first. If
774         # successful, then bht[0] will contain the root for the
775         # shnum, which will be a leaf in the share hash tree, which
776         # will allow us to validate the rest of the tree.
777         try:
778             self.share_hash_tree.set_hashes(hashes=sharehashes,
779                                         leaves={reader.shnum: bht[0]})
780         except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
781                 IndexError), e:
782             raise CorruptShareError(server,
783                                     reader.shnum,
784                                     "corrupt hashes: %s" % e)
785
786         self.log('share %d is valid for segment %d' % (reader.shnum,
787                                                        segnum))
788         return {reader.shnum: (block, salt)}
789
790
791     def _get_needed_hashes(self, reader, segnum):
792         """
793         I get the hashes needed to validate segnum from the reader, then return
794         to my caller when this is done.
795         """
796         bht = self._block_hash_trees[reader.shnum]
797         needed = bht.needed_hashes(segnum, include_leaf=True)
798         # The root of the block hash tree is also a leaf in the share
799         # hash tree. So we don't need to fetch it from the remote
800         # server. In the case of files with one segment, this means that
801         # we won't fetch any block hash tree from the remote server,
802         # since the hash of each share of the file is the entire block
803         # hash tree, and is a leaf in the share hash tree. This is fine,
804         # since any share corruption will be detected in the share hash
805         # tree.
806         #needed.discard(0)
807         self.log("getting blockhashes for segment %d, share %d: %s" % \
808                  (segnum, reader.shnum, str(needed)))
809         # TODO is force_remote necessary here?
810         d1 = reader.get_blockhashes(needed, force_remote=False)
811         if self.share_hash_tree.needed_hashes(reader.shnum):
812             need = self.share_hash_tree.needed_hashes(reader.shnum)
813             self.log("also need sharehashes for share %d: %s" % (reader.shnum,
814                                                                  str(need)))
815             d2 = reader.get_sharehashes(need, force_remote=False)
816         else:
817             d2 = defer.succeed({}) # the logic in the next method
818                                    # expects a dict
819         return d1,d2
820
821
822     def _decode_blocks(self, results, segnum):
823         """
824         I take a list of k blocks and salts, and decode that into a
825         single encrypted segment.
826         """
827         # 'results' is one or more dicts (each {shnum:(block,salt)}), and we
828         # want to merge them all
829         blocks_and_salts = {}
830         for d in results:
831             blocks_and_salts.update(d)
832
833         # All of these blocks should have the same salt; in SDMF, it is
834         # the file-wide IV, while in MDMF it is the per-segment salt. In
835         # either case, we just need to get one of them and use it.
836         #
837         # d.items()[0] is like (shnum, (block, salt))
838         # d.items()[0][1] is like (block, salt)
839         # d.items()[0][1][1] is the salt.
840         salt = blocks_and_salts.items()[0][1][1]
841         # Next, extract just the blocks from the dict. We'll use the
842         # salt in the next step.
843         share_and_shareids = [(k, v[0]) for k, v in blocks_and_salts.items()]
844         d2 = dict(share_and_shareids)
845         shareids = []
846         shares = []
847         for shareid, share in d2.items():
848             shareids.append(shareid)
849             shares.append(share)
850
851         self._set_current_status("decoding")
852         started = time.time()
853         assert len(shareids) >= self._required_shares, len(shareids)
854         # zfec really doesn't want extra shares
855         shareids = shareids[:self._required_shares]
856         shares = shares[:self._required_shares]
857         self.log("decoding segment %d" % segnum)
858         if segnum == self._num_segments - 1:
859             d = defer.maybeDeferred(self._tail_decoder.decode, shares, shareids)
860         else:
861             d = defer.maybeDeferred(self._segment_decoder.decode, shares, shareids)
862         def _process(buffers):
863             segment = "".join(buffers)
864             self.log(format="now decoding segment %(segnum)s of %(numsegs)s",
865                      segnum=segnum,
866                      numsegs=self._num_segments,
867                      level=log.NOISY)
868             self.log(" joined length %d, datalength %d" %
869                      (len(segment), self._data_length))
870             if segnum == self._num_segments - 1:
871                 size_to_use = self._tail_data_size
872             else:
873                 size_to_use = self._segment_size
874             segment = segment[:size_to_use]
875             self.log(" segment len=%d" % len(segment))
876             self._status.accumulate_decode_time(time.time() - started)
877             return segment, salt
878         d.addCallback(_process)
879         return d
880
881
882     def _decrypt_segment(self, segment_and_salt):
883         """
884         I take a single segment and its salt, and decrypt it. I return
885         the plaintext of the segment that is in my argument.
886         """
887         segment, salt = segment_and_salt
888         self._set_current_status("decrypting")
889         self.log("decrypting segment %d" % self._current_segment)
890         started = time.time()
891         key = hashutil.ssk_readkey_data_hash(salt, self._node.get_readkey())
892         decryptor = AES(key)
893         plaintext = decryptor.process(segment)
894         self._status.accumulate_decrypt_time(time.time() - started)
895         return plaintext
896
897
898     def notify_server_corruption(self, server, shnum, reason):
899         rref = server.get_rref()
900         rref.callRemoteOnly("advise_corrupt_share",
901                             "mutable", self._storage_index, shnum, reason)
902
903
904     def _try_to_validate_privkey(self, enc_privkey, reader, server):
905         alleged_privkey_s = self._node._decrypt_privkey(enc_privkey)
906         alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s)
907         if alleged_writekey != self._node.get_writekey():
908             self.log("invalid privkey from %s shnum %d" %
909                      (reader, reader.shnum),
910                      level=log.WEIRD, umid="YIw4tA")
911             if self._verify:
912                 self.servermap.mark_bad_share(server, reader.shnum,
913                                               self.verinfo[-2])
914                 e = CorruptShareError(server,
915                                       reader.shnum,
916                                       "invalid privkey")
917                 f = failure.Failure(e)
918                 self._bad_shares.add((server, reader.shnum, f))
919             return
920
921         # it's good
922         self.log("got valid privkey from shnum %d on reader %s" %
923                  (reader.shnum, reader))
924         privkey = rsa.create_signing_key_from_string(alleged_privkey_s)
925         self._node._populate_encprivkey(enc_privkey)
926         self._node._populate_privkey(privkey)
927         self._need_privkey = False
928
929
930
931     def _done(self):
932         """
933         I am called by _download_current_segment when the download process
934         has finished successfully. After making some useful logging
935         statements, I return the decrypted contents to the owner of this
936         Retrieve object through self._done_deferred.
937         """
938         self._running = False
939         self._status.set_active(False)
940         now = time.time()
941         self._status.timings['total'] = now - self._started
942         self._status.timings['fetch'] = now - self._started_fetching
943         self._status.set_status("Finished")
944         self._status.set_progress(1.0)
945
946         # remember the encoding parameters, use them again next time
947         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
948          offsets_tuple) = self.verinfo
949         self._node._populate_required_shares(k)
950         self._node._populate_total_shares(N)
951
952         if self._verify:
953             ret = self._bad_shares
954             self.log("done verifying, found %d bad shares" % len(ret))
955         else:
956             # TODO: upload status here?
957             ret = self._consumer
958             self._consumer.unregisterProducer()
959         eventually(self._done_deferred.callback, ret)
960
961
962     def _raise_notenoughshareserror(self):
963         """
964         I am called by _activate_enough_servers when there are not enough
965         active servers left to complete the download. After making some
966         useful logging statements, I throw an exception to that effect
967         to the caller of this Retrieve object through
968         self._done_deferred.
969         """
970
971         format = ("ran out of servers: "
972                   "have %(have)d of %(total)d segments "
973                   "found %(bad)d bad shares "
974                   "encoding %(k)d-of-%(n)d")
975         args = {"have": self._current_segment,
976                 "total": self._num_segments,
977                 "need": self._last_segment,
978                 "k": self._required_shares,
979                 "n": self._total_shares,
980                 "bad": len(self._bad_shares)}
981         raise NotEnoughSharesError("%s, last failure: %s" %
982                                    (format % args, str(self._last_failure)))
983
984     def _error(self, f):
985         # all errors, including NotEnoughSharesError, land here
986         self._running = False
987         self._status.set_active(False)
988         now = time.time()
989         self._status.timings['total'] = now - self._started
990         self._status.timings['fetch'] = now - self._started_fetching
991         self._status.set_status("Failed")
992         eventually(self._done_deferred.errback, f)