
import time
from itertools import count
from zope.interface import implements
from twisted.internet import defer
from twisted.python import failure
from twisted.internet.interfaces import IPushProducer, IConsumer
from foolscap.api import eventually, fireEventually
from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \
     DownloadStopped, MDMF_VERSION, SDMF_VERSION
from allmydata.util import hashutil, log, mathutil, deferredutil
from allmydata.util.dictutil import DictOfSets
from allmydata import hashtree, codec
from allmydata.storage.server import si_b2a
from pycryptopp.cipher.aes import AES
from pycryptopp.publickey import rsa

from allmydata.mutable.common import CorruptShareError, UncoordinatedWriteError
from allmydata.mutable.layout import MDMFSlotReadProxy

class RetrieveStatus:
    implements(IRetrieveStatus)
    statusid_counter = count(0)
    def __init__(self):
        self.timings = {}
        self.timings["fetch_per_server"] = {}
        self.timings["decode"] = 0.0
        self.timings["decrypt"] = 0.0
        self.timings["cumulative_verify"] = 0.0
        self._problems = {}
        self.active = True
        self.storage_index = None
        self.helper = False
        self.encoding = ("?","?")
        self.size = None
        self.status = "Not started"
        self.progress = 0.0
        self.counter = self.statusid_counter.next()
        self.started = time.time()

    def get_started(self):
        return self.started
    def get_storage_index(self):
        return self.storage_index
    def get_encoding(self):
        return self.encoding
    def using_helper(self):
        return self.helper
    def get_size(self):
        return self.size
    def get_status(self):
        return self.status
    def get_progress(self):
        return self.progress
    def get_active(self):
        return self.active
    def get_counter(self):
        return self.counter
    def get_problems(self):
        return self._problems

    def add_fetch_timing(self, server, elapsed):
        serverid = server.get_serverid()
        if serverid not in self.timings["fetch_per_server"]:
            self.timings["fetch_per_server"][serverid] = []
        self.timings["fetch_per_server"][serverid].append(elapsed)
    def accumulate_decode_time(self, elapsed):
        self.timings["decode"] += elapsed
    def accumulate_decrypt_time(self, elapsed):
        self.timings["decrypt"] += elapsed
    def set_storage_index(self, si):
        self.storage_index = si
    def set_helper(self, helper):
        self.helper = helper
    def set_encoding(self, k, n):
        self.encoding = (k, n)
    def set_size(self, size):
        self.size = size
    def set_status(self, status):
        self.status = status
    def set_progress(self, value):
        self.progress = value
    def set_active(self, value):
        self.active = value
    def add_problem(self, server, f):
        serverid = server.get_serverid()
        self._problems[serverid] = f

class Marker:
    pass

class Retrieve:
    # this class is currently single-use. Eventually (in MDMF) we will make
    # it multi-use, in which case you can call download(range) multiple
    # times, and each will have a separate response chain. However the
    # Retrieve object will remain tied to a specific version of the file, and
    # will use a single ServerMap instance.
    implements(IPushProducer)

    def __init__(self, filenode, storage_broker, servermap, verinfo,
                 fetch_privkey=False, verify=False):
        self._node = filenode
        assert self._node.get_pubkey()
        self._storage_broker = storage_broker
        self._storage_index = filenode.get_storage_index()
        assert self._node.get_readkey()
        self._last_failure = None
        prefix = si_b2a(self._storage_index)[:5]
        self._log_number = log.msg("Retrieve(%s): starting" % prefix)
        self._running = True
        self._decoding = False
        self._bad_shares = set()

        self.servermap = servermap
        assert self._node.get_pubkey()
        self.verinfo = verinfo
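        # verinfo is the (seqnum, root_hash, IV, segsize, datalength, k, N,
        # prefix, offsets_tuple) tuple that the servermap produced for this
        # version; it gets unpacked again below and in _setup_download.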
        # during repair, we may be called upon to grab the private key, since
        # it wasn't picked up during a verify=False checker run, and we'll
        # need it for repair to generate a new version.
        self._need_privkey = verify or (fetch_privkey
                                        and not self._node.get_privkey())
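        # (verify=True also sets this, so that a verification run checks the
        # encrypted private key too when it can; see
        # _try_to_validate_privkey below.)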

        if self._need_privkey:
            # TODO: Evaluate the need for this. We'll use it if we want
            # to limit how many queries are on the wire for the privkey
            # at once.
            self._privkey_query_markers = [] # one Marker for each time we've
                                             # tried to get the privkey.

        # verify means that we are using the downloader logic to verify all
        # of our shares. This tells the downloader a few things.
        #
        # 1. We need to download all of the shares.
        # 2. We don't need to decode or decrypt the shares, since our
        #    caller doesn't care about the plaintext, only the
        #    information about which shares are or are not valid.
        # 3. When we are validating readers, we need to validate the
        #    signature on the prefix. Do we? We already do this in the
        #    servermap update?
        self._verify = verify

        self._status = RetrieveStatus()
        self._status.set_storage_index(self._storage_index)
        self._status.set_helper(False)
        self._status.set_progress(0.0)
        self._status.set_active(True)
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo
        self._status.set_size(datalength)
        self._status.set_encoding(k, N)
        self.readers = {}
        self._stopped = False
        self._pause_deferred = None
        self._offset = None
        self._read_length = None
        self.log("got seqnum %d" % self.verinfo[0])


    def get_status(self):
        return self._status

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_number
        if "facility" not in kwargs:
            kwargs["facility"] = "tahoe.mutable.retrieve"
        return log.msg(*args, **kwargs)

    def _set_current_status(self, state):
        seg = "%d/%d" % (self._current_segment, self._last_segment)
        self._status.set_status("segment %s (%s)" % (seg, state))

    ###################
    # IPushProducer

    def pauseProducing(self):
        """
        I am called by my download target if we have produced too much
        data for it to handle. I make the downloader stop producing new
        data until my resumeProducing method is called.
        """
        if self._pause_deferred is not None:
            return

        self._old_status = self._status.get_status()
        self._set_current_status("paused")

        # fired when the download is unpaused.
        self._pause_deferred = defer.Deferred()


    def resumeProducing(self):
        """
        I am called by my download target once it is ready to begin
        receiving data again.
        """
        if self._pause_deferred is None:
            return

        p = self._pause_deferred
        self._pause_deferred = None
        self._status.set_status(self._old_status)

        eventually(p.callback, None)

    def stopProducing(self):
        self._stopped = True
        self.resumeProducing()


    def _check_for_paused(self, res):
        """
        I am called just before a write to the consumer. If the download
        has been paused, I return a Deferred that fires with the data to
        be written once the downloader is unpaused. Otherwise I return
        the data unchanged, so the write can proceed immediately.
        """
        if self._pause_deferred is not None:
            d = defer.Deferred()
            self._pause_deferred.addCallback(lambda ignored: d.callback(res))
            return d
        return res

    def _check_for_stopped(self, res):
        if self._stopped:
            raise DownloadStopped("our Consumer called stopProducing()")
        return res


    def download(self, consumer=None, offset=0, size=None):
        assert IConsumer.providedBy(consumer) or self._verify

        if consumer:
            self._consumer = consumer
            # we provide IPushProducer, so streaming=True, per
            # IConsumer.
            self._consumer.registerProducer(self, streaming=True)

        self._done_deferred = defer.Deferred()
        self._offset = offset
        self._read_length = size
        self._setup_download()
        self._setup_encoding_parameters()
        self.log("starting download")
        self._started_fetching = time.time()
        # The download process beyond this is a state machine.
        # _activate_enough_servers will select the servers that we want to
        # use for the download, and then attempt to start downloading. After
        # each segment, it will check for doneness, reacting to broken
        # servers and corrupt shares as necessary. If it runs out of good
        # servers before downloading all of the segments, _done_deferred
        # will errback.  Otherwise, it will eventually callback with the
        # contents of the mutable file.
        self.loop()
        return self._done_deferred

    def loop(self):
        d = fireEventually(None) # avoid #237 recursion limit problem
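        # fireEventually fires its Deferred on a later reactor turn, so each
        # pass through the download starts from a fresh stack instead of
        # recursing through the previous segment's callback chain.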
        d.addCallback(lambda ign: self._activate_enough_servers())
        d.addCallback(lambda ign: self._download_current_segment())
        # when we're done, _download_current_segment will call _done. If we
        # aren't done yet, it will call loop() again.
        d.addErrback(self._error)

    def _setup_download(self):
        self._started = time.time()
        self._status.set_status("Retrieving Shares")

        # how many shares do we need?
        (seqnum,
         root_hash,
         IV,
         segsize,
         datalength,
         k,
         N,
         prefix,
         offsets_tuple) = self.verinfo

        # first, which servers can we use?
        versionmap = self.servermap.make_versionmap()
        shares = versionmap[self.verinfo]
        # this sharemap is consumed as we decide to send requests
        self.remaining_sharemap = DictOfSets()
        for (shnum, server, timestamp) in shares:
            self.remaining_sharemap.add(shnum, server)
            # If the servermap update fetched anything, it fetched at least 1
            # KiB, so we ask for that much.
            # TODO: Change the cache methods to allow us to fetch all of the
            # data that they have, then change this method to do that.
            any_cache = self._node._read_from_cache(self.verinfo, shnum,
                                                    0, 1000)
            reader = MDMFSlotReadProxy(server.get_rref(),
                                       self._storage_index,
                                       shnum,
                                       any_cache)
            reader.server = server
            self.readers[shnum] = reader
        assert len(self.remaining_sharemap) >= k

        self.shares = {} # maps shnum to validated blocks
        self._active_readers = [] # list of active readers for this dl.
        self._block_hash_trees = {} # shnum => hashtree

        # We need one share hash tree for the entire file; its leaves
        # are the roots of the block hash trees for the shares that
        # comprise it, and its root is in the verinfo.
        self.share_hash_tree = hashtree.IncompleteHashTree(N)
        self.share_hash_tree.set_hashes({0: root_hash})
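        # The tree starts out knowing only its root (hash 0, taken from the
        # already-validated verinfo); the share hashes that servers send us
        # later are checked against it as they arrive.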

    def decode(self, blocks_and_salts, segnum):
        """
        I am a helper method that the mutable file update process uses
        as a shortcut to decode and decrypt the segments that it needs
        to fetch in order to perform a file update. I take in a
        collection of blocks and salts, and pick some of those to make a
        segment with. I return the plaintext associated with that
        segment.
        """
        # shnum => block hash tree. Unused, but setup_encoding_parameters will
        # want to set this.
        self._block_hash_trees = None
        self._setup_encoding_parameters()

        # _decode_blocks() expects the output of a gatherResults that
        # contains the outputs of _validate_block() (each of which is a dict
        # mapping shnum to (block,salt) bytestrings).
        d = self._decode_blocks([blocks_and_salts], segnum)
        d.addCallback(self._decrypt_segment)
        return d


    def _setup_encoding_parameters(self):
        """
        I set up the encoding parameters, including k, n, the number
        of segments associated with this file, and the segment decoders.
        """
        (seqnum,
         root_hash,
         IV,
         segsize,
         datalength,
         k,
         n,
         known_prefix,
         offsets_tuple) = self.verinfo
        self._required_shares = k
        self._total_shares = n
        self._segment_size = segsize
        self._data_length = datalength

        if not IV:
            self._version = MDMF_VERSION
        else:
            self._version = SDMF_VERSION

        if datalength and segsize:
            self._num_segments = mathutil.div_ceil(datalength, segsize)
            self._tail_data_size = datalength % segsize
        else:
            self._num_segments = 0
            self._tail_data_size = 0

        self._segment_decoder = codec.CRSDecoder()
        self._segment_decoder.set_params(segsize, k, n)

        if not self._tail_data_size:
            self._tail_data_size = segsize

        self._tail_segment_size = mathutil.next_multiple(self._tail_data_size,
                                                         self._required_shares)
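        # The tail segment is rounded up to a multiple of k because the CRS
        # codec works on k equal-sized blocks per segment; the padding is
        # trimmed back to self._tail_data_size after decoding.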
        if self._tail_segment_size == self._segment_size:
            self._tail_decoder = self._segment_decoder
        else:
            self._tail_decoder = codec.CRSDecoder()
            self._tail_decoder.set_params(self._tail_segment_size,
                                          self._required_shares,
                                          self._total_shares)

        self.log("got encoding parameters: "
                 "k: %d "
                 "n: %d "
                 "%d segments of %d bytes each (%d byte tail segment)" % \
                 (k, n, self._num_segments, self._segment_size,
                  self._tail_segment_size))

        if self._block_hash_trees is not None:
            for i in xrange(self._total_shares):
                # So we don't have to do this later.
                self._block_hash_trees[i] = hashtree.IncompleteHashTree(self._num_segments)

        # Our last task is to tell the downloader where to start and
        # where to stop. We use three parameters for that:
        #   - self._start_segment: the segment that we need to start
        #     downloading from.
        #   - self._current_segment: the next segment that we need to
        #     download.
        #   - self._last_segment: The last segment that we were asked to
        #     download.
        #
        #  We say that the download is complete when
        #  self._current_segment > self._last_segment. We use
        #  self._start_segment and self._last_segment to know when to
        #  strip things off of segments, and how much to strip.
        if self._offset:
            self.log("got offset: %d" % self._offset)
            # our start segment is the first segment containing the
            # offset we were given.
            start = self._offset // self._segment_size

            assert start < self._num_segments
            self._start_segment = start
            self.log("got start segment: %d" % self._start_segment)
        else:
            self._start_segment = 0


        # If self._read_length is None, then we want to read the whole
        # file. Otherwise, we want to read only part of the file, and
        # need to figure out where to stop reading.
        if self._read_length is not None:
            # our end segment is the last segment containing part of the
            # data that we were asked to read.
            self.log("got read length %d" % self._read_length)
            if self._read_length != 0:
                end_data = self._offset + self._read_length

                # We don't actually need to read the byte at end_data,
                # but the one before it.
                end = (end_data - 1) // self._segment_size
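                # For example, with segsize=1000, offset=1500, and
                # read_length=1000, end_data is 2500 and end is
                # (2500-1)//1000 = 2, so segments 1 through 2 get fetched.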

                assert end < self._num_segments
                self._last_segment = end
            else:
                self._last_segment = self._start_segment
            self.log("got end segment: %d" % self._last_segment)
        else:
            self._last_segment = self._num_segments - 1

        self._current_segment = self._start_segment

    def _activate_enough_servers(self):
        """
        I populate self._active_readers with enough active readers to
        retrieve the contents of this mutable file. I am called before
        downloading starts, and (eventually) after each validation
        error, connection error, or other problem in the download.
        """
        # TODO: It would be cool to investigate other heuristics for
        # reader selection. For instance, the cost (in time the user
        # spends waiting for their file) of selecting a really slow server
        # that happens to have a primary share is probably more than
        # selecting a really fast server that doesn't have a primary
        # share. Maybe the servermap could be extended to provide this
        # information; it could keep track of latency information while
        # it gathers more important data, and then this routine could
        # use that to select active readers.
        #
        # (these and other questions would be easier to answer with a
        #  robust, configurable tahoe-lafs simulator, which modeled node
        #  failures, differences in node speed, and other characteristics
        #  that we expect storage servers to have.  You could have
        #  presets for really stable grids (like allmydata.com),
        #  friendnets, make it easy to configure your own settings, and
        #  then simulate the effect of big changes on these use cases
        #  instead of just reasoning about what the effect might be. Out
        #  of scope for MDMF, though.)

        # XXX: Why don't format= log messages work here?

        known_shnums = set(self.remaining_sharemap.keys())
        used_shnums = set([r.shnum for r in self._active_readers])
        unused_shnums = known_shnums - used_shnums

        if self._verify:
            new_shnums = unused_shnums # use them all
        elif len(self._active_readers) < self._required_shares:
            # need more shares
            more = self._required_shares - len(self._active_readers)
            # We favor lower numbered shares, since FEC is faster with
            # primary shares than with other shares, and lower-numbered
            # shares are more likely to be primary than higher numbered
            # shares.
            new_shnums = sorted(unused_shnums)[:more]
            if len(new_shnums) < more:
                # We don't have enough readers to retrieve the file; fail.
                self._raise_notenoughshareserror()
        else:
            new_shnums = []

        self.log("adding %d new servers to the active list" % len(new_shnums))
        for shnum in new_shnums:
            reader = self.readers[shnum]
            self._active_readers.append(reader)
            self.log("added reader for share %d" % shnum)
            # Each time we add a reader, we check to see if we need the
            # private key. If we do, we politely ask for it and then continue
            # computing. If we find that we haven't gotten it at the end of
            # segment decoding, then we'll take more drastic measures.
            if self._need_privkey and not self._node.is_readonly():
                d = reader.get_encprivkey()
                d.addCallback(self._try_to_validate_privkey, reader, reader.server)
                # XXX: don't just drop the Deferred. We need error-reporting
                # but not flow-control here.
        assert len(self._active_readers) >= self._required_shares

508     def _try_to_validate_prefix(self, prefix, reader):
        """
        I check that the prefix returned by a candidate server for
        retrieval matches the prefix that the servermap knows about
        (and, hence, the prefix that was validated earlier). If it
        matches, I return without incident, which means that the
        candidate server may be used for segment retrieval. If it
        doesn't, I raise UncoordinatedWriteError, and another server
        must be chosen.
        """
        (seqnum,
         root_hash,
         IV,
         segsize,
         datalength,
         k,
         N,
         known_prefix,
         offsets_tuple) = self.verinfo
        if known_prefix != prefix:
            self.log("prefix from share %d doesn't match" % reader.shnum)
            raise UncoordinatedWriteError("Mismatched prefix -- this could "
                                          "indicate an uncoordinated write")
        # Otherwise, we're okay -- no issues.


    def _remove_reader(self, reader):
        """
        At various points, we will wish to remove a server from
        consideration and/or use. These include, but are not necessarily
        limited to:

            - A connection error.
            - A mismatched prefix (that is, a prefix that does not match
              our conception of the version information string).
            - A failing block hash, salt hash, or share hash, which can
              indicate disk failure/bit flips, or network trouble.

        This method will do that. I will make sure that the
        (shnum,reader) combination represented by my reader argument is
        not used for anything else during this download. I will not
        advise the reader of any corruption, something that my callers
        may wish to do on their own.
        """
        # TODO: When you're done writing this, see if this is ever
        # actually used for something that _mark_bad_share isn't. I have
        # a feeling that they will be used for very similar things, and
        # that having them both here is just going to be an epic amount
        # of code duplication.
        #
        # (well, okay, not epic, but meaningful)
        self.log("removing reader %s" % reader)
        # Remove the reader from _active_readers
        self._active_readers.remove(reader)
        # TODO: self.readers.remove(reader)?
        for shnum in list(self.remaining_sharemap.keys()):
            self.remaining_sharemap.discard(shnum, reader.server)


    def _mark_bad_share(self, server, shnum, reader, f):
        """
        I mark the given (server, shnum) as a bad share, which means that it
        will not be used anywhere else.

        There are several reasons to want to mark something as a bad
        share. These include:

            - A connection error to the server.
            - A mismatched prefix (that is, a prefix that does not match
              our local conception of the version information string).
            - A failing block hash, salt hash, share hash, or other
              integrity check.

        This method will ensure that readers that we wish to mark bad
        (for these reasons or other reasons) are not used for the rest
        of the download. Additionally, it will attempt to tell the
        remote server (with no guarantee of success) that its share is
        corrupt.
        """
        self.log("marking share %d on server %s as bad" % \
                 (shnum, server.get_name()))
        prefix = self.verinfo[-2]
        self.servermap.mark_bad_share(server, shnum, prefix)
        self._remove_reader(reader)
        self._bad_shares.add((server, shnum, f))
        self._status.add_problem(server, f)
        self._last_failure = f
        if f.check(CorruptShareError):
            self.notify_server_corruption(server, shnum, str(f.value))


    def _download_current_segment(self):
        """
        I download, validate, decode, decrypt, and assemble the segment
        that this Retrieve is currently responsible for downloading.
        """
        assert len(self._active_readers) >= self._required_shares
        if self._current_segment > self._last_segment:
            # No more segments to download, we're done.
            self.log("got plaintext, done")
            return self._done()
        self.log("on segment %d of %d" %
                 (self._current_segment + 1, self._num_segments))
        d = self._process_segment(self._current_segment)
        d.addCallback(lambda ign: self.loop())
        return d

    def _process_segment(self, segnum):
        """
        I download, validate, decode, and decrypt one segment of the
        file that this Retrieve is retrieving. This means coordinating
        the process of getting k blocks of that file, validating them,
        assembling them into one segment with the decoder, and then
        decrypting them.
        """
        self.log("processing segment %d" % segnum)

        # TODO: The old code uses a marker. Should this code do that
        # too? What did the Marker do?
        assert len(self._active_readers) >= self._required_shares

        # We need to ask each of our active readers for its block and
        # salt. We will then validate those. If validation is
        # successful, we will assemble the results into plaintext.
        ds = []
        for reader in self._active_readers:
            started = time.time()
            d = reader.get_block_and_salt(segnum)
            d2 = self._get_needed_hashes(reader, segnum)
            dl = defer.DeferredList([d, d2], consumeErrors=True)
            dl.addCallback(self._validate_block, segnum, reader, reader.server, started)
            dl.addErrback(self._validation_or_decoding_failed, [reader])
            ds.append(dl)
        # _validation_or_decoding_failed is supposed to eat any recoverable
        # errors (like corrupt shares), returning a None when that happens.
        # If it raises an exception itself, or if it can't handle the error,
        # the download should fail. So we can use gatherResults() here.
        dl = deferredutil.gatherResults(ds)
        if self._verify:
            dl.addCallback(lambda ignored: "")
            dl.addCallback(self._set_segment)
        else:
            dl.addCallback(self._maybe_decode_and_decrypt_segment, segnum)
        return dl


    def _maybe_decode_and_decrypt_segment(self, results, segnum):
        """
        I take the results of fetching and validating the blocks from
        _process_segment. If validation and fetching succeeded without
        incident, I will proceed with decoding and decryption. Otherwise, I
        will do nothing.
        """
        self.log("trying to decode and decrypt segment %d" % segnum)

        # 'results' is the output of a gatherResults set up in
        # _process_segment(). Each component Deferred will either contain the
        # non-Failure output of _validate_block() for a single block (i.e.
        # {shnum:(block,salt)}), or None if _validate_block threw an
        # exception and _validation_or_decoding_failed handled it (by
        # dropping that server).

        if None in results:
            self.log("some validation operations failed; not proceeding")
            return defer.succeed(None)
        self.log("everything looks ok, building segment %d" % segnum)
        d = self._decode_blocks(results, segnum)
        d.addCallback(self._decrypt_segment)
        d.addErrback(self._validation_or_decoding_failed,
                     self._active_readers)
        # check to see whether we've been paused before writing
        # anything.
        d.addCallback(self._check_for_paused)
        d.addCallback(self._check_for_stopped)
        d.addCallback(self._set_segment)
        return d


    def _set_segment(self, segment):
        """
        Given a plaintext segment, I register that segment with the
        target that is handling the file download.
        """
        self.log("got plaintext for segment %d" % self._current_segment)
        if self._current_segment == self._start_segment:
            # We're on the first segment. It's possible that we want
            # only some part of the end of this segment, and that we
            # just downloaded the whole thing to get that part. If so,
            # we need to account for that and give the reader just the
            # data that they want.
            n = self._offset % self._segment_size
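            # For example, with segsize=1000 and offset=1500, n is 500, so
            # the first 500 bytes of this segment are dropped and delivery
            # starts exactly at the requested offset.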
            self.log("stripping %d bytes off of the first segment" % n)
            self.log("original segment length: %d" % len(segment))
            segment = segment[n:]
            self.log("new segment length: %d" % len(segment))

        if self._current_segment == self._last_segment and self._read_length is not None:
            # We're on the last segment. It's possible that we only want
            # part of the beginning of this segment, and that we
            # downloaded the whole thing anyway. Make sure to give the
            # caller only the portion of the segment that they want to
            # receive.
            extra = self._read_length
            if self._start_segment != self._last_segment:
                extra -= self._segment_size - \
                            (self._offset % self._segment_size)
            extra %= self._segment_size
            self.log("original segment length: %d" % len(segment))
            segment = segment[:extra]
            self.log("new segment length: %d" % len(segment))
            self.log("only taking %d bytes of the last segment" % extra)

        if not self._verify:
            self._consumer.write(segment)
        else:
            # we don't care about the plaintext if we are doing a verify.
            segment = None
        self._current_segment += 1


    def _validation_or_decoding_failed(self, f, readers):
        """
        I am called when a block or a salt fails to correctly validate, or
        when the decryption or decoding operation fails for some reason.
        I react to this failure by marking the share as bad (which also
        notifies the remote server if the failure is a CorruptShareError)
        and removing the server from further use in this download.
        """
        assert isinstance(readers, list)
        bad_shnums = [reader.shnum for reader in readers]

        self.log("validation or decoding failed on share(s) %s, server(s) %s"
                 ", segment %d: %s" % \
                 (bad_shnums, readers, self._current_segment, str(f)))
        for reader in readers:
            self._mark_bad_share(reader.server, reader.shnum, reader, f)
        return


    def _validate_block(self, results, segnum, reader, server, started):
        """
        I validate a block from one share on a remote server.
        """
        # Grab the part of the block hash tree that is necessary to
        # validate this block, then generate the block hash root.
        self.log("validating share %d for segment %d" % (reader.shnum,
                                                         segnum))
        elapsed = time.time() - started
        self._status.add_fetch_timing(server, elapsed)
        self._set_current_status("validating blocks")
        # Did we fail to fetch either of the things that we were
        # supposed to? Fail if so.
        if not results[0][0] or not results[1][0]:
            # handled by the errback handler.

            # These all get batched into one query, so the resulting
            # failure should be the same for all of them, so we can just
            # use whichever failure we find first.
            if not results[0][0]:
                f = results[0][1]
            else:
                f = results[1][1]
            assert isinstance(f, failure.Failure)

            raise CorruptShareError(server,
                                    reader.shnum,
                                    "Connection error: %s" % str(f))

        block_and_salt, block_and_sharehashes = results
        block, salt = block_and_salt[1]
        blockhashes, sharehashes = block_and_sharehashes[1]

        blockhashes = dict(enumerate(blockhashes[1]))
        self.log("the reader gave me the following blockhashes: %s" % \
                 blockhashes.keys())
        self.log("the reader gave me the following sharehashes: %s" % \
                 sharehashes[1].keys())
        bht = self._block_hash_trees[reader.shnum]

        if bht.needed_hashes(segnum, include_leaf=True):
            try:
                bht.set_hashes(blockhashes)
            except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
                    IndexError), e:
                raise CorruptShareError(server,
                                        reader.shnum,
                                        "block hash tree failure: %s" % e)

        if self._version == MDMF_VERSION:
            blockhash = hashutil.block_hash(salt + block)
        else:
            blockhash = hashutil.block_hash(block)
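        # MDMF uses a fresh salt per segment, so the salt is hashed in with
        # the block; SDMF has a single file-wide IV instead, which is already
        # part of the verified prefix, so only the block itself is hashed.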
        # If this works without an error, then validation is
        # successful.
        try:
            bht.set_hashes(leaves={segnum: blockhash})
        except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
                IndexError), e:
            raise CorruptShareError(server,
                                    reader.shnum,
                                    "block hash tree failure: %s" % e)

        # Reaching this point means that we know that this segment
        # is correct. Now we need to check to see whether the share
        # hash chain is also correct.
        # SDMF wrote share hash chains that didn't contain the
        # leaves, which would be produced from the block hash tree.
        # So we need to validate the block hash tree first. If
        # successful, then bht[0] will contain the root for the
        # shnum, which will be a leaf in the share hash tree, which
        # will allow us to validate the rest of the tree.
        if self.share_hash_tree.needed_hashes(reader.shnum,
                                              include_leaf=True) or \
                                              self._verify:
            try:
                self.share_hash_tree.set_hashes(hashes=sharehashes[1],
                                            leaves={reader.shnum: bht[0]})
            except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
                    IndexError), e:
                raise CorruptShareError(server,
                                        reader.shnum,
                                        "corrupt hashes: %s" % e)

        self.log('share %d is valid for segment %d' % (reader.shnum,
                                                       segnum))
        return {reader.shnum: (block, salt)}


    def _get_needed_hashes(self, reader, segnum):
        """
        I get the hashes needed to validate segnum from the reader, then return
        to my caller when this is done.
        """
        bht = self._block_hash_trees[reader.shnum]
        needed = bht.needed_hashes(segnum, include_leaf=True)
        # The root of the block hash tree is also a leaf in the share
        # hash tree. So we don't need to fetch it from the remote
        # server. In the case of files with one segment, this means that
        # we won't fetch any block hash tree from the remote server,
        # since the hash of each share of the file is the entire block
        # hash tree, and is a leaf in the share hash tree. This is fine,
        # since any share corruption will be detected in the share hash
        # tree.
        #needed.discard(0)
        self.log("getting blockhashes for segment %d, share %d: %s" % \
                 (segnum, reader.shnum, str(needed)))
        d1 = reader.get_blockhashes(needed, force_remote=True)
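        # (force_remote=True asks the reader to fetch these hashes from the
        # server rather than answering from its locally cached data.)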
        if self.share_hash_tree.needed_hashes(reader.shnum):
            need = self.share_hash_tree.needed_hashes(reader.shnum)
            self.log("also need sharehashes for share %d: %s" % (reader.shnum,
                                                                 str(need)))
            d2 = reader.get_sharehashes(need, force_remote=True)
        else:
            d2 = defer.succeed({}) # the logic in the next method
                                   # expects a dict
        dl = defer.DeferredList([d1, d2], consumeErrors=True)
        return dl


    def _decode_blocks(self, results, segnum):
        """
        I take a list of k blocks and salts, and decode that into a
        single encrypted segment.
        """
        # 'results' is one or more dicts (each {shnum:(block,salt)}), and we
        # want to merge them all
        blocks_and_salts = {}
        for d in results:
            blocks_and_salts.update(d)

        # All of these blocks should have the same salt; in SDMF, it is
        # the file-wide IV, while in MDMF it is the per-segment salt. In
        # either case, we just need to get one of them and use it.
        #
        # blocks_and_salts.items()[0] is like (shnum, (block, salt))
        # blocks_and_salts.items()[0][1] is like (block, salt)
        # blocks_and_salts.items()[0][1][1] is the salt.
        salt = blocks_and_salts.items()[0][1][1]
        # Next, extract just the blocks from the dict. We'll use the
        # salt in the next step.
        share_and_shareids = [(k, v[0]) for k, v in blocks_and_salts.items()]
        d2 = dict(share_and_shareids)
        shareids = []
        shares = []
        for shareid, share in d2.items():
            shareids.append(shareid)
            shares.append(share)

        self._set_current_status("decoding")
        started = time.time()
        assert len(shareids) >= self._required_shares, len(shareids)
        # zfec really doesn't want extra shares
        shareids = shareids[:self._required_shares]
        shares = shares[:self._required_shares]
        self.log("decoding segment %d" % segnum)
        if segnum == self._num_segments - 1:
            d = defer.maybeDeferred(self._tail_decoder.decode, shares, shareids)
        else:
            d = defer.maybeDeferred(self._segment_decoder.decode, shares, shareids)
        def _process(buffers):
            segment = "".join(buffers)
            self.log(format="now decoding segment %(segnum)s of %(numsegs)s",
                     segnum=segnum,
                     numsegs=self._num_segments,
                     level=log.NOISY)
            self.log(" joined length %d, datalength %d" %
                     (len(segment), self._data_length))
            if segnum == self._num_segments - 1:
                size_to_use = self._tail_data_size
            else:
                size_to_use = self._segment_size
            segment = segment[:size_to_use]
            self.log(" segment len=%d" % len(segment))
            self._status.accumulate_decode_time(time.time() - started)
            return segment, salt
        d.addCallback(_process)
        return d


    def _decrypt_segment(self, segment_and_salt):
        """
        I take a single segment and its salt, and decrypt it. I return
        the plaintext of the segment that is in my argument.
        """
        segment, salt = segment_and_salt
        self._set_current_status("decrypting")
        self.log("decrypting segment %d" % self._current_segment)
        started = time.time()
        key = hashutil.ssk_readkey_data_hash(salt, self._node.get_readkey())
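        # The segment key is derived from the salt plus the file's readkey,
        # so any holder of the read capability can recompute it, and each
        # distinct salt yields a distinct AES key.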
        decryptor = AES(key)
        plaintext = decryptor.process(segment)
        self._status.accumulate_decrypt_time(time.time() - started)
        return plaintext


    def notify_server_corruption(self, server, shnum, reason):
        rref = server.get_rref()
        rref.callRemoteOnly("advise_corrupt_share",
                            "mutable", self._storage_index, shnum, reason)


    def _try_to_validate_privkey(self, enc_privkey, reader, server):
        alleged_privkey_s = self._node._decrypt_privkey(enc_privkey)
        alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s)
        if alleged_writekey != self._node.get_writekey():
            self.log("invalid privkey from %s shnum %d" %
                     (reader, reader.shnum),
                     level=log.WEIRD, umid="YIw4tA")
            if self._verify:
                self.servermap.mark_bad_share(server, reader.shnum,
                                              self.verinfo[-2])
                e = CorruptShareError(server,
                                      reader.shnum,
                                      "invalid privkey")
                f = failure.Failure(e)
                self._bad_shares.add((server, reader.shnum, f))
            return

        # it's good
        self.log("got valid privkey from shnum %d on reader %s" %
                 (reader.shnum, reader))
        privkey = rsa.create_signing_key_from_string(alleged_privkey_s)
        self._node._populate_encprivkey(enc_privkey)
        self._node._populate_privkey(privkey)
        self._need_privkey = False


    def _done(self):
        """
        I am called by _download_current_segment when the download process
        has finished successfully. After making some useful logging
        statements, I return the decrypted contents to the owner of this
        Retrieve object through self._done_deferred.
        """
        self._running = False
        self._status.set_active(False)
        now = time.time()
        self._status.timings['total'] = now - self._started
        self._status.timings['fetch'] = now - self._started_fetching
        self._status.set_status("Finished")
        self._status.set_progress(1.0)

        # remember the encoding parameters, use them again next time
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo
        self._node._populate_required_shares(k)
        self._node._populate_total_shares(N)

        if self._verify:
            ret = self._bad_shares
            self.log("done verifying, found %d bad shares" % len(ret))
        else:
            # TODO: upload status here?
            ret = self._consumer
            self._consumer.unregisterProducer()
        eventually(self._done_deferred.callback, ret)


    def _raise_notenoughshareserror(self):
        """
        I am called by _activate_enough_servers when there are not enough
        active servers left to complete the download. After making some
        useful logging statements, I throw an exception to that effect
        to the caller of this Retrieve object through
        self._done_deferred.
        """

        format = ("ran out of servers: "
                  "have %(have)d of %(total)d segments "
                  "found %(bad)d bad shares "
                  "encoding %(k)d-of-%(n)d")
        args = {"have": self._current_segment,
                "total": self._num_segments,
                "need": self._last_segment,
                "k": self._required_shares,
                "n": self._total_shares,
                "bad": len(self._bad_shares)}
        raise NotEnoughSharesError("%s, last failure: %s" %
                                   (format % args, str(self._last_failure)))

    def _error(self, f):
        # all errors, including NotEnoughSharesError, land here
        self._running = False
        self._status.set_active(False)
        now = time.time()
        self._status.timings['total'] = now - self._started
        self._status.timings['fetch'] = now - self._started_fetching
        self._status.set_status("Failed")
        eventually(self._done_deferred.errback, f)