git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blob - src/allmydata/mutable/retrieve.py
mutable/retrieve.py: remove all bare assert()s
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / mutable / retrieve.py
1
2 import time
3 from itertools import count
4 from zope.interface import implements
5 from twisted.internet import defer
6 from twisted.python import failure
7 from twisted.internet.interfaces import IPushProducer, IConsumer
8 from foolscap.api import eventually, fireEventually, DeadReferenceError, \
9      RemoteException
10
11 from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \
12      DownloadStopped, MDMF_VERSION, SDMF_VERSION
13 from allmydata.util.assertutil import _assert, precondition
14 from allmydata.util import hashutil, log, mathutil, deferredutil
15 from allmydata.util.dictutil import DictOfSets
16 from allmydata import hashtree, codec
17 from allmydata.storage.server import si_b2a
18 from pycryptopp.cipher.aes import AES
19 from pycryptopp.publickey import rsa
20
21 from allmydata.mutable.common import CorruptShareError, BadShareError, \
22      UncoordinatedWriteError
23 from allmydata.mutable.layout import MDMFSlotReadProxy
24
class RetrieveStatus:
    """
    Bookkeeping record for one mutable-file retrieve.

    Holds timing data, per-server problems, encoding/size information,
    a progress fraction and a free-form status string. Each instance is
    stamped with a sequential number drawn from the class-wide
    statusid_counter.
    """
    implements(IRetrieveStatus)
    statusid_counter = count(0)

    def __init__(self):
        self.timings = {
            "fetch_per_server": {},   # server -> list of elapsed fetch times
            "decode": 0.0,
            "decrypt": 0.0,
            "cumulative_verify": 0.0,
        }
        self._problems = {}           # serverid -> failure
        self.active = True
        self.storage_index = None
        self.helper = False
        self.encoding = ("?", "?")
        self.size = None
        self.status = "Not started"
        self.progress = 0.0
        self.counter = next(self.statusid_counter)
        self.started = time.time()

    # -- read accessors --------------------------------------------------

    def get_started(self):
        return self.started

    def get_storage_index(self):
        return self.storage_index

    def get_encoding(self):
        return self.encoding

    def using_helper(self):
        return self.helper

    def get_size(self):
        return self.size

    def get_status(self):
        return self.status

    def get_progress(self):
        return self.progress

    def get_active(self):
        return self.active

    def get_counter(self):
        return self.counter

    def get_problems(self):
        return self._problems

    # -- mutators ----------------------------------------------------------

    def add_fetch_timing(self, server, elapsed):
        """Append one fetch duration to the list kept for *server*."""
        per_server = self.timings["fetch_per_server"]
        per_server.setdefault(server, []).append(elapsed)

    def accumulate_decode_time(self, elapsed):
        self.timings["decode"] += elapsed

    def accumulate_decrypt_time(self, elapsed):
        self.timings["decrypt"] += elapsed

    def set_storage_index(self, si):
        self.storage_index = si

    def set_helper(self, helper):
        self.helper = helper

    def set_encoding(self, k, n):
        self.encoding = (k, n)

    def set_size(self, size):
        self.size = size

    def set_status(self, status):
        self.status = status

    def set_progress(self, value):
        self.progress = value

    def set_active(self, value):
        self.active = value

    def add_problem(self, server, f):
        """Record failure *f* against *server*'s serverid (last one wins)."""
        self._problems[server.get_serverid()] = f
91
class Marker:
    """
    Empty placeholder class with no state or behaviour of its own.

    A comment in Retrieve.__init__ describes keeping one Marker per
    attempt to fetch the private key (self._privkey_query_markers).
    """
