]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blob - src/allmydata/mutable/retrieve.py
Retrieve: implement/test stopProducing
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / mutable / retrieve.py
1
2 import time
3 from itertools import count
4 from zope.interface import implements
5 from twisted.internet import defer
6 from twisted.python import failure
7 from twisted.internet.interfaces import IPushProducer, IConsumer
8 from foolscap.api import eventually, fireEventually
9 from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \
10      DownloadStopped, MDMF_VERSION, SDMF_VERSION
11 from allmydata.util import hashutil, log, mathutil
12 from allmydata.util.dictutil import DictOfSets
13 from allmydata import hashtree, codec
14 from allmydata.storage.server import si_b2a
15 from pycryptopp.cipher.aes import AES
16 from pycryptopp.publickey import rsa
17
18 from allmydata.mutable.common import CorruptShareError, UncoordinatedWriteError
19 from allmydata.mutable.layout import MDMFSlotReadProxy
20
class RetrieveStatus:
    """
    Progress, timing and problem record for one mutable-file retrieve.

    Every instance takes a unique id from the class-wide statusid_counter
    when it is constructed, and remembers its construction time.
    """
    implements(IRetrieveStatus)
    statusid_counter = count(0)

    def __init__(self):
        # fetch_per_server maps peerid -> list of fetch durations; the other
        # entries are running totals accumulated by the methods below.
        self.timings = {
            "fetch_per_server": {},
            "decode": 0.0,
            "decrypt": 0.0,
            "cumulative_verify": 0.0,
        }
        self.problems = {}
        self.active = True
        self.storage_index = None
        self.helper = False
        self.encoding = ("?", "?")
        self.size = None
        self.status = "Not started"
        self.progress = 0.0
        self.counter = next(self.statusid_counter)
        self.started = time.time()

    # --- read accessors ---------------------------------------------------

    def get_started(self):
        return self.started

    def get_storage_index(self):
        return self.storage_index

    def get_encoding(self):
        return self.encoding

    def using_helper(self):
        return self.helper

    def get_size(self):
        return self.size

    def get_status(self):
        return self.status

    def get_progress(self):
        return self.progress

    def get_active(self):
        return self.active

    def get_counter(self):
        return self.counter

    # --- timing accumulators ----------------------------------------------

    def add_fetch_timing(self, peerid, elapsed):
        """Append one fetch duration to the list kept for peerid."""
        per_server = self.timings["fetch_per_server"]
        per_server.setdefault(peerid, []).append(elapsed)

    def accumulate_decode_time(self, elapsed):
        self.timings["decode"] += elapsed

    def accumulate_decrypt_time(self, elapsed):
        self.timings["decrypt"] += elapsed

    # --- write accessors --------------------------------------------------

    def set_storage_index(self, si):
        self.storage_index = si

    def set_helper(self, helper):
        self.helper = helper

    def set_encoding(self, k, n):
        self.encoding = (k, n)

    def set_size(self, size):
        self.size = size

    def set_status(self, status):
        self.status = status

    def set_progress(self, value):
        self.progress = value

    def set_active(self, value):
        self.active = value
82
class Marker:
    """
    Empty placeholder class: instances carry no state and are told apart
    only by identity.
    """
