]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blob - src/allmydata/mutable/retrieve.py
Retrieve: rewrite flow-control: use a top-level loop() to catch all errors
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / mutable / retrieve.py
1
2 import time
3 from itertools import count
4 from zope.interface import implements
5 from twisted.internet import defer
6 from twisted.python import failure
7 from twisted.internet.interfaces import IPushProducer, IConsumer
8 from foolscap.api import eventually, fireEventually
9 from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \
10                                  MDMF_VERSION, SDMF_VERSION
11 from allmydata.util import hashutil, log, mathutil
12 from allmydata.util.dictutil import DictOfSets
13 from allmydata import hashtree, codec
14 from allmydata.storage.server import si_b2a
15 from pycryptopp.cipher.aes import AES
16 from pycryptopp.publickey import rsa
17
18 from allmydata.mutable.common import CorruptShareError, UncoordinatedWriteError
19 from allmydata.mutable.layout import MDMFSlotReadProxy
20
class RetrieveStatus:
    """
    Mutable record of the progress and statistics of one Retrieve,
    published through the IRetrieveStatus interface.

    Every instance draws the next value of the class-wide statusid_counter
    as its 'counter', and records the creation time as its start time.
    """
    implements(IRetrieveStatus)
    statusid_counter = count(0)

    def __init__(self):
        # Per-phase timing data: a list of elapsed values per server, plus
        # running totals fed by the accumulate_* methods below.
        self.timings = {"fetch_per_server": {},
                        "decode": 0.0,
                        "decrypt": 0.0,
                        "cumulative_verify": 0.0}
        self.problems = {}
        self.active = True
        self.storage_index = None
        self.helper = False
        self.encoding = ("?","?")
        self.size = None
        self.status = "Not started"
        self.progress = 0.0
        self.counter = self.statusid_counter.next()
        self.started = time.time()

    # --- read accessors -------------------------------------------------

    def get_started(self):
        return self.started

    def get_storage_index(self):
        return self.storage_index

    def get_encoding(self):
        return self.encoding

    def using_helper(self):
        return self.helper

    def get_size(self):
        return self.size

    def get_status(self):
        return self.status

    def get_progress(self):
        return self.progress

    def get_active(self):
        return self.active

    def get_counter(self):
        return self.counter

    # --- recorders / mutators -------------------------------------------

    def add_fetch_timing(self, peerid, elapsed):
        """Append one fetch duration to the list kept for peerid."""
        per_server = self.timings["fetch_per_server"]
        per_server.setdefault(peerid, []).append(elapsed)

    def accumulate_decode_time(self, elapsed):
        """Add elapsed to the running decode total."""
        self.timings["decode"] += elapsed

    def accumulate_decrypt_time(self, elapsed):
        """Add elapsed to the running decrypt total."""
        self.timings["decrypt"] += elapsed

    def set_storage_index(self, si):
        self.storage_index = si

    def set_helper(self, helper):
        self.helper = helper

    def set_encoding(self, k, n):
        self.encoding = (k, n)

    def set_size(self, size):
        self.size = size

    def set_status(self, status):
        self.status = status

    def set_progress(self, value):
        self.progress = value

    def set_active(self, value):
        self.active = value
82
class Marker:
    """
    Empty placeholder class: instances carry no state of their own and,
    lacking any comparison methods, are distinguished only by identity.
    """