94
95 class Retrieve:
96     # this class is currently single-use. Eventually (in MDMF) we will make
97     # it multi-use, in which case you can call download(range) multiple
98     # times, and each will have a separate response chain. However the
99     # Retrieve object will remain tied to a specific version of the file, and
100     # will use a single ServerMap instance.
101     implements(IPushProducer)
102
    def __init__(self, filenode, storage_broker, servermap, verinfo,
                 fetch_privkey=False, verify=False):
        """
        Set up a (single-use) retrieve of one version of a mutable file.

        filenode: the mutable filenode being read. It must already know its
            public key and read key (both are checked with _assert below).
        storage_broker: kept as self._storage_broker.
        servermap: the ServerMap whose share locations (and cached read
            proxies) will be used.
        verinfo: the version to fetch, a 9-tuple of (seqnum, root_hash, IV,
            segsize, datalength, k, N, prefix, offsets_tuple).
        fetch_privkey: if true and the node has no private key yet, also
            ask servers for the encrypted private key during the download.
        verify: if true, check every share instead of producing plaintext;
            this also forces the private-key fetch.

        No fetching starts here; call download() to begin.
        """
        self._node = filenode
        _assert(self._node.get_pubkey())
        self._storage_broker = storage_broker
        self._storage_index = filenode.get_storage_index()
        _assert(self._node.get_readkey())
        self._last_failure = None
        # Short storage-index prefix, used only to label the log message
        # below. (The name 'prefix' is re-bound by the verinfo unpacking
        # further down.)
        prefix = si_b2a(self._storage_index)[:5]
        # self.log() uses this entry as the default parent of later messages.
        self._log_number = log.msg("Retrieve(%s): starting" % prefix)
        self._running = True
        self._decoding = False
        # (server, shnum, failure) triples, added by _mark_bad_share().
        self._bad_shares = set()

        self.servermap = servermap
        self.verinfo = verinfo
        # TODO: make it possible to use self.verinfo.datalength instead
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo
        self._data_length = datalength
        # during repair, we may be called upon to grab the private key, since
        # it wasn't picked up during a verify=False checker run, and we'll
        # need it for repair to generate a new version.
        self._need_privkey = verify or (fetch_privkey
                                        and not self._node.get_privkey())

        if self._need_privkey:
            # TODO: Evaluate the need for this. We'll use it if we want
            # to limit how many queries are on the wire for the privkey
            # at once.
            self._privkey_query_markers = [] # one Marker for each time we've
                                             # tried to get the privkey.

        # verify means that we are using the downloader logic to verify all
        # of our shares. This tells the downloader a few things.
        #
        # 1. We need to download all of the shares.
        # 2. We don't need to decode or decrypt the shares, since our
        #    caller doesn't care about the plaintext, only the
        #    information about which shares are or are not valid.
        # 3. When we are validating readers, we need to validate the
        #    signature on the prefix. Do we? We already do this in the
        #    servermap update?
        self._verify = verify

        self._status = RetrieveStatus()
        self._status.set_storage_index(self._storage_index)
        self._status.set_helper(False)
        self._status.set_progress(0.0)
        self._status.set_active(True)
        self._status.set_size(datalength)
        self._status.set_encoding(k, N)
        # shnum -> read proxy, filled in by _setup_download().
        self.readers = {}
        self._stopped = False
        # Non-None while the consumer has paused us (see pauseProducing).
        self._pause_deferred = None
        # The requested byte range; set by _start_download() or decode().
        self._offset = None
        self._read_length = None
        self.log("got seqnum %d" % self.verinfo[0])
