import os, random, weakref, itertools, time
from zope.interface import implements
from twisted.internet import defer
from twisted.internet.interfaces import IPushProducer, IConsumer
from twisted.application import service
from foolscap.eventual import eventually

from allmydata.util import base32, mathutil, hashutil, log
from allmydata.util.assertutil import _assert
from allmydata import codec, hashtree, storage, uri
from allmydata.interfaces import IDownloadTarget, IDownloader, IFileURI, \
     IDownloadStatus, IDownloadResults
from allmydata.encode import NotEnoughPeersError
from pycryptopp.cipher.aes import AES

class HaveAllPeersError(Exception):
    # we use this to jump out of the loop
    pass

class BadURIExtensionHashValue(Exception):
    pass
class BadPlaintextHashValue(Exception):
    pass
class BadCrypttextHashValue(Exception):
    pass

class DownloadStopped(Exception):
    pass

class DownloadResults:
    implements(IDownloadResults)

    def __init__(self):
        self.servers_used = set()
        self.server_problems = {}
        self.servermap = {}
        self.timings = {}
        self.file_size = None

class Output:
    def __init__(self, downloadable, key, total_length, log_parent,
                 download_status):
        self.downloadable = downloadable
        self._decryptor = AES(key)
        self._crypttext_hasher = hashutil.crypttext_hasher()
        self._plaintext_hasher = hashutil.plaintext_hasher()
        self.length = 0
        self.total_length = total_length
        self._segment_number = 0
        self._plaintext_hash_tree = None
        self._crypttext_hash_tree = None
        self._opened = False
        self._log_parent = log_parent
        self._status = download_status
        self._status.set_progress(0.0)

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_parent
        if "facility" not in kwargs:
            kwargs["facility"] = "download.output"
        return log.msg(*args, **kwargs)

    def setup_hashtrees(self, plaintext_hashtree, crypttext_hashtree):
        self._plaintext_hash_tree = plaintext_hashtree
        self._crypttext_hash_tree = crypttext_hashtree

    def write_segment(self, crypttext):
        self.length += len(crypttext)
        self._status.set_progress(float(self.length) / self.total_length)

        # memory footprint: 'crypttext' is the only segment_size usage
        # outstanding. While we decrypt it into 'plaintext', we hit
        # 2*segment_size.
        self._crypttext_hasher.update(crypttext)
        if self._crypttext_hash_tree:
            ch = hashutil.crypttext_segment_hasher()
            ch.update(crypttext)
            crypttext_leaves = {self._segment_number: ch.digest()}
            self.log(format="crypttext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s",
                     bytes=len(crypttext),
                     segnum=self._segment_number, hash=base32.b2a(ch.digest()),
                     level=log.NOISY)
            self._crypttext_hash_tree.set_hashes(leaves=crypttext_leaves)

        plaintext = self._decryptor.process(crypttext)
        del crypttext

        # now we're back down to 1*segment_size.

        self._plaintext_hasher.update(plaintext)
        if self._plaintext_hash_tree:
            ph = hashutil.plaintext_segment_hasher()
            ph.update(plaintext)
            plaintext_leaves = {self._segment_number: ph.digest()}
            self.log(format="plaintext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s",
                     bytes=len(plaintext),
                     segnum=self._segment_number, hash=base32.b2a(ph.digest()),
                     level=log.NOISY)
            self._plaintext_hash_tree.set_hashes(leaves=plaintext_leaves)

        self._segment_number += 1
        # We're still at 1*segment_size. The Downloadable is responsible for
        # any memory usage beyond this.
        if not self._opened:
            self._opened = True
            self.downloadable.open(self.total_length)
        self.downloadable.write(plaintext)

    def fail(self, why):
        # this is really unusual, and deserves maximum forensics
        self.log("download failed!", failure=why, level=log.SCARY)
        self.downloadable.fail(why)

    def close(self):
        self.crypttext_hash = self._crypttext_hasher.digest()
        self.plaintext_hash = self._plaintext_hasher.digest()
        self.log("download finished, closing IDownloadable", level=log.NOISY)
        self.downloadable.close()

    def finish(self):
        return self.downloadable.finish()

class ValidatedBucket:
    """I am a front-end for a remote storage bucket, responsible for
    retrieving and validating data from that bucket.

    My get_block() method is used by BlockDownloaders.
    """
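    # A hedged sketch (not part of the original source) of the validation
    # chain implemented in _got_data() below, using only calls that appear
    # in this file. The root hash comes from the capability URI, so a
    # malicious server cannot forge a share that chains up to it:
    #
    #   sh = dict(sharehashes_from_peer)
    #   sh[0] = roothash                   # trusted root, from the URI
    #   share_hash_tree.set_hashes(sh)     # raises BadHashError on mismatch
    #   share_hash = share_hash_tree.get_leaf(sharenum)
    #   bh = dict(enumerate(blockhashes_from_peer))
    #   bh[0] = share_hash                 # chain the two trees together
    #   block_hash_tree.set_hashes(bh, {blocknum: hashutil.block_hash(data)})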

    def __init__(self, sharenum, bucket,
                 share_hash_tree, roothash,
                 num_blocks):
        self.sharenum = sharenum
        self.bucket = bucket
        self._share_hash = None # None means not validated yet
        self.share_hash_tree = share_hash_tree
        self._roothash = roothash
        self.block_hash_tree = hashtree.IncompleteHashTree(num_blocks)
        self.started = False

    def get_block(self, blocknum):
        if not self.started:
            d = self.bucket.start()
            def _started(res):
                self.started = True
                return self.get_block(blocknum)
            d.addCallback(_started)
            return d

        # the first time we use this bucket, we need to fetch enough elements
        # of the share hash tree to validate it from our share hash up to the
        # hashroot.
        if not self._share_hash:
            d1 = self.bucket.get_share_hashes()
        else:
            d1 = defer.succeed([])

        # we might need to grab some elements of our block hash tree, to
        # validate the requested block up to the share hash
        needed = self.block_hash_tree.needed_hashes(blocknum)
        if needed:
            # TODO: get fewer hashes, use get_block_hashes(needed)
            d2 = self.bucket.get_block_hashes()
        else:
            d2 = defer.succeed([])

        d3 = self.bucket.get_block(blocknum)

        d = defer.gatherResults([d1, d2, d3])
        d.addCallback(self._got_data, blocknum)
        return d

    def _got_data(self, res, blocknum):
        sharehashes, blockhashes, blockdata = res
        blockhash = None # to make logging it safe

        try:
            if not self._share_hash:
                sh = dict(sharehashes)
                sh[0] = self._roothash # always use our own root, from the URI
                sht = self.share_hash_tree
                if sht.get_leaf_index(self.sharenum) not in sh:
                    raise hashtree.NotEnoughHashesError
                sht.set_hashes(sh)
                self._share_hash = sht.get_leaf(self.sharenum)

            blockhash = hashutil.block_hash(blockdata)
            #log.msg("checking block_hash(shareid=%d, blocknum=%d) len=%d "
            #        "%r .. %r: %s" %
            #        (self.sharenum, blocknum, len(blockdata),
            #         blockdata[:50], blockdata[-50:], base32.b2a(blockhash)))

            # we always validate the blockhash
            bh = dict(enumerate(blockhashes))
            # replace blockhash root with validated value
            bh[0] = self._share_hash
            self.block_hash_tree.set_hashes(bh, {blocknum: blockhash})

        except (hashtree.BadHashError, hashtree.NotEnoughHashesError):
            # log.WEIRD: indicates undetected disk/network error, or more
            # likely a programming error
            log.msg("hash failure in block=%d, shnum=%d on %s" %
                    (blocknum, self.sharenum, self.bucket))
            if self._share_hash:
                log.msg(""" failure occurred when checking the block_hash_tree.
                This suggests that either the block data was bad, or that the
                block hashes we received along with it were bad.""")
            else:
                log.msg(""" the failure probably occurred when checking the
                share_hash_tree, which suggests that the share hashes we
                received from the remote peer were bad.""")
            log.msg(" have self._share_hash: %s" % bool(self._share_hash))
            log.msg(" block length: %d" % len(blockdata))
            log.msg(" block hash: %s" % base32.b2a_or_none(blockhash))
            if len(blockdata) < 100:
                log.msg(" block data: %r" % (blockdata,))
            else:
                log.msg(" block data start/end: %r .. %r" %
                        (blockdata[:50], blockdata[-50:]))
            log.msg(" root hash: %s" % base32.b2a(self._roothash))
            log.msg(" share hash tree:\n" + self.share_hash_tree.dump())
            log.msg(" block hash tree:\n" + self.block_hash_tree.dump())
            lines = []
            for i, h in sorted(sharehashes):
                lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
            log.msg(" sharehashes:\n" + "\n".join(lines) + "\n")
            lines = []
            for i, h in enumerate(blockhashes):
                lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
            log.msg(" blockhashes:\n" + "\n".join(lines) + "\n")
            raise

        # If we made it here, the block is good. If the hash trees didn't
        # like what they saw, they would have raised a BadHashError, causing
        # our caller to see a Failure and thus ignore this block (as well as
        # dropping this bucket).
        return blockdata



class BlockDownloader:
    """I am responsible for downloading a single block (from a single bucket)
    for a single segment.

    I am a child of the SegmentDownloader.
    """

    def __init__(self, vbucket, blocknum, parent):
        self.vbucket = vbucket
        self.blocknum = blocknum
        self.parent = parent
        self._log_number = self.parent.log("starting block %d" % blocknum)

    def log(self, msg, parent=None):
        if parent is None:
            parent = self._log_number
        return self.parent.log(msg, parent=parent)

    def start(self, segnum):
        lognum = self.log("get_block(segnum=%d)" % segnum)
        d = self.vbucket.get_block(segnum)
        d.addCallbacks(self._hold_block, self._got_block_error,
                       callbackArgs=(lognum,), errbackArgs=(lognum,))
        return d

    def _hold_block(self, data, lognum):
        self.log("got block", parent=lognum)
        self.parent.hold_block(self.blocknum, data)

    def _got_block_error(self, f, lognum):
        self.log("BlockDownloader[%d] got error: %s" % (self.blocknum, f),
                 parent=lognum)
        self.parent.bucket_failed(self.vbucket)

class SegmentDownloader:
    """I am responsible for downloading all the blocks for a single segment
    of data.

    I am a child of the FileDownloader.
    """
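    # Download object hierarchy, as a hedged summary of this file: one
    # FileDownloader per file creates one SegmentDownloader per segment
    # (in sequence), and each SegmentDownloader fans out to one
    # BlockDownloader per active share:
    #
    #   FileDownloader
    #     SegmentDownloader(segnum=0)
    #       BlockDownloader(shnum=...)   # k of these, in parallel
    #     SegmentDownloader(segnum=1)
    #       ...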

    def __init__(self, parent, segmentnumber, needed_shares):
        self.parent = parent
        self.segmentnumber = segmentnumber
        self.needed_blocks = needed_shares
        self.blocks = {} # k: blocknum, v: data
        self._log_number = self.parent.log("starting segment %d" %
                                           segmentnumber)

    def log(self, msg, parent=None):
        if parent is None:
            parent = self._log_number
        return self.parent.log(msg, parent=parent)

    def start(self):
        return self._download()

    def _download(self):
        d = self._try()
        def _done(res):
            if len(self.blocks) >= self.needed_blocks:
                # we only need self.needed_blocks blocks
                # we want to get the smallest blockids, because they are
                # more likely to be fast "primary blocks" (see the note just
                # below this method)
                blockids = sorted(self.blocks.keys())[:self.needed_blocks]
                blocks = []
                for blocknum in blockids:
                    blocks.append(self.blocks[blocknum])
                return (blocks, blockids)
            else:
                return self._download()
        d.addCallback(_done)
        return d
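    # Note on "primary blocks" (a hedged aside, not in the original text):
    # with k-of-N erasure coding, the first k shares typically carry the
    # segment's data verbatim, so decoding from the lowest-numbered blocks
    # lets the codec skip most of its matrix arithmetic.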

    def _try(self):
        # fill our set of active buckets, maybe raising NotEnoughPeersError
        active_buckets = self.parent._activate_enough_buckets()
        # Now we have enough buckets, in self.parent.active_buckets.

        # in test cases, bd.start might mutate active_buckets right away, so
        # we need to put off calling start() until we've iterated all the way
        # through it.
        downloaders = []
        for blocknum, vbucket in active_buckets.iteritems():
            bd = BlockDownloader(vbucket, blocknum, self)
            downloaders.append(bd)
        l = [bd.start(self.segmentnumber) for bd in downloaders]
        return defer.DeferredList(l, fireOnOneErrback=True)

    def hold_block(self, blocknum, data):
        self.blocks[blocknum] = data

    def bucket_failed(self, vbucket):
        self.parent.bucket_failed(vbucket)

class DownloadStatus:
    implements(IDownloadStatus)
    statusid_counter = itertools.count(0)

    def __init__(self):
        self.storage_index = None
        self.size = None
        self.helper = False
        self.status = "Not started"
        self.progress = 0.0
        self.paused = False
        self.stopped = False
        self.active = True
        self.results = None
        self.counter = self.statusid_counter.next()

    def get_storage_index(self):
        return self.storage_index
    def get_size(self):
        return self.size
    def using_helper(self):
        return self.helper
    def get_status(self):
        status = self.status
        if self.paused:
            status += " (output paused)"
        if self.stopped:
            status += " (output stopped)"
        return status
    def get_progress(self):
        return self.progress
    def get_active(self):
        return self.active
    def get_results(self):
        return self.results
    def get_counter(self):
        return self.counter

    def set_storage_index(self, si):
        self.storage_index = si
    def set_size(self, size):
        self.size = size
    def set_helper(self, helper):
        self.helper = helper
    def set_status(self, status):
        self.status = status
    def set_paused(self, paused):
        self.paused = paused
    def set_stopped(self, stopped):
        self.stopped = stopped
    def set_progress(self, value):
        self.progress = value
    def set_active(self, value):
        self.active = value
    def set_results(self, value):
        self.results = value

class FileDownloader:
    implements(IPushProducer)
    check_crypttext_hash = True
    check_plaintext_hash = True
    _status = None

    def __init__(self, client, u, downloadable):
        self._client = client

        u = IFileURI(u)
        self._storage_index = u.storage_index
        self._uri_extension_hash = u.uri_extension_hash
        self._total_shares = u.total_shares
        self._size = u.size
        self._num_needed_shares = u.needed_shares

        self._si_s = storage.si_b2a(self._storage_index)
        self.init_logging()

        self._started = time.time()
        self._status = s = DownloadStatus()
        s.set_status("Starting")
        s.set_storage_index(self._storage_index)
        s.set_size(self._size)
        s.set_helper(False)
        s.set_active(True)

        self._results = DownloadResults()
        s.set_results(self._results)
        self._results.file_size = self._size
        self._results.timings["servers_peer_selection"] = {}
        self._results.timings["cumulative_fetch"] = 0.0
        self._results.timings["cumulative_decode"] = 0.0
        self._results.timings["cumulative_decrypt"] = 0.0

        if IConsumer.providedBy(downloadable):
            downloadable.registerProducer(self, True)
        self._downloadable = downloadable
        self._output = Output(downloadable, u.key, self._size, self._log_number,
                              self._status)
        self._paused = False
        self._stopped = False

        self.active_buckets = {} # k: shnum, v: bucket
        self._share_buckets = [] # list of (sharenum, bucket) tuples
        self._share_vbuckets = {} # k: shnum, v: set of ValidatedBuckets
        self._uri_extension_sources = []

        self._uri_extension_data = None

        self._fetch_failures = {"uri_extension": 0,
                                "plaintext_hashroot": 0,
                                "plaintext_hashtree": 0,
                                "crypttext_hashroot": 0,
                                "crypttext_hashtree": 0,
                                }

    def init_logging(self):
        self._log_prefix = storage.si_b2a(self._storage_index)[:5]
        num = self._client.log(format="FileDownloader(%(si)s): starting",
                               si=storage.si_b2a(self._storage_index))
        self._log_number = num

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_number
        if "facility" not in kwargs:
            kwargs["facility"] = "tahoe.download"
        return log.msg(*args, **kwargs)

    def pauseProducing(self):
        if self._paused:
            return
        self._paused = defer.Deferred()
        if self._status:
            self._status.set_paused(True)

    def resumeProducing(self):
        if self._paused:
            p = self._paused
            self._paused = None
            eventually(p.callback, None)
            if self._status:
                self._status.set_paused(False)

    def stopProducing(self):
        self.log("Download.stopProducing")
        self._stopped = True
        if self._status:
            self._status.set_stopped(True)
            self._status.set_active(False)

    def start(self):
        self.log("starting download")

        # first step: who should we download from?
        d = defer.maybeDeferred(self._get_all_shareholders)
        d.addCallback(self._got_all_shareholders)
        # now get the uri_extension block from somebody and validate it
        d.addCallback(self._obtain_uri_extension)
        d.addCallback(self._got_uri_extension)
        d.addCallback(self._get_hashtrees)
        d.addCallback(self._create_validated_buckets)
        # once we know that, we can download blocks from everybody
        d.addCallback(self._download_all_segments)
        def _finished(res):
            if self._status:
                self._status.set_status("Finished")
                self._status.set_active(False)
                self._status.set_paused(False)
            if IConsumer.providedBy(self._downloadable):
                self._downloadable.unregisterProducer()
            return res
        d.addBoth(_finished)
        def _failed(why):
            if self._status:
                self._status.set_status("Failed")
                self._status.set_active(False)
            self._output.fail(why)
            return why
        d.addErrback(_failed)
        d.addCallback(self._done)
        return d

    def _get_all_shareholders(self):
        dl = []
        for (peerid, ss) in self._client.get_permuted_peers("storage",
                                                            self._storage_index):
            d = ss.callRemote("get_buckets", self._storage_index)
            d.addCallbacks(self._got_response, self._got_error,
                           callbackArgs=(peerid,))
            dl.append(d)
        self._responses_received = 0
        self._queries_sent = len(dl)
        if self._status:
            self._status.set_status("Locating Shares (%d/%d)" %
                                    (self._responses_received,
                                     self._queries_sent))
        return defer.DeferredList(dl)

    def _got_response(self, buckets, peerid):
        self._responses_received += 1
        if self._results:
            elapsed = time.time() - self._started
            self._results.timings["servers_peer_selection"][peerid] = elapsed
        if self._status:
            self._status.set_status("Locating Shares (%d/%d)" %
                                    (self._responses_received,
                                     self._queries_sent))
        for sharenum, bucket in buckets.iteritems():
            b = storage.ReadBucketProxy(bucket, peerid, self._si_s)
            self.add_share_bucket(sharenum, b)
            self._uri_extension_sources.append(b)
            if self._results:
                if peerid not in self._results.servermap:
                    self._results.servermap[peerid] = set()
                self._results.servermap[peerid].add(sharenum)

    def add_share_bucket(self, sharenum, bucket):
        # this is split out for the benefit of test_encode.py
        self._share_buckets.append((sharenum, bucket))

    def _got_error(self, f):
        self._client.log("Error during get_buckets to a storage server: %s"
                         % (f,))

    def bucket_failed(self, vbucket):
        shnum = vbucket.sharenum
        del self.active_buckets[shnum]
        s = self._share_vbuckets[shnum]
        # s is a set of ValidatedBucket instances
        s.remove(vbucket)
        # ... which might now be empty
        if not s:
            # there are no more buckets which can provide this share, so
            # remove the key. This may prompt us to use a different share.
            del self._share_vbuckets[shnum]

    def _got_all_shareholders(self, res):
        if self._results:
            now = time.time()
            self._results.timings["peer_selection"] = now - self._started

        if len(self._share_buckets) < self._num_needed_shares:
            raise NotEnoughPeersError

        #for s in self._share_vbuckets.values():
        #    for vb in s:
        #        assert isinstance(vb, ValidatedBucket), \
        #               "vb is %s but should be a ValidatedBucket" % (vb,)

    def _unpack_uri_extension_data(self, data):
        return uri.unpack_extension(data)

    def _obtain_uri_extension(self, ignored):
        # all shareholders are supposed to have a copy of uri_extension, and
        # all are supposed to be identical. We compute the hash of the data
        # that comes back, and compare it against the version in our URI. If
        # they don't match, ignore their data and try someone else.
        if self._status:
            self._status.set_status("Obtaining URI Extension")

        self._uri_extension_fetch_started = time.time()
        def _validate(proposal, bucket):
            h = hashutil.uri_extension_hash(proposal)
            if h != self._uri_extension_hash:
                self._fetch_failures["uri_extension"] += 1
                msg = ("The copy of uri_extension we received from "
                       "%s was bad: wanted %s, got %s" %
                       (bucket,
                        base32.b2a(self._uri_extension_hash),
                        base32.b2a(h)))
                self.log(msg, level=log.SCARY)
                raise BadURIExtensionHashValue(msg)
            return self._unpack_uri_extension_data(proposal)
        return self._obtain_validated_thing(None,
                                            self._uri_extension_sources,
                                            "uri_extension",
                                            "get_uri_extension", (), _validate)

    def _obtain_validated_thing(self, ignored, sources, name, methname, args,
                                validatorfunc):
        if not sources:
            raise NotEnoughPeersError("started with zero peers while fetching "
                                      "%s" % name)
        bucket = sources[0]
        sources = sources[1:]
        #d = bucket.callRemote(methname, *args)
        d = bucket.startIfNecessary()
        d.addCallback(lambda res: getattr(bucket, methname)(*args))
        d.addCallback(validatorfunc, bucket)
        def _bad(f):
            self.log("%s from vbucket %s failed:" % (name, bucket),
                     failure=f, level=log.WEIRD)
            if not sources:
                raise NotEnoughPeersError("ran out of peers, last error was %s"
                                          % (f,))
            # try again with a different one
            return self._obtain_validated_thing(None, sources, name,
                                                methname, args, validatorfunc)
        d.addErrback(_bad)
        return d
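    # Failover behavior, illustrated (a hedged example, not original text):
    # each recursion consumes one source, so with sources [b1, b2, b3] a bad
    # or unreachable b1 is logged at log.WEIRD and the same fetch is retried
    # on b2, then b3, before NotEnoughPeersError finally propagates.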

    def _got_uri_extension(self, uri_extension_data):
        if self._results:
            elapsed = time.time() - self._uri_extension_fetch_started
            self._results.timings["uri_extension"] = elapsed

        d = self._uri_extension_data = uri_extension_data

        self._codec = codec.get_decoder_by_name(d['codec_name'])
        self._codec.set_serialized_params(d['codec_params'])
        self._tail_codec = codec.get_decoder_by_name(d['codec_name'])
        self._tail_codec.set_serialized_params(d['tail_codec_params'])

        crypttext_hash = d['crypttext_hash']
        assert isinstance(crypttext_hash, str)
        assert len(crypttext_hash) == 32
        self._crypttext_hash = crypttext_hash
        self._plaintext_hash = d['plaintext_hash']
        self._roothash = d['share_root_hash']

        self._segment_size = segment_size = d['segment_size']
        self._total_segments = mathutil.div_ceil(self._size, segment_size)
        self._current_segnum = 0

        self._share_hashtree = hashtree.IncompleteHashTree(d['total_shares'])
        self._share_hashtree.set_hashes({0: self._roothash})

    def _get_hashtrees(self, res):
        self._get_hashtrees_started = time.time()
        if self._status:
            self._status.set_status("Retrieving Hash Trees")
        d = self._get_plaintext_hashtrees()
        d.addCallback(self._get_crypttext_hashtrees)
        d.addCallback(self._setup_hashtrees)
        return d

    def _get_plaintext_hashtrees(self):
        def _validate_plaintext_hashtree(proposal, bucket):
            if proposal[0] != self._uri_extension_data['plaintext_root_hash']:
                self._fetch_failures["plaintext_hashroot"] += 1
                msg = ("The copy of the plaintext_root_hash we received from"
                       " %s was bad" % bucket)
                raise BadPlaintextHashValue(msg)
            pt_hashtree = hashtree.IncompleteHashTree(self._total_segments)
            pt_hashes = dict(enumerate(proposal))
            try:
                pt_hashtree.set_hashes(pt_hashes)
            except hashtree.BadHashError:
                # the hashes they gave us were not self-consistent, even
                # though the root matched what we saw in the uri_extension
                # block
                self._fetch_failures["plaintext_hashtree"] += 1
                raise
            self._plaintext_hashtree = pt_hashtree
        d = self._obtain_validated_thing(None,
                                         self._uri_extension_sources,
                                         "plaintext_hashes",
                                         "get_plaintext_hashes", (),
                                         _validate_plaintext_hashtree)
        return d

    def _get_crypttext_hashtrees(self, res):
        def _validate_crypttext_hashtree(proposal, bucket):
            if proposal[0] != self._uri_extension_data['crypttext_root_hash']:
                self._fetch_failures["crypttext_hashroot"] += 1
                msg = ("The copy of the crypttext_root_hash we received from"
                       " %s was bad" % bucket)
                raise BadCrypttextHashValue(msg)
            ct_hashtree = hashtree.IncompleteHashTree(self._total_segments)
            ct_hashes = dict(enumerate(proposal))
            try:
                ct_hashtree.set_hashes(ct_hashes)
            except hashtree.BadHashError:
                # the hashes were not self-consistent, even though the root
                # matched the uri_extension block
                self._fetch_failures["crypttext_hashtree"] += 1
                raise
            self._crypttext_hashtree = ct_hashtree
        d = self._obtain_validated_thing(None,
                                         self._uri_extension_sources,
                                         "crypttext_hashes",
                                         "get_crypttext_hashes", (),
                                         _validate_crypttext_hashtree)
        return d

    def _setup_hashtrees(self, res):
        self._output.setup_hashtrees(self._plaintext_hashtree,
                                     self._crypttext_hashtree)
        if self._results:
            elapsed = time.time() - self._get_hashtrees_started
            self._results.timings["hashtrees"] = elapsed

    def _create_validated_buckets(self, ignored=None):
        self._share_vbuckets = {}
        for sharenum, bucket in self._share_buckets:
            vbucket = ValidatedBucket(sharenum, bucket,
                                      self._share_hashtree,
                                      self._roothash,
                                      self._total_segments)
            s = self._share_vbuckets.setdefault(sharenum, set())
            s.add(vbucket)

    def _activate_enough_buckets(self):
        """either return a mapping from shnum to a ValidatedBucket that can
        provide data for that share, or raise NotEnoughPeersError"""

        while len(self.active_buckets) < self._num_needed_shares:
            # need some more
            handled_shnums = set(self.active_buckets.keys())
            available_shnums = set(self._share_vbuckets.keys())
            potential_shnums = list(available_shnums - handled_shnums)
            if not potential_shnums:
                raise NotEnoughPeersError
            # choose a random share
            shnum = random.choice(potential_shnums)
            # and a random bucket that will provide it
            validated_bucket = random.choice(list(self._share_vbuckets[shnum]))
            self.active_buckets[shnum] = validated_bucket
        return self.active_buckets
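    # For example (hedged, illustrative numbers): with 3-of-10 encoding and
    # self._share_vbuckets = {0: set([vbA]), 4: set([vbB, vbC]), 7: set([vbD])},
    # the loop above fills self.active_buckets with all three shnums, picking
    # one of vbB/vbC at random for shnum 4.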


    def _download_all_segments(self, res):
        # the promise: upon entry to this function, self._share_vbuckets
        # contains enough buckets to complete the download, and some extra
        # ones to tolerate some buckets dropping out or having errors.
        # self._share_vbuckets is a dictionary that maps from shnum to a set
        # of ValidatedBuckets, which themselves are wrappers around
        # RIBucketReader references.
        self.active_buckets = {} # k: shnum, v: ValidatedBucket instance

        self._started_fetching = time.time()

        d = defer.succeed(None)
        for segnum in range(self._total_segments-1):
            d.addCallback(self._download_segment, segnum)
            # this pause, at the end of write, prevents pre-fetch from
            # happening until the consumer is ready for more data.
            d.addCallback(self._check_for_pause)
        d.addCallback(self._download_tail_segment, self._total_segments-1)
        return d

    def _check_for_pause(self, res):
        if self._paused:
            d = defer.Deferred()
            self._paused.addCallback(lambda ignored: d.callback(res))
            return d
        if self._stopped:
            raise DownloadStopped("our Consumer called stopProducing()")
        return res

    def _download_segment(self, res, segnum):
        if self._status:
            self._status.set_status("Downloading segment %d of %d" %
                                    (segnum+1, self._total_segments))
        self.log("downloading seg#%d of %d (%d%%)"
                 % (segnum, self._total_segments,
                    100.0 * segnum / self._total_segments))
        # memory footprint: when the SegmentDownloader finishes pulling down
        # all shares, we have 1*segment_size of usage.
        segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares)
        started = time.time()
        d = segmentdler.start()
        def _finished_fetching(res):
            elapsed = time.time() - started
            self._results.timings["cumulative_fetch"] += elapsed
            return res
        if self._results:
            d.addCallback(_finished_fetching)
        # pause before using more memory
        d.addCallback(self._check_for_pause)
        # while the codec does its job, we hit 2*segment_size
        def _started_decode(res):
            self._started_decode = time.time()
            return res
        if self._results:
            d.addCallback(_started_decode)
        d.addCallback(lambda (shares, shareids):
                      self._codec.decode(shares, shareids))
        # once the codec is done, we drop back to 1*segment_size, because
        # 'shares' goes out of scope. The memory usage is all in the
        # plaintext now, spread out into a bunch of tiny buffers.
        def _finished_decode(res):
            elapsed = time.time() - self._started_decode
            self._results.timings["cumulative_decode"] += elapsed
            return res
        if self._results:
            d.addCallback(_finished_decode)

        # pause/check-for-stop just before writing, to honor stopProducing
        d.addCallback(self._check_for_pause)
        def _done(buffers):
            # we start by joining all these buffers together into a single
            # string. This makes Output.write easier, since it wants to hash
            # data one segment at a time anyway, and doesn't impact our
            # memory footprint since we were already peaking at 2*segment_size
            # inside the codec a moment ago.
            segment = "".join(buffers)
            del buffers
            # we're down to 1*segment_size right now, but write_segment()
            # will decrypt a copy of the segment internally, which will push
            # us up to 2*segment_size while it runs.
            started_decrypt = time.time()
            self._output.write_segment(segment)
            if self._results:
                elapsed = time.time() - started_decrypt
                self._results.timings["cumulative_decrypt"] += elapsed
        d.addCallback(_done)
        return d

    def _download_tail_segment(self, res, segnum):
        self.log("downloading seg#%d of %d (%d%%)"
                 % (segnum, self._total_segments,
                    100.0 * segnum / self._total_segments))
        segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares)
        started = time.time()
        d = segmentdler.start()
        def _finished_fetching(res):
            elapsed = time.time() - started
            self._results.timings["cumulative_fetch"] += elapsed
            return res
        if self._results:
            d.addCallback(_finished_fetching)
        # pause before using more memory
        d.addCallback(self._check_for_pause)
        def _started_decode(res):
            self._started_decode = time.time()
            return res
        if self._results:
            d.addCallback(_started_decode)
        d.addCallback(lambda (shares, shareids):
                      self._tail_codec.decode(shares, shareids))
        def _finished_decode(res):
            elapsed = time.time() - self._started_decode
            self._results.timings["cumulative_decode"] += elapsed
            return res
        if self._results:
            d.addCallback(_finished_decode)
        # pause/check-for-stop just before writing, to honor stopProducing
        d.addCallback(self._check_for_pause)
        def _done(buffers):
            # trim off any padding added by the upload side
            segment = "".join(buffers)
            del buffers
            # we never send empty segments. If the data was an exact multiple
            # of the segment size, the last segment will be full.
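            # Worked example (hedged, illustrative numbers): with size=1000
            # and segment_size=128, total_segments=8 and the tail segment
            # carries 1000 - 7*128 = 104 real bytes, so pad_size = 24 and
            # tail_size = 128 - 24 = 104.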
            pad_size = mathutil.pad_size(self._size, self._segment_size)
            tail_size = self._segment_size - pad_size
            segment = segment[:tail_size]
            started_decrypt = time.time()
            self._output.write_segment(segment)
            if self._results:
                elapsed = time.time() - started_decrypt
                self._results.timings["cumulative_decrypt"] += elapsed
        d.addCallback(_done)
        return d

    def _done(self, res):
        self.log("download done")
        if self._results:
            now = time.time()
            self._results.timings["total"] = now - self._started
            self._results.timings["segments"] = now - self._started_fetching
        self._output.close()
        if self.check_crypttext_hash:
            _assert(self._crypttext_hash == self._output.crypttext_hash,
                    "bad crypttext_hash: computed=%s, expected=%s" %
                    (base32.b2a(self._output.crypttext_hash),
                     base32.b2a(self._crypttext_hash)))
        if self.check_plaintext_hash:
            _assert(self._plaintext_hash == self._output.plaintext_hash,
                    "bad plaintext_hash: computed=%s, expected=%s" %
                    (base32.b2a(self._output.plaintext_hash),
                     base32.b2a(self._plaintext_hash)))
        _assert(self._output.length == self._size,
                got=self._output.length, expected=self._size)
        return self._output.finish()

    def get_download_status(self):
        return self._status


class LiteralDownloader:
    def __init__(self, client, u, downloadable):
        self._uri = IFileURI(u)
        assert isinstance(self._uri, uri.LiteralFileURI)
        self._downloadable = downloadable
        self._status = s = DownloadStatus()
        s.set_storage_index(None)
        s.set_helper(False)
        s.set_status("Done")
        s.set_active(False)
        s.set_progress(1.0)

    def start(self):
        data = self._uri.data
        self._status.set_size(len(data))
        self._downloadable.open(len(data))
        self._downloadable.write(data)
        self._downloadable.close()
        return defer.maybeDeferred(self._downloadable.finish)

    def get_download_status(self):
        return self._status

class FileName:
    implements(IDownloadTarget)
    def __init__(self, filename):
        self._filename = filename
        self.f = None
    def open(self, size):
        self.f = open(self._filename, "wb")
        return self.f
    def write(self, data):
        self.f.write(data)
    def close(self):
        if self.f:
            self.f.close()
    def fail(self, why):
        if self.f:
            self.f.close()
            os.unlink(self._filename)
    def register_canceller(self, cb):
        pass # we won't use it
    def finish(self):
        pass

class Data:
    implements(IDownloadTarget)
    def __init__(self):
        self._data = []
    def open(self, size):
        pass
    def write(self, data):
        self._data.append(data)
    def close(self):
        self.data = "".join(self._data)
        del self._data
    def fail(self, why):
        del self._data
    def register_canceller(self, cb):
        pass # we won't use it
    def finish(self):
        return self.data

class FileHandle:
    """Use me to download data to a pre-defined filehandle-like object. I
    will use the filehandle's write() method. I will *not* close the
    filehandle: I leave that up to the originator of the filehandle. The
    download process will return the filehandle when it completes.
    """
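    # Usage sketch (hedged; not part of the original file). Given the
    # Downloader service defined below, a caller might write:
    #
    #   fh = open("output.bin", "wb")
    #   d = downloader.download(uri, FileHandle(fh))
    #   d.addCallback(lambda fh: fh.close())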
    implements(IDownloadTarget)
    def __init__(self, filehandle):
        self._filehandle = filehandle
    def open(self, size):
        pass
    def write(self, data):
        self._filehandle.write(data)
    def close(self):
        # the originator of the filehandle reserves the right to close it
        pass
    def fail(self, why):
        pass
    def register_canceller(self, cb):
        pass
    def finish(self):
        return self._filehandle

class Downloader(service.MultiService):
    """I am a service that allows file downloading.
    """
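    # A hedged usage sketch (the method names are from this file; how the
    # parent client attaches the service is an assumption):
    #
    #   downloader = Downloader()
    #   downloader.setServiceParent(client)  # standard MultiService wiring
    #   d = downloader.download_to_data(uri)
    #   d.addCallback(lambda data: ...)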
    implements(IDownloader)
    name = "downloader"
    MAX_DOWNLOAD_STATUSES = 10

    def __init__(self):
        service.MultiService.__init__(self)
        self._all_downloads = weakref.WeakKeyDictionary()
        self._recent_download_status = []

    def download(self, u, t):
        assert self.parent
        assert self.running
        u = IFileURI(u)
        t = IDownloadTarget(t)
        assert t.write
        assert t.close
        if isinstance(u, uri.LiteralFileURI):
            dl = LiteralDownloader(self.parent, u, t)
        elif isinstance(u, uri.CHKFileURI):
            dl = FileDownloader(self.parent, u, t)
        else:
            raise RuntimeError("I don't know how to download a %s" % u)
        self._all_downloads[dl] = None
        self._recent_download_status.append(dl.get_download_status())
        while len(self._recent_download_status) > self.MAX_DOWNLOAD_STATUSES:
            self._recent_download_status.pop(0)
        d = dl.start()
        return d

    # utility functions
    def download_to_data(self, uri):
        return self.download(uri, Data())
    def download_to_filename(self, uri, filename):
        return self.download(uri, FileName(filename))
    def download_to_filehandle(self, uri, filehandle):
        return self.download(uri, FileHandle(filehandle))


    def list_all_downloads(self):
        return self._all_downloads.keys()
    def list_active_downloads(self):
        return [d.get_download_status() for d in self._all_downloads.keys()
                if d.get_download_status().get_active()]
    def list_recent_downloads(self):
        return self._recent_download_status