85
86 class Retrieve:
87     # this class is currently single-use. Eventually (in MDMF) we will make
88     # it multi-use, in which case you can call download(range) multiple
89     # times, and each will have a separate response chain. However the
90     # Retrieve object will remain tied to a specific version of the file, and
91     # will use a single ServerMap instance.
92     implements(IPushProducer)
93
94     def __init__(self, filenode, servermap, verinfo, fetch_privkey=False,
95                  verify=False):
96         self._node = filenode
97         assert self._node.get_pubkey()
98         self._storage_index = filenode.get_storage_index()
99         assert self._node.get_readkey()
100         self._last_failure = None
101         prefix = si_b2a(self._storage_index)[:5]
102         self._log_number = log.msg("Retrieve(%s): starting" % prefix)
103         self._outstanding_queries = {} # maps (peerid,shnum) to start_time
104         self._running = True
105         self._decoding = False
106         self._bad_shares = set()
107
108         self.servermap = servermap
109         assert self._node.get_pubkey()
110         self.verinfo = verinfo
111         # during repair, we may be called upon to grab the private key, since
112         # it wasn't picked up during a verify=False checker run, and we'll
113         # need it for repair to generate a new version.
114         self._need_privkey = verify or (fetch_privkey
115                                         and not self._node.get_privkey())
116
117         if self._need_privkey:
118             # TODO: Evaluate the need for this. We'll use it if we want
119             # to limit how many queries are on the wire for the privkey
120             # at once.
121             self._privkey_query_markers = [] # one Marker for each time we've
122                                              # tried to get the privkey.
123
124         # verify means that we are using the downloader logic to verify all
125         # of our shares. This tells the downloader a few things.
126         # 
127         # 1. We need to download all of the shares.
128         # 2. We don't need to decode or decrypt the shares, since our
129         #    caller doesn't care about the plaintext, only the
130         #    information about which shares are or are not valid.
131         # 3. When we are validating readers, we need to validate the
132         #    signature on the prefix. Do we? We already do this in the
133         #    servermap update?
134         self._verify = verify
135
136         self._status = RetrieveStatus()
137         self._status.set_storage_index(self._storage_index)
138         self._status.set_helper(False)
139         self._status.set_progress(0.0)
140         self._status.set_active(True)
141         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
142          offsets_tuple) = self.verinfo
143         self._status.set_size(datalength)
144         self._status.set_encoding(k, N)
145         self.readers = {}
146         self._stopped = False
147         self._pause_deferred = None
148         self._offset = None
149         self._read_length = None
150         self.log("got seqnum %d" % self.verinfo[0])
151
152
153     def get_status(self):
154         return self._status
155
156     def log(self, *args, **kwargs):
157         if "parent" not in kwargs:
158             kwargs["parent"] = self._log_number
159         if "facility" not in kwargs:
160             kwargs["facility"] = "tahoe.mutable.retrieve"
161         return log.msg(*args, **kwargs)
162
163     def _set_current_status(self, state):
164         seg = "%d/%d" % (self._current_segment, self._last_segment)
165         self._status.set_status("segment %s (%s)" % (seg, state))
166
167     ###################
168     # IPushProducer
169
170     def pauseProducing(self):
171         """
172         I am called by my download target if we have produced too much
173         data for it to handle. I make the downloader stop producing new
174         data until my resumeProducing method is called.
175         """
176         if self._pause_deferred is not None:
177             return
178
179         # fired when the download is unpaused. 
180         self._old_status = self._status.get_status()
181         self._set_current_status("paused")
182
183         self._pause_deferred = defer.Deferred()
184
185
186     def resumeProducing(self):
187         """
188         I am called by my download target once it is ready to begin
189         receiving data again.
190         """
191         if self._pause_deferred is None:
192             return
193
194         p = self._pause_deferred
195         self._pause_deferred = None
196         self._status.set_status(self._old_status)
197
198         eventually(p.callback, None)
199
200     def stopProducing(self):
201         self._stopped = True
202         self.resumeProducing()
203
204
205     def _check_for_paused(self, res):
206         """
207         I am called just before a write to the consumer. I return a
208         Deferred that eventually fires with the data that is to be
209         written to the consumer. If the download has not been paused,
210         the Deferred fires immediately. Otherwise, the Deferred fires
211         when the downloader is unpaused.
212         """
213         if self._stopped:
214             raise DownloadStopped("our Consumer called stopProducing()")
215         if self._pause_deferred is not None:
216             d = defer.Deferred()
217             self._pause_deferred.addCallback(lambda ignored: d.callback(res))
218             return d
219         return defer.succeed(res)
220
221
222     def download(self, consumer=None, offset=0, size=None):
223         assert IConsumer.providedBy(consumer) or self._verify
224
225         if consumer:
226             self._consumer = consumer
227             # we provide IPushProducer, so streaming=True, per
228             # IConsumer.
229             self._consumer.registerProducer(self, streaming=True)
230
231         self._done_deferred = defer.Deferred()
232         self._offset = offset
233         self._read_length = size
234         self._setup_download()
235         self._setup_encoding_parameters()
236         self.log("starting download")
237         self._started_fetching = time.time()
238         # The download process beyond this is a state machine.
239         # _add_active_peers will select the peers that we want to use
240         # for the download, and then attempt to start downloading. After
241         # each segment, it will check for doneness, reacting to broken
242         # peers and corrupt shares as necessary. If it runs out of good
243         # peers before downloading all of the segments, _done_deferred
244         # will errback.  Otherwise, it will eventually callback with the
245         # contents of the mutable file.
246         self.loop()
247         return self._done_deferred
248
249     def loop(self):
250         d = fireEventually(None) # avoid #237 recursion limit problem
251         d.addCallback(lambda ign: self._activate_enough_peers())
252         d.addCallback(lambda ign: self._download_current_segment())
253         # when we're done, _download_current_segment will call _done. If we
254         # aren't, it will call loop() again.
255         d.addErrback(self._error)
256
257     def _setup_download(self):
258         self._started = time.time()
259         self._status.set_status("Retrieving Shares")
260
261         # how many shares do we need?
262         (seqnum,
263          root_hash,
264          IV,
265          segsize,
266          datalength,
267          k,
268          N,
269          prefix,
270          offsets_tuple) = self.verinfo
271
272         # first, which servers can we use?
273         versionmap = self.servermap.make_versionmap()
274         shares = versionmap[self.verinfo]
275         # this sharemap is consumed as we decide to send requests
276         self.remaining_sharemap = DictOfSets()
277         for (shnum, peerid, timestamp) in shares:
278             self.remaining_sharemap.add(shnum, peerid)
279             # If the servermap update fetched anything, it fetched at least 1
280             # KiB, so we ask for that much.
281             # TODO: Change the cache methods to allow us to fetch all of the
282             # data that they have, then change this method to do that.
283             any_cache = self._node._read_from_cache(self.verinfo, shnum,
284                                                     0, 1000)
285             ss = self.servermap.connections[peerid]
286             reader = MDMFSlotReadProxy(ss,
287                                        self._storage_index,
288                                        shnum,
289                                        any_cache)
290             reader.peerid = peerid
291             self.readers[shnum] = reader
292         assert len(self.remaining_sharemap) >= k
293
294         self.shares = {} # maps shnum to validated blocks
295         self._active_readers = [] # list of active readers for this dl.
296         self._block_hash_trees = {} # shnum => hashtree
297
298         # We need one share hash tree for the entire file; its leaves
299         # are the roots of the block hash trees for the shares that
300         # comprise it, and its root is in the verinfo.
301         self.share_hash_tree = hashtree.IncompleteHashTree(N)
302         self.share_hash_tree.set_hashes({0: root_hash})
303
304     def decode(self, blocks_and_salts, segnum):
305         """
306         I am a helper method that the mutable file update process uses
307         as a shortcut to decode and decrypt the segments that it needs
308         to fetch in order to perform a file update. I take in a
309         collection of blocks and salts, and pick some of those to make a
310         segment with. I return the plaintext associated with that
311         segment.
312         """
313         # shnum => block hash tree. Unused, but setup_encoding_parameters will
314         # want to set this.
315         self._block_hash_trees = None
316         self._setup_encoding_parameters()
317
318         # This is the form expected by decode.
319         blocks_and_salts = blocks_and_salts.items()
320         blocks_and_salts = [(True, [d]) for d in blocks_and_salts]
321
322         d = self._decode_blocks(blocks_and_salts, segnum)
323         d.addCallback(self._decrypt_segment)
324         return d
325
326
327     def _setup_encoding_parameters(self):
328         """
329         I set up the encoding parameters, including k, n, the number
330         of segments associated with this file, and the segment decoders.
331         """
332         (seqnum,
333          root_hash,
334          IV,
335          segsize,
336          datalength,
337          k,
338          n,
339          known_prefix,
340          offsets_tuple) = self.verinfo
341         self._required_shares = k
342         self._total_shares = n
343         self._segment_size = segsize
344         self._data_length = datalength
345
346         if not IV:
347             self._version = MDMF_VERSION
348         else:
349             self._version = SDMF_VERSION
350
351         if datalength and segsize:
352             self._num_segments = mathutil.div_ceil(datalength, segsize)
353             self._tail_data_size = datalength % segsize
354         else:
355             self._num_segments = 0
356             self._tail_data_size = 0
357
358         self._segment_decoder = codec.CRSDecoder()
359         self._segment_decoder.set_params(segsize, k, n)
360
361         if  not self._tail_data_size:
362             self._tail_data_size = segsize
363
364         self._tail_segment_size = mathutil.next_multiple(self._tail_data_size,
365                                                          self._required_shares)
366         if self._tail_segment_size == self._segment_size:
367             self._tail_decoder = self._segment_decoder
368         else:
369             self._tail_decoder = codec.CRSDecoder()
370             self._tail_decoder.set_params(self._tail_segment_size,
371                                           self._required_shares,
372                                           self._total_shares)
373
374         self.log("got encoding parameters: "
375                  "k: %d "
376                  "n: %d "
377                  "%d segments of %d bytes each (%d byte tail segment)" % \
378                  (k, n, self._num_segments, self._segment_size,
379                   self._tail_segment_size))
380
381         if self._block_hash_trees is not None:
382             for i in xrange(self._total_shares):
383                 # So we don't have to do this later.
384                 self._block_hash_trees[i] = hashtree.IncompleteHashTree(self._num_segments)
385
386         # Our last task is to tell the downloader where to start and
387         # where to stop. We use three parameters for that:
388         #   - self._start_segment: the segment that we need to start
389         #     downloading from. 
390         #   - self._current_segment: the next segment that we need to
391         #     download.
392         #   - self._last_segment: The last segment that we were asked to
393         #     download.
394         #
395         #  We say that the download is complete when
396         #  self._current_segment > self._last_segment. We use
397         #  self._start_segment and self._last_segment to know when to
398         #  strip things off of segments, and how much to strip.
399         if self._offset:
400             self.log("got offset: %d" % self._offset)
401             # our start segment is the first segment containing the
402             # offset we were given. 
403             start = self._offset // self._segment_size
404
405             assert start < self._num_segments
406             self._start_segment = start
407             self.log("got start segment: %d" % self._start_segment)
408         else:
409             self._start_segment = 0
410
411
412         # If self._read_length is None, then we want to read the whole
413         # file. Otherwise, we want to read only part of the file, and
414         # need to figure out where to stop reading.
415         if self._read_length is not None:
416             # our end segment is the last segment containing part of the
417             # segment that we were asked to read.
418             self.log("got read length %d" % self._read_length)
419             if self._read_length != 0:
420                 end_data = self._offset + self._read_length
421
422                 # We don't actually need to read the byte at end_data,
423                 # but the one before it.
424                 end = (end_data - 1) // self._segment_size
425
426                 assert end < self._num_segments
427                 self._last_segment = end
428             else:
429                 self._last_segment = self._start_segment
430             self.log("got end segment: %d" % self._last_segment)
431         else:
432             self._last_segment = self._num_segments - 1
433
434         self._current_segment = self._start_segment
435
436     def _activate_enough_peers(self):
437         """
438         I populate self._active_readers with enough active readers to
439         retrieve the contents of this mutable file. I am called before
440         downloading starts, and (eventually) after each validation
441         error, connection error, or other problem in the download.
442         """
443         # TODO: It would be cool to investigate other heuristics for
444         # reader selection. For instance, the cost (in time the user
445         # spends waiting for their file) of selecting a really slow peer
446         # that happens to have a primary share is probably more than
447         # selecting a really fast peer that doesn't have a primary
448         # share. Maybe the servermap could be extended to provide this
449         # information; it could keep track of latency information while
450         # it gathers more important data, and then this routine could
451         # use that to select active readers.
452         #
453         # (these and other questions would be easier to answer with a
454         #  robust, configurable tahoe-lafs simulator, which modeled node
455         #  failures, differences in node speed, and other characteristics
456         #  that we expect storage servers to have.  You could have
457         #  presets for really stable grids (like allmydata.com),
458         #  friendnets, make it easy to configure your own settings, and
459         #  then simulate the effect of big changes on these use cases
460         #  instead of just reasoning about what the effect might be. Out
461         #  of scope for MDMF, though.)
462
463         # We need at least self._required_shares readers to download a
464         # segment. If we're verifying, we need all shares.
465         if self._verify:
466             needed = self._total_shares
467         else:
468             needed = self._required_shares
469         # XXX: Why don't format= log messages work here?
470         self.log("adding %d peers to the active peers list" % needed)
471
472         if len(self._active_readers) >= needed:
473             # enough shares are active
474             return
475
476         more = needed - len(self._active_readers)
477         known_shnums = set(self.remaining_sharemap.keys())
478         used_shnums = set([r.shnum for r in self._active_readers])
479         unused_shnums = known_shnums - used_shnums
480         # We favor lower numbered shares, since FEC is faster with
481         # primary shares than with other shares, and lower-numbered
482         # shares are more likely to be primary than higher numbered
483         # shares.
484         new_shnums = sorted(unused_shnums)[:more]
485         if len(new_shnums) < more and not self._verify:
486             # We don't have enough readers to retrieve the file; fail.
487             self._raise_notenoughshareserror()
488
489         for shnum in new_shnums:
490             reader = self.readers[shnum]
491             self._active_readers.append(reader)
492             self.log("added reader for share %d" % shnum)
493             # Each time we add a reader, we check to see if we need the
494             # private key. If we do, we politely ask for it and then continue
495             # computing. If we find that we haven't gotten it at the end of
496             # segment decoding, then we'll take more drastic measures.
497             if self._need_privkey and not self._node.is_readonly():
498                 d = reader.get_encprivkey()
499                 d.addCallback(self._try_to_validate_privkey, reader)
500                 # XXX: don't just drop the Deferred. We need error-reporting
501                 # but not flow-control here.
502         assert len(self._active_readers) >= self._required_shares
503
504     def _try_to_validate_prefix(self, prefix, reader):
505         """
506         I check that the prefix returned by a candidate server for
507         retrieval matches the prefix that the servermap knows about
508         (and, hence, the prefix that was validated earlier). If it does,
509         I return True, which means that I approve of the use of the
510         candidate server for segment retrieval. If it doesn't, I return
511         False, which means that another server must be chosen.
512         """
513         (seqnum,
514          root_hash,
515          IV,
516          segsize,
517          datalength,
518          k,
519          N,
520          known_prefix,
521          offsets_tuple) = self.verinfo
522         if known_prefix != prefix:
523             self.log("prefix from share %d doesn't match" % reader.shnum)
524             raise UncoordinatedWriteError("Mismatched prefix -- this could "
525                                           "indicate an uncoordinated write")
526         # Otherwise, we're okay -- no issues.
527
528
529     def _remove_reader(self, reader):
530         """
531         At various points, we will wish to remove a peer from
532         consideration and/or use. These include, but are not necessarily
533         limited to:
534
535             - A connection error.
536             - A mismatched prefix (that is, a prefix that does not match
537               our conception of the version information string).
538             - A failing block hash, salt hash, or share hash, which can
539               indicate disk failure/bit flips, or network trouble.
540
541         This method will do that. I will make sure that the
542         (shnum,reader) combination represented by my reader argument is
543         not used for anything else during this download. I will not
544         advise the reader of any corruption, something that my callers
545         may wish to do on their own.
546         """
547         # TODO: When you're done writing this, see if this is ever
548         # actually used for something that _mark_bad_share isn't. I have
549         # a feeling that they will be used for very similar things, and
550         # that having them both here is just going to be an epic amount
551         # of code duplication.
552         #
553         # (well, okay, not epic, but meaningful)
554         self.log("removing reader %s" % reader)
555         # Remove the reader from _active_readers
556         self._active_readers.remove(reader)
557         # TODO: self.readers.remove(reader)?
558         for shnum in list(self.remaining_sharemap.keys()):
559             self.remaining_sharemap.discard(shnum, reader.peerid)
560
561
562     def _mark_bad_share(self, reader, f):
563         """
564         I mark the (peerid, shnum) encapsulated by my reader argument as
565         a bad share, which means that it will not be used anywhere else.
566
567         There are several reasons to want to mark something as a bad
568         share. These include:
569
570             - A connection error to the peer.
571             - A mismatched prefix (that is, a prefix that does not match
572               our local conception of the version information string).
573             - A failing block hash, salt hash, share hash, or other
574               integrity check.
575
576         This method will ensure that readers that we wish to mark bad
577         (for these reasons or other reasons) are not used for the rest
578         of the download. Additionally, it will attempt to tell the
579         remote peer (with no guarantee of success) that its share is
580         corrupt.
581         """
582         self.log("marking share %d on server %s as bad" % \
583                  (reader.shnum, reader))
584         prefix = self.verinfo[-2]
585         self.servermap.mark_bad_share(reader.peerid,
586                                       reader.shnum,
587                                       prefix)
588         self._remove_reader(reader)
589         self._bad_shares.add((reader.peerid, reader.shnum, f))
590         self._status.problems[reader.peerid] = f
591         self._last_failure = f
592         self.notify_server_corruption(reader.peerid, reader.shnum,
593                                       str(f.value))
594
595
596     def _download_current_segment(self):
597         """
598         I download, validate, decode, decrypt, and assemble the segment
599         that this Retrieve is currently responsible for downloading.
600         """
601         assert len(self._active_readers) >= self._required_shares
602         if self._current_segment > self._last_segment:
603             # No more segments to download, we're done.
604             self.log("got plaintext, done")
605             return self._done()
606         self.log("on segment %d of %d" %
607                  (self._current_segment + 1, self._num_segments))
608         d = self._process_segment(self._current_segment)
609         d.addCallback(lambda ign: self.loop())
610         return d
611
612     def _process_segment(self, segnum):
613         """
614         I download, validate, decode, and decrypt one segment of the
615         file that this Retrieve is retrieving. This means coordinating
616         the process of getting k blocks of that file, validating them,
617         assembling them into one segment with the decoder, and then
618         decrypting them.
619         """
620         self.log("processing segment %d" % segnum)
621
622         # TODO: The old code uses a marker. Should this code do that
623         # too? What did the Marker do?
624         assert len(self._active_readers) >= self._required_shares
625
626         # We need to ask each of our active readers for its block and
627         # salt. We will then validate those. If validation is
628         # successful, we will assemble the results into plaintext.
629         ds = []
630         for reader in self._active_readers:
631             started = time.time()
632             d = reader.get_block_and_salt(segnum)
633             d2 = self._get_needed_hashes(reader, segnum)
634             dl = defer.DeferredList([d, d2], consumeErrors=True)
635             dl.addCallback(self._validate_block, segnum, reader, started)
636             dl.addErrback(self._validation_or_decoding_failed, [reader])
637             ds.append(dl)
638         dl = defer.DeferredList(ds)
639         if self._verify:
640             dl.addCallback(lambda ignored: "")
641             dl.addCallback(self._set_segment)
642         else:
643             dl.addCallback(self._maybe_decode_and_decrypt_segment, segnum)
644         return dl
645
646
647     def _maybe_decode_and_decrypt_segment(self, blocks_and_salts, segnum):
648         """
649         I take the results of fetching and validating the blocks from a
650         callback chain in another method. If the results are such that
651         they tell me that validation and fetching succeeded without
652         incident, I will proceed with decoding and decryption.
653         Otherwise, I will do nothing.
654         """
655         self.log("trying to decode and decrypt segment %d" % segnum)
656         failures = False
657         for block_and_salt in blocks_and_salts:
658             if not block_and_salt[0] or block_and_salt[1] == None:
659                 self.log("some validation operations failed; not proceeding")
660                 failures = True
661                 break
662         if not failures:
663             self.log("everything looks ok, building segment %d" % segnum)
664             d = self._decode_blocks(blocks_and_salts, segnum)
665             d.addCallback(self._decrypt_segment)
666             d.addErrback(self._validation_or_decoding_failed,
667                          self._active_readers)
668             # check to see whether we've been paused before writing
669             # anything.
670             d.addCallback(self._check_for_paused)
671             d.addCallback(self._set_segment)
672             return d
673         else:
674             return defer.succeed(None)
675
676
677     def _set_segment(self, segment):
678         """
679         Given a plaintext segment, I register that segment with the
680         target that is handling the file download.
681         """
682         self.log("got plaintext for segment %d" % self._current_segment)
683         if self._current_segment == self._start_segment:
684             # We're on the first segment. It's possible that we want
685             # only some part of the end of this segment, and that we
686             # just downloaded the whole thing to get that part. If so,
687             # we need to account for that and give the reader just the
688             # data that they want.
689             n = self._offset % self._segment_size
690             self.log("stripping %d bytes off of the first segment" % n)
691             self.log("original segment length: %d" % len(segment))
692             segment = segment[n:]
693             self.log("new segment length: %d" % len(segment))
694
695         if self._current_segment == self._last_segment and self._read_length is not None:
696             # We're on the last segment. It's possible that we only want
697             # part of the beginning of this segment, and that we
698             # downloaded the whole thing anyway. Make sure to give the
699             # caller only the portion of the segment that they want to
700             # receive.
701             extra = self._read_length
702             if self._start_segment != self._last_segment:
703                 extra -= self._segment_size - \
704                             (self._offset % self._segment_size)
705             extra %= self._segment_size
706             self.log("original segment length: %d" % len(segment))
707             segment = segment[:extra]
708             self.log("new segment length: %d" % len(segment))
709             self.log("only taking %d bytes of the last segment" % extra)
710
711         if not self._verify:
712             self._consumer.write(segment)
713         else:
714             # we don't care about the plaintext if we are doing a verify.
715             segment = None
716         self._current_segment += 1
717
718
719     def _validation_or_decoding_failed(self, f, readers):
720         """
721         I am called when a block or a salt fails to correctly validate, or when
722         the decryption or decoding operation fails for some reason.  I react to
723         this failure by notifying the remote server of corruption, and then
724         removing the remote peer from further activity.
725         """
726         assert isinstance(readers, list)
727         bad_shnums = [reader.shnum for reader in readers]
728
729         self.log("validation or decoding failed on share(s) %s, peer(s) %s "
730                  ", segment %d: %s" % \
731                  (bad_shnums, readers, self._current_segment, str(f)))
732         for reader in readers:
733             self._mark_bad_share(reader, f)
734         return
735
736
737     def _validate_block(self, results, segnum, reader, started):
738         """
739         I validate a block from one share on a remote server.
740         """
741         # Grab the part of the block hash tree that is necessary to
742         # validate this block, then generate the block hash root.
743         self.log("validating share %d for segment %d" % (reader.shnum,
744                                                              segnum))
745         elapsed = time.time() - started
746         self._status.add_fetch_timing(reader.peerid, elapsed)
747         self._set_current_status("validating blocks")
748         # Did we fail to fetch either of the things that we were
749         # supposed to? Fail if so.
750         if not results[0][0] and results[1][0]:
751             # handled by the errback handler.
752
753             # These all get batched into one query, so the resulting
754             # failure should be the same for all of them, so we can just
755             # use the first one.
756             assert isinstance(results[0][1], failure.Failure)
757
758             f = results[0][1]
759             raise CorruptShareError(reader.peerid,
760                                     reader.shnum,
761                                     "Connection error: %s" % str(f))
762
763         block_and_salt, block_and_sharehashes = results
764         block, salt = block_and_salt[1]
765         blockhashes, sharehashes = block_and_sharehashes[1]
766
767         blockhashes = dict(enumerate(blockhashes[1]))
768         self.log("the reader gave me the following blockhashes: %s" % \
769                  blockhashes.keys())
770         self.log("the reader gave me the following sharehashes: %s" % \
771                  sharehashes[1].keys())
772         bht = self._block_hash_trees[reader.shnum]
773
774         if bht.needed_hashes(segnum, include_leaf=True):
775             try:
776                 bht.set_hashes(blockhashes)
777             except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
778                     IndexError), e:
779                 raise CorruptShareError(reader.peerid,
780                                         reader.shnum,
781                                         "block hash tree failure: %s" % e)
782
783         if self._version == MDMF_VERSION:
784             blockhash = hashutil.block_hash(salt + block)
785         else:
786             blockhash = hashutil.block_hash(block)
787         # If this works without an error, then validation is
788         # successful.
789         try:
790            bht.set_hashes(leaves={segnum: blockhash})
791         except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
792                 IndexError), e:
793             raise CorruptShareError(reader.peerid,
794                                     reader.shnum,
795                                     "block hash tree failure: %s" % e)
796
797         # Reaching this point means that we know that this segment
798         # is correct. Now we need to check to see whether the share
799         # hash chain is also correct. 
800         # SDMF wrote share hash chains that didn't contain the
801         # leaves, which would be produced from the block hash tree.
802         # So we need to validate the block hash tree first. If
803         # successful, then bht[0] will contain the root for the
804         # shnum, which will be a leaf in the share hash tree, which
805         # will allow us to validate the rest of the tree.
806         if self.share_hash_tree.needed_hashes(reader.shnum,
807                                               include_leaf=True) or \
808                                               self._verify:
809             try:
810                 self.share_hash_tree.set_hashes(hashes=sharehashes[1],
811                                             leaves={reader.shnum: bht[0]})
812             except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
813                     IndexError), e:
814                 raise CorruptShareError(reader.peerid,
815                                         reader.shnum,
816                                         "corrupt hashes: %s" % e)
817
818         self.log('share %d is valid for segment %d' % (reader.shnum,
819                                                        segnum))
820         return {reader.shnum: (block, salt)}
821
822
823     def _get_needed_hashes(self, reader, segnum):
824         """
825         I get the hashes needed to validate segnum from the reader, then return
826         to my caller when this is done.
827         """
828         bht = self._block_hash_trees[reader.shnum]
829         needed = bht.needed_hashes(segnum, include_leaf=True)
830         # The root of the block hash tree is also a leaf in the share
831         # hash tree. So we don't need to fetch it from the remote
832         # server. In the case of files with one segment, this means that
833         # we won't fetch any block hash tree from the remote server,
834         # since the hash of each share of the file is the entire block
835         # hash tree, and is a leaf in the share hash tree. This is fine,
836         # since any share corruption will be detected in the share hash
837         # tree.
838         #needed.discard(0)
839         self.log("getting blockhashes for segment %d, share %d: %s" % \
840                  (segnum, reader.shnum, str(needed)))
841         d1 = reader.get_blockhashes(needed, force_remote=True)
842         if self.share_hash_tree.needed_hashes(reader.shnum):
843             need = self.share_hash_tree.needed_hashes(reader.shnum)
844             self.log("also need sharehashes for share %d: %s" % (reader.shnum,
845                                                                  str(need)))
846             d2 = reader.get_sharehashes(need, force_remote=True)
847         else:
848             d2 = defer.succeed({}) # the logic in the next method
849                                    # expects a dict
850         dl = defer.DeferredList([d1, d2], consumeErrors=True)
851         return dl
852
853
854     def _decode_blocks(self, blocks_and_salts, segnum):
855         """
856         I take a list of k blocks and salts, and decode that into a
857         single encrypted segment.
858         """
859         d = {}
860         # We want to merge our dictionaries to the form 
861         # {shnum: blocks_and_salts}
862         #
863         # The dictionaries come from validate block that way, so we just
864         # need to merge them.
865         for block_and_salt in blocks_and_salts:
866             d.update(block_and_salt[1])
867
868         # All of these blocks should have the same salt; in SDMF, it is
869         # the file-wide IV, while in MDMF it is the per-segment salt. In
870         # either case, we just need to get one of them and use it.
871         #
872         # d.items()[0] is like (shnum, (block, salt))
873         # d.items()[0][1] is like (block, salt)
874         # d.items()[0][1][1] is the salt.
875         salt = d.items()[0][1][1]
876         # Next, extract just the blocks from the dict. We'll use the
877         # salt in the next step.
878         share_and_shareids = [(k, v[0]) for k, v in d.items()]
879         d2 = dict(share_and_shareids)
880         shareids = []
881         shares = []
882         for shareid, share in d2.items():
883             shareids.append(shareid)
884             shares.append(share)
885
886         self._set_current_status("decoding")
887         started = time.time()
888         assert len(shareids) >= self._required_shares, len(shareids)
889         # zfec really doesn't want extra shares
890         shareids = shareids[:self._required_shares]
891         shares = shares[:self._required_shares]
892         self.log("decoding segment %d" % segnum)
893         if segnum == self._num_segments - 1:
894             d = defer.maybeDeferred(self._tail_decoder.decode, shares, shareids)
895         else:
896             d = defer.maybeDeferred(self._segment_decoder.decode, shares, shareids)
897         def _process(buffers):
898             segment = "".join(buffers)
899             self.log(format="now decoding segment %(segnum)s of %(numsegs)s",
900                      segnum=segnum,
901                      numsegs=self._num_segments,
902                      level=log.NOISY)
903             self.log(" joined length %d, datalength %d" %
904                      (len(segment), self._data_length))
905             if segnum == self._num_segments - 1:
906                 size_to_use = self._tail_data_size
907             else:
908                 size_to_use = self._segment_size
909             segment = segment[:size_to_use]
910             self.log(" segment len=%d" % len(segment))
911             self._status.accumulate_decode_time(time.time() - started)
912             return segment, salt
913         d.addCallback(_process)
914         return d
915
916
917     def _decrypt_segment(self, segment_and_salt):
918         """
919         I take a single segment and its salt, and decrypt it. I return
920         the plaintext of the segment that is in my argument.
921         """
922         segment, salt = segment_and_salt
923         self._set_current_status("decrypting")
924         self.log("decrypting segment %d" % self._current_segment)
925         started = time.time()
926         key = hashutil.ssk_readkey_data_hash(salt, self._node.get_readkey())
927         decryptor = AES(key)
928         plaintext = decryptor.process(segment)
929         self._status.accumulate_decrypt_time(time.time() - started)
930         return plaintext
931
932
933     def notify_server_corruption(self, peerid, shnum, reason):
934         ss = self.servermap.connections[peerid]
935         ss.callRemoteOnly("advise_corrupt_share",
936                           "mutable", self._storage_index, shnum, reason)
937
938
939     def _try_to_validate_privkey(self, enc_privkey, reader):
940         alleged_privkey_s = self._node._decrypt_privkey(enc_privkey)
941         alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s)
942         if alleged_writekey != self._node.get_writekey():
943             self.log("invalid privkey from %s shnum %d" %
944                      (reader, reader.shnum),
945                      level=log.WEIRD, umid="YIw4tA")
946             if self._verify:
947                 self.servermap.mark_bad_share(reader.peerid, reader.shnum,
948                                               self.verinfo[-2])
949                 e = CorruptShareError(reader.peerid,
950                                       reader.shnum,
951                                       "invalid privkey")
952                 f = failure.Failure(e)
953                 self._bad_shares.add((reader.peerid, reader.shnum, f))
954             return
955
956         # it's good
957         self.log("got valid privkey from shnum %d on reader %s" %
958                  (reader.shnum, reader))
959         privkey = rsa.create_signing_key_from_string(alleged_privkey_s)
960         self._node._populate_encprivkey(enc_privkey)
961         self._node._populate_privkey(privkey)
962         self._need_privkey = False
963
964
965
966     def _done(self):
967         """
968         I am called by _download_current_segment when the download process
969         has finished successfully. After making some useful logging
970         statements, I return the decrypted contents to the owner of this
971         Retrieve object through self._done_deferred.
972         """
973         self._running = False
974         self._status.set_active(False)
975         now = time.time()
976         self._status.timings['total'] = now - self._started
977         self._status.timings['fetch'] = now - self._started_fetching
978         self._status.set_status("Finished")
979         self._status.set_progress(1.0)
980
981         # remember the encoding parameters, use them again next time
982         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
983          offsets_tuple) = self.verinfo
984         self._node._populate_required_shares(k)
985         self._node._populate_total_shares(N)
986
987         if self._verify:
988             ret = list(self._bad_shares)
989             self.log("done verifying, found %d bad shares" % len(ret))
990         else:
991             # TODO: upload status here?
992             ret = self._consumer
993             self._consumer.unregisterProducer()
994         eventually(self._done_deferred.callback, ret)
995
996
997     def _raise_notenoughshareserror(self):
998         """
999         I am called by _activate_enough_peers when there are not enough
1000         active peers left to complete the download. After making some
1001         useful logging statements, I throw an exception to that effect
1002         to the caller of this Retrieve object through
1003         self._done_deferred.
1004         """
1005
1006         format = ("ran out of peers: "
1007                   "have %(have)d of %(total)d segments "
1008                   "found %(bad)d bad shares "
1009                   "encoding %(k)d-of-%(n)d")
1010         args = {"have": self._current_segment,
1011                 "total": self._num_segments,
1012                 "need": self._last_segment,
1013                 "k": self._required_shares,
1014                 "n": self._total_shares,
1015                 "bad": len(self._bad_shares)}
1016         raise NotEnoughSharesError("%s, last failure: %s" %
1017                                    (format % args, str(self._last_failure)))
1018
1019     def _error(self, f):
1020         # all errors, including NotEnoughSharesError, land here
1021         self._running = False
1022         self._status.set_active(False)
1023         now = time.time()
1024         self._status.timings['total'] = now - self._started
1025         self._status.timings['fetch'] = now - self._started_fetching
1026         self._status.set_status("Failed")
1027         eventually(self._done_deferred.errback, f)