161
162
163     def get_status(self):
164         return self._status
165
166     def log(self, *args, **kwargs):
167         if "parent" not in kwargs:
168             kwargs["parent"] = self._log_number
169         if "facility" not in kwargs:
170             kwargs["facility"] = "tahoe.mutable.retrieve"
171         return log.msg(*args, **kwargs)
172
173     def _set_current_status(self, state):
174         seg = "%d/%d" % (self._current_segment, self._last_segment)
175         self._status.set_status("segment %s (%s)" % (seg, state))
176
177     ###################
178     # IPushProducer
179
180     def pauseProducing(self):
181         """
182         I am called by my download target if we have produced too much
183         data for it to handle. I make the downloader stop producing new
184         data until my resumeProducing method is called.
185         """
186         if self._pause_deferred is not None:
187             return
188
189         # fired when the download is unpaused.
190         self._old_status = self._status.get_status()
191         self._set_current_status("paused")
192
193         self._pause_deferred = defer.Deferred()
194
195
196     def resumeProducing(self):
197         """
198         I am called by my download target once it is ready to begin
199         receiving data again.
200         """
201         if self._pause_deferred is None:
202             return
203
204         p = self._pause_deferred
205         self._pause_deferred = None
206         self._status.set_status(self._old_status)
207
208         eventually(p.callback, None)
209
210     def stopProducing(self):
211         self._stopped = True
212         self.resumeProducing()
213
214
215     def _check_for_paused(self, res):
216         """
217         I am called just before a write to the consumer. I return a
218         Deferred that eventually fires with the data that is to be
219         written to the consumer. If the download has not been paused,
220         the Deferred fires immediately. Otherwise, the Deferred fires
221         when the downloader is unpaused.
222         """
223         if self._pause_deferred is not None:
224             d = defer.Deferred()
225             self._pause_deferred.addCallback(lambda ignored: d.callback(res))
226             return d
227         return res
228
229     def _check_for_stopped(self, res):
230         if self._stopped:
231             raise DownloadStopped("our Consumer called stopProducing()")
232         return res
233
234
235     def download(self, consumer=None, offset=0, size=None):
236         precondition(self._verify or IConsumer.providedBy(consumer))
237         if size is None:
238             size = self._data_length - offset
239         if self._verify:
240             _assert(size == self._data_length, (size, self._data_length))
241         self.log("starting download")
242         self._done_deferred = defer.Deferred()
243         if consumer:
244             self._consumer = consumer
245             # we provide IPushProducer, so streaming=True, per IConsumer.
246             self._consumer.registerProducer(self, streaming=True)
247         self._started = time.time()
248         self._started_fetching = time.time()
249         if size == 0:
250             # short-circuit the rest of the process
251             self._done()
252         else:
253             self._start_download(consumer, offset, size)
254         return self._done_deferred
255
256     def _start_download(self, consumer, offset, size):
257         precondition((0 <= offset < self._data_length)
258                      and (size > 0)
259                      and (offset+size <= self._data_length),
260                      (offset, size, self._data_length))
261
262         self._offset = offset
263         self._read_length = size
264         self._setup_encoding_parameters()
265         self._setup_download()
266
267         # The download process beyond this is a state machine.
268         # _add_active_servers will select the servers that we want to use
269         # for the download, and then attempt to start downloading. After
270         # each segment, it will check for doneness, reacting to broken
271         # servers and corrupt shares as necessary. If it runs out of good
272         # servers before downloading all of the segments, _done_deferred
273         # will errback.  Otherwise, it will eventually callback with the
274         # contents of the mutable file.
275         self.loop()
276
277     def loop(self):
278         d = fireEventually(None) # avoid #237 recursion limit problem
279         d.addCallback(lambda ign: self._activate_enough_servers())
280         d.addCallback(lambda ign: self._download_current_segment())
281         # when we're done, _download_current_segment will call _done. If we
282         # aren't, it will call loop() again.
283         d.addErrback(self._error)
284
285     def _setup_download(self):
286         self._status.set_status("Retrieving Shares")
287
288         # how many shares do we need?
289         (seqnum,
290          root_hash,
291          IV,
292          segsize,
293          datalength,
294          k,
295          N,
296          prefix,
297          offsets_tuple) = self.verinfo
298
299         # first, which servers can we use?
300         versionmap = self.servermap.make_versionmap()
301         shares = versionmap[self.verinfo]
302         # this sharemap is consumed as we decide to send requests
303         self.remaining_sharemap = DictOfSets()
304         for (shnum, server, timestamp) in shares:
305             self.remaining_sharemap.add(shnum, server)
306             # Reuse the SlotReader from the servermap.
307             key = (self.verinfo, server.get_serverid(),
308                    self._storage_index, shnum)
309             if key in self.servermap.proxies:
310                 reader = self.servermap.proxies[key]
311             else:
312                 reader = MDMFSlotReadProxy(server.get_rref(),
313                                            self._storage_index, shnum, None)
314             reader.server = server
315             self.readers[shnum] = reader
316
317         if len(self.remaining_sharemap) < k:
318             self._raise_notenoughshareserror()
319
320         self.shares = {} # maps shnum to validated blocks
321         self._active_readers = [] # list of active readers for this dl.
322         self._block_hash_trees = {} # shnum => hashtree
323
324         for i in xrange(self._total_shares):
325             # So we don't have to do this later.
326             self._block_hash_trees[i] = hashtree.IncompleteHashTree(self._num_segments)
327
328         # We need one share hash tree for the entire file; its leaves
329         # are the roots of the block hash trees for the shares that
330         # comprise it, and its root is in the verinfo.
331         self.share_hash_tree = hashtree.IncompleteHashTree(N)
332         self.share_hash_tree.set_hashes({0: root_hash})
333
334     def decode(self, blocks_and_salts, segnum):
335         """
336         I am a helper method that the mutable file update process uses
337         as a shortcut to decode and decrypt the segments that it needs
338         to fetch in order to perform a file update. I take in a
339         collection of blocks and salts, and pick some of those to make a
340         segment with. I return the plaintext associated with that
341         segment.
342         """
343         # We don't need the block hash trees in this case.
344         self._block_hash_trees = None
345         self._offset = 0
346         self._read_length = self._data_length
347         self._setup_encoding_parameters()
348
349         # _decode_blocks() expects the output of a gatherResults that
350         # contains the outputs of _validate_block() (each of which is a dict
351         # mapping shnum to (block,salt) bytestrings).
352         d = self._decode_blocks([blocks_and_salts], segnum)
353         d.addCallback(self._decrypt_segment)
354         return d
355
356
    def _setup_encoding_parameters(self):
        """
        I set up the encoding parameters, including k, n, the number
        of segments associated with this file, and the segment decoders.

        I also work out which segments must be fetched, and set
        self._start_segment, self._last_segment and self._current_segment.
        These are derived from self._offset and self._read_length, which
        must already be set (by _start_download() or decode()).
        """
        (seqnum,
         root_hash,
         IV,
         segsize,
         datalength,
         k,
         n,
         known_prefix,
         offsets_tuple) = self.verinfo
        self._required_shares = k
        self._total_shares = n
        self._segment_size = segsize
        #self._data_length = datalength # set during __init__()

        # An empty/None IV in verinfo selects MDMF; otherwise SDMF.
        if not IV:
            self._version = MDMF_VERSION
        else:
            self._version = SDMF_VERSION

        if datalength and segsize:
            self._num_segments = mathutil.div_ceil(datalength, segsize)
            self._tail_data_size = datalength % segsize
        else:
            self._num_segments = 0
            self._tail_data_size = 0

        self._segment_decoder = codec.CRSDecoder()
        self._segment_decoder.set_params(segsize, k, n)

        # A zero tail (an exact multiple of segsize, or the zero-length case
        # above) is treated as one full-sized segment.
        if  not self._tail_data_size:
            self._tail_data_size = segsize

        # The tail segment is padded up to a multiple of k (presumably what
        # the decoder requires). It only needs its own decoder when that
        # padded size differs from the regular segment size.
        self._tail_segment_size = mathutil.next_multiple(self._tail_data_size,
                                                         self._required_shares)
        if self._tail_segment_size == self._segment_size:
            self._tail_decoder = self._segment_decoder
        else:
            self._tail_decoder = codec.CRSDecoder()
            self._tail_decoder.set_params(self._tail_segment_size,
                                          self._required_shares,
                                          self._total_shares)

        self.log("got encoding parameters: "
                 "k: %d "
                 "n: %d "
                 "%d segments of %d bytes each (%d byte tail segment)" % \
                 (k, n, self._num_segments, self._segment_size,
                  self._tail_segment_size))

        # Our last task is to tell the downloader where to start and
        # where to stop. We use three parameters for that:
        #   - self._start_segment: the segment that we need to start
        #     downloading from.
        #   - self._current_segment: the next segment that we need to
        #     download.
        #   - self._last_segment: The last segment that we were asked to
        #     download.
        #
        #  We say that the download is complete when
        #  self._current_segment > self._last_segment. We use
        #  self._start_segment and self._last_segment to know when to
        #  strip things off of segments, and how much to strip.
        if self._offset:
            self.log("got offset: %d" % self._offset)
            # our start segment is the first segment containing the
            # offset we were given.
            start = self._offset // self._segment_size

            _assert(start <= self._num_segments,
                    start=start, num_segments=self._num_segments,
                    offset=self._offset, segment_size=self._segment_size)
            self._start_segment = start
            self.log("got start segment: %d" % self._start_segment)
        else:
            self._start_segment = 0

        # We might want to read only part of the file, and need to figure out
        # where to stop reading. Our end segment is the last segment
        # containing part of the segment that we were asked to read.
        _assert(self._read_length > 0, self._read_length)
        end_data = self._offset + self._read_length

        # We don't actually need to read the byte at end_data, but the one
        # before it.
        end = (end_data - 1) // self._segment_size
        _assert(0 <= end < self._num_segments,
                end=end, num_segments=self._num_segments,
                end_data=end_data, offset=self._offset,
                read_length=self._read_length, segment_size=self._segment_size)
        self._last_segment = end
        self.log("got end segment: %d" % self._last_segment)

        # Downloading begins at the first wanted segment.
        self._current_segment = self._start_segment
