
import time
from itertools import count
from zope.interface import implements
from twisted.internet import defer
from twisted.python import failure
from twisted.internet.interfaces import IPushProducer, IConsumer
from foolscap.api import eventually, fireEventually, DeadReferenceError, \
     RemoteException

from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \
     DownloadStopped, MDMF_VERSION, SDMF_VERSION
from allmydata.util.assertutil import _assert, precondition
from allmydata.util import hashutil, log, mathutil, deferredutil
from allmydata.util.dictutil import DictOfSets
from allmydata import hashtree, codec
from allmydata.storage.server import si_b2a
from pycryptopp.cipher.aes import AES
from pycryptopp.publickey import rsa

from allmydata.mutable.common import CorruptShareError, BadShareError, \
     UncoordinatedWriteError
from allmydata.mutable.layout import MDMFSlotReadProxy

class RetrieveStatus:
    implements(IRetrieveStatus)
    statusid_counter = count(0)
    def __init__(self):
        self.timings = {}
        self.timings["fetch_per_server"] = {}
        self.timings["decode"] = 0.0
        self.timings["decrypt"] = 0.0
        self.timings["cumulative_verify"] = 0.0
        self._problems = {}
        self.active = True
        self.storage_index = None
        self.helper = False
        self.encoding = ("?","?")
        self.size = None
        self.status = "Not started"
        self.progress = 0.0
        self.counter = self.statusid_counter.next()
        self.started = time.time()

    def get_started(self):
        return self.started
    def get_storage_index(self):
        return self.storage_index
    def get_encoding(self):
        return self.encoding
    def using_helper(self):
        return self.helper
    def get_size(self):
        return self.size
    def get_status(self):
        return self.status
    def get_progress(self):
        return self.progress
    def get_active(self):
        return self.active
    def get_counter(self):
        return self.counter
    def get_problems(self):
        return self._problems

    def add_fetch_timing(self, server, elapsed):
        if server not in self.timings["fetch_per_server"]:
            self.timings["fetch_per_server"][server] = []
        self.timings["fetch_per_server"][server].append(elapsed)
    def accumulate_decode_time(self, elapsed):
        self.timings["decode"] += elapsed
    def accumulate_decrypt_time(self, elapsed):
        self.timings["decrypt"] += elapsed
    def set_storage_index(self, si):
        self.storage_index = si
    def set_helper(self, helper):
        self.helper = helper
    def set_encoding(self, k, n):
        self.encoding = (k, n)
    def set_size(self, size):
        self.size = size
    def set_status(self, status):
        self.status = status
    def set_progress(self, value):
        self.progress = value
    def set_active(self, value):
        self.active = value
    def add_problem(self, server, f):
        serverid = server.get_serverid()
        self._problems[serverid] = f

class Marker:
    pass

class Retrieve:
    # this class is currently single-use. Eventually (in MDMF) we will make
    # it multi-use, in which case you can call download(range) multiple
    # times, and each will have a separate response chain. However the
    # Retrieve object will remain tied to a specific version of the file, and
    # will use a single ServerMap instance.
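    #
    # Illustrative usage (a sketch; the collaborating objects are built
    # elsewhere -- see download() below for the real signature):
    #
    #   r = Retrieve(filenode, storage_broker, servermap, verinfo)
    #   d = r.download(consumer, offset=0, size=None)
    #   # d fires with the consumer once every requested segment has been
    #   # decrypted and written to it.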
    implements(IPushProducer)

    def __init__(self, filenode, storage_broker, servermap, verinfo,
                 fetch_privkey=False, verify=False):
        self._node = filenode
        assert self._node.get_pubkey()
        self._storage_broker = storage_broker
        self._storage_index = filenode.get_storage_index()
        assert self._node.get_readkey()
        self._last_failure = None
        prefix = si_b2a(self._storage_index)[:5]
        self._log_number = log.msg("Retrieve(%s): starting" % prefix)
        self._running = True
        self._decoding = False
        self._bad_shares = set()

        self.servermap = servermap
        assert self._node.get_pubkey()
        self.verinfo = verinfo
        # TODO: make it possible to use self.verinfo.datalength instead
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo
        self._data_length = datalength
        # during repair, we may be called upon to grab the private key, since
        # it wasn't picked up during a verify=False checker run, and we'll
        # need it for repair to generate a new version.
        self._need_privkey = verify or (fetch_privkey
                                        and not self._node.get_privkey())

        if self._need_privkey:
            # TODO: Evaluate the need for this. We'll use it if we want
            # to limit how many queries are on the wire for the privkey
            # at once.
            self._privkey_query_markers = [] # one Marker for each time we've
                                             # tried to get the privkey.

        # verify means that we are using the downloader logic to verify all
        # of our shares. This tells the downloader a few things.
        #
        # 1. We need to download all of the shares.
        # 2. We don't need to decode or decrypt the shares, since our
        #    caller doesn't care about the plaintext, only the
        #    information about which shares are or are not valid.
        # 3. When we are validating readers, we need to validate the
        #    signature on the prefix. Do we? We already do this in the
        #    servermap update?
        self._verify = verify

        self._status = RetrieveStatus()
        self._status.set_storage_index(self._storage_index)
        self._status.set_helper(False)
        self._status.set_progress(0.0)
        self._status.set_active(True)
        self._status.set_size(datalength)
        self._status.set_encoding(k, N)
        self.readers = {}
        self._stopped = False
        self._pause_deferred = None
        self._offset = None
        self._read_length = None
        self.log("got seqnum %d" % self.verinfo[0])


    def get_status(self):
        return self._status

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_number
        if "facility" not in kwargs:
            kwargs["facility"] = "tahoe.mutable.retrieve"
        return log.msg(*args, **kwargs)

    def _set_current_status(self, state):
        seg = "%d/%d" % (self._current_segment, self._last_segment)
        self._status.set_status("segment %s (%s)" % (seg, state))

    ###################
    # IPushProducer

    def pauseProducing(self):
        """
        I am called by my download target if we have produced too much
        data for it to handle. I make the downloader stop producing new
        data until my resumeProducing method is called.
        """
        if self._pause_deferred is not None:
            return

        # fired when the download is unpaused.
        self._old_status = self._status.get_status()
        self._set_current_status("paused")

        self._pause_deferred = defer.Deferred()


    def resumeProducing(self):
        """
        I am called by my download target once it is ready to begin
        receiving data again.
        """
        if self._pause_deferred is None:
            return

        p = self._pause_deferred
        self._pause_deferred = None
        self._status.set_status(self._old_status)

        eventually(p.callback, None)

    def stopProducing(self):
        self._stopped = True
        self.resumeProducing()


    def _check_for_paused(self, res):
        """
        I am called just before a write to the consumer. If the download
        has not been paused, I return the data to be written immediately.
        Otherwise, I return a Deferred that fires with that data once the
        downloader is unpaused.
        """
        if self._pause_deferred is not None:
            d = defer.Deferred()
            self._pause_deferred.addCallback(lambda ignored: d.callback(res))
            return d
        return res

    def _check_for_stopped(self, res):
        if self._stopped:
            raise DownloadStopped("our Consumer called stopProducing()")
        return res


    def download(self, consumer=None, offset=0, size=None):
        precondition(self._verify or IConsumer.providedBy(consumer))
        if size is None:
            size = self._data_length - offset
        if self._verify:
            _assert(size == self._data_length, (size, self._data_length))
        self.log("starting download")
        self._done_deferred = defer.Deferred()
        if consumer:
            self._consumer = consumer
            # we provide IPushProducer, so streaming=True, per IConsumer.
            self._consumer.registerProducer(self, streaming=True)
        self._started = time.time()
        self._started_fetching = time.time()
        if size == 0:
            # short-circuit the rest of the process
            self._done()
        else:
            self._start_download(consumer, offset, size)
        return self._done_deferred

    def _start_download(self, consumer, offset, size):
        precondition((0 <= offset < self._data_length)
                     and (size > 0)
                     and (offset+size <= self._data_length),
                     (offset, size, self._data_length))

        self._offset = offset
        self._read_length = size
        self._setup_encoding_parameters()
        self._setup_download()

        # The download process beyond this is a state machine.
        # loop() will use _activate_enough_servers to select the servers
        # that we want to use for the download, and will then attempt to
        # start downloading. After each segment, it will check for doneness,
        # reacting to broken servers and corrupt shares as necessary. If it
        # runs out of good servers before downloading all of the segments,
        # _done_deferred will errback. Otherwise, it will eventually
        # callback with the contents of the mutable file.
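        #
        # A rough sketch of the calling sequence from here:
        #   loop() -> _activate_enough_servers()
        #          -> _download_current_segment()
        #               -> _process_segment() then loop() again,
        #                  or _done() / _error() when finished or failed.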
        self.loop()

    def loop(self):
        d = fireEventually(None) # avoid #237 recursion limit problem
        d.addCallback(lambda ign: self._activate_enough_servers())
        d.addCallback(lambda ign: self._download_current_segment())
        # when we're done, _download_current_segment will call _done. If we
        # aren't, it will call loop() again.
        d.addErrback(self._error)

    def _setup_download(self):
        self._status.set_status("Retrieving Shares")

        # how many shares do we need?
        (seqnum,
         root_hash,
         IV,
         segsize,
         datalength,
         k,
         N,
         prefix,
         offsets_tuple) = self.verinfo

        # first, which servers can we use?
        versionmap = self.servermap.make_versionmap()
        shares = versionmap[self.verinfo]
        # this sharemap is consumed as we decide to send requests
        self.remaining_sharemap = DictOfSets()
        for (shnum, server, timestamp) in shares:
            self.remaining_sharemap.add(shnum, server)
            # Reuse the SlotReader from the servermap.
            key = (self.verinfo, server.get_serverid(),
                   self._storage_index, shnum)
            if key in self.servermap.proxies:
                reader = self.servermap.proxies[key]
            else:
                reader = MDMFSlotReadProxy(server.get_rref(),
                                           self._storage_index, shnum, None)
            reader.server = server
            self.readers[shnum] = reader

        if len(self.remaining_sharemap) < k:
            self._raise_notenoughshareserror()

        self.shares = {} # maps shnum to validated blocks
        self._active_readers = [] # list of active readers for this dl.
        self._block_hash_trees = {} # shnum => hashtree

        for i in xrange(self._total_shares):
            # So we don't have to do this later.
            self._block_hash_trees[i] = hashtree.IncompleteHashTree(self._num_segments)

        # We need one share hash tree for the entire file; its leaves
        # are the roots of the block hash trees for the shares that
        # comprise it, and its root is in the verinfo.
        self.share_hash_tree = hashtree.IncompleteHashTree(N)
        self.share_hash_tree.set_hashes({0: root_hash})
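        # (With N total shares this tree has N leaves, one per share, and
        # seeding it with {0: root_hash} pins the tree's root to the value
        # from the verinfo, so each share's leaf can be checked against it.)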

    def decode(self, blocks_and_salts, segnum):
        """
        I am a helper method that the mutable file update process uses
        as a shortcut to decode and decrypt the segments that it needs
        to fetch in order to perform a file update. I take in a
        collection of blocks and salts, and pick some of those to make a
        segment with. I return the plaintext associated with that
        segment.
        """
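        # Illustrative call (a sketch; the real caller lives in the mutable
        # file update path, not in this file):
        #   d = retrieve.decode({shnum: (block, salt), ...}, segnum)
        #   d.addCallback(lambda plaintext: ...)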
        # We don't need the block hash trees in this case.
        self._block_hash_trees = None
        self._offset = 0
        self._read_length = self._data_length
        self._setup_encoding_parameters()

        # _decode_blocks() expects the output of a gatherResults that
        # contains the outputs of _validate_block() (each of which is a dict
        # mapping shnum to (block,salt) bytestrings).
        d = self._decode_blocks([blocks_and_salts], segnum)
        d.addCallback(self._decrypt_segment)
        return d


    def _setup_encoding_parameters(self):
        """
        I set up the encoding parameters, including k, n, the number
        of segments associated with this file, and the segment decoders.
        """
        (seqnum,
         root_hash,
         IV,
         segsize,
         datalength,
         k,
         n,
         known_prefix,
         offsets_tuple) = self.verinfo
        self._required_shares = k
        self._total_shares = n
        self._segment_size = segsize
        #self._data_length = datalength # set during __init__()

        if not IV:
            self._version = MDMF_VERSION
        else:
            self._version = SDMF_VERSION

        if datalength and segsize:
            self._num_segments = mathutil.div_ceil(datalength, segsize)
            self._tail_data_size = datalength % segsize
        else:
            self._num_segments = 0
            self._tail_data_size = 0

        self._segment_decoder = codec.CRSDecoder()
        self._segment_decoder.set_params(segsize, k, n)

        if not self._tail_data_size:
            self._tail_data_size = segsize

        self._tail_segment_size = mathutil.next_multiple(self._tail_data_size,
                                                         self._required_shares)
        if self._tail_segment_size == self._segment_size:
            self._tail_decoder = self._segment_decoder
        else:
            self._tail_decoder = codec.CRSDecoder()
            self._tail_decoder.set_params(self._tail_segment_size,
                                          self._required_shares,
                                          self._total_shares)
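        # Worked example (illustrative numbers): datalength=3500 and
        # segsize=1000 give num_segments=4 and tail_data_size=500; with
        # k=3 the tail segment is padded out to next_multiple(500, 3) = 501
        # bytes, so it gets its own decoder.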

        self.log("got encoding parameters: "
                 "k: %d "
                 "n: %d "
                 "%d segments of %d bytes each (%d byte tail segment)" % \
                 (k, n, self._num_segments, self._segment_size,
                  self._tail_segment_size))

        # Our last task is to tell the downloader where to start and
        # where to stop. We use three parameters for that:
        #   - self._start_segment: the segment that we need to start
        #     downloading from.
        #   - self._current_segment: the next segment that we need to
        #     download.
        #   - self._last_segment: The last segment that we were asked to
        #     download.
        #
        #  We say that the download is complete when
        #  self._current_segment > self._last_segment. We use
        #  self._start_segment and self._last_segment to know when to
        #  strip things off of segments, and how much to strip.
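        #
        #  Worked example (illustrative numbers): with segment_size=1000,
        #  offset=2500 and read_length=1000, start = 2500 // 1000 = 2 and
        #  end = (2500 + 1000 - 1) // 1000 = 3, so we fetch segments 2 and
        #  3 and trim the surplus head and tail bytes in _set_segment().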
        if self._offset:
            self.log("got offset: %d" % self._offset)
            # our start segment is the first segment containing the
            # offset we were given.
            start = self._offset // self._segment_size

            _assert(start <= self._num_segments,
                    start=start, num_segments=self._num_segments,
                    offset=self._offset, segment_size=self._segment_size)
            self._start_segment = start
            self.log("got start segment: %d" % self._start_segment)
        else:
            self._start_segment = 0

        # We might want to read only part of the file, and need to figure out
        # where to stop reading. Our end segment is the last segment
        # containing any part of the data that we were asked to read.
        _assert(self._read_length > 0, self._read_length)
        end_data = self._offset + self._read_length

        # We don't actually need to read the byte at end_data, but the one
        # before it.
        end = (end_data - 1) // self._segment_size
        _assert(0 <= end < self._num_segments,
                end=end, num_segments=self._num_segments,
                end_data=end_data, offset=self._offset,
                read_length=self._read_length, segment_size=self._segment_size)
        self._last_segment = end
        self.log("got end segment: %d" % self._last_segment)

        self._current_segment = self._start_segment

    def _activate_enough_servers(self):
        """
        I populate self._active_readers with enough active readers to
        retrieve the contents of this mutable file. I am called before
        downloading starts, and (eventually) after each validation
        error, connection error, or other problem in the download.
        """
        # TODO: It would be cool to investigate other heuristics for
        # reader selection. For instance, the cost (in time the user
        # spends waiting for their file) of selecting a really slow server
        # that happens to have a primary share is probably more than
        # selecting a really fast server that doesn't have a primary
        # share. Maybe the servermap could be extended to provide this
        # information; it could keep track of latency information while
        # it gathers more important data, and then this routine could
        # use that to select active readers.
        #
        # (these and other questions would be easier to answer with a
        #  robust, configurable tahoe-lafs simulator, which modeled node
        #  failures, differences in node speed, and other characteristics
        #  that we expect storage servers to have.  You could have
        #  presets for really stable grids (like allmydata.com),
        #  friendnets, make it easy to configure your own settings, and
        #  then simulate the effect of big changes on these use cases
        #  instead of just reasoning about what the effect might be. Out
        #  of scope for MDMF, though.)

        # XXX: Why don't format= log messages work here?

        known_shnums = set(self.remaining_sharemap.keys())
        used_shnums = set([r.shnum for r in self._active_readers])
        unused_shnums = known_shnums - used_shnums

        if self._verify:
            new_shnums = unused_shnums # use them all
        elif len(self._active_readers) < self._required_shares:
            # need more shares
            more = self._required_shares - len(self._active_readers)
            # We favor lower numbered shares, since FEC is faster with
            # primary shares than with other shares, and lower-numbered
            # shares are more likely to be primary than higher numbered
            # shares.
            new_shnums = sorted(unused_shnums)[:more]
            if len(new_shnums) < more:
                # We don't have enough readers to retrieve the file; fail.
                self._raise_notenoughshareserror()
        else:
            new_shnums = []

        self.log("adding %d new servers to the active list" % len(new_shnums))
        for shnum in new_shnums:
            reader = self.readers[shnum]
            self._active_readers.append(reader)
            self.log("added reader for share %d" % shnum)
            # Each time we add a reader, we check to see if we need the
            # private key. If we do, we politely ask for it and then continue
            # computing. If we find that we haven't gotten it at the end of
            # segment decoding, then we'll take more drastic measures.
            if self._need_privkey and not self._node.is_readonly():
                d = reader.get_encprivkey()
                d.addCallback(self._try_to_validate_privkey, reader, reader.server)
                # XXX: don't just drop the Deferred. We need error-reporting
                # but not flow-control here.

    def _try_to_validate_prefix(self, prefix, reader):
        """
        I check that the prefix returned by a candidate server for
        retrieval matches the prefix that the servermap knows about
        (and, hence, the prefix that was validated earlier). If it does,
        I return without incident, which means that I approve of the use
        of the candidate server for segment retrieval. If it doesn't, I
        raise UncoordinatedWriteError, which means that the candidate
        server cannot be used.
        """
        (seqnum,
         root_hash,
         IV,
         segsize,
         datalength,
         k,
         N,
         known_prefix,
         offsets_tuple) = self.verinfo
        if known_prefix != prefix:
            self.log("prefix from share %d doesn't match" % reader.shnum)
            raise UncoordinatedWriteError("Mismatched prefix -- this could "
                                          "indicate an uncoordinated write")
        # Otherwise, we're okay -- no issues.

    def _mark_bad_share(self, server, shnum, reader, f):
        """
        I mark the given (server, shnum) as a bad share, which means that it
        will not be used anywhere else.

        There are several reasons to want to mark something as a bad
        share. These include:

            - A connection error to the server.
            - A mismatched prefix (that is, a prefix that does not match
              our local conception of the version information string).
            - A failing block hash, salt hash, share hash, or other
              integrity check.

        This method will ensure that readers that we wish to mark bad
        (for these reasons or other reasons) are not used for the rest
        of the download. Additionally, it will attempt to tell the
        remote server (with no guarantee of success) that its share is
        corrupt.
        """
        self.log("marking share %d on server %s as bad" % \
                 (shnum, server.get_name()))
        prefix = self.verinfo[-2]
        self.servermap.mark_bad_share(server, shnum, prefix)
        self._bad_shares.add((server, shnum, f))
        self._status.add_problem(server, f)
        self._last_failure = f

        # Remove the reader from _active_readers
        self._active_readers.remove(reader)
        for shnum in list(self.remaining_sharemap.keys()):
            self.remaining_sharemap.discard(shnum, reader.server)

        if f.check(BadShareError):
            self.notify_server_corruption(server, shnum, str(f.value))

    def _download_current_segment(self):
        """
        I download, validate, decode, decrypt, and assemble the segment
        that this Retrieve is currently responsible for downloading.
        """

        if self._current_segment > self._last_segment:
            # No more segments to download, we're done.
            self.log("got plaintext, done")
            return self._done()
        elif self._verify and len(self._active_readers) == 0:
            self.log("no more good shares, no need to keep verifying")
            return self._done()
        self.log("on segment %d of %d" %
                 (self._current_segment + 1, self._num_segments))
        d = self._process_segment(self._current_segment)
        d.addCallback(lambda ign: self.loop())
        return d

    def _process_segment(self, segnum):
        """
        I download, validate, decode, and decrypt one segment of the
        file that this Retrieve is retrieving. This means coordinating
        the process of getting k blocks of that file, validating them,
        assembling them into one segment with the decoder, and then
        decrypting them.
        """
        self.log("processing segment %d" % segnum)

        # TODO: The old code uses a marker. Should this code do that
        # too? What did the Marker do?

        # We need to ask each of our active readers for its block and
        # salt. We will then validate those. If validation is
        # successful, we will assemble the results into plaintext.
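        #
        # Per-reader Deferred structure (a sketch):
        #   gatherResults([block+salt, blockhashes, sharehashes])
        #     -> _validate_block -> {shnum: (block, salt)}
        #     -> (on error) _handle_bad_share -> None
        # The outer gatherResults below hands the whole list to
        # _maybe_decode_and_decrypt_segment (or _set_segment when verifying).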
        ds = []
        for reader in self._active_readers:
            started = time.time()
            d1 = reader.get_block_and_salt(segnum)
            d2,d3 = self._get_needed_hashes(reader, segnum)
            d = deferredutil.gatherResults([d1,d2,d3])
            d.addCallback(self._validate_block, segnum, reader, reader.server, started)
            # _handle_bad_share takes care of recoverable errors (by dropping
            # that share and returning None). Any other errors (i.e. code
            # bugs) are passed through and cause the retrieve to fail.
            d.addErrback(self._handle_bad_share, [reader])
            ds.append(d)
        dl = deferredutil.gatherResults(ds)
        if self._verify:
            dl.addCallback(lambda ignored: "")
            dl.addCallback(self._set_segment)
        else:
            dl.addCallback(self._maybe_decode_and_decrypt_segment, segnum)
        return dl


    def _maybe_decode_and_decrypt_segment(self, results, segnum):
        """
        I take the results of fetching and validating the blocks from
        _process_segment. If validation and fetching succeeded without
        incident, I will proceed with decoding and decryption. Otherwise, I
        will do nothing.
        """
        self.log("trying to decode and decrypt segment %d" % segnum)

        # 'results' is the output of a gatherResults set up in
        # _process_segment(). Each component Deferred will either contain the
        # non-Failure output of _validate_block() for a single block (i.e.
        # {shnum:(block,salt)}), or None if _validate_block threw an
        # exception and _handle_bad_share handled it (by dropping that
        # server).

        if None in results:
            self.log("some validation operations failed; not proceeding")
            return defer.succeed(None)
        self.log("everything looks ok, building segment %d" % segnum)
        d = self._decode_blocks(results, segnum)
        d.addCallback(self._decrypt_segment)
        # check to see whether we've been paused before writing
        # anything.
        d.addCallback(self._check_for_paused)
        d.addCallback(self._check_for_stopped)
        d.addCallback(self._set_segment)
        return d


    def _set_segment(self, segment):
        """
        Given a plaintext segment, I register that segment with the
        target that is handling the file download.
        """
        self.log("got plaintext for segment %d" % self._current_segment)

        if self._read_length == 0:
            self.log("on first+last segment, size=0, using 0 bytes")
            segment = b""

        if self._current_segment == self._last_segment:
            # trim off the tail
            wanted = (self._offset + self._read_length) % self._segment_size
            if wanted != 0:
                self.log("on the last segment: using first %d bytes" % wanted)
                segment = segment[:wanted]
            else:
                self.log("on the last segment: using all %d bytes" %
                         len(segment))

        if self._current_segment == self._start_segment:
            # Trim off the head, if offset != 0. This should also work if
            # start==last, because we trim the tail first.
            skip = self._offset % self._segment_size
            self.log("on the first segment: skipping first %d bytes" % skip)
            segment = segment[skip:]
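        # Worked example (illustrative numbers): with segment_size=1000,
        # offset=2500 and read_length=1000, wanted = 3500 % 1000 = 500 (keep
        # the first 500 bytes of the last segment) and skip = 2500 % 1000 =
        # 500 (drop the first 500 bytes of the start segment).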

        if not self._verify:
            self._consumer.write(segment)
        else:
            # we don't care about the plaintext if we are doing a verify.
            segment = None
        self._current_segment += 1


    def _handle_bad_share(self, f, readers):
        """
        I am called when a block or a salt fails to correctly validate, or when
        the decryption or decoding operation fails for some reason.  I react to
        this failure by notifying the remote server of corruption, and then
        removing the remote server from further activity.
        """
        # these are the errors we can tolerate: by giving up on this share
        # and finding others to replace it. Any other errors (i.e. coding
        # bugs) are re-raised, causing the download to fail.
        f.trap(DeadReferenceError, RemoteException, BadShareError)

        # DeadReferenceError happens when we try to fetch data from a server
        # that has gone away. RemoteException happens if the server had an
        # internal error. BadShareError encompasses: (UnknownVersionError,
        # LayoutInvalid, struct.error) which happen when we get obviously
        # wrong data, and CorruptShareError which happens later, when we
        # perform integrity checks on the data.

        assert isinstance(readers, list)
        bad_shnums = [reader.shnum for reader in readers]

        self.log("validation or decoding failed on share(s) %s, server(s) %s "
                 ", segment %d: %s" % \
                 (bad_shnums, readers, self._current_segment, str(f)))
        for reader in readers:
            self._mark_bad_share(reader.server, reader.shnum, reader, f)
        return None


    def _validate_block(self, results, segnum, reader, server, started):
        """
        I validate a block from one share on a remote server.
        """
        # Grab the part of the block hash tree that is necessary to
        # validate this block, then generate the block hash root.
        self.log("validating share %d for segment %d" % (reader.shnum,
                                                         segnum))
        elapsed = time.time() - started
        self._status.add_fetch_timing(server, elapsed)
        self._set_current_status("validating blocks")

        block_and_salt, blockhashes, sharehashes = results
        block, salt = block_and_salt
        assert type(block) is str, (block, salt)

        blockhashes = dict(enumerate(blockhashes))
        self.log("the reader gave me the following blockhashes: %s" % \
                 blockhashes.keys())
        self.log("the reader gave me the following sharehashes: %s" % \
                 sharehashes.keys())
        bht = self._block_hash_trees[reader.shnum]

        if bht.needed_hashes(segnum, include_leaf=True):
            try:
                bht.set_hashes(blockhashes)
            except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
                    IndexError), e:
                raise CorruptShareError(server,
                                        reader.shnum,
                                        "block hash tree failure: %s" % e)

        if self._version == MDMF_VERSION:
            blockhash = hashutil.block_hash(salt + block)
        else:
            blockhash = hashutil.block_hash(block)
        # If this works without an error, then validation is
        # successful.
        try:
            bht.set_hashes(leaves={segnum: blockhash})
        except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
                IndexError), e:
            raise CorruptShareError(server,
                                    reader.shnum,
                                    "block hash tree failure: %s" % e)

        # Reaching this point means that we know that this segment
        # is correct. Now we need to check to see whether the share
        # hash chain is also correct.
        # SDMF wrote share hash chains that didn't contain the
        # leaves, which would be produced from the block hash tree.
        # So we need to validate the block hash tree first. If
        # successful, then bht[0] will contain the root for the
        # shnum, which will be a leaf in the share hash tree, which
        # will allow us to validate the rest of the tree.
        try:
            self.share_hash_tree.set_hashes(hashes=sharehashes,
                                            leaves={reader.shnum: bht[0]})
        except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
                IndexError), e:
            raise CorruptShareError(server,
                                    reader.shnum,
                                    "corrupt hashes: %s" % e)

        self.log('share %d is valid for segment %d' % (reader.shnum,
                                                       segnum))
        return {reader.shnum: (block, salt)}


    def _get_needed_hashes(self, reader, segnum):
        """
        I get the hashes needed to validate segnum from the reader, then return
        to my caller when this is done.
        """
        bht = self._block_hash_trees[reader.shnum]
        needed = bht.needed_hashes(segnum, include_leaf=True)
        # The root of the block hash tree is also a leaf in the share
        # hash tree. So we don't need to fetch it from the remote
        # server. In the case of files with one segment, this means that
        # we won't fetch any block hash tree from the remote server,
        # since the hash of each share of the file is the entire block
        # hash tree, and is a leaf in the share hash tree. This is fine,
        # since any share corruption will be detected in the share hash
        # tree.
        #needed.discard(0)
        self.log("getting blockhashes for segment %d, share %d: %s" % \
                 (segnum, reader.shnum, str(needed)))
        # TODO is force_remote necessary here?
        d1 = reader.get_blockhashes(needed, force_remote=False)
        if self.share_hash_tree.needed_hashes(reader.shnum):
            need = self.share_hash_tree.needed_hashes(reader.shnum)
            self.log("also need sharehashes for share %d: %s" % (reader.shnum,
                                                                 str(need)))
            d2 = reader.get_sharehashes(need, force_remote=False)
        else:
            d2 = defer.succeed({}) # the logic in the next method
                                   # expects a dict
        return d1,d2


    def _decode_blocks(self, results, segnum):
        """
        I take a list of k blocks and salts, and decode that into a
        single encrypted segment.
        """
        # 'results' is one or more dicts (each {shnum:(block,salt)}), and we
        # want to merge them all
        blocks_and_salts = {}
        for d in results:
            blocks_and_salts.update(d)

        # All of these blocks should have the same salt; in SDMF, it is
        # the file-wide IV, while in MDMF it is the per-segment salt. In
        # either case, we just need to get one of them and use it.
        #
        # d.items()[0] is like (shnum, (block, salt))
        # d.items()[0][1] is like (block, salt)
        # d.items()[0][1][1] is the salt.
        salt = blocks_and_salts.items()[0][1][1]
        # Next, extract just the blocks from the dict. We'll use the
        # salt in the next step.
        share_and_shareids = [(k, v[0]) for k, v in blocks_and_salts.items()]
        d2 = dict(share_and_shareids)
        shareids = []
        shares = []
        for shareid, share in d2.items():
            shareids.append(shareid)
            shares.append(share)

        self._set_current_status("decoding")
        started = time.time()
        assert len(shareids) >= self._required_shares, len(shareids)
        # zfec really doesn't want extra shares
        shareids = shareids[:self._required_shares]
        shares = shares[:self._required_shares]
        self.log("decoding segment %d" % segnum)
        if segnum == self._num_segments - 1:
            d = defer.maybeDeferred(self._tail_decoder.decode, shares, shareids)
        else:
            d = defer.maybeDeferred(self._segment_decoder.decode, shares, shareids)
        def _process(buffers):
            segment = "".join(buffers)
            self.log(format="now decoding segment %(segnum)s of %(numsegs)s",
                     segnum=segnum,
                     numsegs=self._num_segments,
                     level=log.NOISY)
            self.log(" joined length %d, datalength %d" %
                     (len(segment), self._data_length))
            if segnum == self._num_segments - 1:
                size_to_use = self._tail_data_size
            else:
                size_to_use = self._segment_size
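            # The decoder hands back k equal-sized buffers, so the joined
            # string can include a few bytes of padding (the segment was
            # padded to a multiple of k before encoding); trim it back to
            # the real segment size.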
            segment = segment[:size_to_use]
            self.log(" segment len=%d" % len(segment))
            self._status.accumulate_decode_time(time.time() - started)
            return segment, salt
        d.addCallback(_process)
        return d


    def _decrypt_segment(self, segment_and_salt):
        """
        I take a single segment and its salt, and decrypt it. I return
        the plaintext of the segment that is in my argument.
        """
        segment, salt = segment_and_salt
        self._set_current_status("decrypting")
        self.log("decrypting segment %d" % self._current_segment)
        started = time.time()
        key = hashutil.ssk_readkey_data_hash(salt, self._node.get_readkey())
        decryptor = AES(key)
        plaintext = decryptor.process(segment)
        self._status.accumulate_decrypt_time(time.time() - started)
        return plaintext


    def notify_server_corruption(self, server, shnum, reason):
        rref = server.get_rref()
        rref.callRemoteOnly("advise_corrupt_share",
                            "mutable", self._storage_index, shnum, reason)


    def _try_to_validate_privkey(self, enc_privkey, reader, server):
        alleged_privkey_s = self._node._decrypt_privkey(enc_privkey)
        alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s)
        if alleged_writekey != self._node.get_writekey():
            self.log("invalid privkey from %s shnum %d" %
                     (reader, reader.shnum),
                     level=log.WEIRD, umid="YIw4tA")
            if self._verify:
                self.servermap.mark_bad_share(server, reader.shnum,
                                              self.verinfo[-2])
                e = CorruptShareError(server,
                                      reader.shnum,
                                      "invalid privkey")
                f = failure.Failure(e)
                self._bad_shares.add((server, reader.shnum, f))
            return

        # it's good
        self.log("got valid privkey from shnum %d on reader %s" %
                 (reader.shnum, reader))
        privkey = rsa.create_signing_key_from_string(alleged_privkey_s)
        self._node._populate_encprivkey(enc_privkey)
        self._node._populate_privkey(privkey)
        self._need_privkey = False


    def _done(self):
        """
        I am called by _download_current_segment when the download process
        has finished successfully. After making some useful logging
        statements, I return the decrypted contents to the owner of this
        Retrieve object through self._done_deferred.
        """
        self._running = False
        self._status.set_active(False)
        now = time.time()
        self._status.timings['total'] = now - self._started
        self._status.timings['fetch'] = now - self._started_fetching
        self._status.set_status("Finished")
        self._status.set_progress(1.0)

        # remember the encoding parameters, use them again next time
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo
        self._node._populate_required_shares(k)
        self._node._populate_total_shares(N)

        if self._verify:
            ret = self._bad_shares
            self.log("done verifying, found %d bad shares" % len(ret))
        else:
            # TODO: upload status here?
            ret = self._consumer
            self._consumer.unregisterProducer()
        eventually(self._done_deferred.callback, ret)


    def _raise_notenoughshareserror(self):
        """
        I am called when there are not enough active servers left to complete
        the download. After making some useful logging statements, I raise a
        NotEnoughSharesError, which reaches the caller of this Retrieve
        object through self._done_deferred (via the errback chain set up in
        loop()).
        """

        format = ("ran out of servers: "
                  "have %(have)d of %(total)d segments; "
                  "found %(bad)d bad shares; "
                  "have %(remaining)d remaining shares of the right version; "
                  "encoding %(k)d-of-%(n)d")
        args = {"have": self._current_segment,
                "total": self._num_segments,
                "need": self._last_segment,
                "k": self._required_shares,
                "n": self._total_shares,
                "bad": len(self._bad_shares),
                "remaining": len(self.remaining_sharemap),
               }
        raise NotEnoughSharesError("%s, last failure: %s" %
                                   (format % args, str(self._last_failure)))

    def _error(self, f):
        # all errors, including NotEnoughSharesError, land here
        self._running = False
        self._status.set_active(False)
        now = time.time()
        self._status.timings['total'] = now - self._started
        self._status.timings['fetch'] = now - self._started_fetching
        self._status.set_status("Failed")
        eventually(self._done_deferred.errback, f)