85
86 class Retrieve:
87     # this class is currently single-use. Eventually (in MDMF) we will make
88     # it multi-use, in which case you can call download(range) multiple
89     # times, and each will have a separate response chain. However the
90     # Retrieve object will remain tied to a specific version of the file, and
91     # will use a single ServerMap instance.
92     implements(IPushProducer)
93
94     def __init__(self, filenode, servermap, verinfo, fetch_privkey=False,
95                  verify=False):
96         self._node = filenode
97         assert self._node.get_pubkey()
98         self._storage_index = filenode.get_storage_index()
99         assert self._node.get_readkey()
100         self._last_failure = None
101         prefix = si_b2a(self._storage_index)[:5]
102         self._log_number = log.msg("Retrieve(%s): starting" % prefix)
103         self._outstanding_queries = {} # maps (peerid,shnum) to start_time
104         self._running = True
105         self._decoding = False
106         self._bad_shares = set()
107
108         self.servermap = servermap
109         assert self._node.get_pubkey()
110         self.verinfo = verinfo
111         # during repair, we may be called upon to grab the private key, since
112         # it wasn't picked up during a verify=False checker run, and we'll
113         # need it for repair to generate a new version.
114         self._need_privkey = verify or (fetch_privkey
115                                         and not self._node.get_privkey())
116
117         if self._need_privkey:
118             # TODO: Evaluate the need for this. We'll use it if we want
119             # to limit how many queries are on the wire for the privkey
120             # at once.
121             self._privkey_query_markers = [] # one Marker for each time we've
122                                              # tried to get the privkey.
123
124         # verify means that we are using the downloader logic to verify all
125         # of our shares. This tells the downloader a few things.
126         # 
127         # 1. We need to download all of the shares.
128         # 2. We don't need to decode or decrypt the shares, since our
129         #    caller doesn't care about the plaintext, only the
130         #    information about which shares are or are not valid.
131         # 3. When we are validating readers, we need to validate the
132         #    signature on the prefix. Do we? We already do this in the
133         #    servermap update?
134         self._verify = verify
135
136         self._status = RetrieveStatus()
137         self._status.set_storage_index(self._storage_index)
138         self._status.set_helper(False)
139         self._status.set_progress(0.0)
140         self._status.set_active(True)
141         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
142          offsets_tuple) = self.verinfo
143         self._status.set_size(datalength)
144         self._status.set_encoding(k, N)
145         self.readers = {}
146         self._pause_deferred = None
147         self._offset = None
148         self._read_length = None
149         self.log("got seqnum %d" % self.verinfo[0])
150
151
152     def get_status(self):
153         return self._status
154
155     def log(self, *args, **kwargs):
156         if "parent" not in kwargs:
157             kwargs["parent"] = self._log_number
158         if "facility" not in kwargs:
159             kwargs["facility"] = "tahoe.mutable.retrieve"
160         return log.msg(*args, **kwargs)
161
162     def _set_current_status(self, state):
163         seg = "%d/%d" % (self._current_segment, self._last_segment)
164         self._status.set_status("segment %s (%s)" % (seg, state))
165
166     ###################
167     # IPushProducer
168
169     def pauseProducing(self):
170         """
171         I am called by my download target if we have produced too much
172         data for it to handle. I make the downloader stop producing new
173         data until my resumeProducing method is called.
174         """
175         if self._pause_deferred is not None:
176             return
177
178         # fired when the download is unpaused. 
179         self._old_status = self._status.get_status()
180         self._set_current_status("paused")
181
182         self._pause_deferred = defer.Deferred()
183
184
185     def resumeProducing(self):
186         """
187         I am called by my download target once it is ready to begin
188         receiving data again.
189         """
190         if self._pause_deferred is None:
191             return
192
193         p = self._pause_deferred
194         self._pause_deferred = None
195         self._status.set_status(self._old_status)
196
197         eventually(p.callback, None)
198
199
200     def _check_for_paused(self, res):
201         """
202         I am called just before a write to the consumer. I return a
203         Deferred that eventually fires with the data that is to be
204         written to the consumer. If the download has not been paused,
205         the Deferred fires immediately. Otherwise, the Deferred fires
206         when the downloader is unpaused.
207         """
208         if self._pause_deferred is not None:
209             d = defer.Deferred()
210             self._pause_deferred.addCallback(lambda ignored: d.callback(res))
211             return d
212         return defer.succeed(res)
213
214
215     def download(self, consumer=None, offset=0, size=None):
216         assert IConsumer.providedBy(consumer) or self._verify
217
218         if consumer:
219             self._consumer = consumer
220             # we provide IPushProducer, so streaming=True, per
221             # IConsumer.
222             self._consumer.registerProducer(self, streaming=True)
223
224         self._done_deferred = defer.Deferred()
225         self._offset = offset
226         self._read_length = size
227         self._setup_download()
228         self._setup_encoding_parameters()
229         self.log("starting download")
230         self._started_fetching = time.time()
231         # The download process beyond this is a state machine.
232         # _add_active_peers will select the peers that we want to use
233         # for the download, and then attempt to start downloading. After
234         # each segment, it will check for doneness, reacting to broken
235         # peers and corrupt shares as necessary. If it runs out of good
236         # peers before downloading all of the segments, _done_deferred
237         # will errback.  Otherwise, it will eventually callback with the
238         # contents of the mutable file.
239         self.loop()
240         return self._done_deferred
241
242     def loop(self):
243         d = fireEventually(None) # avoid #237 recursion limit problem
244         d.addCallback(lambda ign: self._activate_enough_peers())
245         d.addCallback(lambda ign: self._download_current_segment())
246         # when we're done, _download_current_segment will call _done. If we
247         # aren't, it will call loop() again.
248         d.addErrback(self._error)
249
250     def _setup_download(self):
251         self._started = time.time()
252         self._status.set_status("Retrieving Shares")
253
254         # how many shares do we need?
255         (seqnum,
256          root_hash,
257          IV,
258          segsize,
259          datalength,
260          k,
261          N,
262          prefix,
263          offsets_tuple) = self.verinfo
264
265         # first, which servers can we use?
266         versionmap = self.servermap.make_versionmap()
267         shares = versionmap[self.verinfo]
268         # this sharemap is consumed as we decide to send requests
269         self.remaining_sharemap = DictOfSets()
270         for (shnum, peerid, timestamp) in shares:
271             self.remaining_sharemap.add(shnum, peerid)
272             # If the servermap update fetched anything, it fetched at least 1
273             # KiB, so we ask for that much.
274             # TODO: Change the cache methods to allow us to fetch all of the
275             # data that they have, then change this method to do that.
276             any_cache = self._node._read_from_cache(self.verinfo, shnum,
277                                                     0, 1000)
278             ss = self.servermap.connections[peerid]
279             reader = MDMFSlotReadProxy(ss,
280                                        self._storage_index,
281                                        shnum,
282                                        any_cache)
283             reader.peerid = peerid
284             self.readers[shnum] = reader
285         assert len(self.remaining_sharemap) >= k
286
287         self.shares = {} # maps shnum to validated blocks
288         self._active_readers = [] # list of active readers for this dl.
289         self._validated_readers = set() # set of readers that we have
290                                         # validated the prefix of
291         self._block_hash_trees = {} # shnum => hashtree
292
293         # We need one share hash tree for the entire file; its leaves
294         # are the roots of the block hash trees for the shares that
295         # comprise it, and its root is in the verinfo.
296         self.share_hash_tree = hashtree.IncompleteHashTree(N)
297         self.share_hash_tree.set_hashes({0: root_hash})
298
299     def decode(self, blocks_and_salts, segnum):
300         """
301         I am a helper method that the mutable file update process uses
302         as a shortcut to decode and decrypt the segments that it needs
303         to fetch in order to perform a file update. I take in a
304         collection of blocks and salts, and pick some of those to make a
305         segment with. I return the plaintext associated with that
306         segment.
307         """
308         # shnum => block hash tree. Unused, but setup_encoding_parameters will
309         # want to set this.
310         self._block_hash_trees = None
311         self._setup_encoding_parameters()
312
313         # This is the form expected by decode.
314         blocks_and_salts = blocks_and_salts.items()
315         blocks_and_salts = [(True, [d]) for d in blocks_and_salts]
316
317         d = self._decode_blocks(blocks_and_salts, segnum)
318         d.addCallback(self._decrypt_segment)
319         return d
320
321
322     def _setup_encoding_parameters(self):
323         """
324         I set up the encoding parameters, including k, n, the number
325         of segments associated with this file, and the segment decoders.
326         """
327         (seqnum,
328          root_hash,
329          IV,
330          segsize,
331          datalength,
332          k,
333          n,
334          known_prefix,
335          offsets_tuple) = self.verinfo
336         self._required_shares = k
337         self._total_shares = n
338         self._segment_size = segsize
339         self._data_length = datalength
340
341         if not IV:
342             self._version = MDMF_VERSION
343         else:
344             self._version = SDMF_VERSION
345
346         if datalength and segsize:
347             self._num_segments = mathutil.div_ceil(datalength, segsize)
348             self._tail_data_size = datalength % segsize
349         else:
350             self._num_segments = 0
351             self._tail_data_size = 0
352
353         self._segment_decoder = codec.CRSDecoder()
354         self._segment_decoder.set_params(segsize, k, n)
355
356         if  not self._tail_data_size:
357             self._tail_data_size = segsize
358
359         self._tail_segment_size = mathutil.next_multiple(self._tail_data_size,
360                                                          self._required_shares)
361         if self._tail_segment_size == self._segment_size:
362             self._tail_decoder = self._segment_decoder
363         else:
364             self._tail_decoder = codec.CRSDecoder()
365             self._tail_decoder.set_params(self._tail_segment_size,
366                                           self._required_shares,
367                                           self._total_shares)
368
369         self.log("got encoding parameters: "
370                  "k: %d "
371                  "n: %d "
372                  "%d segments of %d bytes each (%d byte tail segment)" % \
373                  (k, n, self._num_segments, self._segment_size,
374                   self._tail_segment_size))
375
376         if self._block_hash_trees is not None:
377             for i in xrange(self._total_shares):
378                 # So we don't have to do this later.
379                 self._block_hash_trees[i] = hashtree.IncompleteHashTree(self._num_segments)
380
381         # Our last task is to tell the downloader where to start and
382         # where to stop. We use three parameters for that:
383         #   - self._start_segment: the segment that we need to start
384         #     downloading from. 
385         #   - self._current_segment: the next segment that we need to
386         #     download.
387         #   - self._last_segment: The last segment that we were asked to
388         #     download.
389         #
390         #  We say that the download is complete when
391         #  self._current_segment > self._last_segment. We use
392         #  self._start_segment and self._last_segment to know when to
393         #  strip things off of segments, and how much to strip.
394         if self._offset:
395             self.log("got offset: %d" % self._offset)
396             # our start segment is the first segment containing the
397             # offset we were given. 
398             start = self._offset // self._segment_size
399
400             assert start < self._num_segments
401             self._start_segment = start
402             self.log("got start segment: %d" % self._start_segment)
403         else:
404             self._start_segment = 0
405
406
407         # If self._read_length is None, then we want to read the whole
408         # file. Otherwise, we want to read only part of the file, and
409         # need to figure out where to stop reading.
410         if self._read_length is not None:
411             # our end segment is the last segment containing part of the
412             # segment that we were asked to read.
413             self.log("got read length %d" % self._read_length)
414             if self._read_length != 0:
415                 end_data = self._offset + self._read_length
416
417                 # We don't actually need to read the byte at end_data,
418                 # but the one before it.
419                 end = (end_data - 1) // self._segment_size
420
421                 assert end < self._num_segments
422                 self._last_segment = end
423             else:
424                 self._last_segment = self._start_segment
425             self.log("got end segment: %d" % self._last_segment)
426         else:
427             self._last_segment = self._num_segments - 1
428
429         self._current_segment = self._start_segment
430
431     def _activate_enough_peers(self):
432         """
433         I populate self._active_readers with enough active readers to
434         retrieve the contents of this mutable file. I am called before
435         downloading starts, and (eventually) after each validation
436         error, connection error, or other problem in the download.
437         """
438         # TODO: It would be cool to investigate other heuristics for
439         # reader selection. For instance, the cost (in time the user
440         # spends waiting for their file) of selecting a really slow peer
441         # that happens to have a primary share is probably more than
442         # selecting a really fast peer that doesn't have a primary
443         # share. Maybe the servermap could be extended to provide this
444         # information; it could keep track of latency information while
445         # it gathers more important data, and then this routine could
446         # use that to select active readers.
447         #
448         # (these and other questions would be easier to answer with a
449         #  robust, configurable tahoe-lafs simulator, which modeled node
450         #  failures, differences in node speed, and other characteristics
451         #  that we expect storage servers to have.  You could have
452         #  presets for really stable grids (like allmydata.com),
453         #  friendnets, make it easy to configure your own settings, and
454         #  then simulate the effect of big changes on these use cases
455         #  instead of just reasoning about what the effect might be. Out
456         #  of scope for MDMF, though.)
457
458         # We need at least self._required_shares readers to download a
459         # segment. If we're verifying, we need all shares.
460         if self._verify:
461             needed = self._total_shares
462         else:
463             needed = self._required_shares
464         # XXX: Why don't format= log messages work here?
465         self.log("adding %d peers to the active peers list" % needed)
466
467         if len(self._active_readers) >= needed:
468             # enough shares are active
469             return
470
471         more = needed - len(self._active_readers)
472         known_shnums = set(self.remaining_sharemap.keys())
473         used_shnums = set([r.shnum for r in self._active_readers])
474         unused_shnums = known_shnums - used_shnums
475         # We favor lower numbered shares, since FEC is faster with
476         # primary shares than with other shares, and lower-numbered
477         # shares are more likely to be primary than higher numbered
478         # shares.
479         new_shnums = sorted(unused_shnums)[:more]
480         if len(new_shnums) < more and not self._verify:
481             # We don't have enough readers to retrieve the file; fail.
482             self._raise_notenoughshareserror()
483
484         for shnum in new_shnums:
485             reader = self.readers[shnum]
486             self._active_readers.append(reader)
487             self._validated_readers.add(reader)
488             self.log("added reader for share %d" % shnum)
489             # Each time we validate a reader, we check to see if we need the
490             # private key. If we do, we politely ask for it and then continue
491             # computing. If we find that we haven't gotten it at the end of
492             # segment decoding, then we'll take more drastic measures.
493             if self._need_privkey and not self._node.is_readonly():
494                 d = reader.get_encprivkey()
495                 d.addCallback(self._try_to_validate_privkey, reader)
496                 # XXX: don't just drop the Deferred. We need error-reporting
497                 # but not flow-control here.
498         assert len(self._active_readers) >= self._required_shares
499
500     def _try_to_validate_prefix(self, prefix, reader):
501         """
502         I check that the prefix returned by a candidate server for
503         retrieval matches the prefix that the servermap knows about
504         (and, hence, the prefix that was validated earlier). If it does,
505         I return True, which means that I approve of the use of the
506         candidate server for segment retrieval. If it doesn't, I return
507         False, which means that another server must be chosen.
508         """
509         (seqnum,
510          root_hash,
511          IV,
512          segsize,
513          datalength,
514          k,
515          N,
516          known_prefix,
517          offsets_tuple) = self.verinfo
518         if known_prefix != prefix:
519             self.log("prefix from share %d doesn't match" % reader.shnum)
520             raise UncoordinatedWriteError("Mismatched prefix -- this could "
521                                           "indicate an uncoordinated write")
522         # Otherwise, we're okay -- no issues.
523
524
525     def _remove_reader(self, reader):
526         """
527         At various points, we will wish to remove a peer from
528         consideration and/or use. These include, but are not necessarily
529         limited to:
530
531             - A connection error.
532             - A mismatched prefix (that is, a prefix that does not match
533               our conception of the version information string).
534             - A failing block hash, salt hash, or share hash, which can
535               indicate disk failure/bit flips, or network trouble.
536
537         This method will do that. I will make sure that the
538         (shnum,reader) combination represented by my reader argument is
539         not used for anything else during this download. I will not
540         advise the reader of any corruption, something that my callers
541         may wish to do on their own.
542         """
543         # TODO: When you're done writing this, see if this is ever
544         # actually used for something that _mark_bad_share isn't. I have
545         # a feeling that they will be used for very similar things, and
546         # that having them both here is just going to be an epic amount
547         # of code duplication.
548         #
549         # (well, okay, not epic, but meaningful)
550         self.log("removing reader %s" % reader)
551         # Remove the reader from _active_readers
552         self._active_readers.remove(reader)
553         # TODO: self.readers.remove(reader)?
554         for shnum in list(self.remaining_sharemap.keys()):
555             self.remaining_sharemap.discard(shnum, reader.peerid)
556
557
558     def _mark_bad_share(self, reader, f):
559         """
560         I mark the (peerid, shnum) encapsulated by my reader argument as
561         a bad share, which means that it will not be used anywhere else.
562
563         There are several reasons to want to mark something as a bad
564         share. These include:
565
566             - A connection error to the peer.
567             - A mismatched prefix (that is, a prefix that does not match
568               our local conception of the version information string).
569             - A failing block hash, salt hash, share hash, or other
570               integrity check.
571
572         This method will ensure that readers that we wish to mark bad
573         (for these reasons or other reasons) are not used for the rest
574         of the download. Additionally, it will attempt to tell the
575         remote peer (with no guarantee of success) that its share is
576         corrupt.
577         """
578         self.log("marking share %d on server %s as bad" % \
579                  (reader.shnum, reader))
580         prefix = self.verinfo[-2]
581         self.servermap.mark_bad_share(reader.peerid,
582                                       reader.shnum,
583                                       prefix)
584         self._remove_reader(reader)
585         self._bad_shares.add((reader.peerid, reader.shnum, f))
586         self._status.problems[reader.peerid] = f
587         self._last_failure = f
588         self.notify_server_corruption(reader.peerid, reader.shnum,
589                                       str(f.value))
590
591
592     def _download_current_segment(self):
593         """
594         I download, validate, decode, decrypt, and assemble the segment
595         that this Retrieve is currently responsible for downloading.
596         """
597         assert len(self._active_readers) >= self._required_shares
598         if self._current_segment > self._last_segment:
599             # No more segments to download, we're done.
600             self.log("got plaintext, done")
601             return self._done()
602         self.log("on segment %d of %d" %
603                  (self._current_segment + 1, self._num_segments))
604         d = self._process_segment(self._current_segment)
605         d.addCallback(lambda ign: self.loop())
606         return d
607
608     def _process_segment(self, segnum):
609         """
610         I download, validate, decode, and decrypt one segment of the
611         file that this Retrieve is retrieving. This means coordinating
612         the process of getting k blocks of that file, validating them,
613         assembling them into one segment with the decoder, and then
614         decrypting them.
615         """
616         self.log("processing segment %d" % segnum)
617
618         # TODO: The old code uses a marker. Should this code do that
619         # too? What did the Marker do?
620         assert len(self._active_readers) >= self._required_shares
621
622         # We need to ask each of our active readers for its block and
623         # salt. We will then validate those. If validation is
624         # successful, we will assemble the results into plaintext.
625         ds = []
626         for reader in self._active_readers:
627             started = time.time()
628             d = reader.get_block_and_salt(segnum)
629             d2 = self._get_needed_hashes(reader, segnum)
630             dl = defer.DeferredList([d, d2], consumeErrors=True)
631             dl.addCallback(self._validate_block, segnum, reader, started)
632             dl.addErrback(self._validation_or_decoding_failed, [reader])
633             ds.append(dl)
634         dl = defer.DeferredList(ds)
635         if self._verify:
636             dl.addCallback(lambda ignored: "")
637             dl.addCallback(self._set_segment)
638         else:
639             dl.addCallback(self._maybe_decode_and_decrypt_segment, segnum)
640         return dl
641
642
643     def _maybe_decode_and_decrypt_segment(self, blocks_and_salts, segnum):
644         """
645         I take the results of fetching and validating the blocks from a
646         callback chain in another method. If the results are such that
647         they tell me that validation and fetching succeeded without
648         incident, I will proceed with decoding and decryption.
649         Otherwise, I will do nothing.
650         """
651         self.log("trying to decode and decrypt segment %d" % segnum)
652         failures = False
653         for block_and_salt in blocks_and_salts:
654             if not block_and_salt[0] or block_and_salt[1] == None:
655                 self.log("some validation operations failed; not proceeding")
656                 failures = True
657                 break
658         if not failures:
659             self.log("everything looks ok, building segment %d" % segnum)
660             d = self._decode_blocks(blocks_and_salts, segnum)
661             d.addCallback(self._decrypt_segment)
662             d.addErrback(self._validation_or_decoding_failed,
663                          self._active_readers)
664             # check to see whether we've been paused before writing
665             # anything.
666             d.addCallback(self._check_for_paused)
667             d.addCallback(self._set_segment)
668             return d
669         else:
670             return defer.succeed(None)
671
672
673     def _set_segment(self, segment):
674         """
675         Given a plaintext segment, I register that segment with the
676         target that is handling the file download.
677         """
678         self.log("got plaintext for segment %d" % self._current_segment)
679         if self._current_segment == self._start_segment:
680             # We're on the first segment. It's possible that we want
681             # only some part of the end of this segment, and that we
682             # just downloaded the whole thing to get that part. If so,
683             # we need to account for that and give the reader just the
684             # data that they want.
685             n = self._offset % self._segment_size
686             self.log("stripping %d bytes off of the first segment" % n)
687             self.log("original segment length: %d" % len(segment))
688             segment = segment[n:]
689             self.log("new segment length: %d" % len(segment))
690
691         if self._current_segment == self._last_segment and self._read_length is not None:
692             # We're on the last segment. It's possible that we only want
693             # part of the beginning of this segment, and that we
694             # downloaded the whole thing anyway. Make sure to give the
695             # caller only the portion of the segment that they want to
696             # receive.
697             extra = self._read_length
698             if self._start_segment != self._last_segment:
699                 extra -= self._segment_size - \
700                             (self._offset % self._segment_size)
701             extra %= self._segment_size
702             self.log("original segment length: %d" % len(segment))
703             segment = segment[:extra]
704             self.log("new segment length: %d" % len(segment))
705             self.log("only taking %d bytes of the last segment" % extra)
706
707         if not self._verify:
708             self._consumer.write(segment)
709         else:
710             # we don't care about the plaintext if we are doing a verify.
711             segment = None
712         self._current_segment += 1
713
714
715     def _validation_or_decoding_failed(self, f, readers):
716         """
717         I am called when a block or a salt fails to correctly validate, or when
718         the decryption or decoding operation fails for some reason.  I react to
719         this failure by notifying the remote server of corruption, and then
720         removing the remote peer from further activity.
721         """
722         assert isinstance(readers, list)
723         bad_shnums = [reader.shnum for reader in readers]
724
725         self.log("validation or decoding failed on share(s) %s, peer(s) %s "
726                  ", segment %d: %s" % \
727                  (bad_shnums, readers, self._current_segment, str(f)))
728         for reader in readers:
729             self._mark_bad_share(reader, f)
730         return
731
732
733     def _validate_block(self, results, segnum, reader, started):
734         """
735         I validate a block from one share on a remote server.
736         """
737         # Grab the part of the block hash tree that is necessary to
738         # validate this block, then generate the block hash root.
739         self.log("validating share %d for segment %d" % (reader.shnum,
740                                                              segnum))
741         elapsed = time.time() - started
742         self._status.add_fetch_timing(reader.peerid, elapsed)
743         self._set_current_status("validating blocks")
744         # Did we fail to fetch either of the things that we were
745         # supposed to? Fail if so.
746         if not results[0][0] and results[1][0]:
747             # handled by the errback handler.
748
749             # These all get batched into one query, so the resulting
750             # failure should be the same for all of them, so we can just
751             # use the first one.
752             assert isinstance(results[0][1], failure.Failure)
753
754             f = results[0][1]
755             raise CorruptShareError(reader.peerid,
756                                     reader.shnum,
757                                     "Connection error: %s" % str(f))
758
759         block_and_salt, block_and_sharehashes = results
760         block, salt = block_and_salt[1]
761         blockhashes, sharehashes = block_and_sharehashes[1]
762
763         blockhashes = dict(enumerate(blockhashes[1]))
764         self.log("the reader gave me the following blockhashes: %s" % \
765                  blockhashes.keys())
766         self.log("the reader gave me the following sharehashes: %s" % \
767                  sharehashes[1].keys())
768         bht = self._block_hash_trees[reader.shnum]
769
770         if bht.needed_hashes(segnum, include_leaf=True):
771             try:
772                 bht.set_hashes(blockhashes)
773             except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
774                     IndexError), e:
775                 raise CorruptShareError(reader.peerid,
776                                         reader.shnum,
777                                         "block hash tree failure: %s" % e)
778
779         if self._version == MDMF_VERSION:
780             blockhash = hashutil.block_hash(salt + block)
781         else:
782             blockhash = hashutil.block_hash(block)
783         # If this works without an error, then validation is
784         # successful.
785         try:
786            bht.set_hashes(leaves={segnum: blockhash})
787         except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
788                 IndexError), e:
789             raise CorruptShareError(reader.peerid,
790                                     reader.shnum,
791                                     "block hash tree failure: %s" % e)
792
793         # Reaching this point means that we know that this segment
794         # is correct. Now we need to check to see whether the share
795         # hash chain is also correct. 
796         # SDMF wrote share hash chains that didn't contain the
797         # leaves, which would be produced from the block hash tree.
798         # So we need to validate the block hash tree first. If
799         # successful, then bht[0] will contain the root for the
800         # shnum, which will be a leaf in the share hash tree, which
801         # will allow us to validate the rest of the tree.
802         if self.share_hash_tree.needed_hashes(reader.shnum,
803                                               include_leaf=True) or \
804                                               self._verify:
805             try:
806                 self.share_hash_tree.set_hashes(hashes=sharehashes[1],
807                                             leaves={reader.shnum: bht[0]})
808             except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
809                     IndexError), e:
810                 raise CorruptShareError(reader.peerid,
811                                         reader.shnum,
812                                         "corrupt hashes: %s" % e)
813
814         self.log('share %d is valid for segment %d' % (reader.shnum,
815                                                        segnum))
816         return {reader.shnum: (block, salt)}
817
818
819     def _get_needed_hashes(self, reader, segnum):
820         """
821         I get the hashes needed to validate segnum from the reader, then return
822         to my caller when this is done.
823         """
824         bht = self._block_hash_trees[reader.shnum]
825         needed = bht.needed_hashes(segnum, include_leaf=True)
826         # The root of the block hash tree is also a leaf in the share
827         # hash tree. So we don't need to fetch it from the remote
828         # server. In the case of files with one segment, this means that
829         # we won't fetch any block hash tree from the remote server,
830         # since the hash of each share of the file is the entire block
831         # hash tree, and is a leaf in the share hash tree. This is fine,
832         # since any share corruption will be detected in the share hash
833         # tree.
834         #needed.discard(0)
835         self.log("getting blockhashes for segment %d, share %d: %s" % \
836                  (segnum, reader.shnum, str(needed)))
837         d1 = reader.get_blockhashes(needed, force_remote=True)
838         if self.share_hash_tree.needed_hashes(reader.shnum):
839             need = self.share_hash_tree.needed_hashes(reader.shnum)
840             self.log("also need sharehashes for share %d: %s" % (reader.shnum,
841                                                                  str(need)))
842             d2 = reader.get_sharehashes(need, force_remote=True)
843         else:
844             d2 = defer.succeed({}) # the logic in the next method
845                                    # expects a dict
846         dl = defer.DeferredList([d1, d2], consumeErrors=True)
847         return dl
848
849
850     def _decode_blocks(self, blocks_and_salts, segnum):
851         """
852         I take a list of k blocks and salts, and decode that into a
853         single encrypted segment.
854         """
855         d = {}
856         # We want to merge our dictionaries to the form 
857         # {shnum: blocks_and_salts}
858         #
859         # The dictionaries come from validate block that way, so we just
860         # need to merge them.
861         for block_and_salt in blocks_and_salts:
862             d.update(block_and_salt[1])
863
864         # All of these blocks should have the same salt; in SDMF, it is
865         # the file-wide IV, while in MDMF it is the per-segment salt. In
866         # either case, we just need to get one of them and use it.
867         #
868         # d.items()[0] is like (shnum, (block, salt))
869         # d.items()[0][1] is like (block, salt)
870         # d.items()[0][1][1] is the salt.
871         salt = d.items()[0][1][1]
872         # Next, extract just the blocks from the dict. We'll use the
873         # salt in the next step.
874         share_and_shareids = [(k, v[0]) for k, v in d.items()]
875         d2 = dict(share_and_shareids)
876         shareids = []
877         shares = []
878         for shareid, share in d2.items():
879             shareids.append(shareid)
880             shares.append(share)
881
882         self._set_current_status("decoding")
883         started = time.time()
884         assert len(shareids) >= self._required_shares, len(shareids)
885         # zfec really doesn't want extra shares
886         shareids = shareids[:self._required_shares]
887         shares = shares[:self._required_shares]
888         self.log("decoding segment %d" % segnum)
889         if segnum == self._num_segments - 1:
890             d = defer.maybeDeferred(self._tail_decoder.decode, shares, shareids)
891         else:
892             d = defer.maybeDeferred(self._segment_decoder.decode, shares, shareids)
893         def _process(buffers):
894             segment = "".join(buffers)
895             self.log(format="now decoding segment %(segnum)s of %(numsegs)s",
896                      segnum=segnum,
897                      numsegs=self._num_segments,
898                      level=log.NOISY)
899             self.log(" joined length %d, datalength %d" %
900                      (len(segment), self._data_length))
901             if segnum == self._num_segments - 1:
902                 size_to_use = self._tail_data_size
903             else:
904                 size_to_use = self._segment_size
905             segment = segment[:size_to_use]
906             self.log(" segment len=%d" % len(segment))
907             self._status.accumulate_decode_time(time.time() - started)
908             return segment, salt
909         d.addCallback(_process)
910         return d
911
912
913     def _decrypt_segment(self, segment_and_salt):
914         """
915         I take a single segment and its salt, and decrypt it. I return
916         the plaintext of the segment that is in my argument.
917         """
918         segment, salt = segment_and_salt
919         self._set_current_status("decrypting")
920         self.log("decrypting segment %d" % self._current_segment)
921         started = time.time()
922         key = hashutil.ssk_readkey_data_hash(salt, self._node.get_readkey())
923         decryptor = AES(key)
924         plaintext = decryptor.process(segment)
925         self._status.accumulate_decrypt_time(time.time() - started)
926         return plaintext
927
928
929     def notify_server_corruption(self, peerid, shnum, reason):
930         ss = self.servermap.connections[peerid]
931         ss.callRemoteOnly("advise_corrupt_share",
932                           "mutable", self._storage_index, shnum, reason)
933
934
935     def _try_to_validate_privkey(self, enc_privkey, reader):
936         alleged_privkey_s = self._node._decrypt_privkey(enc_privkey)
937         alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s)
938         if alleged_writekey != self._node.get_writekey():
939             self.log("invalid privkey from %s shnum %d" %
940                      (reader, reader.shnum),
941                      level=log.WEIRD, umid="YIw4tA")
942             if self._verify:
943                 self.servermap.mark_bad_share(reader.peerid, reader.shnum,
944                                               self.verinfo[-2])
945                 e = CorruptShareError(reader.peerid,
946                                       reader.shnum,
947                                       "invalid privkey")
948                 f = failure.Failure(e)
949                 self._bad_shares.add((reader.peerid, reader.shnum, f))
950             return
951
952         # it's good
953         self.log("got valid privkey from shnum %d on reader %s" %
954                  (reader.shnum, reader))
955         privkey = rsa.create_signing_key_from_string(alleged_privkey_s)
956         self._node._populate_encprivkey(enc_privkey)
957         self._node._populate_privkey(privkey)
958         self._need_privkey = False
959
960
961
962     def _done(self):
963         """
964         I am called by _download_current_segment when the download process
965         has finished successfully. After making some useful logging
966         statements, I return the decrypted contents to the owner of this
967         Retrieve object through self._done_deferred.
968         """
969         self._running = False
970         self._status.set_active(False)
971         now = time.time()
972         self._status.timings['total'] = now - self._started
973         self._status.timings['fetch'] = now - self._started_fetching
974         self._status.set_status("Finished")
975         self._status.set_progress(1.0)
976
977         # remember the encoding parameters, use them again next time
978         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
979          offsets_tuple) = self.verinfo
980         self._node._populate_required_shares(k)
981         self._node._populate_total_shares(N)
982
983         if self._verify:
984             ret = list(self._bad_shares)
985             self.log("done verifying, found %d bad shares" % len(ret))
986         else:
987             # TODO: upload status here?
988             ret = self._consumer
989             self._consumer.unregisterProducer()
990         eventually(self._done_deferred.callback, ret)
991
992
993     def _raise_notenoughshareserror(self):
994         """
995         I am called by _activate_enough_peers when there are not enough
996         active peers left to complete the download. After making some
997         useful logging statements, I throw an exception to that effect
998         to the caller of this Retrieve object through
999         self._done_deferred.
1000         """
1001
1002         format = ("ran out of peers: "
1003                   "have %(have)d of %(total)d segments "
1004                   "found %(bad)d bad shares "
1005                   "encoding %(k)d-of-%(n)d")
1006         args = {"have": self._current_segment,
1007                 "total": self._num_segments,
1008                 "need": self._last_segment,
1009                 "k": self._required_shares,
1010                 "n": self._total_shares,
1011                 "bad": len(self._bad_shares)}
1012         raise NotEnoughSharesError("%s, last failure: %s" %
1013                                    (format % args, str(self._last_failure)))
1014
1015     def _error(self, f):
1016         # all errors, including NotEnoughSharesError, land here
1017         self._running = False
1018         self._status.set_active(False)
1019         now = time.time()
1020         self._status.timings['total'] = now - self._started
1021         self._status.timings['fetch'] = now - self._started_fetching
1022         self._status.set_status("Failed")
1023         eventually(self._done_deferred.errback, f)