455
456     def _activate_enough_servers(self):
457         """
458         I populate self._active_readers with enough active readers to
459         retrieve the contents of this mutable file. I am called before
460         downloading starts, and (eventually) after each validation
461         error, connection error, or other problem in the download.
462         """
463         # TODO: It would be cool to investigate other heuristics for
464         # reader selection. For instance, the cost (in time the user
465         # spends waiting for their file) of selecting a really slow server
466         # that happens to have a primary share is probably more than
467         # selecting a really fast server that doesn't have a primary
468         # share. Maybe the servermap could be extended to provide this
469         # information; it could keep track of latency information while
470         # it gathers more important data, and then this routine could
471         # use that to select active readers.
472         #
473         # (these and other questions would be easier to answer with a
474         #  robust, configurable tahoe-lafs simulator, which modeled node
475         #  failures, differences in node speed, and other characteristics
476         #  that we expect storage servers to have.  You could have
477         #  presets for really stable grids (like allmydata.com),
478         #  friendnets, make it easy to configure your own settings, and
479         #  then simulate the effect of big changes on these use cases
480         #  instead of just reasoning about what the effect might be. Out
481         #  of scope for MDMF, though.)
482
483         # XXX: Why don't format= log messages work here?
484
485         known_shnums = set(self.remaining_sharemap.keys())
486         used_shnums = set([r.shnum for r in self._active_readers])
487         unused_shnums = known_shnums - used_shnums
488
489         if self._verify:
490             new_shnums = unused_shnums # use them all
491         elif len(self._active_readers) < self._required_shares:
492             # need more shares
493             more = self._required_shares - len(self._active_readers)
494             # We favor lower numbered shares, since FEC is faster with
495             # primary shares than with other shares, and lower-numbered
496             # shares are more likely to be primary than higher numbered
497             # shares.
498             new_shnums = sorted(unused_shnums)[:more]
499             if len(new_shnums) < more:
500                 # We don't have enough readers to retrieve the file; fail.
501                 self._raise_notenoughshareserror()
502         else:
503             new_shnums = []
504
505         self.log("adding %d new servers to the active list" % len(new_shnums))
506         for shnum in new_shnums:
507             reader = self.readers[shnum]
508             self._active_readers.append(reader)
509             self.log("added reader for share %d" % shnum)
510             # Each time we add a reader, we check to see if we need the
511             # private key. If we do, we politely ask for it and then continue
512             # computing. If we find that we haven't gotten it at the end of
513             # segment decoding, then we'll take more drastic measures.
514             if self._need_privkey and not self._node.is_readonly():
515                 d = reader.get_encprivkey()
516                 d.addCallback(self._try_to_validate_privkey, reader, reader.server)
517                 # XXX: don't just drop the Deferred. We need error-reporting
518                 # but not flow-control here.
519
520     def _try_to_validate_prefix(self, prefix, reader):
521         """
522         I check that the prefix returned by a candidate server for
523         retrieval matches the prefix that the servermap knows about
524         (and, hence, the prefix that was validated earlier). If it does,
525         I return True, which means that I approve of the use of the
526         candidate server for segment retrieval. If it doesn't, I return
527         False, which means that another server must be chosen.
528         """
529         (seqnum,
530          root_hash,
531          IV,
532          segsize,
533          datalength,
534          k,
535          N,
536          known_prefix,
537          offsets_tuple) = self.verinfo
538         if known_prefix != prefix:
539             self.log("prefix from share %d doesn't match" % reader.shnum)
540             raise UncoordinatedWriteError("Mismatched prefix -- this could "
541                                           "indicate an uncoordinated write")
542         # Otherwise, we're okay -- no issues.
543
544     def _mark_bad_share(self, server, shnum, reader, f):
545         """
546         I mark the given (server, shnum) as a bad share, which means that it
547         will not be used anywhere else.
548
549         There are several reasons to want to mark something as a bad
550         share. These include:
551
552             - A connection error to the server.
553             - A mismatched prefix (that is, a prefix that does not match
554               our local conception of the version information string).
555             - A failing block hash, salt hash, share hash, or other
556               integrity check.
557
558         This method will ensure that readers that we wish to mark bad
559         (for these reasons or other reasons) are not used for the rest
560         of the download. Additionally, it will attempt to tell the
561         remote server (with no guarantee of success) that its share is
562         corrupt.
563         """
564         self.log("marking share %d on server %s as bad" % \
565                  (shnum, server.get_name()))
566         prefix = self.verinfo[-2]
567         self.servermap.mark_bad_share(server, shnum, prefix)
568         self._bad_shares.add((server, shnum, f))
569         self._status.add_problem(server, f)
570         self._last_failure = f
571
572         # Remove the reader from _active_readers
573         self._active_readers.remove(reader)
574         for shnum in list(self.remaining_sharemap.keys()):
575             self.remaining_sharemap.discard(shnum, reader.server)
576
577         if f.check(BadShareError):
578             self.notify_server_corruption(server, shnum, str(f.value))
579
580     def _download_current_segment(self):
581         """
582         I download, validate, decode, decrypt, and assemble the segment
583         that this Retrieve is currently responsible for downloading.
584         """
585
586         if self._current_segment > self._last_segment:
587             # No more segments to download, we're done.
588             self.log("got plaintext, done")
589             return self._done()
590         elif self._verify and len(self._active_readers) == 0:
591             self.log("no more good shares, no need to keep verifying")
592             return self._done()
593         self.log("on segment %d of %d" %
594                  (self._current_segment + 1, self._num_segments))
595         d = self._process_segment(self._current_segment)
596         d.addCallback(lambda ign: self.loop())
597         return d
598
599     def _process_segment(self, segnum):
600         """
601         I download, validate, decode, and decrypt one segment of the
602         file that this Retrieve is retrieving. This means coordinating
603         the process of getting k blocks of that file, validating them,
604         assembling them into one segment with the decoder, and then
605         decrypting them.
606         """
607         self.log("processing segment %d" % segnum)
608
609         # TODO: The old code uses a marker. Should this code do that
610         # too? What did the Marker do?
611
612         # We need to ask each of our active readers for its block and
613         # salt. We will then validate those. If validation is
614         # successful, we will assemble the results into plaintext.
615         ds = []
616         for reader in self._active_readers:
617             started = time.time()
618             d1 = reader.get_block_and_salt(segnum)
619             d2,d3 = self._get_needed_hashes(reader, segnum)
620             d = deferredutil.gatherResults([d1,d2,d3])
621             d.addCallback(self._validate_block, segnum, reader, reader.server, started)
622             # _handle_bad_share takes care of recoverable errors (by dropping
623             # that share and returning None). Any other errors (i.e. code
624             # bugs) are passed through and cause the retrieve to fail.
625             d.addErrback(self._handle_bad_share, [reader])
626             ds.append(d)
627         dl = deferredutil.gatherResults(ds)
628         if self._verify:
629             dl.addCallback(lambda ignored: "")
630             dl.addCallback(self._set_segment)
631         else:
632             dl.addCallback(self._maybe_decode_and_decrypt_segment, segnum)
633         return dl
634
635
636     def _maybe_decode_and_decrypt_segment(self, results, segnum):
637         """
638         I take the results of fetching and validating the blocks from
639         _process_segment. If validation and fetching succeeded without
640         incident, I will proceed with decoding and decryption. Otherwise, I
641         will do nothing.
642         """
643         self.log("trying to decode and decrypt segment %d" % segnum)
644
645         # 'results' is the output of a gatherResults set up in
646         # _process_segment(). Each component Deferred will either contain the
647         # non-Failure output of _validate_block() for a single block (i.e.
648         # {segnum:(block,salt)}), or None if _validate_block threw an
649         # exception and _validation_or_decoding_failed handled it (by
650         # dropping that server).
651
652         if None in results:
653             self.log("some validation operations failed; not proceeding")
654             return defer.succeed(None)
655         self.log("everything looks ok, building segment %d" % segnum)
656         d = self._decode_blocks(results, segnum)
657         d.addCallback(self._decrypt_segment)
658         # check to see whether we've been paused before writing
659         # anything.
660         d.addCallback(self._check_for_paused)
661         d.addCallback(self._check_for_stopped)
662         d.addCallback(self._set_segment)
663         return d
664
665
666     def _set_segment(self, segment):
667         """
668         Given a plaintext segment, I register that segment with the
669         target that is handling the file download.
670         """
671         self.log("got plaintext for segment %d" % self._current_segment)
672
673         if self._read_length == 0:
674             self.log("on first+last segment, size=0, using 0 bytes")
675             segment = b""
676
677         if self._current_segment == self._last_segment:
678             # trim off the tail
679             wanted = (self._offset + self._read_length) % self._segment_size
680             if wanted != 0:
681                 self.log("on the last segment: using first %d bytes" % wanted)
682                 segment = segment[:wanted]
683             else:
684                 self.log("on the last segment: using all %d bytes" %
685                          len(segment))
686
687         if self._current_segment == self._start_segment:
688             # Trim off the head, if offset != 0. This should also work if
689             # start==last, because we trim the tail first.
690             skip = self._offset % self._segment_size
691             self.log("on the first segment: skipping first %d bytes" % skip)
692             segment = segment[skip:]
693
694         if not self._verify:
695             self._consumer.write(segment)
696         else:
697             # we don't care about the plaintext if we are doing a verify.
698             segment = None
699         self._current_segment += 1
700
701
702     def _handle_bad_share(self, f, readers):
703         """
704         I am called when a block or a salt fails to correctly validate, or when
705         the decryption or decoding operation fails for some reason.  I react to
706         this failure by notifying the remote server of corruption, and then
707         removing the remote server from further activity.
708         """
709         # these are the errors we can tolerate: by giving up on this share
710         # and finding others to replace it. Any other errors (i.e. coding
711         # bugs) are re-raised, causing the download to fail.
712         f.trap(DeadReferenceError, RemoteException, BadShareError)
713
714         # DeadReferenceError happens when we try to fetch data from a server
715         # that has gone away. RemoteException happens if the server had an
716         # internal error. BadShareError encompasses: (UnknownVersionError,
717         # LayoutInvalid, struct.error) which happen when we get obviously
718         # wrong data, and CorruptShareError which happens later, when we
719         # perform integrity checks on the data.
720
721         precondition(isinstance(readers, list), readers)
722         bad_shnums = [reader.shnum for reader in readers]
723
724         self.log("validation or decoding failed on share(s) %s, server(s) %s "
725                  ", segment %d: %s" % \
726                  (bad_shnums, readers, self._current_segment, str(f)))
727         for reader in readers:
728             self._mark_bad_share(reader.server, reader.shnum, reader, f)
729         return None
730
731
732     def _validate_block(self, results, segnum, reader, server, started):
733         """
734         I validate a block from one share on a remote server.
735         """
736         # Grab the part of the block hash tree that is necessary to
737         # validate this block, then generate the block hash root.
738         self.log("validating share %d for segment %d" % (reader.shnum,
739                                                              segnum))
740         elapsed = time.time() - started
741         self._status.add_fetch_timing(server, elapsed)
742         self._set_current_status("validating blocks")
743
744         block_and_salt, blockhashes, sharehashes = results
745         block, salt = block_and_salt
746         _assert(type(block) is str, (block, salt))
747
748         blockhashes = dict(enumerate(blockhashes))
749         self.log("the reader gave me the following blockhashes: %s" % \
750                  blockhashes.keys())
751         self.log("the reader gave me the following sharehashes: %s" % \
752                  sharehashes.keys())
753         bht = self._block_hash_trees[reader.shnum]
754
755         if bht.needed_hashes(segnum, include_leaf=True):
756             try:
757                 bht.set_hashes(blockhashes)
758             except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
759                     IndexError), e:
760                 raise CorruptShareError(server,
761                                         reader.shnum,
762                                         "block hash tree failure: %s" % e)
763
764         if self._version == MDMF_VERSION:
765             blockhash = hashutil.block_hash(salt + block)
766         else:
767             blockhash = hashutil.block_hash(block)
768         # If this works without an error, then validation is
769         # successful.
770         try:
771            bht.set_hashes(leaves={segnum: blockhash})
772         except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
773                 IndexError), e:
774             raise CorruptShareError(server,
775                                     reader.shnum,
776                                     "block hash tree failure: %s" % e)
777
778         # Reaching this point means that we know that this segment
779         # is correct. Now we need to check to see whether the share
780         # hash chain is also correct.
781         # SDMF wrote share hash chains that didn't contain the
782         # leaves, which would be produced from the block hash tree.
783         # So we need to validate the block hash tree first. If
784         # successful, then bht[0] will contain the root for the
785         # shnum, which will be a leaf in the share hash tree, which
786         # will allow us to validate the rest of the tree.
787         try:
788             self.share_hash_tree.set_hashes(hashes=sharehashes,
789                                         leaves={reader.shnum: bht[0]})
790         except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
791                 IndexError), e:
792             raise CorruptShareError(server,
793                                     reader.shnum,
794                                     "corrupt hashes: %s" % e)
795
796         self.log('share %d is valid for segment %d' % (reader.shnum,
797                                                        segnum))
798         return {reader.shnum: (block, salt)}
799
800
801     def _get_needed_hashes(self, reader, segnum):
802         """
803         I get the hashes needed to validate segnum from the reader, then return
804         to my caller when this is done.
805         """
806         bht = self._block_hash_trees[reader.shnum]
807         needed = bht.needed_hashes(segnum, include_leaf=True)
808         # The root of the block hash tree is also a leaf in the share
809         # hash tree. So we don't need to fetch it from the remote
810         # server. In the case of files with one segment, this means that
811         # we won't fetch any block hash tree from the remote server,
812         # since the hash of each share of the file is the entire block
813         # hash tree, and is a leaf in the share hash tree. This is fine,
814         # since any share corruption will be detected in the share hash
815         # tree.
816         #needed.discard(0)
817         self.log("getting blockhashes for segment %d, share %d: %s" % \
818                  (segnum, reader.shnum, str(needed)))
819         # TODO is force_remote necessary here?
820         d1 = reader.get_blockhashes(needed, force_remote=False)
821         if self.share_hash_tree.needed_hashes(reader.shnum):
822             need = self.share_hash_tree.needed_hashes(reader.shnum)
823             self.log("also need sharehashes for share %d: %s" % (reader.shnum,
824                                                                  str(need)))
825             d2 = reader.get_sharehashes(need, force_remote=False)
826         else:
827             d2 = defer.succeed({}) # the logic in the next method
828                                    # expects a dict
829         return d1,d2
830
831
832     def _decode_blocks(self, results, segnum):
833         """
834         I take a list of k blocks and salts, and decode that into a
835         single encrypted segment.
836         """
837         # 'results' is one or more dicts (each {shnum:(block,salt)}), and we
838         # want to merge them all
839         blocks_and_salts = {}
840         for d in results:
841             blocks_and_salts.update(d)
842
843         # All of these blocks should have the same salt; in SDMF, it is
844         # the file-wide IV, while in MDMF it is the per-segment salt. In
845         # either case, we just need to get one of them and use it.
846         #
847         # d.items()[0] is like (shnum, (block, salt))
848         # d.items()[0][1] is like (block, salt)
849         # d.items()[0][1][1] is the salt.
850         salt = blocks_and_salts.items()[0][1][1]
851         # Next, extract just the blocks from the dict. We'll use the
852         # salt in the next step.
853         share_and_shareids = [(k, v[0]) for k, v in blocks_and_salts.items()]
854         d2 = dict(share_and_shareids)
855         shareids = []
856         shares = []
857         for shareid, share in d2.items():
858             shareids.append(shareid)
859             shares.append(share)
860
861         self._set_current_status("decoding")
862         started = time.time()
863         _assert(len(shareids) >= self._required_shares, len(shareids))
864         # zfec really doesn't want extra shares
865         shareids = shareids[:self._required_shares]
866         shares = shares[:self._required_shares]
867         self.log("decoding segment %d" % segnum)
868         if segnum == self._num_segments - 1:
869             d = defer.maybeDeferred(self._tail_decoder.decode, shares, shareids)
870         else:
871             d = defer.maybeDeferred(self._segment_decoder.decode, shares, shareids)
872         def _process(buffers):
873             segment = "".join(buffers)
874             self.log(format="now decoding segment %(segnum)s of %(numsegs)s",
875                      segnum=segnum,
876                      numsegs=self._num_segments,
877                      level=log.NOISY)
878             self.log(" joined length %d, datalength %d" %
879                      (len(segment), self._data_length))
880             if segnum == self._num_segments - 1:
881                 size_to_use = self._tail_data_size
882             else:
883                 size_to_use = self._segment_size
884             segment = segment[:size_to_use]
885             self.log(" segment len=%d" % len(segment))
886             self._status.accumulate_decode_time(time.time() - started)
887             return segment, salt
888         d.addCallback(_process)
889         return d
890
891
892     def _decrypt_segment(self, segment_and_salt):
893         """
894         I take a single segment and its salt, and decrypt it. I return
895         the plaintext of the segment that is in my argument.
896         """
897         segment, salt = segment_and_salt
898         self._set_current_status("decrypting")
899         self.log("decrypting segment %d" % self._current_segment)
900         started = time.time()
901         key = hashutil.ssk_readkey_data_hash(salt, self._node.get_readkey())
902         decryptor = AES(key)
903         plaintext = decryptor.process(segment)
904         self._status.accumulate_decrypt_time(time.time() - started)
905         return plaintext
906
907
908     def notify_server_corruption(self, server, shnum, reason):
909         rref = server.get_rref()
910         rref.callRemoteOnly("advise_corrupt_share",
911                             "mutable", self._storage_index, shnum, reason)
912
913
914     def _try_to_validate_privkey(self, enc_privkey, reader, server):
915         alleged_privkey_s = self._node._decrypt_privkey(enc_privkey)
916         alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s)
917         if alleged_writekey != self._node.get_writekey():
918             self.log("invalid privkey from %s shnum %d" %
919                      (reader, reader.shnum),
920                      level=log.WEIRD, umid="YIw4tA")
921             if self._verify:
922                 self.servermap.mark_bad_share(server, reader.shnum,
923                                               self.verinfo[-2])
924                 e = CorruptShareError(server,
925                                       reader.shnum,
926                                       "invalid privkey")
927                 f = failure.Failure(e)
928                 self._bad_shares.add((server, reader.shnum, f))
929             return
930
931         # it's good
932         self.log("got valid privkey from shnum %d on reader %s" %
933                  (reader.shnum, reader))
934         privkey = rsa.create_signing_key_from_string(alleged_privkey_s)
935         self._node._populate_encprivkey(enc_privkey)
936         self._node._populate_privkey(privkey)
937         self._need_privkey = False
938
939
940
941     def _done(self):
942         """
943         I am called by _download_current_segment when the download process
944         has finished successfully. After making some useful logging
945         statements, I return the decrypted contents to the owner of this
946         Retrieve object through self._done_deferred.
947         """
948         self._running = False
949         self._status.set_active(False)
950         now = time.time()
951         self._status.timings['total'] = now - self._started
952         self._status.timings['fetch'] = now - self._started_fetching
953         self._status.set_status("Finished")
954         self._status.set_progress(1.0)
955
956         # remember the encoding parameters, use them again next time
957         (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
958          offsets_tuple) = self.verinfo
959         self._node._populate_required_shares(k)
960         self._node._populate_total_shares(N)
961
962         if self._verify:
963             ret = self._bad_shares
964             self.log("done verifying, found %d bad shares" % len(ret))
965         else:
966             # TODO: upload status here?
967             ret = self._consumer
968             self._consumer.unregisterProducer()
969         eventually(self._done_deferred.callback, ret)
970
971
972     def _raise_notenoughshareserror(self):
973         """
974         I am called when there are not enough active servers left to complete
975         the download. After making some useful logging statements, I throw an
976         exception to that effect to the caller of this Retrieve object through
977         self._done_deferred.
978         """
979
980         format = ("ran out of servers: "
981                   "have %(have)d of %(total)d segments; "
982                   "found %(bad)d bad shares; "
983                   "have %(remaining)d remaining shares of the right version; "
984                   "encoding %(k)d-of-%(n)d")
985         args = {"have": self._current_segment,
986                 "total": self._num_segments,
987                 "need": self._last_segment,
988                 "k": self._required_shares,
989                 "n": self._total_shares,
990                 "bad": len(self._bad_shares),
991                 "remaining": len(self.remaining_sharemap),
992                }
993         raise NotEnoughSharesError("%s, last failure: %s" %
994                                    (format % args, str(self._last_failure)))
995
996     def _error(self, f):
997         # all errors, including NotEnoughSharesError, land here
998         self._running = False
999         self._status.set_active(False)
1000         now = time.time()
1001         self._status.timings['total'] = now - self._started
1002         self._status.timings['fetch'] = now - self._started_fetching
1003         self._status.set_status("Failed")
1004         eventually(self._done_deferred.errback, f)