]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blob - src/allmydata/mutable/retrieve.py
Ensure that verification proceeds and stops when appropriate.
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / mutable / retrieve.py
1
2 import time
3 from itertools import count
4 from zope.interface import implements
5 from twisted.internet import defer
6 from twisted.python import failure
7 from twisted.internet.interfaces import IPushProducer, IConsumer
8 from foolscap.api import eventually, fireEventually, DeadReferenceError, \
9      RemoteException
10 from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \
11      DownloadStopped, MDMF_VERSION, SDMF_VERSION
12 from allmydata.util import hashutil, log, mathutil, deferredutil
13 from allmydata.util.dictutil import DictOfSets
14 from allmydata import hashtree, codec
15 from allmydata.storage.server import si_b2a
16 from pycryptopp.cipher.aes import AES
17 from pycryptopp.publickey import rsa
18
19 from allmydata.mutable.common import CorruptShareError, BadShareError, \
20      UncoordinatedWriteError
21 from allmydata.mutable.layout import MDMFSlotReadProxy
22
23 class RetrieveStatus:
24     implements(IRetrieveStatus)
25     statusid_counter = count(0)
26     def __init__(self):
27         self.timings = {}
28         self.timings["fetch_per_server"] = {}
29         self.timings["decode"] = 0.0
30         self.timings["decrypt"] = 0.0
31         self.timings["cumulative_verify"] = 0.0
32         self._problems = {}
33         self.active = True
34         self.storage_index = None
35         self.helper = False
36         self.encoding = ("?","?")
37         self.size = None
38         self.status = "Not started"
39         self.progress = 0.0
40         self.counter = self.statusid_counter.next()
41         self.started = time.time()
42
43     def get_started(self):
44         return self.started
45     def get_storage_index(self):
46         return self.storage_index
47     def get_encoding(self):
48         return self.encoding
49     def using_helper(self):
50         return self.helper
51     def get_size(self):
52         return self.size
53     def get_status(self):
54         return self.status
55     def get_progress(self):
56         return self.progress
57     def get_active(self):
58         return self.active
59     def get_counter(self):
60         return self.counter
61     def get_problems(self):
62         return self._problems
63
64     def add_fetch_timing(self, server, elapsed):
65         serverid = server.get_serverid()
66         if serverid not in self.timings["fetch_per_server"]:
67             self.timings["fetch_per_server"][serverid] = []
68         self.timings["fetch_per_server"][serverid].append(elapsed)
69     def accumulate_decode_time(self, elapsed):
70         self.timings["decode"] += elapsed
71     def accumulate_decrypt_time(self, elapsed):
72         self.timings["decrypt"] += elapsed
73     def set_storage_index(self, si):
74         self.storage_index = si
75     def set_helper(self, helper):
76         self.helper = helper
77     def set_encoding(self, k, n):
78         self.encoding = (k, n)
79     def set_size(self, size):
80         self.size = size
81     def set_status(self, status):
82         self.status = status
83     def set_progress(self, value):
84         self.progress = value
85     def set_active(self, value):
86         self.active = value
87     def add_problem(self, server, f):
88         serverid = server.get_serverid()
89         self._problems[serverid] = f
90
91 class Marker:
92     pass
93
94 class Retrieve:
95     # this class is currently single-use. Eventually (in MDMF) we will make
96     # it multi-use, in which case you can call download(range) multiple
97     # times, and each will have a separate response chain. However the
98     # Retrieve object will remain tied to a specific version of the file, and
99     # will use a single ServerMap instance.
100     implements(IPushProducer)
101
102     def __init__(self, filenode, storage_broker, servermap, verinfo,
103                  fetch_privkey=False, verify=False):
104         self._node = filenode
105         assert self._node.get_pubkey()
106         self._storage_broker = storage_broker
107         self._storage_index = filenode.get_storage_index()
108         assert self._node.get_readkey()
109         self._last_failure = None
110         prefix = si_b2a(self._storage_index)[:5]
111         self._log_number = log.msg("Retrieve(%s): starting" % prefix)
112         self._running = True
113         self._decoding = False
114         self._bad_shares = set()
115
116         self.servermap = servermap
117         assert self._node.get_pubkey()
118         self.verinfo = verinfo
119         # during repair, we may be called upon to grab the private key, since
120         # it wasn't picked up during a verify=False checker run, and we'll
121         # need it for repair to generate a new version.
122         self._need_privkey = verify or (fetch_privkey
123                                         and not self._node.get_privkey())
124
125         if self._need_privkey:
126             # TODO: Evaluate the need for this. We'll use it if we want
127             # to limit how many queries are on the wire for the privkey
128             # at once.
129             self._privkey_query_markers = [] # one Marker for each time we've
130                                              # tried to get the privkey.
131
132         # verify means that we are using the downloader logic to verify all
133         # of our shares. This tells the downloader a few things.
134         # 
135         # 1. We need to download all of the shares.
136         # 2. We don't need to decode or decrypt the shares, since our
137         #    caller doesn't care about the plaintext, only the
138         #    information about which shares are or are not valid.
139         # 3. When we are validating readers, we need to validate the
140         #    signature on the prefix. Do we? We already do this in the
141         #    servermap update?
142         self._verify = verify
143
144         self._status = RetrieveStatus()
145         self._status.set_storage_index(self._storage_index)
146         self._status.set_helper(False)
147         self._status.set_progress(0.0)
148         self._status.set_active(True)
149         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
150          offsets_tuple) = self.verinfo
151         self._status.set_size(datalength)
152         self._status.set_encoding(k, N)
153         self.readers = {}
154         self._stopped = False
155         self._pause_deferred = None
156         self._offset = None
157         self._read_length = None
158         self.log("got seqnum %d" % self.verinfo[0])
159
160
161     def get_status(self):
162         return self._status
163
164     def log(self, *args, **kwargs):
165         if "parent" not in kwargs:
166             kwargs["parent"] = self._log_number
167         if "facility" not in kwargs:
168             kwargs["facility"] = "tahoe.mutable.retrieve"
169         return log.msg(*args, **kwargs)
170
171     def _set_current_status(self, state):
172         seg = "%d/%d" % (self._current_segment, self._last_segment)
173         self._status.set_status("segment %s (%s)" % (seg, state))
174
175     ###################
176     # IPushProducer
177
178     def pauseProducing(self):
179         """
180         I am called by my download target if we have produced too much
181         data for it to handle. I make the downloader stop producing new
182         data until my resumeProducing method is called.
183         """
184         if self._pause_deferred is not None:
185             return
186
187         # fired when the download is unpaused.
188         self._old_status = self._status.get_status()
189         self._set_current_status("paused")
190
191         self._pause_deferred = defer.Deferred()
192
193
194     def resumeProducing(self):
195         """
196         I am called by my download target once it is ready to begin
197         receiving data again.
198         """
199         if self._pause_deferred is None:
200             return
201
202         p = self._pause_deferred
203         self._pause_deferred = None
204         self._status.set_status(self._old_status)
205
206         eventually(p.callback, None)
207
208     def stopProducing(self):
209         self._stopped = True
210         self.resumeProducing()
211
212
213     def _check_for_paused(self, res):
214         """
215         I am called just before a write to the consumer. I return a
216         Deferred that eventually fires with the data that is to be
217         written to the consumer. If the download has not been paused,
218         the Deferred fires immediately. Otherwise, the Deferred fires
219         when the downloader is unpaused.
220         """
221         if self._pause_deferred is not None:
222             d = defer.Deferred()
223             self._pause_deferred.addCallback(lambda ignored: d.callback(res))
224             return d
225         return res
226
227     def _check_for_stopped(self, res):
228         if self._stopped:
229             raise DownloadStopped("our Consumer called stopProducing()")
230         return res
231
232
233     def download(self, consumer=None, offset=0, size=None):
234         assert IConsumer.providedBy(consumer) or self._verify
235
236         if consumer:
237             self._consumer = consumer
238             # we provide IPushProducer, so streaming=True, per
239             # IConsumer.
240             self._consumer.registerProducer(self, streaming=True)
241
242         self._done_deferred = defer.Deferred()
243         self._offset = offset
244         self._read_length = size
245         self._setup_download()
246         self._setup_encoding_parameters()
247         self.log("starting download")
248         self._started_fetching = time.time()
249         # The download process beyond this is a state machine.
250         # _add_active_servers will select the servers that we want to use
251         # for the download, and then attempt to start downloading. After
252         # each segment, it will check for doneness, reacting to broken
253         # servers and corrupt shares as necessary. If it runs out of good
254         # servers before downloading all of the segments, _done_deferred
255         # will errback.  Otherwise, it will eventually callback with the
256         # contents of the mutable file.
257         self.loop()
258         return self._done_deferred
259
260     def loop(self):
261         d = fireEventually(None) # avoid #237 recursion limit problem
262         d.addCallback(lambda ign: self._activate_enough_servers())
263         d.addCallback(lambda ign: self._download_current_segment())
264         # when we're done, _download_current_segment will call _done. If we
265         # aren't, it will call loop() again.
266         d.addErrback(self._error)
267
268     def _setup_download(self):
269         self._started = time.time()
270         self._status.set_status("Retrieving Shares")
271
272         # how many shares do we need?
273         (seqnum,
274          root_hash,
275          IV,
276          segsize,
277          datalength,
278          k,
279          N,
280          prefix,
281          offsets_tuple) = self.verinfo
282
283         # first, which servers can we use?
284         versionmap = self.servermap.make_versionmap()
285         shares = versionmap[self.verinfo]
286         # this sharemap is consumed as we decide to send requests
287         self.remaining_sharemap = DictOfSets()
288         for (shnum, server, timestamp) in shares:
289             self.remaining_sharemap.add(shnum, server)
290             # If the servermap update fetched anything, it fetched at least 1
291             # KiB, so we ask for that much.
292             # TODO: Change the cache methods to allow us to fetch all of the
293             # data that they have, then change this method to do that.
294             any_cache = self._node._read_from_cache(self.verinfo, shnum,
295                                                     0, 1000)
296             reader = MDMFSlotReadProxy(server.get_rref(),
297                                        self._storage_index,
298                                        shnum,
299                                        any_cache)
300             reader.server = server
301             self.readers[shnum] = reader
302         assert len(self.remaining_sharemap) >= k
303
304         self.shares = {} # maps shnum to validated blocks
305         self._active_readers = [] # list of active readers for this dl.
306         self._block_hash_trees = {} # shnum => hashtree
307
308         # We need one share hash tree for the entire file; its leaves
309         # are the roots of the block hash trees for the shares that
310         # comprise it, and its root is in the verinfo.
311         self.share_hash_tree = hashtree.IncompleteHashTree(N)
312         self.share_hash_tree.set_hashes({0: root_hash})
313
314     def decode(self, blocks_and_salts, segnum):
315         """
316         I am a helper method that the mutable file update process uses
317         as a shortcut to decode and decrypt the segments that it needs
318         to fetch in order to perform a file update. I take in a
319         collection of blocks and salts, and pick some of those to make a
320         segment with. I return the plaintext associated with that
321         segment.
322         """
323         # shnum => block hash tree. Unused, but setup_encoding_parameters will
324         # want to set this.
325         self._block_hash_trees = None
326         self._setup_encoding_parameters()
327
328         # _decode_blocks() expects the output of a gatherResults that
329         # contains the outputs of _validate_block() (each of which is a dict
330         # mapping shnum to (block,salt) bytestrings).
331         d = self._decode_blocks([blocks_and_salts], segnum)
332         d.addCallback(self._decrypt_segment)
333         return d
334
335
336     def _setup_encoding_parameters(self):
337         """
338         I set up the encoding parameters, including k, n, the number
339         of segments associated with this file, and the segment decoders.
340         """
341         (seqnum,
342          root_hash,
343          IV,
344          segsize,
345          datalength,
346          k,
347          n,
348          known_prefix,
349          offsets_tuple) = self.verinfo
350         self._required_shares = k
351         self._total_shares = n
352         self._segment_size = segsize
353         self._data_length = datalength
354
355         if not IV:
356             self._version = MDMF_VERSION
357         else:
358             self._version = SDMF_VERSION
359
360         if datalength and segsize:
361             self._num_segments = mathutil.div_ceil(datalength, segsize)
362             self._tail_data_size = datalength % segsize
363         else:
364             self._num_segments = 0
365             self._tail_data_size = 0
366
367         self._segment_decoder = codec.CRSDecoder()
368         self._segment_decoder.set_params(segsize, k, n)
369
370         if  not self._tail_data_size:
371             self._tail_data_size = segsize
372
373         self._tail_segment_size = mathutil.next_multiple(self._tail_data_size,
374                                                          self._required_shares)
375         if self._tail_segment_size == self._segment_size:
376             self._tail_decoder = self._segment_decoder
377         else:
378             self._tail_decoder = codec.CRSDecoder()
379             self._tail_decoder.set_params(self._tail_segment_size,
380                                           self._required_shares,
381                                           self._total_shares)
382
383         self.log("got encoding parameters: "
384                  "k: %d "
385                  "n: %d "
386                  "%d segments of %d bytes each (%d byte tail segment)" % \
387                  (k, n, self._num_segments, self._segment_size,
388                   self._tail_segment_size))
389
390         if self._block_hash_trees is not None:
391             for i in xrange(self._total_shares):
392                 # So we don't have to do this later.
393                 self._block_hash_trees[i] = hashtree.IncompleteHashTree(self._num_segments)
394
395         # Our last task is to tell the downloader where to start and
396         # where to stop. We use three parameters for that:
397         #   - self._start_segment: the segment that we need to start
398         #     downloading from. 
399         #   - self._current_segment: the next segment that we need to
400         #     download.
401         #   - self._last_segment: The last segment that we were asked to
402         #     download.
403         #
404         #  We say that the download is complete when
405         #  self._current_segment > self._last_segment. We use
406         #  self._start_segment and self._last_segment to know when to
407         #  strip things off of segments, and how much to strip.
408         if self._offset:
409             self.log("got offset: %d" % self._offset)
410             # our start segment is the first segment containing the
411             # offset we were given. 
412             start = self._offset // self._segment_size
413
414             assert start < self._num_segments
415             self._start_segment = start
416             self.log("got start segment: %d" % self._start_segment)
417         else:
418             self._start_segment = 0
419
420
421         # If self._read_length is None, then we want to read the whole
422         # file. Otherwise, we want to read only part of the file, and
423         # need to figure out where to stop reading.
424         if self._read_length is not None:
425             # our end segment is the last segment containing part of the
426             # segment that we were asked to read.
427             self.log("got read length %d" % self._read_length)
428             if self._read_length != 0:
429                 end_data = self._offset + self._read_length
430
431                 # We don't actually need to read the byte at end_data,
432                 # but the one before it.
433                 end = (end_data - 1) // self._segment_size
434
435                 assert end < self._num_segments
436                 self._last_segment = end
437             else:
438                 self._last_segment = self._start_segment
439             self.log("got end segment: %d" % self._last_segment)
440         else:
441             self._last_segment = self._num_segments - 1
442
443         self._current_segment = self._start_segment
444
445     def _activate_enough_servers(self):
446         """
447         I populate self._active_readers with enough active readers to
448         retrieve the contents of this mutable file. I am called before
449         downloading starts, and (eventually) after each validation
450         error, connection error, or other problem in the download.
451         """
452         # TODO: It would be cool to investigate other heuristics for
453         # reader selection. For instance, the cost (in time the user
454         # spends waiting for their file) of selecting a really slow server
455         # that happens to have a primary share is probably more than
456         # selecting a really fast server that doesn't have a primary
457         # share. Maybe the servermap could be extended to provide this
458         # information; it could keep track of latency information while
459         # it gathers more important data, and then this routine could
460         # use that to select active readers.
461         #
462         # (these and other questions would be easier to answer with a
463         #  robust, configurable tahoe-lafs simulator, which modeled node
464         #  failures, differences in node speed, and other characteristics
465         #  that we expect storage servers to have.  You could have
466         #  presets for really stable grids (like allmydata.com),
467         #  friendnets, make it easy to configure your own settings, and
468         #  then simulate the effect of big changes on these use cases
469         #  instead of just reasoning about what the effect might be. Out
470         #  of scope for MDMF, though.)
471
472         # XXX: Why don't format= log messages work here?
473
474         known_shnums = set(self.remaining_sharemap.keys())
475         used_shnums = set([r.shnum for r in self._active_readers])
476         unused_shnums = known_shnums - used_shnums
477
478         if self._verify:
479             new_shnums = unused_shnums # use them all
480         elif len(self._active_readers) < self._required_shares:
481             # need more shares
482             more = self._required_shares - len(self._active_readers)
483             # We favor lower numbered shares, since FEC is faster with
484             # primary shares than with other shares, and lower-numbered
485             # shares are more likely to be primary than higher numbered
486             # shares.
487             new_shnums = sorted(unused_shnums)[:more]
488             if len(new_shnums) < more:
489                 # We don't have enough readers to retrieve the file; fail.
490                 self._raise_notenoughshareserror()
491         else:
492             new_shnums = []
493
494         self.log("adding %d new servers to the active list" % len(new_shnums))
495         for shnum in new_shnums:
496             reader = self.readers[shnum]
497             self._active_readers.append(reader)
498             self.log("added reader for share %d" % shnum)
499             # Each time we add a reader, we check to see if we need the
500             # private key. If we do, we politely ask for it and then continue
501             # computing. If we find that we haven't gotten it at the end of
502             # segment decoding, then we'll take more drastic measures.
503             if self._need_privkey and not self._node.is_readonly():
504                 d = reader.get_encprivkey()
505                 d.addCallback(self._try_to_validate_privkey, reader, reader.server)
506                 # XXX: don't just drop the Deferred. We need error-reporting
507                 # but not flow-control here.
508
509     def _try_to_validate_prefix(self, prefix, reader):
510         """
511         I check that the prefix returned by a candidate server for
512         retrieval matches the prefix that the servermap knows about
513         (and, hence, the prefix that was validated earlier). If it does,
514         I return True, which means that I approve of the use of the
515         candidate server for segment retrieval. If it doesn't, I return
516         False, which means that another server must be chosen.
517         """
518         (seqnum,
519          root_hash,
520          IV,
521          segsize,
522          datalength,
523          k,
524          N,
525          known_prefix,
526          offsets_tuple) = self.verinfo
527         if known_prefix != prefix:
528             self.log("prefix from share %d doesn't match" % reader.shnum)
529             raise UncoordinatedWriteError("Mismatched prefix -- this could "
530                                           "indicate an uncoordinated write")
531         # Otherwise, we're okay -- no issues.
532
533
534     def _remove_reader(self, reader):
535         """
536         At various points, we will wish to remove a server from
537         consideration and/or use. These include, but are not necessarily
538         limited to:
539
540             - A connection error.
541             - A mismatched prefix (that is, a prefix that does not match
542               our conception of the version information string).
543             - A failing block hash, salt hash, or share hash, which can
544               indicate disk failure/bit flips, or network trouble.
545
546         This method will do that. I will make sure that the
547         (shnum,reader) combination represented by my reader argument is
548         not used for anything else during this download. I will not
549         advise the reader of any corruption, something that my callers
550         may wish to do on their own.
551         """
552         # TODO: When you're done writing this, see if this is ever
553         # actually used for something that _mark_bad_share isn't. I have
554         # a feeling that they will be used for very similar things, and
555         # that having them both here is just going to be an epic amount
556         # of code duplication.
557         #
558         # (well, okay, not epic, but meaningful)
559         self.log("removing reader %s" % reader)
560         # Remove the reader from _active_readers
561         self._active_readers.remove(reader)
562         # TODO: self.readers.remove(reader)?
563         for shnum in list(self.remaining_sharemap.keys()):
564             self.remaining_sharemap.discard(shnum, reader.server)
565
566
567     def _mark_bad_share(self, server, shnum, reader, f):
568         """
569         I mark the given (server, shnum) as a bad share, which means that it
570         will not be used anywhere else.
571
572         There are several reasons to want to mark something as a bad
573         share. These include:
574
575             - A connection error to the server.
576             - A mismatched prefix (that is, a prefix that does not match
577               our local conception of the version information string).
578             - A failing block hash, salt hash, share hash, or other
579               integrity check.
580
581         This method will ensure that readers that we wish to mark bad
582         (for these reasons or other reasons) are not used for the rest
583         of the download. Additionally, it will attempt to tell the
584         remote server (with no guarantee of success) that its share is
585         corrupt.
586         """
587         self.log("marking share %d on server %s as bad" % \
588                  (shnum, server.get_name()))
589         prefix = self.verinfo[-2]
590         self.servermap.mark_bad_share(server, shnum, prefix)
591         self._remove_reader(reader)
592         self._bad_shares.add((server, shnum, f))
593         self._status.add_problem(server, f)
594         self._last_failure = f
595         if f.check(BadShareError):
596             self.notify_server_corruption(server, shnum, str(f.value))
597
598
599     def _download_current_segment(self):
600         """
601         I download, validate, decode, decrypt, and assemble the segment
602         that this Retrieve is currently responsible for downloading.
603         """
604         if self._current_segment > self._last_segment:
605             # No more segments to download, we're done.
606             self.log("got plaintext, done")
607             return self._done()
608         elif self._verify and len(self._active_readers) == 0:
609             self.log("no more good shares, no need to keep verifying")
610             return self._done()
611         self.log("on segment %d of %d" %
612                  (self._current_segment + 1, self._num_segments))
613         d = self._process_segment(self._current_segment)
614         d.addCallback(lambda ign: self.loop())
615         return d
616
617     def _process_segment(self, segnum):
618         """
619         I download, validate, decode, and decrypt one segment of the
620         file that this Retrieve is retrieving. This means coordinating
621         the process of getting k blocks of that file, validating them,
622         assembling them into one segment with the decoder, and then
623         decrypting them.
624         """
625         self.log("processing segment %d" % segnum)
626
627         # TODO: The old code uses a marker. Should this code do that
628         # too? What did the Marker do?
629
630         # We need to ask each of our active readers for its block and
631         # salt. We will then validate those. If validation is
632         # successful, we will assemble the results into plaintext.
633         ds = []
634         for reader in self._active_readers:
635             started = time.time()
636             d1 = reader.get_block_and_salt(segnum)
637             d2,d3 = self._get_needed_hashes(reader, segnum)
638             d = deferredutil.gatherResults([d1,d2,d3])
639             d.addCallback(self._validate_block, segnum, reader, reader.server, started)
640             # _handle_bad_share takes care of recoverable errors (by dropping
641             # that share and returning None). Any other errors (i.e. code
642             # bugs) are passed through and cause the retrieve to fail.
643             d.addErrback(self._handle_bad_share, [reader])
644             ds.append(d)
645         dl = deferredutil.gatherResults(ds)
646         if self._verify:
647             dl.addCallback(lambda ignored: "")
648             dl.addCallback(self._set_segment)
649         else:
650             dl.addCallback(self._maybe_decode_and_decrypt_segment, segnum)
651         return dl
652
653
654     def _maybe_decode_and_decrypt_segment(self, results, segnum):
655         """
656         I take the results of fetching and validating the blocks from
657         _process_segment. If validation and fetching succeeded without
658         incident, I will proceed with decoding and decryption. Otherwise, I
659         will do nothing.
660         """
661         self.log("trying to decode and decrypt segment %d" % segnum)
662
663         # 'results' is the output of a gatherResults set up in
664         # _process_segment(). Each component Deferred will either contain the
665         # non-Failure output of _validate_block() for a single block (i.e.
666         # {segnum:(block,salt)}), or None if _validate_block threw an
667         # exception and _validation_or_decoding_failed handled it (by
668         # dropping that server).
669
670         if None in results:
671             self.log("some validation operations failed; not proceeding")
672             return defer.succeed(None)
673         self.log("everything looks ok, building segment %d" % segnum)
674         d = self._decode_blocks(results, segnum)
675         d.addCallback(self._decrypt_segment)
676         # check to see whether we've been paused before writing
677         # anything.
678         d.addCallback(self._check_for_paused)
679         d.addCallback(self._check_for_stopped)
680         d.addCallback(self._set_segment)
681         return d
682
683
684     def _set_segment(self, segment):
685         """
686         Given a plaintext segment, I register that segment with the
687         target that is handling the file download.
688         """
689         self.log("got plaintext for segment %d" % self._current_segment)
690         if self._current_segment == self._start_segment:
691             # We're on the first segment. It's possible that we want
692             # only some part of the end of this segment, and that we
693             # just downloaded the whole thing to get that part. If so,
694             # we need to account for that and give the reader just the
695             # data that they want.
696             n = self._offset % self._segment_size
697             self.log("stripping %d bytes off of the first segment" % n)
698             self.log("original segment length: %d" % len(segment))
699             segment = segment[n:]
700             self.log("new segment length: %d" % len(segment))
701
702         if self._current_segment == self._last_segment and self._read_length is not None:
703             # We're on the last segment. It's possible that we only want
704             # part of the beginning of this segment, and that we
705             # downloaded the whole thing anyway. Make sure to give the
706             # caller only the portion of the segment that they want to
707             # receive.
708             extra = self._read_length
709             if self._start_segment != self._last_segment:
710                 extra -= self._segment_size - \
711                             (self._offset % self._segment_size)
712             extra %= self._segment_size
713             self.log("original segment length: %d" % len(segment))
714             segment = segment[:extra]
715             self.log("new segment length: %d" % len(segment))
716             self.log("only taking %d bytes of the last segment" % extra)
717
718         if not self._verify:
719             self._consumer.write(segment)
720         else:
721             # we don't care about the plaintext if we are doing a verify.
722             segment = None
723         self._current_segment += 1
724
725
726     def _handle_bad_share(self, f, readers):
727         """
728         I am called when a block or a salt fails to correctly validate, or when
729         the decryption or decoding operation fails for some reason.  I react to
730         this failure by notifying the remote server of corruption, and then
731         removing the remote server from further activity.
732         """
733         # these are the errors we can tolerate: by giving up on this share
734         # and finding others to replace it. Any other errors (i.e. coding
735         # bugs) are re-raised, causing the download to fail.
736         f.trap(DeadReferenceError, RemoteException, BadShareError)
737
738         # DeadReferenceError happens when we try to fetch data from a server
739         # that has gone away. RemoteException happens if the server had an
740         # internal error. BadShareError encompasses: (UnknownVersionError,
741         # LayoutInvalid, struct.error) which happen when we get obviously
742         # wrong data, and CorruptShareError which happens later, when we
743         # perform integrity checks on the data.
744
745         assert isinstance(readers, list)
746         bad_shnums = [reader.shnum for reader in readers]
747
748         self.log("validation or decoding failed on share(s) %s, server(s) %s "
749                  ", segment %d: %s" % \
750                  (bad_shnums, readers, self._current_segment, str(f)))
751         for reader in readers:
752             self._mark_bad_share(reader.server, reader.shnum, reader, f)
753         return None
754
755
756     def _validate_block(self, results, segnum, reader, server, started):
757         """
758         I validate a block from one share on a remote server.
759         """
760         # Grab the part of the block hash tree that is necessary to
761         # validate this block, then generate the block hash root.
762         self.log("validating share %d for segment %d" % (reader.shnum,
763                                                              segnum))
764         elapsed = time.time() - started
765         self._status.add_fetch_timing(server, elapsed)
766         self._set_current_status("validating blocks")
767
768         block_and_salt, blockhashes, sharehashes = results
769         block, salt = block_and_salt
770
771         blockhashes = dict(enumerate(blockhashes))
772         self.log("the reader gave me the following blockhashes: %s" % \
773                  blockhashes.keys())
774         self.log("the reader gave me the following sharehashes: %s" % \
775                  sharehashes.keys())
776         bht = self._block_hash_trees[reader.shnum]
777
778         if bht.needed_hashes(segnum, include_leaf=True):
779             try:
780                 bht.set_hashes(blockhashes)
781             except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
782                     IndexError), e:
783                 raise CorruptShareError(server,
784                                         reader.shnum,
785                                         "block hash tree failure: %s" % e)
786
787         if self._version == MDMF_VERSION:
788             blockhash = hashutil.block_hash(salt + block)
789         else:
790             blockhash = hashutil.block_hash(block)
791         # If this works without an error, then validation is
792         # successful.
793         try:
794            bht.set_hashes(leaves={segnum: blockhash})
795         except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
796                 IndexError), e:
797             raise CorruptShareError(server,
798                                     reader.shnum,
799                                     "block hash tree failure: %s" % e)
800
801         # Reaching this point means that we know that this segment
802         # is correct. Now we need to check to see whether the share
803         # hash chain is also correct. 
804         # SDMF wrote share hash chains that didn't contain the
805         # leaves, which would be produced from the block hash tree.
806         # So we need to validate the block hash tree first. If
807         # successful, then bht[0] will contain the root for the
808         # shnum, which will be a leaf in the share hash tree, which
809         # will allow us to validate the rest of the tree.
810         try:
811             self.share_hash_tree.set_hashes(hashes=sharehashes,
812                                         leaves={reader.shnum: bht[0]})
813         except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
814                 IndexError), e:
815             raise CorruptShareError(server,
816                                     reader.shnum,
817                                     "corrupt hashes: %s" % e)
818
819         self.log('share %d is valid for segment %d' % (reader.shnum,
820                                                        segnum))
821         return {reader.shnum: (block, salt)}
822
823
824     def _get_needed_hashes(self, reader, segnum):
825         """
826         I get the hashes needed to validate segnum from the reader, then return
827         to my caller when this is done.
828         """
829         bht = self._block_hash_trees[reader.shnum]
830         needed = bht.needed_hashes(segnum, include_leaf=True)
831         # The root of the block hash tree is also a leaf in the share
832         # hash tree. So we don't need to fetch it from the remote
833         # server. In the case of files with one segment, this means that
834         # we won't fetch any block hash tree from the remote server,
835         # since the hash of each share of the file is the entire block
836         # hash tree, and is a leaf in the share hash tree. This is fine,
837         # since any share corruption will be detected in the share hash
838         # tree.
839         #needed.discard(0)
840         self.log("getting blockhashes for segment %d, share %d: %s" % \
841                  (segnum, reader.shnum, str(needed)))
842         d1 = reader.get_blockhashes(needed, force_remote=True)
843         if self.share_hash_tree.needed_hashes(reader.shnum):
844             need = self.share_hash_tree.needed_hashes(reader.shnum)
845             self.log("also need sharehashes for share %d: %s" % (reader.shnum,
846                                                                  str(need)))
847             d2 = reader.get_sharehashes(need, force_remote=True)
848         else:
849             d2 = defer.succeed({}) # the logic in the next method
850                                    # expects a dict
851         return d1,d2
852
853
854     def _decode_blocks(self, results, segnum):
855         """
856         I take a list of k blocks and salts, and decode that into a
857         single encrypted segment.
858         """
859         # 'results' is one or more dicts (each {shnum:(block,salt)}), and we
860         # want to merge them all
861         blocks_and_salts = {}
862         for d in results:
863             blocks_and_salts.update(d)
864
865         # All of these blocks should have the same salt; in SDMF, it is
866         # the file-wide IV, while in MDMF it is the per-segment salt. In
867         # either case, we just need to get one of them and use it.
868         #
869         # d.items()[0] is like (shnum, (block, salt))
870         # d.items()[0][1] is like (block, salt)
871         # d.items()[0][1][1] is the salt.
872         salt = blocks_and_salts.items()[0][1][1]
873         # Next, extract just the blocks from the dict. We'll use the
874         # salt in the next step.
875         share_and_shareids = [(k, v[0]) for k, v in blocks_and_salts.items()]
876         d2 = dict(share_and_shareids)
877         shareids = []
878         shares = []
879         for shareid, share in d2.items():
880             shareids.append(shareid)
881             shares.append(share)
882
883         self._set_current_status("decoding")
884         started = time.time()
885         assert len(shareids) >= self._required_shares, len(shareids)
886         # zfec really doesn't want extra shares
887         shareids = shareids[:self._required_shares]
888         shares = shares[:self._required_shares]
889         self.log("decoding segment %d" % segnum)
890         if segnum == self._num_segments - 1:
891             d = defer.maybeDeferred(self._tail_decoder.decode, shares, shareids)
892         else:
893             d = defer.maybeDeferred(self._segment_decoder.decode, shares, shareids)
894         def _process(buffers):
895             segment = "".join(buffers)
896             self.log(format="now decoding segment %(segnum)s of %(numsegs)s",
897                      segnum=segnum,
898                      numsegs=self._num_segments,
899                      level=log.NOISY)
900             self.log(" joined length %d, datalength %d" %
901                      (len(segment), self._data_length))
902             if segnum == self._num_segments - 1:
903                 size_to_use = self._tail_data_size
904             else:
905                 size_to_use = self._segment_size
906             segment = segment[:size_to_use]
907             self.log(" segment len=%d" % len(segment))
908             self._status.accumulate_decode_time(time.time() - started)
909             return segment, salt
910         d.addCallback(_process)
911         return d
912
913
914     def _decrypt_segment(self, segment_and_salt):
915         """
916         I take a single segment and its salt, and decrypt it. I return
917         the plaintext of the segment that is in my argument.
918         """
919         segment, salt = segment_and_salt
920         self._set_current_status("decrypting")
921         self.log("decrypting segment %d" % self._current_segment)
922         started = time.time()
923         key = hashutil.ssk_readkey_data_hash(salt, self._node.get_readkey())
924         decryptor = AES(key)
925         plaintext = decryptor.process(segment)
926         self._status.accumulate_decrypt_time(time.time() - started)
927         return plaintext
928
929
930     def notify_server_corruption(self, server, shnum, reason):
931         rref = server.get_rref()
932         rref.callRemoteOnly("advise_corrupt_share",
933                             "mutable", self._storage_index, shnum, reason)
934
935
936     def _try_to_validate_privkey(self, enc_privkey, reader, server):
937         alleged_privkey_s = self._node._decrypt_privkey(enc_privkey)
938         alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s)
939         if alleged_writekey != self._node.get_writekey():
940             self.log("invalid privkey from %s shnum %d" %
941                      (reader, reader.shnum),
942                      level=log.WEIRD, umid="YIw4tA")
943             if self._verify:
944                 self.servermap.mark_bad_share(server, reader.shnum,
945                                               self.verinfo[-2])
946                 e = CorruptShareError(server,
947                                       reader.shnum,
948                                       "invalid privkey")
949                 f = failure.Failure(e)
950                 self._bad_shares.add((server, reader.shnum, f))
951             return
952
953         # it's good
954         self.log("got valid privkey from shnum %d on reader %s" %
955                  (reader.shnum, reader))
956         privkey = rsa.create_signing_key_from_string(alleged_privkey_s)
957         self._node._populate_encprivkey(enc_privkey)
958         self._node._populate_privkey(privkey)
959         self._need_privkey = False
960
961
962
963     def _done(self):
964         """
965         I am called by _download_current_segment when the download process
966         has finished successfully. After making some useful logging
967         statements, I return the decrypted contents to the owner of this
968         Retrieve object through self._done_deferred.
969         """
970         self._running = False
971         self._status.set_active(False)
972         now = time.time()
973         self._status.timings['total'] = now - self._started
974         self._status.timings['fetch'] = now - self._started_fetching
975         self._status.set_status("Finished")
976         self._status.set_progress(1.0)
977
978         # remember the encoding parameters, use them again next time
979         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
980          offsets_tuple) = self.verinfo
981         self._node._populate_required_shares(k)
982         self._node._populate_total_shares(N)
983
984         if self._verify:
985             ret = self._bad_shares
986             self.log("done verifying, found %d bad shares" % len(ret))
987         else:
988             # TODO: upload status here?
989             ret = self._consumer
990             self._consumer.unregisterProducer()
991         eventually(self._done_deferred.callback, ret)
992
993
994     def _raise_notenoughshareserror(self):
995         """
996         I am called by _activate_enough_servers when there are not enough
997         active servers left to complete the download. After making some
998         useful logging statements, I throw an exception to that effect
999         to the caller of this Retrieve object through
1000         self._done_deferred.
1001         """
1002
1003         format = ("ran out of servers: "
1004                   "have %(have)d of %(total)d segments "
1005                   "found %(bad)d bad shares "
1006                   "encoding %(k)d-of-%(n)d")
1007         args = {"have": self._current_segment,
1008                 "total": self._num_segments,
1009                 "need": self._last_segment,
1010                 "k": self._required_shares,
1011                 "n": self._total_shares,
1012                 "bad": len(self._bad_shares)}
1013         raise NotEnoughSharesError("%s, last failure: %s" %
1014                                    (format % args, str(self._last_failure)))
1015
1016     def _error(self, f):
1017         # all errors, including NotEnoughSharesError, land here
1018         self._running = False
1019         self._status.set_active(False)
1020         now = time.time()
1021         self._status.timings['total'] = now - self._started
1022         self._status.timings['fetch'] = now - self._started_fetching
1023         self._status.set_status("Failed")
1024         eventually(self._done_deferred.errback, f)