1
2 import time
3 from itertools import count
4 from zope.interface import implements
5 from twisted.internet import defer
6 from twisted.python import failure
7 from twisted.internet.interfaces import IPushProducer, IConsumer
8 from foolscap.api import eventually, fireEventually
9 from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \
10      DownloadStopped, MDMF_VERSION, SDMF_VERSION
11 from allmydata.util import hashutil, log, mathutil, base32
12 from allmydata.util.dictutil import DictOfSets
13 from allmydata import hashtree, codec
14 from allmydata.storage.server import si_b2a
15 from pycryptopp.cipher.aes import AES
16 from pycryptopp.publickey import rsa
17
18 from allmydata.mutable.common import CorruptShareError, UncoordinatedWriteError
19 from allmydata.mutable.layout import MDMFSlotReadProxy
20
21 class RetrieveStatus:
22     implements(IRetrieveStatus)
23     statusid_counter = count(0)
24     def __init__(self):
25         self.timings = {}
26         self.timings["fetch_per_server"] = {}
27         self.timings["decode"] = 0.0
28         self.timings["decrypt"] = 0.0
29         self.timings["cumulative_verify"] = 0.0
30         self.problems = {}
31         self.active = True
32         self.storage_index = None
33         self.helper = False
34         self.encoding = ("?","?")
35         self.size = None
36         self.status = "Not started"
37         self.progress = 0.0
38         self.counter = self.statusid_counter.next()
39         self.started = time.time()
40
41     def get_started(self):
42         return self.started
43     def get_storage_index(self):
44         return self.storage_index
45     def get_encoding(self):
46         return self.encoding
47     def using_helper(self):
48         return self.helper
49     def get_size(self):
50         return self.size
51     def get_status(self):
52         return self.status
53     def get_progress(self):
54         return self.progress
55     def get_active(self):
56         return self.active
57     def get_counter(self):
58         return self.counter
59
60     def add_fetch_timing(self, peerid, elapsed):
61         if peerid not in self.timings["fetch_per_server"]:
62             self.timings["fetch_per_server"][peerid] = []
63         self.timings["fetch_per_server"][peerid].append(elapsed)
64     def accumulate_decode_time(self, elapsed):
65         self.timings["decode"] += elapsed
66     def accumulate_decrypt_time(self, elapsed):
67         self.timings["decrypt"] += elapsed
68     def set_storage_index(self, si):
69         self.storage_index = si
70     def set_helper(self, helper):
71         self.helper = helper
72     def set_encoding(self, k, n):
73         self.encoding = (k, n)
74     def set_size(self, size):
75         self.size = size
76     def set_status(self, status):
77         self.status = status
78     def set_progress(self, value):
79         self.progress = value
80     def set_active(self, value):
81         self.active = value
82
83 class Marker:
84     pass
85
86 class Retrieve:
87     # this class is currently single-use. Eventually (in MDMF) we will make
88     # it multi-use, in which case you can call download(range) multiple
89     # times, and each will have a separate response chain. However the
90     # Retrieve object will remain tied to a specific version of the file, and
91     # will use a single ServerMap instance.
92     implements(IPushProducer)
93
94     def __init__(self, filenode, servermap, verinfo, fetch_privkey=False,
95                  verify=False):
96         self._node = filenode
97         assert self._node.get_pubkey()
98         self._storage_index = filenode.get_storage_index()
99         assert self._node.get_readkey()
100         self._last_failure = None
101         prefix = si_b2a(self._storage_index)[:5]
102         self._log_number = log.msg("Retrieve(%s): starting" % prefix)
103         self._outstanding_queries = {} # maps (peerid,shnum) to start_time
104         self._running = True
105         self._decoding = False
106         self._bad_shares = set()
107
108         self.servermap = servermap
109         assert self._node.get_pubkey()
110         self.verinfo = verinfo
111         # during repair, we may be called upon to grab the private key, since
112         # it wasn't picked up during a verify=False checker run, and we'll
113         # need it for repair to generate a new version.
114         self._need_privkey = verify or (fetch_privkey
115                                         and not self._node.get_privkey())
116
117         if self._need_privkey:
118             # TODO: Evaluate the need for this. We'll use it if we want
119             # to limit how many queries are on the wire for the privkey
120             # at once.
121             self._privkey_query_markers = [] # one Marker for each time we've
122                                              # tried to get the privkey.
123
124         # verify means that we are using the downloader logic to verify all
125         # of our shares. This tells the downloader a few things.
126         # 
127         # 1. We need to download all of the shares.
128         # 2. We don't need to decode or decrypt the shares, since our
129         #    caller doesn't care about the plaintext, only the
130         #    information about which shares are or are not valid.
131         # 3. When we are validating readers, we need to validate the
132         #    signature on the prefix. (Do we? We already do this in the
133         #    servermap update.)
134         self._verify = verify
135
136         self._status = RetrieveStatus()
137         self._status.set_storage_index(self._storage_index)
138         self._status.set_helper(False)
139         self._status.set_progress(0.0)
140         self._status.set_active(True)
141         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
142          offsets_tuple) = self.verinfo
143         self._status.set_size(datalength)
144         self._status.set_encoding(k, N)
145         self.readers = {}
146         self._stopped = False
147         self._pause_deferred = None
148         self._offset = None
149         self._read_length = None
150         self.log("got seqnum %d" % self.verinfo[0])
151
152
153     def get_status(self):
154         return self._status
155
156     def log(self, *args, **kwargs):
157         if "parent" not in kwargs:
158             kwargs["parent"] = self._log_number
159         if "facility" not in kwargs:
160             kwargs["facility"] = "tahoe.mutable.retrieve"
161         return log.msg(*args, **kwargs)
162
163     def _set_current_status(self, state):
164         seg = "%d/%d" % (self._current_segment, self._last_segment)
165         self._status.set_status("segment %s (%s)" % (seg, state))
166
167     ###################
168     # IPushProducer
169
170     def pauseProducing(self):
171         """
172         I am called by my download target if we have produced too much
173         data for it to handle. I make the downloader stop producing new
174         data until my resumeProducing method is called.
175         """
176         if self._pause_deferred is not None:
177             return
178
179         # fired when the download is unpaused.
180         self._old_status = self._status.get_status()
181         self._set_current_status("paused")
182
183         self._pause_deferred = defer.Deferred()
184
185
186     def resumeProducing(self):
187         """
188         I am called by my download target once it is ready to begin
189         receiving data again.
190         """
191         if self._pause_deferred is None:
192             return
193
194         p = self._pause_deferred
195         self._pause_deferred = None
196         self._status.set_status(self._old_status)
197
198         eventually(p.callback, None)
199
200     def stopProducing(self):
201         self._stopped = True
202         self.resumeProducing()
203
204
205     def _check_for_paused(self, res):
206         """
207         I am called just before a write to the consumer. I return a
208         Deferred that eventually fires with the data that is to be
209         written to the consumer. If the download has not been paused,
210         the Deferred fires immediately. Otherwise, the Deferred fires
211         when the downloader is unpaused.
212         """
213         if self._pause_deferred is not None:
214             d = defer.Deferred()
215             self._pause_deferred.addCallback(lambda ignored: d.callback(res))
216             return d
217         return res
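
    # Roughly, the pause/resume machinery above behaves like this
    # (illustrative sequence only, with a hypothetical Retrieve instance r):
    #
    #   r.pauseProducing()             # creates r._pause_deferred
    #   d = r._check_for_paused(data)  # returns a fresh Deferred, not data
    #   r.resumeProducing()            # fires r._pause_deferred, which in
    #                                  # turn fires d with `data`, so the
    #                                  # pending write to the consumer resumes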
218
219     def _check_for_stopped(self, res):
220         if self._stopped:
221             raise DownloadStopped("our Consumer called stopProducing()")
222         return res
223
224
225     def download(self, consumer=None, offset=0, size=None):
226         assert IConsumer.providedBy(consumer) or self._verify
227
228         if consumer:
229             self._consumer = consumer
230             # we provide IPushProducer, so streaming=True, per
231             # IConsumer.
232             self._consumer.registerProducer(self, streaming=True)
233
234         self._done_deferred = defer.Deferred()
235         self._offset = offset
236         self._read_length = size
237         self._setup_download()
238         self._setup_encoding_parameters()
239         self.log("starting download")
240         self._started_fetching = time.time()
241         # The download process beyond this is a state machine.
242         # _add_active_peers will select the peers that we want to use
243         # for the download, and then attempt to start downloading. After
244         # each segment, it will check for doneness, reacting to broken
245         # peers and corrupt shares as necessary. If it runs out of good
246         # peers before downloading all of the segments, _done_deferred
247         # will errback.  Otherwise, it will eventually callback with the
248         # contents of the mutable file.
249         self.loop()
250         return self._done_deferred
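
    # A minimal sketch of how a caller might drive download() with an
    # IConsumer. The _Accumulator class and the filenode/servermap/verinfo
    # variables are hypothetical placeholders, not part of this module:
    #
    #   from zope.interface import implements
    #   from twisted.internet.interfaces import IConsumer
    #
    #   class _Accumulator:
    #       implements(IConsumer)
    #       def __init__(self):
    #           self.chunks = []
    #       def registerProducer(self, producer, streaming):
    #           self.producer = producer
    #       def unregisterProducer(self):
    #           self.producer = None
    #       def write(self, data):
    #           self.chunks.append(data)
    #
    #   c = _Accumulator()
    #   r = Retrieve(filenode, servermap, verinfo)
    #   d = r.download(consumer=c, offset=0, size=None)
    #   # d fires with c once every requested segment has been written to it;
    #   # the plaintext is then "".join(c.chunks).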
251
252     def loop(self):
253         d = fireEventually(None) # avoid #237 recursion limit problem
254         d.addCallback(lambda ign: self._activate_enough_peers())
255         d.addCallback(lambda ign: self._download_current_segment())
256         # when we're done, _download_current_segment will call _done. If we
257         # aren't, it will call loop() again.
258         d.addErrback(self._error)
259
260     def _setup_download(self):
261         self._started = time.time()
262         self._status.set_status("Retrieving Shares")
263
264         # how many shares do we need?
265         (seqnum,
266          root_hash,
267          IV,
268          segsize,
269          datalength,
270          k,
271          N,
272          prefix,
273          offsets_tuple) = self.verinfo
274
275         # first, which servers can we use?
276         versionmap = self.servermap.make_versionmap()
277         shares = versionmap[self.verinfo]
278         # this sharemap is consumed as we decide to send requests
279         self.remaining_sharemap = DictOfSets()
280         for (shnum, peerid, timestamp) in shares:
281             self.remaining_sharemap.add(shnum, peerid)
282             # If the servermap update fetched anything, it fetched at least 1
283             # KiB, so we ask for that much.
284             # TODO: Change the cache methods to allow us to fetch all of the
285             # data that they have, then change this method to do that.
286             any_cache = self._node._read_from_cache(self.verinfo, shnum,
287                                                     0, 1000)
288             ss = self.servermap.connections[peerid]
289             reader = MDMFSlotReadProxy(ss,
290                                        self._storage_index,
291                                        shnum,
292                                        any_cache)
293             reader.peerid = peerid
294             self.readers[shnum] = reader
295         assert len(self.remaining_sharemap) >= k
296
297         self.shares = {} # maps shnum to validated blocks
298         self._active_readers = [] # list of active readers for this dl.
299         self._block_hash_trees = {} # shnum => hashtree
300
301         # We need one share hash tree for the entire file; its leaves
302         # are the roots of the block hash trees for the shares that
303         # comprise it, and its root is in the verinfo.
304         self.share_hash_tree = hashtree.IncompleteHashTree(N)
305         self.share_hash_tree.set_hashes({0: root_hash})
306
307     def decode(self, blocks_and_salts, segnum):
308         """
309         I am a helper method that the mutable file update process uses
310         as a shortcut to decode and decrypt the segments that it needs
311         to fetch in order to perform a file update. I take in a
312         collection of blocks and salts, and pick some of those to make a
313         segment with. I return the plaintext associated with that
314         segment.
315         """
316         # shnum => block hash tree. Unused, but _setup_encoding_parameters will
317         # want to set this.
318         self._block_hash_trees = None
319         self._setup_encoding_parameters()
320
321         # This is the form expected by decode.
322         blocks_and_salts = blocks_and_salts.items()
323         blocks_and_salts = [(True, [d]) for d in blocks_and_salts]
324
325         d = self._decode_blocks(blocks_and_salts, segnum)
326         d.addCallback(self._decrypt_segment)
327         return d
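
    # For reference, the reshaping in decode() above turns an input such as
    # (block/salt values are illustrative):
    #
    #   blocks_and_salts = {0: (block_0, salt), 1: (block_1, salt)}
    #
    # into the DeferredList-style form that _decode_blocks() expects:
    #
    #   [(True, [(0, (block_0, salt))]), (True, [(1, (block_1, salt))])]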
328
329
330     def _setup_encoding_parameters(self):
331         """
332         I set up the encoding parameters, including k, n, the number
333         of segments associated with this file, and the segment decoders.
334         """
335         (seqnum,
336          root_hash,
337          IV,
338          segsize,
339          datalength,
340          k,
341          n,
342          known_prefix,
343          offsets_tuple) = self.verinfo
344         self._required_shares = k
345         self._total_shares = n
346         self._segment_size = segsize
347         self._data_length = datalength
348
349         if not IV:
350             self._version = MDMF_VERSION
351         else:
352             self._version = SDMF_VERSION
353
354         if datalength and segsize:
355             self._num_segments = mathutil.div_ceil(datalength, segsize)
356             self._tail_data_size = datalength % segsize
357         else:
358             self._num_segments = 0
359             self._tail_data_size = 0
360
361         self._segment_decoder = codec.CRSDecoder()
362         self._segment_decoder.set_params(segsize, k, n)
363
364         if not self._tail_data_size:
365             self._tail_data_size = segsize
366
367         self._tail_segment_size = mathutil.next_multiple(self._tail_data_size,
368                                                          self._required_shares)
369         if self._tail_segment_size == self._segment_size:
370             self._tail_decoder = self._segment_decoder
371         else:
372             self._tail_decoder = codec.CRSDecoder()
373             self._tail_decoder.set_params(self._tail_segment_size,
374                                           self._required_shares,
375                                           self._total_shares)
376
377         self.log("got encoding parameters: "
378                  "k: %d "
379                  "n: %d "
380                  "%d segments of %d bytes each (%d byte tail segment)" % \
381                  (k, n, self._num_segments, self._segment_size,
382                   self._tail_segment_size))
383
384         if self._block_hash_trees is not None:
385             for i in xrange(self._total_shares):
386                 # So we don't have to do this later.
387                 self._block_hash_trees[i] = hashtree.IncompleteHashTree(self._num_segments)
388
389         # Our last task is to tell the downloader where to start and
390         # where to stop. We use three parameters for that:
391         #   - self._start_segment: the segment that we need to start
392         #     downloading from. 
393         #   - self._current_segment: the next segment that we need to
394         #     download.
395         #   - self._last_segment: The last segment that we were asked to
396         #     download.
397         #
398         #  We say that the download is complete when
399         #  self._current_segment > self._last_segment. We use
400         #  self._start_segment and self._last_segment to know when to
401         #  strip things off of segments, and how much to strip.
402         if self._offset:
403             self.log("got offset: %d" % self._offset)
404             # our start segment is the first segment containing the
405             # offset we were given. 
406             start = self._offset // self._segment_size
407
408             assert start < self._num_segments
409             self._start_segment = start
410             self.log("got start segment: %d" % self._start_segment)
411         else:
412             self._start_segment = 0
413
414
415         # If self._read_length is None, then we want to read the whole
416         # file. Otherwise, we want to read only part of the file, and
417         # need to figure out where to stop reading.
418         if self._read_length is not None:
419             # our end segment is the last segment containing part of the
420             # data that we were asked to read.
421             self.log("got read length %d" % self._read_length)
422             if self._read_length != 0:
423                 end_data = self._offset + self._read_length
424
425                 # We don't actually need to read the byte at end_data,
426                 # but the one before it.
427                 end = (end_data - 1) // self._segment_size
428
429                 assert end < self._num_segments
430                 self._last_segment = end
431             else:
432                 self._last_segment = self._start_segment
433             self.log("got end segment: %d" % self._last_segment)
434         else:
435             self._last_segment = self._num_segments - 1
436
437         self._current_segment = self._start_segment
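
    # A worked example of the start/last segment arithmetic above, using
    # hypothetical numbers: with segsize = 100000 and datalength = 1000000
    # there are div_ceil(1000000, 100000) = 10 segments. A read with
    # offset = 250000 and size = 300000 gives:
    #
    #   start    = 250000 // 100000          = 2    (self._start_segment)
    #   end_data = 250000 + 300000           = 550000
    #   end      = (550000 - 1) // 100000    = 5    (self._last_segment)
    #
    # so segments 2 through 5 are fetched, and _set_segment() later strips
    # the unwanted bytes from the first and last of them.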
438
439     def _activate_enough_peers(self):
440         """
441         I populate self._active_readers with enough active readers to
442         retrieve the contents of this mutable file. I am called before
443         downloading starts, and (eventually) after each validation
444         error, connection error, or other problem in the download.
445         """
446         # TODO: It would be cool to investigate other heuristics for
447         # reader selection. For instance, the cost (in time the user
448         # spends waiting for their file) of selecting a really slow peer
449         # that happens to have a primary share is probably more than
450         # selecting a really fast peer that doesn't have a primary
451         # share. Maybe the servermap could be extended to provide this
452         # information; it could keep track of latency information while
453         # it gathers more important data, and then this routine could
454         # use that to select active readers.
455         #
456         # (these and other questions would be easier to answer with a
457         #  robust, configurable tahoe-lafs simulator, which modeled node
458         #  failures, differences in node speed, and other characteristics
459         #  that we expect storage servers to have.  You could have
460         #  presets for really stable grids (like allmydata.com),
461         #  friendnets, make it easy to configure your own settings, and
462         #  then simulate the effect of big changes on these use cases
463         #  instead of just reasoning about what the effect might be. Out
464         #  of scope for MDMF, though.)
465
466         # XXX: Why don't format= log messages work here?
467
468         known_shnums = set(self.remaining_sharemap.keys())
469         used_shnums = set([r.shnum for r in self._active_readers])
470         unused_shnums = known_shnums - used_shnums
471
472         if self._verify:
473             new_shnums = unused_shnums # use them all
474         elif len(self._active_readers) < self._required_shares:
475             # need more shares
476             more = self._required_shares - len(self._active_readers)
477             # We favor lower numbered shares, since FEC is faster with
478             # primary shares than with other shares, and lower-numbered
479             # shares are more likely to be primary than higher numbered
480             # shares.
481             new_shnums = sorted(unused_shnums)[:more]
482             if len(new_shnums) < more:
483                 # We don't have enough readers to retrieve the file; fail.
484                 self._raise_notenoughshareserror()
485         else:
486             new_shnums = []
487
488         self.log("adding %d new peers to the active list" % len(new_shnums))
489         for shnum in new_shnums:
490             reader = self.readers[shnum]
491             self._active_readers.append(reader)
492             self.log("added reader for share %d" % shnum)
493             # Each time we add a reader, we check to see if we need the
494             # private key. If we do, we politely ask for it and then continue
495             # computing. If we find that we haven't gotten it at the end of
496             # segment decoding, then we'll take more drastic measures.
497             if self._need_privkey and not self._node.is_readonly():
498                 d = reader.get_encprivkey()
499                 d.addCallback(self._try_to_validate_privkey, reader)
500                 # XXX: don't just drop the Deferred. We need error-reporting
501                 # but not flow-control here.
502         assert len(self._active_readers) >= self._required_shares
503
504     def _try_to_validate_prefix(self, prefix, reader):
505         """
506         I check that the prefix returned by a candidate server for
507         retrieval matches the prefix that the servermap knows about
508         (and, hence, the prefix that was validated earlier). If it does,
509         I return True, which means that I approve of the use of the
510         candidate server for segment retrieval. If it doesn't, I return
511         False, which means that another server must be chosen.
512         """
513         (seqnum,
514          root_hash,
515          IV,
516          segsize,
517          datalength,
518          k,
519          N,
520          known_prefix,
521          offsets_tuple) = self.verinfo
522         if known_prefix != prefix:
523             self.log("prefix from share %d doesn't match" % reader.shnum)
524             raise UncoordinatedWriteError("Mismatched prefix -- this could "
525                                           "indicate an uncoordinated write")
526         # Otherwise, we're okay -- no issues.
527
528
529     def _remove_reader(self, reader):
530         """
531         At various points, we will wish to remove a peer from
532         consideration and/or use. These include, but are not necessarily
533         limited to:
534
535             - A connection error.
536             - A mismatched prefix (that is, a prefix that does not match
537               our conception of the version information string).
538             - A failing block hash, salt hash, or share hash, which can
539               indicate disk failure/bit flips, or network trouble.
540
541         This method will do that. I will make sure that the
542         (shnum,reader) combination represented by my reader argument is
543         not used for anything else during this download. I will not
544         advise the reader of any corruption, something that my callers
545         may wish to do on their own.
546         """
547         # TODO: When you're done writing this, see if this is ever
548         # actually used for something that _mark_bad_share isn't. I have
549         # a feeling that they will be used for very similar things, and
550         # that having them both here is just going to be an epic amount
551         # of code duplication.
552         #
553         # (well, okay, not epic, but meaningful)
554         self.log("removing reader %s" % reader)
555         # Remove the reader from _active_readers
556         self._active_readers.remove(reader)
557         # TODO: self.readers.remove(reader)?
558         for shnum in list(self.remaining_sharemap.keys()):
559             self.remaining_sharemap.discard(shnum, reader.peerid)
560
561
562     def _mark_bad_share(self, reader, f):
563         """
564         I mark the (peerid, shnum) encapsulated by my reader argument as
565         a bad share, which means that it will not be used anywhere else.
566
567         There are several reasons to want to mark something as a bad
568         share. These include:
569
570             - A connection error to the peer.
571             - A mismatched prefix (that is, a prefix that does not match
572               our local conception of the version information string).
573             - A failing block hash, salt hash, share hash, or other
574               integrity check.
575
576         This method will ensure that readers that we wish to mark bad
577         (for these reasons or other reasons) are not used for the rest
578         of the download. Additionally, it will attempt to tell the
579         remote peer (with no guarantee of success) that its share is
580         corrupt.
581         """
582         self.log("marking share %d on server %s as bad" % \
583                  (reader.shnum, reader))
584         prefix = self.verinfo[-2]
585         self.servermap.mark_bad_share(reader.peerid,
586                                       reader.shnum,
587                                       prefix)
588         self._remove_reader(reader)
589         self._bad_shares.add((reader.peerid, reader.shnum, f))
590         self._status.problems[reader.peerid] = f
591         self._last_failure = f
592         self.notify_server_corruption(reader.peerid, reader.shnum,
593                                       str(f.value))
594
595
596     def _download_current_segment(self):
597         """
598         I download, validate, decode, decrypt, and assemble the segment
599         that this Retrieve is currently responsible for downloading.
600         """
601         assert len(self._active_readers) >= self._required_shares
602         if self._current_segment > self._last_segment:
603             # No more segments to download, we're done.
604             self.log("got plaintext, done")
605             return self._done()
606         self.log("on segment %d of %d" %
607                  (self._current_segment + 1, self._num_segments))
608         d = self._process_segment(self._current_segment)
609         d.addCallback(lambda ign: self.loop())
610         return d
611
612     def _process_segment(self, segnum):
613         """
614         I download, validate, decode, and decrypt one segment of the
615         file that this Retrieve is retrieving. This means coordinating
616         the process of getting k blocks of that file, validating them,
617         assembling them into one segment with the decoder, and then
618         decrypting them.
619         """
620         self.log("processing segment %d" % segnum)
621
622         # TODO: The old code uses a marker. Should this code do that
623         # too? What did the Marker do?
624         assert len(self._active_readers) >= self._required_shares
625
626         # We need to ask each of our active readers for its block and
627         # salt. We will then validate those. If validation is
628         # successful, we will assemble the results into plaintext.
629         ds = []
630         for reader in self._active_readers:
631             started = time.time()
632             d = reader.get_block_and_salt(segnum)
633             d2 = self._get_needed_hashes(reader, segnum)
634             dl = defer.DeferredList([d, d2], consumeErrors=True)
635             dl.addCallback(self._validate_block, segnum, reader, started)
636             dl.addErrback(self._validation_or_decoding_failed, [reader])
637             ds.append(dl)
638         dl = defer.DeferredList(ds)
639         if self._verify:
640             dl.addCallback(lambda ignored: "")
641             dl.addCallback(self._set_segment)
642         else:
643             dl.addCallback(self._maybe_decode_and_decrypt_segment, segnum)
644         return dl
645
646
647     def _maybe_decode_and_decrypt_segment(self, blocks_and_salts, segnum):
648         """
649         I take the results of fetching and validating the blocks from a
650         callback chain in another method. If the results are such that
651         they tell me that validation and fetching succeeded without
652         incident, I will proceed with decoding and decryption.
653         Otherwise, I will do nothing.
654         """
655         self.log("trying to decode and decrypt segment %d" % segnum)
656         failures = False
657         for block_and_salt in blocks_and_salts:
658             if not block_and_salt[0] or block_and_salt[1] == None:
659                 self.log("some validation operations failed; not proceeding")
660                 failures = True
661                 break
662         if not failures:
663             self.log("everything looks ok, building segment %d" % segnum)
664             d = self._decode_blocks(blocks_and_salts, segnum)
665             d.addCallback(self._decrypt_segment)
666             d.addErrback(self._validation_or_decoding_failed,
667                          self._active_readers)
668             # check to see whether we've been paused before writing
669             # anything.
670             d.addCallback(self._check_for_paused)
671             d.addCallback(self._check_for_stopped)
672             d.addCallback(self._set_segment)
673             return d
674         else:
675             return defer.succeed(None)
676
677
678     def _set_segment(self, segment):
679         """
680         Given a plaintext segment, I register that segment with the
681         target that is handling the file download.
682         """
683         self.log("got plaintext for segment %d" % self._current_segment)
684         if self._current_segment == self._start_segment:
685             # We're on the first segment. It's possible that we want
686             # only some part of the end of this segment, and that we
687             # just downloaded the whole thing to get that part. If so,
688             # we need to account for that and give the reader just the
689             # data that they want.
690             n = self._offset % self._segment_size
691             self.log("stripping %d bytes off of the first segment" % n)
692             self.log("original segment length: %d" % len(segment))
693             segment = segment[n:]
694             self.log("new segment length: %d" % len(segment))
695
696         if self._current_segment == self._last_segment and self._read_length is not None:
697             # We're on the last segment. It's possible that we only want
698             # part of the beginning of this segment, and that we
699             # downloaded the whole thing anyway. Make sure to give the
700             # caller only the portion of the segment that they want to
701             # receive.
702             extra = self._read_length
703             if self._start_segment != self._last_segment:
704                 extra -= self._segment_size - \
705                             (self._offset % self._segment_size)
706             extra %= self._segment_size
707             self.log("original segment length: %d" % len(segment))
708             segment = segment[:extra]
709             self.log("new segment length: %d" % len(segment))
710             self.log("only taking %d bytes of the last segment" % extra)
711
712         if not self._verify:
713             self._consumer.write(segment)
714         else:
715             # we don't care about the plaintext if we are doing a verify.
716             segment = None
717         self._current_segment += 1
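
    # Continuing the hypothetical numbers used above for
    # _setup_encoding_parameters (segsize = 100000, offset = 250000,
    # size = 300000, segments 2 through 5 fetched), the trimming in
    # _set_segment() works out to:
    #
    #   first segment (2): n = 250000 % 100000 = 50000 bytes stripped from
    #                      the front, leaving bytes 250000..299999
    #   last segment (5):  extra = 300000 - (100000 - 50000) = 250000;
    #                      extra %= 100000 gives 50000, so only the first
    #                      50000 bytes (500000..549999) are delivered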
718
719
720     def _validation_or_decoding_failed(self, f, readers):
721         """
722         I am called when a block or a salt fails to correctly validate, or when
723         the decryption or decoding operation fails for some reason.  I react to
724         this failure by notifying the remote server of corruption, and then
725         removing the remote peer from further activity.
726         """
727         assert isinstance(readers, list)
728         bad_shnums = [reader.shnum for reader in readers]
729
730         self.log("validation or decoding failed on share(s) %s, peer(s) %s "
731                  ", segment %d: %s" % \
732                  (bad_shnums, readers, self._current_segment, str(f)))
733         for reader in readers:
734             self._mark_bad_share(reader, f)
735         return
736
737
738     def _validate_block(self, results, segnum, reader, started):
739         """
740         I validate a block from one share on a remote server.
741         """
742         # Grab the part of the block hash tree that is necessary to
743         # validate this block, then generate the block hash root.
744         self.log("validating share %d for segment %d" % (reader.shnum,
745                                                              segnum))
746         elapsed = time.time() - started
747         self._status.add_fetch_timing(reader.peerid, elapsed)
748         self._set_current_status("validating blocks")
749         # Did we fail to fetch either of the things that we were
750         # supposed to? Fail if so.
751         if not results[0][0] or not results[1][0]:
752             # handled by the errback handler.
753
754             # These all get batched into one query, so the resulting
755             # failure should be the same for all of them, so we can just
756             # use the first failure we find.
757             failed = [r for r in results if not r[0]]
758             f = failed[0][1]
759             assert isinstance(f, failure.Failure)
760             raise CorruptShareError(reader.peerid,
761                                     reader.shnum,
762                                     "Connection error: %s" % str(f))
763
764         block_and_salt, block_and_sharehashes = results
765         block, salt = block_and_salt[1]
766         blockhashes, sharehashes = block_and_sharehashes[1]
767
768         blockhashes = dict(enumerate(blockhashes[1]))
769         self.log("the reader gave me the following blockhashes: %s" % \
770                  blockhashes.keys())
771         self.log("the reader gave me the following sharehashes: %s" % \
772                  sharehashes[1].keys())
773         bht = self._block_hash_trees[reader.shnum]
774
775         for bhk, bhv in blockhashes.iteritems():
776             log.msg("xxx 0 blockhash: %s %s" % (bhk, base32.b2a(bhv),))
777
778         if bht.needed_hashes(segnum, include_leaf=True):
779             try:
780                 bht.set_hashes(blockhashes)
781             except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
782                     IndexError), e:
783                 raise CorruptShareError(reader.peerid,
784                                         reader.shnum,
785                                         "block hash tree failure: %s" % e)
786
787         if self._version == MDMF_VERSION:
788             blockhash = hashutil.block_hash(salt + block)
789         else:
790             blockhash = hashutil.block_hash(block)
791         # If this works without an error, then validation is
792         # successful.
793         try:
794            bht.set_hashes(leaves={segnum: blockhash})
795         except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
796                 IndexError), e:
797             raise CorruptShareError(reader.peerid,
798                                     reader.shnum,
799                                     "block hash tree failure: %s" % e)
800
801         # Reaching this point means that we know that this segment
802         # is correct. Now we need to check to see whether the share
803         # hash chain is also correct. 
804         # SDMF wrote share hash chains that didn't contain the
805         # leaves, which would be produced from the block hash tree.
806         # So we need to validate the block hash tree first. If
807         # successful, then bht[0] will contain the root for the
808         # shnum, which will be a leaf in the share hash tree, which
809         # will allow us to validate the rest of the tree.
810         if self.share_hash_tree.needed_hashes(reader.shnum,
811                                               include_leaf=True) or \
812                                               self._verify:
813             try:
814                 self.share_hash_tree.set_hashes(hashes=sharehashes[1],
815                                             leaves={reader.shnum: bht[0]})
816             except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
817                     IndexError), e:
818                 raise CorruptShareError(reader.peerid,
819                                         reader.shnum,
820                                         "corrupt hashes: %s" % e)
821
822         self.log('share %d is valid for segment %d' % (reader.shnum,
823                                                        segnum))
824         return {reader.shnum: (block, salt)}
825
826
827     def _get_needed_hashes(self, reader, segnum):
828         """
829         I get the hashes needed to validate segnum from the reader, then return
830         to my caller when this is done.
831         """
832         bht = self._block_hash_trees[reader.shnum]
833         needed = bht.needed_hashes(segnum, include_leaf=True)
834         # The root of the block hash tree is also a leaf in the share
835         # hash tree. So we don't need to fetch it from the remote
836         # server. In the case of files with one segment, this means that
837         # we won't fetch any block hash tree from the remote server,
838         # since the hash of each share of the file is the entire block
839         # hash tree, and is a leaf in the share hash tree. This is fine,
840         # since any share corruption will be detected in the share hash
841         # tree.
842         #needed.discard(0)
843         self.log("getting blockhashes for segment %d, share %d: %s" % \
844                  (segnum, reader.shnum, str(needed)))
845         d1 = reader.get_blockhashes(needed, force_remote=True)
846         if self.share_hash_tree.needed_hashes(reader.shnum):
847             need = self.share_hash_tree.needed_hashes(reader.shnum)
848             self.log("also need sharehashes for share %d: %s" % (reader.shnum,
849                                                                  str(need)))
850             d2 = reader.get_sharehashes(need, force_remote=True)
851         else:
852             d2 = defer.succeed({}) # the logic in the next method
853                                    # expects a dict
854         dl = defer.DeferredList([d1, d2], consumeErrors=True)
855         return dl
856
857
858     def _decode_blocks(self, blocks_and_salts, segnum):
859         """
860         I take a list of k blocks and salts, and decode that into a
861         single encrypted segment.
862         """
863         d = {}
864         # We want to merge our dictionaries to the form 
865         # {shnum: blocks_and_salts}
866         #
867         # The dictionaries come from _validate_block() that way, so we just
868         # need to merge them.
869         for block_and_salt in blocks_and_salts:
870             d.update(block_and_salt[1])
871
872         # All of these blocks should have the same salt; in SDMF, it is
873         # the file-wide IV, while in MDMF it is the per-segment salt. In
874         # either case, we just need to get one of them and use it.
875         #
876         # d.items()[0] is like (shnum, (block, salt))
877         # d.items()[0][1] is like (block, salt)
878         # d.items()[0][1][1] is the salt.
879         salt = d.items()[0][1][1]
880         # Next, extract just the blocks from the dict. We'll use the
881         # salt in the next step.
882         share_and_shareids = [(k, v[0]) for k, v in d.items()]
883         d2 = dict(share_and_shareids)
884         shareids = []
885         shares = []
886         for shareid, share in d2.items():
887             shareids.append(shareid)
888             shares.append(share)
889
890         self._set_current_status("decoding")
891         started = time.time()
892         assert len(shareids) >= self._required_shares, len(shareids)
893         # zfec really doesn't want extra shares
894         shareids = shareids[:self._required_shares]
895         shares = shares[:self._required_shares]
896         self.log("decoding segment %d" % segnum)
897         if segnum == self._num_segments - 1:
898             d = defer.maybeDeferred(self._tail_decoder.decode, shares, shareids)
899         else:
900             d = defer.maybeDeferred(self._segment_decoder.decode, shares, shareids)
901         def _process(buffers):
902             segment = "".join(buffers)
903             self.log(format="now decoding segment %(segnum)s of %(numsegs)s",
904                      segnum=segnum,
905                      numsegs=self._num_segments,
906                      level=log.NOISY)
907             self.log(" joined length %d, datalength %d" %
908                      (len(segment), self._data_length))
909             if segnum == self._num_segments - 1:
910                 size_to_use = self._tail_data_size
911             else:
912                 size_to_use = self._segment_size
913             segment = segment[:size_to_use]
914             self.log(" segment len=%d" % len(segment))
915             self._status.accumulate_decode_time(time.time() - started)
916             return segment, salt
917         d.addCallback(_process)
918         return d
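
    # After the merge at the top of _decode_blocks(), d looks like this
    # (illustrative values, k = 3, shares 0, 1 and 4 were fetched):
    #
    #   {0: (block_0, salt), 1: (block_1, salt), 4: (block_4, salt)}
    #
    # and zfec's decode() is then handed shares = [block_0, block_1, block_4]
    # and shareids = [0, 1, 4] (in whatever order the dict yields them).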
919
920
921     def _decrypt_segment(self, segment_and_salt):
922         """
923         I take a single segment and its salt, and decrypt it. I return
924         the plaintext of the segment that is in my argument.
925         """
926         segment, salt = segment_and_salt
927         self._set_current_status("decrypting")
928         self.log("decrypting segment %d" % self._current_segment)
929         started = time.time()
930         key = hashutil.ssk_readkey_data_hash(salt, self._node.get_readkey())
931         decryptor = AES(key)
932         plaintext = decryptor.process(segment)
933         self._status.accumulate_decrypt_time(time.time() - started)
934         return plaintext
935
936
937     def notify_server_corruption(self, peerid, shnum, reason):
938         ss = self.servermap.connections[peerid]
939         ss.callRemoteOnly("advise_corrupt_share",
940                           "mutable", self._storage_index, shnum, reason)
941
942
943     def _try_to_validate_privkey(self, enc_privkey, reader):
944         alleged_privkey_s = self._node._decrypt_privkey(enc_privkey)
945         alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s)
946         if alleged_writekey != self._node.get_writekey():
947             self.log("invalid privkey from %s shnum %d" %
948                      (reader, reader.shnum),
949                      level=log.WEIRD, umid="YIw4tA")
950             if self._verify:
951                 self.servermap.mark_bad_share(reader.peerid, reader.shnum,
952                                               self.verinfo[-2])
953                 e = CorruptShareError(reader.peerid,
954                                       reader.shnum,
955                                       "invalid privkey")
956                 f = failure.Failure(e)
957                 self._bad_shares.add((reader.peerid, reader.shnum, f))
958             return
959
960         # it's good
961         self.log("got valid privkey from shnum %d on reader %s" %
962                  (reader.shnum, reader))
963         privkey = rsa.create_signing_key_from_string(alleged_privkey_s)
964         self._node._populate_encprivkey(enc_privkey)
965         self._node._populate_privkey(privkey)
966         self._need_privkey = False
967
968
969
970     def _done(self):
971         """
972         I am called by _download_current_segment when the download process
973         has finished successfully. After making some useful logging
974         statements, I return the decrypted contents to the owner of this
975         Retrieve object through self._done_deferred.
976         """
977         self._running = False
978         self._status.set_active(False)
979         now = time.time()
980         self._status.timings['total'] = now - self._started
981         self._status.timings['fetch'] = now - self._started_fetching
982         self._status.set_status("Finished")
983         self._status.set_progress(1.0)
984
985         # remember the encoding parameters, use them again next time
986         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
987          offsets_tuple) = self.verinfo
988         self._node._populate_required_shares(k)
989         self._node._populate_total_shares(N)
990
991         if self._verify:
992             ret = list(self._bad_shares)
993             self.log("done verifying, found %d bad shares" % len(ret))
994         else:
995             # TODO: upload status here?
996             ret = self._consumer
997             self._consumer.unregisterProducer()
998         eventually(self._done_deferred.callback, ret)
999
1000
1001     def _raise_notenoughshareserror(self):
1002         """
1003         I am called by _activate_enough_peers when there are not enough
1004         active peers left to complete the download. After making some
1005         useful logging statements, I throw an exception to that effect
1006         to the caller of this Retrieve object through
1007         self._done_deferred.
1008         """
1009
1010         format = ("ran out of peers: "
1011                   "have %(have)d of %(total)d segments, "
1012                   "found %(bad)d bad shares, "
1013                   "encoding %(k)d-of-%(n)d")
1014         args = {"have": self._current_segment,
1015                 "total": self._num_segments,
1016                 "need": self._last_segment,
1017                 "k": self._required_shares,
1018                 "n": self._total_shares,
1019                 "bad": len(self._bad_shares)}
1020         raise NotEnoughSharesError("%s, last failure: %s" %
1021                                    (format % args, str(self._last_failure)))
1022
1023     def _error(self, f):
1024         # all errors, including NotEnoughSharesError, land here
1025         self._running = False
1026         self._status.set_active(False)
1027         now = time.time()
1028         self._status.timings['total'] = now - self._started
1029         self._status.timings['fetch'] = now - self._started_fetching
1030         self._status.set_status("Failed")
1031         eventually(self._done_deferred.errback, f)