2 import os, random, weakref, itertools, time
3 from zope.interface import implements
4 from twisted.internet import defer
5 from twisted.internet.interfaces import IPushProducer, IConsumer
6 from twisted.application import service
7 from foolscap.eventual import eventually
9 from allmydata.util import base32, mathutil, hashutil, log
10 from allmydata.util.assertutil import _assert
11 from allmydata import codec, hashtree, storage, uri
12 from allmydata.interfaces import IDownloadTarget, IDownloader, IFileURI, \
13 IDownloadStatus, IDownloadResults
14 from allmydata.encode import NotEnoughSharesError
15 from pycryptopp.cipher.aes import AES
# Exception types used by the download machinery. The class bodies
# (presumably bare `pass` statements) are elided from this listing --
# note the gaps in the embedded line numbering.
17 class HaveAllPeersError(Exception):
18 # we use this to jump out of the loop
# BadURIExtensionHashValue: the URI-extension block fetched from a peer
# did not hash to the value recorded in the URI (raised in _validate,
# see _obtain_uri_extension below).
21 class BadURIExtensionHashValue(Exception):
# BadPlaintextHashValue / BadCrypttextHashValue: a fetched plaintext- or
# crypttext-hash root did not match the value in the URI extension
# (raised in the _validate_*_hashtree helpers below).
23 class BadPlaintextHashValue(Exception):
25 class BadCrypttextHashValue(Exception):
# DownloadStopped: raised by _check_for_pause when the consumer has
# called stopProducing().
28 class DownloadStopped(Exception):
# DownloadResults: a plain record object (declares IDownloadResults)
# accumulating per-download statistics. The __init__ header is elided
# from this listing; other code in this file also assigns .servermap,
# .timings and .file_size on instances of this class.
31 class DownloadResults:
32 implements(IDownloadResults)
35 self.servers_used = set()
36 self.server_problems = {}
# Constructor of the output/decryption stage (its enclosing class header
# is elided from this listing; FileDownloader instantiates it as
# Output(...)). Holds the AES decryptor plus whole-file crypttext and
# plaintext hashers; hash trees are installed later via
# setup_hashtrees(). NOTE(review): the `download_status` parameter used
# below is presumably declared on the elided continuation line 43.
42 def __init__(self, downloadable, key, total_length, log_parent,
44 self.downloadable = downloadable
45 self._decryptor = AES(key)
46 self._crypttext_hasher = hashutil.crypttext_hasher()
47 self._plaintext_hasher = hashutil.plaintext_hasher()
49 self.total_length = total_length
50 self._segment_number = 0
# hash trees start as None; write_segment() skips per-segment leaf
# validation until setup_hashtrees() provides them
51 self._plaintext_hash_tree = None
52 self._crypttext_hash_tree = None
54 self._log_parent = log_parent
55 self._status = download_status
56 self._status.set_progress(0.0)
# Log via allmydata.util.log, defaulting parent/facility so callers
# don't have to pass them every time.
58 def log(self, *args, **kwargs):
59 if "parent" not in kwargs:
60 kwargs["parent"] = self._log_parent
61 if "facility" not in kwargs:
62 kwargs["facility"] = "download.output"
63 return log.msg(*args, **kwargs)
# Install the (possibly None) plaintext/crypttext hash trees; called by
# FileDownloader._setup_hashtrees once the trees have been fetched.
65 def setup_hashtrees(self, plaintext_hashtree, crypttext_hashtree):
66 self._plaintext_hash_tree = plaintext_hashtree
67 self._crypttext_hash_tree = crypttext_hashtree
# Accept one decoded crypttext segment: update progress, validate the
# segment against the crypttext hash tree (if present), decrypt,
# validate against the plaintext hash tree (if present), then hand the
# plaintext to the downloadable. NOTE(review): `self.length` is
# incremented here but its initialization is elided from this listing --
# presumably set to 0 in __init__.
69 def write_segment(self, crypttext):
70 self.length += len(crypttext)
71 self._status.set_progress( float(self.length) / self.total_length )
73 # memory footprint: 'crypttext' is the only segment_size usage
74 # outstanding. While we decrypt it into 'plaintext', we hit
76 self._crypttext_hasher.update(crypttext)
77 if self._crypttext_hash_tree:
78 ch = hashutil.crypttext_segment_hasher()
# set_hashes() raises BadHashError if this leaf disagrees with the
# already-validated tree, aborting the segment
80 crypttext_leaves = {self._segment_number: ch.digest()}
81 self.log(format="crypttext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s",
83 segnum=self._segment_number, hash=base32.b2a(ch.digest()),
85 self._crypttext_hash_tree.set_hashes(leaves=crypttext_leaves)
87 plaintext = self._decryptor.process(crypttext)
90 # now we're back down to 1*segment_size.
92 self._plaintext_hasher.update(plaintext)
93 if self._plaintext_hash_tree:
94 ph = hashutil.plaintext_segment_hasher()
96 plaintext_leaves = {self._segment_number: ph.digest()}
97 self.log(format="plaintext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s",
99 segnum=self._segment_number, hash=base32.b2a(ph.digest()),
101 self._plaintext_hash_tree.set_hashes(leaves=plaintext_leaves)
103 self._segment_number += 1
104 # We're still at 1*segment_size. The Downloadable is responsible for
105 # any memory usage beyond this.
# NOTE(review): a guard such as `if not self._opened:` around this
# open() call appears to be elided (line numbers jump 105 -> 108)
108 self.downloadable.open(self.total_length)
109 self.downloadable.write(plaintext)
# Failure path (the `def fail(self, why):` header at the elided line 111
# is missing from this listing): log at an appropriate severity --
# DownloadStopped is merely UNUSUAL, anything else is SCARY -- then
# propagate the Failure to the downloadable target.
112 # this is really unusual, and deserves maximum forensics
113 if why.check(DownloadStopped):
114 # except DownloadStopped just means the consumer aborted the
115 # download, not so scary
116 self.log("download stopped", level=log.UNUSUAL)
# NOTE(review): an `else:` between lines 116 and 118 appears elided
118 self.log("download failed!", failure=why, level=log.SCARY)
119 self.downloadable.fail(why)
# close()/finish() bodies (both `def` headers, at elided lines ~121 and
# ~127, are missing from this listing). close() finalizes the whole-file
# crypttext/plaintext digests -- later compared against the URI-extension
# values in FileDownloader._done -- and closes the target; finish()
# delegates to the downloadable's own finish().
122 self.crypttext_hash = self._crypttext_hasher.digest()
123 self.plaintext_hash = self._plaintext_hasher.digest()
124 self.log("download finished, closing IDownloadable", level=log.NOISY)
125 self.downloadable.close()
128 return self.downloadable.finish()
# ValidatedBucket: wraps one remote share bucket and hash-validates every
# block it hands out (share hash chain first, then per-block hashes).
# Several lines are elided from this listing, including the closing
# docstring quotes (~line 135), the tail of the __init__ signature, and
# the try/else scaffolding inside get_block/_got_data.
130 class ValidatedBucket:
131 """I am a front-end for a remote storage bucket, responsible for
132 retrieving and validating data from that bucket.
134 My get_block() method is used by BlockDownloaders.
137 def __init__(self, sharenum, bucket,
138 share_hash_tree, roothash,
140 self.sharenum = sharenum
142 self._share_hash = None # None means not validated yet
143 self.share_hash_tree = share_hash_tree
144 self._roothash = roothash
145 self.block_hash_tree = hashtree.IncompleteHashTree(num_blocks)
# Fetch block `blocknum`, along with whatever share-hash / block-hash
# chain elements are still needed, then validate in _got_data.
148 def get_block(self, blocknum):
150 d = self.bucket.start()
153 return self.get_block(blocknum)
154 d.addCallback(_started)
157 # the first time we use this bucket, we need to fetch enough elements
158 # of the share hash tree to validate it from our share hash up to the
160 if not self._share_hash:
161 d1 = self.bucket.get_share_hashes()
163 d1 = defer.succeed([])
165 # we might need to grab some elements of our block hash tree, to
166 # validate the requested block up to the share hash
167 needed = self.block_hash_tree.needed_hashes(blocknum)
169 # TODO: get fewer hashes, use get_block_hashes(needed)
170 d2 = self.bucket.get_block_hashes()
172 d2 = defer.succeed([])
174 d3 = self.bucket.get_block(blocknum)
# gatherResults fires with [sharehashes, blockhashes, blockdata]
176 d = defer.gatherResults([d1, d2, d3])
177 d.addCallback(self._got_data, blocknum)
# Validate the fetched data; raises BadHashError/NotEnoughHashesError
# (after extensive forensic logging) on any mismatch.
180 def _got_data(self, res, blocknum):
181 sharehashes, blockhashes, blockdata = res
182 blockhash = None # to make logging it safe
185 if not self._share_hash:
186 sh = dict(sharehashes)
187 sh[0] = self._roothash # always use our own root, from the URI
188 sht = self.share_hash_tree
189 if sht.get_leaf_index(self.sharenum) not in sh:
190 raise hashtree.NotEnoughHashesError
192 self._share_hash = sht.get_leaf(self.sharenum)
194 blockhash = hashutil.block_hash(blockdata)
195 #log.msg("checking block_hash(shareid=%d, blocknum=%d) len=%d "
197 # (self.sharenum, blocknum, len(blockdata),
198 # blockdata[:50], blockdata[-50:], base32.b2a(blockhash)))
200 # we always validate the blockhash
201 bh = dict(enumerate(blockhashes))
202 # replace blockhash root with validated value
203 bh[0] = self._share_hash
204 self.block_hash_tree.set_hashes(bh, {blocknum: blockhash})
206 except (hashtree.BadHashError, hashtree.NotEnoughHashesError):
207 # log.WEIRD: indicates undetected disk/network error, or more
208 # likely a programming error
209 log.msg("hash failure in block=%d, shnum=%d on %s" %
210 (blocknum, self.sharenum, self.bucket))
212 log.msg(""" failure occurred when checking the block_hash_tree.
213 This suggests that either the block data was bad, or that the
214 block hashes we received along with it were bad.""")
216 log.msg(""" the failure probably occurred when checking the
217 share_hash_tree, which suggests that the share hashes we
218 received from the remote peer were bad.""")
219 log.msg(" have self._share_hash: %s" % bool(self._share_hash))
220 log.msg(" block length: %d" % len(blockdata))
221 log.msg(" block hash: %s" % base32.b2a_or_none(blockhash))
222 if len(blockdata) < 100:
223 log.msg(" block data: %r" % (blockdata,))
225 log.msg(" block data start/end: %r .. %r" %
226 (blockdata[:50], blockdata[-50:]))
227 log.msg(" root hash: %s" % base32.b2a(self._roothash))
228 log.msg(" share hash tree:\n" + self.share_hash_tree.dump())
229 log.msg(" block hash tree:\n" + self.block_hash_tree.dump())
# NOTE(review): `lines = []` initializations (elided ~230 and ~234)
# are missing from this listing
231 for i,h in sorted(sharehashes):
232 lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
233 log.msg(" sharehashes:\n" + "\n".join(lines) + "\n")
235 for i,h in enumerate(blockhashes):
236 lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
237 log.msg(" blockhashes:\n" + "\n".join(lines) + "\n")
240 # If we made it here, the block is good. If the hash trees didn't
241 # like what they saw, they would have raised a BadHashError, causing
242 # our caller to see a Failure and thus ignore this block (as well as
243 # dropping this bucket).
# BlockDownloader: fetches one block of one segment from one
# ValidatedBucket, records per-server fetch timings into the results
# object, and reports success/failure back to its SegmentDownloader
# parent. The closing docstring quotes and some assignments (e.g.
# self.parent) are elided from this listing.
248 class BlockDownloader:
249 """I am responsible for downloading a single block (from a single bucket)
250 for a single segment.
252 I am a child of the SegmentDownloader.
255 def __init__(self, vbucket, blocknum, parent, results):
256 self.vbucket = vbucket
257 self.blocknum = blocknum
259 self.results = results
260 self._log_number = self.parent.log("starting block %d" % blocknum)
262 def log(self, msg, parent=None):
264 parent = self._log_number
265 return self.parent.log(msg, parent=parent)
# Kick off the validated-block fetch; timings and the log anchor are
# threaded through as callback args.
267 def start(self, segnum):
268 lognum = self.log("get_block(segnum=%d)" % segnum)
269 started = time.time()
270 d = self.vbucket.get_block(segnum)
271 d.addCallbacks(self._hold_block, self._got_block_error,
272 callbackArgs=(started, lognum,), errbackArgs=(lognum,))
275 def _hold_block(self, data, started, lognum):
277 elapsed = time.time() - started
278 peerid = self.vbucket.bucket.get_peerid()
279 if peerid not in self.results.timings["fetch_per_server"]:
280 self.results.timings["fetch_per_server"][peerid] = []
281 self.results.timings["fetch_per_server"][peerid].append(elapsed)
282 self.log("got block", parent=lognum)
283 self.parent.hold_block(self.blocknum, data)
# Record the failure against the peer and drop the bucket; the parent
# may then fall back to a different share.
285 def _got_block_error(self, f, lognum):
286 self.log("BlockDownloader[%d] got error: %s" % (self.blocknum, f),
289 peerid = self.vbucket.bucket.get_peerid()
290 self.results.server_problems[peerid] = str(f)
291 self.parent.bucket_failed(self.vbucket)
# SegmentDownloader: fans out BlockDownloaders across the active buckets
# until it holds `needed_blocks` blocks for one segment, preferring the
# lowest blockids (likely primary/fast shares). Several lines are elided
# from this listing, including the closing docstring quotes, the
# self.parent assignment, and the def headers around lines ~314/~334.
293 class SegmentDownloader:
294 """I am responsible for downloading all the blocks for a single segment
297 I am a child of the FileDownloader.
300 def __init__(self, parent, segmentnumber, needed_shares, results):
302 self.segmentnumber = segmentnumber
303 self.needed_blocks = needed_shares
304 self.blocks = {} # k: blocknum, v: data
305 self.results = results
306 self._log_number = self.parent.log("starting segment %d" %
309 def log(self, msg, parent=None):
311 parent = self._log_number
312 return self.parent.log(msg, parent=parent)
# entry point (def header elided): returns a Deferred that fires with
# (blocks, blockids) once enough blocks have been collected
315 return self._download()
320 if len(self.blocks) >= self.needed_blocks:
321 # we only need self.needed_blocks blocks
322 # we want to get the smallest blockids, because they are
323 # more likely to be fast "primary blocks"
324 blockids = sorted(self.blocks.keys())[:self.needed_blocks]
326 for blocknum in blockids:
327 blocks.append(self.blocks[blocknum])
328 return (blocks, blockids)
# not enough blocks yet: try another round of fetching
330 return self._download()
335 # fill our set of active buckets, maybe raising NotEnoughSharesError
336 active_buckets = self.parent._activate_enough_buckets()
337 # Now we have enough buckets, in self.parent.active_buckets.
339 # in test cases, bd.start might mutate active_buckets right away, so
340 # we need to put off calling start() until we've iterated all the way
343 for blocknum, vbucket in active_buckets.iteritems():
344 bd = BlockDownloader(vbucket, blocknum, self, self.results)
345 downloaders.append(bd)
347 self.results.servers_used.add(vbucket.bucket.get_peerid())
348 l = [bd.start(self.segmentnumber) for bd in downloaders]
# fireOnOneErrback: a single block failure aborts this fetch round
349 return defer.DeferredList(l, fireOnOneErrback=True)
# called back by BlockDownloader._hold_block with validated block data
351 def hold_block(self, blocknum, data):
352 self.blocks[blocknum] = data
354 def bucket_failed(self, vbucket):
355 self.parent.bucket_failed(vbucket)
# DownloadStatus: mutable status record (declares IDownloadStatus) shown
# by the status-reporting machinery. Each instance gets a unique counter
# from the shared class-level itertools.count. Most one-line accessor
# bodies are elided from this listing (note the numbering gaps).
357 class DownloadStatus:
358 implements(IDownloadStatus)
359 statusid_counter = itertools.count(0)
362 self.storage_index = None
365 self.status = "Not started"
371 self.counter = self.statusid_counter.next()
372 self.started = time.time()
374 def get_started(self):
376 def get_storage_index(self):
377 return self.storage_index
380 def using_helper(self):
382 def get_status(self):
# annotate the base status string with paused/stopped markers
385 status += " (output paused)"
387 status += " (output stopped)"
389 def get_progress(self):
391 def get_active(self):
393 def get_results(self):
395 def get_counter(self):
398 def set_storage_index(self, si):
399 self.storage_index = si
400 def set_size(self, size):
402 def set_helper(self, helper):
404 def set_status(self, status):
406 def set_paused(self, paused):
408 def set_stopped(self, stopped):
409 self.stopped = stopped
410 def set_progress(self, value):
411 self.progress = value
412 def set_active(self, value):
414 def set_results(self, value):
# FileDownloader: orchestrates a CHK file download -- locate shares,
# fetch/validate the URI extension and hash trees, then download and
# decode segment by segment, feeding crypttext to an Output instance.
# Registers itself as an IPushProducer on consumers that support it.
# Several __init__ lines are elided from this listing (e.g. the
# self._size assignment read below, and the init_logging() call that
# must precede the self._log_number use on line 457).
417 class FileDownloader:
418 implements(IPushProducer)
# class-level knobs: whole-file hash checks in _done can be disabled
419 check_crypttext_hash = True
420 check_plaintext_hash = True
423 def __init__(self, client, u, downloadable):
424 self._client = client
427 self._storage_index = u.storage_index
428 self._uri_extension_hash = u.uri_extension_hash
429 self._total_shares = u.total_shares
431 self._num_needed_shares = u.needed_shares
433 self._si_s = storage.si_b2a(self._storage_index)
436 self._started = time.time()
437 self._status = s = DownloadStatus()
438 s.set_status("Starting")
439 s.set_storage_index(self._storage_index)
440 s.set_size(self._size)
444 self._results = DownloadResults()
445 s.set_results(self._results)
446 self._results.file_size = self._size
# pre-seed every timing bucket so later += / dict updates can assume
# the keys exist
447 self._results.timings["servers_peer_selection"] = {}
448 self._results.timings["fetch_per_server"] = {}
449 self._results.timings["cumulative_fetch"] = 0.0
450 self._results.timings["cumulative_decode"] = 0.0
451 self._results.timings["cumulative_decrypt"] = 0.0
452 self._results.timings["paused"] = 0.0
454 if IConsumer.providedBy(downloadable):
455 downloadable.registerProducer(self, True)
456 self._downloadable = downloadable
457 self._output = Output(downloadable, u.key, self._size, self._log_number,
460 self._stopped = False
462 self.active_buckets = {} # k: shnum, v: bucket
463 self._share_buckets = [] # list of (sharenum, bucket) tuples
464 self._share_vbuckets = {} # k: shnum, v: set of ValidatedBuckets
465 self._uri_extension_sources = []
467 self._uri_extension_data = None
# counts of bad copies received, keyed by what was being fetched
469 self._fetch_failures = {"uri_extension": 0,
470 "plaintext_hashroot": 0,
471 "plaintext_hashtree": 0,
472 "crypttext_hashroot": 0,
473 "crypttext_hashtree": 0,
# Logging setup plus the IPushProducer control methods. Some lines are
# elided from this listing (e.g. the pauseProducing guard around line
# 490-491 and the resumeProducing unwinding around 498-502).
476 def init_logging(self):
477 self._log_prefix = prefix = storage.si_b2a(self._storage_index)[:5]
478 num = self._client.log(format="FileDownloader(%(si)s): starting",
479 si=storage.si_b2a(self._storage_index))
480 self._log_number = num
482 def log(self, *args, **kwargs):
483 if "parent" not in kwargs:
484 kwargs["parent"] = self._log_number
485 if "facility" not in kwargs:
486 kwargs["facility"] = "tahoe.download"
487 return log.msg(*args, **kwargs)
# IPushProducer: _check_for_pause waits on self._paused between segments
489 def pauseProducing(self):
492 self._paused = defer.Deferred()
493 self._paused_at = time.time()
495 self._status.set_paused(True)
# IPushProducer: fire the pause Deferred via foolscap's eventual-send so
# the resumption happens on a later reactor turn
497 def resumeProducing(self):
501 eventually(p.callback, None)
503 self._status.set_paused(False)
# IPushProducer: account any open pause interval, then mark stopped;
# _check_for_pause will raise DownloadStopped at the next checkpoint
505 def stopProducing(self):
506 self.log("Download.stopProducing")
508 paused_for = time.time() - self._paused_at
509 self._results.timings['paused'] += paused_for
511 self._status.set_stopped(True)
512 self._status.set_active(False)
# Top-level download pipeline (the `def start(self):` header at elided
# line ~514 is missing from this listing), followed by shareholder
# discovery and bookkeeping.
515 self.log("starting download")
517 # first step: who should we download from?
518 d = defer.maybeDeferred(self._get_all_shareholders)
519 d.addCallback(self._got_all_shareholders)
520 # now get the uri_extension block from somebody and validate it
521 d.addCallback(self._obtain_uri_extension)
522 d.addCallback(self._got_uri_extension)
523 d.addCallback(self._get_hashtrees)
524 d.addCallback(self._create_validated_buckets)
525 # once we know that, we can download blocks from everybody
526 d.addCallback(self._download_all_segments)
# success path (def header elided, ~line 528): finalize status and
# detach from the consumer
529 self._status.set_status("Finished")
530 self._status.set_active(False)
531 self._status.set_paused(False)
532 if IConsumer.providedBy(self._downloadable):
533 self._downloadable.unregisterProducer()
# failure path (def header elided, ~line 537): route the Failure
# through Output.fail for logging and target notification
538 self._status.set_status("Failed")
539 self._status.set_active(False)
540 self._output.fail(why)
542 d.addErrback(_failed)
543 d.addCallback(self._done)
# Ask every permuted storage peer for its buckets for this storage
# index; responses stream into _got_response / _got_error.
546 def _get_all_shareholders(self):
548 for (peerid,ss) in self._client.get_permuted_peers("storage",
549 self._storage_index):
550 d = ss.callRemote("get_buckets", self._storage_index)
551 d.addCallbacks(self._got_response, self._got_error,
552 callbackArgs=(peerid,))
554 self._responses_received = 0
555 self._queries_sent = len(dl)
557 self._status.set_status("Locating Shares (%d/%d)" %
558 (self._responses_received,
560 return defer.DeferredList(dl)
562 def _got_response(self, buckets, peerid):
563 self._responses_received += 1
565 elapsed = time.time() - self._started
566 self._results.timings["servers_peer_selection"][peerid] = elapsed
568 self._status.set_status("Locating Shares (%d/%d)" %
569 (self._responses_received,
571 for sharenum, bucket in buckets.iteritems():
572 b = storage.ReadBucketProxy(bucket, peerid, self._si_s)
573 self.add_share_bucket(sharenum, b)
574 self._uri_extension_sources.append(b)
576 if peerid not in self._results.servermap:
577 self._results.servermap[peerid] = set()
578 self._results.servermap[peerid].add(sharenum)
580 def add_share_bucket(self, sharenum, bucket):
581 # this is split out for the benefit of test_encode.py
582 self._share_buckets.append( (sharenum, bucket) )
# peer query failed: log and carry on; other peers may still suffice
584 def _got_error(self, f):
585 self._client.log("Somebody failed. -- %s" % (f,))
# A ValidatedBucket turned out to be bad: retire it, and if no bucket
# can provide that share any more, forget the share entirely.
587 def bucket_failed(self, vbucket):
588 shnum = vbucket.sharenum
589 del self.active_buckets[shnum]
590 s = self._share_vbuckets[shnum]
# s is a set of ValidatedBucket instances
592 # ... which might now be empty
595 # there are no more buckets which can provide this share, so
596 # remove the key. This may prompt us to use a different share.
597 del self._share_vbuckets[shnum]
599 def _got_all_shareholders(self, res):
602 self._results.timings["peer_selection"] = now - self._started
604 if len(self._share_buckets) < self._num_needed_shares:
605 raise NotEnoughSharesError
607 #for s in self._share_vbuckets.values():
609 # assert isinstance(vb, ValidatedBucket), \
610 # "vb is %s but should be a ValidatedBucket" % (vb,)
# URI-extension handling: fetch a copy from any shareholder, verify its
# hash against the URI, unpack it, and derive codec/segment parameters.
612 def _unpack_uri_extension_data(self, data):
613 return uri.unpack_extension(data)
615 def _obtain_uri_extension(self, ignored):
616 # all shareholders are supposed to have a copy of uri_extension, and
617 # all are supposed to be identical. We compute the hash of the data
618 # that comes back, and compare it against the version in our URI. If
619 # they don't match, ignore their data and try someone else.
621 self._status.set_status("Obtaining URI Extension")
623 self._uri_extension_fetch_started = time.time()
624 def _validate(proposal, bucket):
625 h = hashutil.uri_extension_hash(proposal)
626 if h != self._uri_extension_hash:
627 self._fetch_failures["uri_extension"] += 1
628 msg = ("The copy of uri_extension we received from "
629 "%s was bad: wanted %s, got %s" %
631 base32.b2a(self._uri_extension_hash),
633 self.log(msg, level=log.SCARY)
634 raise BadURIExtensionHashValue(msg)
635 return self._unpack_uri_extension_data(proposal)
636 return self._obtain_validated_thing(None,
637 self._uri_extension_sources,
639 "get_uri_extension", (), _validate)
# Generic fetch-and-validate loop: try `methname` on each source bucket
# in turn, validating with `validatorfunc`; recurse on failure until the
# source list is exhausted, then raise NotEnoughSharesError.
641 def _obtain_validated_thing(self, ignored, sources, name, methname, args,
644 raise NotEnoughSharesError("started with zero peers while fetching "
647 sources = sources[1:]
648 #d = bucket.callRemote(methname, *args)
649 d = bucket.startIfNecessary()
650 d.addCallback(lambda res: getattr(bucket, methname)(*args))
651 d.addCallback(validatorfunc, bucket)
653 self.log("%s from vbucket %s failed:" % (name, bucket),
654 failure=f, level=log.WEIRD)
656 raise NotEnoughSharesError("ran out of peers, last error was %s"
658 # try again with a different one
659 return self._obtain_validated_thing(None, sources, name,
660 methname, args, validatorfunc)
# Record timings, build the (tail-)codecs from the unpacked extension
# data, and seed the share hash tree with the validated root.
664 def _got_uri_extension(self, uri_extension_data):
666 elapsed = time.time() - self._uri_extension_fetch_started
667 self._results.timings["uri_extension"] = elapsed
669 d = self._uri_extension_data = uri_extension_data
671 self._codec = codec.get_decoder_by_name(d['codec_name'])
672 self._codec.set_serialized_params(d['codec_params'])
# the final (possibly padded) segment uses its own codec parameters
673 self._tail_codec = codec.get_decoder_by_name(d['codec_name'])
674 self._tail_codec.set_serialized_params(d['tail_codec_params'])
676 crypttext_hash = d.get('crypttext_hash', None) # optional
678 assert isinstance(crypttext_hash, str)
679 assert len(crypttext_hash) == 32
680 self._crypttext_hash = crypttext_hash
681 self._plaintext_hash = d.get('plaintext_hash', None) # optional
683 self._roothash = d['share_root_hash']
685 self._segment_size = segment_size = d['segment_size']
686 self._total_segments = mathutil.div_ceil(self._size, segment_size)
687 self._current_segnum = 0
689 self._share_hashtree = hashtree.IncompleteHashTree(d['total_shares'])
690 self._share_hashtree.set_hashes({0: self._roothash})
# Hash-tree retrieval. Both plaintext and crypttext trees are optional:
# if the corresponding root is absent from the URI extension the tree is
# left as None and Output skips that per-segment check. Some lines
# (early `return`s, try/except scaffolding) are elided from this listing.
692 def _get_hashtrees(self, res):
693 self._get_hashtrees_started = time.time()
695 self._status.set_status("Retrieving Hash Trees")
696 d = defer.maybeDeferred(self._get_plaintext_hashtrees)
697 d.addCallback(self._get_crypttext_hashtrees)
698 d.addCallback(self._setup_hashtrees)
701 def _get_plaintext_hashtrees(self):
702 # plaintext hashes are optional. If the root isn't in the UEB, then
703 # the share will be holding an empty list. We don't even bother
705 if "plaintext_root_hash" not in self._uri_extension_data:
706 self._plaintext_hashtree = None
708 def _validate_plaintext_hashtree(proposal, bucket):
# proposal[0] is the tree root; it must match the UEB value
709 if proposal[0] != self._uri_extension_data['plaintext_root_hash']:
710 self._fetch_failures["plaintext_hashroot"] += 1
711 msg = ("The copy of the plaintext_root_hash we received from"
712 " %s was bad" % bucket)
713 raise BadPlaintextHashValue(msg)
714 pt_hashtree = hashtree.IncompleteHashTree(self._total_segments)
715 pt_hashes = dict(list(enumerate(proposal)))
717 pt_hashtree.set_hashes(pt_hashes)
718 except hashtree.BadHashError:
719 # the hashes they gave us were not self-consistent, even
720 # though the root matched what we saw in the uri_extension
722 self._fetch_failures["plaintext_hashtree"] += 1
724 self._plaintext_hashtree = pt_hashtree
725 d = self._obtain_validated_thing(None,
726 self._uri_extension_sources,
728 "get_plaintext_hashes", (),
729 _validate_plaintext_hashtree)
732 def _get_crypttext_hashtrees(self, res):
733 # crypttext hashes are optional too
734 if "crypttext_root_hash" not in self._uri_extension_data:
735 self._crypttext_hashtree = None
737 def _validate_crypttext_hashtree(proposal, bucket):
738 if proposal[0] != self._uri_extension_data['crypttext_root_hash']:
739 self._fetch_failures["crypttext_hashroot"] += 1
740 msg = ("The copy of the crypttext_root_hash we received from"
741 " %s was bad" % bucket)
742 raise BadCrypttextHashValue(msg)
743 ct_hashtree = hashtree.IncompleteHashTree(self._total_segments)
744 ct_hashes = dict(list(enumerate(proposal)))
746 ct_hashtree.set_hashes(ct_hashes)
747 except hashtree.BadHashError:
748 self._fetch_failures["crypttext_hashtree"] += 1
# NOTE(review): this second set_hashes call appears after the except
# clause; the surrounding try/else structure is elided from this listing
750 ct_hashtree.set_hashes(ct_hashes)
751 self._crypttext_hashtree = ct_hashtree
752 d = self._obtain_validated_thing(None,
753 self._uri_extension_sources,
755 "get_crypttext_hashes", (),
756 _validate_crypttext_hashtree)
# hand both (possibly None) trees to the Output stage and record timing
759 def _setup_hashtrees(self, res):
760 self._output.setup_hashtrees(self._plaintext_hashtree,
761 self._crypttext_hashtree)
763 elapsed = time.time() - self._get_hashtrees_started
764 self._results.timings["hashtrees"] = elapsed
# Bucket activation and the per-segment download driver.
766 def _create_validated_buckets(self, ignored=None):
767 self._share_vbuckets = {}
768 for sharenum, bucket in self._share_buckets:
769 vbucket = ValidatedBucket(sharenum, bucket,
770 self._share_hashtree,
772 self._total_segments)
773 s = self._share_vbuckets.setdefault(sharenum, set())
776 def _activate_enough_buckets(self):
777 """either return a mapping from shnum to a ValidatedBucket that can
778 provide data for that share, or raise NotEnoughSharesError"""
780 while len(self.active_buckets) < self._num_needed_shares:
782 handled_shnums = set(self.active_buckets.keys())
783 available_shnums = set(self._share_vbuckets.keys())
784 potential_shnums = list(available_shnums - handled_shnums)
785 if not potential_shnums:
786 raise NotEnoughSharesError
787 # choose a random share
788 shnum = random.choice(potential_shnums)
789 # and a random bucket that will provide it
790 validated_bucket = random.choice(list(self._share_vbuckets[shnum]))
791 self.active_buckets[shnum] = validated_bucket
792 return self.active_buckets
795 def _download_all_segments(self, res):
796 # the promise: upon entry to this function, self._share_vbuckets
797 # contains enough buckets to complete the download, and some extra
798 # ones to tolerate some buckets dropping out or having errors.
799 # self._share_vbuckets is a dictionary that maps from shnum to a set
800 # of ValidatedBuckets, which themselves are wrappers around
801 # RIBucketReader references.
802 self.active_buckets = {} # k: shnum, v: ValidatedBucket instance
804 self._started_fetching = time.time()
806 d = defer.succeed(None)
# all full-size segments, then the tail segment with its own codec
807 for segnum in range(self._total_segments-1):
808 d.addCallback(self._download_segment, segnum)
809 # this pause, at the end of write, prevents pre-fetch from
810 # happening until the consumer is ready for more data.
811 d.addCallback(self._check_for_pause)
812 d.addCallback(self._download_tail_segment, self._total_segments-1)
# Between-segment checkpoint: if paused, chain onto the pause Deferred;
# if the consumer stopped us, abort the whole download.
815 def _check_for_pause(self, res):
818 self._paused.addCallback(lambda ignored: d.callback(res))
821 raise DownloadStopped("our Consumer called stopProducing()")
# Download, decode and decrypt one full-size segment, accounting each
# phase's elapsed time into the results timings. Some lines (e.g. the
# `def _done_decrypting` style wrappers and return statements) are
# elided from this listing.
824 def _download_segment(self, res, segnum):
826 self._status.set_status("Downloading segment %d of %d" %
827 (segnum+1, self._total_segments))
828 self.log("downloading seg#%d of %d (%d%%)"
829 % (segnum, self._total_segments,
830 100.0 * segnum / self._total_segments))
831 # memory footprint: when the SegmentDownloader finishes pulling down
832 # all shares, we have 1*segment_size of usage.
833 segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares,
835 started = time.time()
836 d = segmentdler.start()
837 def _finished_fetching(res):
838 elapsed = time.time() - started
839 self._results.timings["cumulative_fetch"] += elapsed
842 d.addCallback(_finished_fetching)
843 # pause before using more memory
844 d.addCallback(self._check_for_pause)
845 # while the codec does its job, we hit 2*segment_size
846 def _started_decode(res):
847 self._started_decode = time.time()
850 d.addCallback(_started_decode)
# tuple-unpacking lambda: Python 2 syntax
851 d.addCallback(lambda (shares, shareids):
852 self._codec.decode(shares, shareids))
853 # once the codec is done, we drop back to 1*segment_size, because
854 # 'shares' goes out of scope. The memory usage is all in the
855 # plaintext now, spread out into a bunch of tiny buffers.
856 def _finished_decode(res):
857 elapsed = time.time() - self._started_decode
858 self._results.timings["cumulative_decode"] += elapsed
861 d.addCallback(_finished_decode)
863 # pause/check-for-stop just before writing, to honor stopProducing
864 d.addCallback(self._check_for_pause)
866 # we start by joining all these buffers together into a single
867 # string. This makes Output.write easier, since it wants to hash
868 # data one segment at a time anyways, and doesn't impact our
869 # memory footprint since we're already peaking at 2*segment_size
870 # inside the codec a moment ago.
871 segment = "".join(buffers)
873 # we're down to 1*segment_size right now, but write_segment()
874 # will decrypt a copy of the segment internally, which will push
875 # us up to 2*segment_size while it runs.
876 started_decrypt = time.time()
877 self._output.write_segment(segment)
879 elapsed = time.time() - started_decrypt
880 self._results.timings["cumulative_decrypt"] += elapsed
# Same as _download_segment but for the final segment: uses the tail
# codec and trims off the zero padding added at upload time.
884 def _download_tail_segment(self, res, segnum):
885 self.log("downloading seg#%d of %d (%d%%)"
886 % (segnum, self._total_segments,
887 100.0 * segnum / self._total_segments))
888 segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares,
890 started = time.time()
891 d = segmentdler.start()
892 def _finished_fetching(res):
893 elapsed = time.time() - started
894 self._results.timings["cumulative_fetch"] += elapsed
897 d.addCallback(_finished_fetching)
898 # pause before using more memory
899 d.addCallback(self._check_for_pause)
900 def _started_decode(res):
901 self._started_decode = time.time()
904 d.addCallback(_started_decode)
905 d.addCallback(lambda (shares, shareids):
906 self._tail_codec.decode(shares, shareids))
907 def _finished_decode(res):
908 elapsed = time.time() - self._started_decode
909 self._results.timings["cumulative_decode"] += elapsed
912 d.addCallback(_finished_decode)
913 # pause/check-for-stop just before writing, to honor stopProducing
914 d.addCallback(self._check_for_pause)
916 # trim off any padding added by the upload side
917 segment = "".join(buffers)
919 # we never send empty segments. If the data was an exact multiple
920 # of the segment size, the last segment will be full.
921 pad_size = mathutil.pad_size(self._size, self._segment_size)
922 tail_size = self._segment_size - pad_size
923 segment = segment[:tail_size]
924 started_decrypt = time.time()
925 self._output.write_segment(segment)
927 elapsed = time.time() - started_decrypt
928 self._results.timings["cumulative_decrypt"] += elapsed
# Final step of the pipeline: verify the whole-file crypttext/plaintext
# digests computed by Output against the (optional) values from the URI
# extension, check the total length, and return the target's result.
932 def _done(self, res):
933 self.log("download done")
936 self._results.timings["total"] = now - self._started
937 self._results.timings["segments"] = now - self._started_fetching
939 if self.check_crypttext_hash and self._crypttext_hash:
940 _assert(self._crypttext_hash == self._output.crypttext_hash,
941 "bad crypttext_hash: computed=%s, expected=%s" %
942 (base32.b2a(self._output.crypttext_hash),
943 base32.b2a(self._crypttext_hash)))
944 if self.check_plaintext_hash and self._plaintext_hash:
945 _assert(self._plaintext_hash == self._output.plaintext_hash,
946 "bad plaintext_hash: computed=%s, expected=%s" %
947 (base32.b2a(self._output.plaintext_hash),
948 base32.b2a(self._plaintext_hash)))
949 _assert(self._output.length == self._size,
950 got=self._output.length, expected=self._size)
951 return self._output.finish()
# accessor body elided from this listing (presumably returns self._status)
953 def get_download_status(self):
# LiteralDownloader: handles LIT URIs, whose full contents are embedded
# in the URI itself -- no network fetch, just replay the bytes into the
# target. The start() def header and some status lines are elided from
# this listing.
957 class LiteralDownloader:
958 def __init__(self, client, u, downloadable):
959 self._uri = IFileURI(u)
960 assert isinstance(self._uri, uri.LiteralFileURI)
961 self._downloadable = downloadable
962 self._status = s = DownloadStatus()
963 s.set_storage_index(None)
# start() body (header elided): open/write/close the target with the
# URI-embedded data, finishing via maybeDeferred
970 data = self._uri.data
971 self._status.set_size(len(data))
972 self._downloadable.open(len(data))
973 self._downloadable.write(data)
974 self._downloadable.close()
975 return defer.maybeDeferred(self._downloadable.finish)
977 def get_download_status(self):
# IDownloadTarget that writes to a named file (the class header at the
# elided line ~980 is missing from this listing). The fail() handler
# (header elided, ~line 990+) unlinks the partial file on error.
981 implements(IDownloadTarget)
982 def __init__(self, filename):
983 self._filename = filename
985 def open(self, size):
986 self.f = open(self._filename, "wb")
988 def write(self, data):
996 os.unlink(self._filename)
997 def register_canceller(self, cb):
998 pass # we won't use it
# IDownloadTarget that accumulates the download into memory (class
# header elided from this listing); close() joins the buffered chunks
# into self.data.
1003 implements(IDownloadTarget)
1006 def open(self, size):
1008 def write(self, data):
1009 self._data.append(data)
# close() body (header elided): materialize the accumulated chunks
1011 self.data = "".join(self._data)
1013 def fail(self, why):
1015 def register_canceller(self, cb):
1016 pass # we won't use it
# IDownloadTarget wrapping a caller-supplied filehandle (class header
# and closing docstring quotes elided from this listing). Deliberately
# does NOT close the handle -- that is the caller's responsibility.
1021 """Use me to download data to a pre-defined filehandle-like object. I
1022 will use the target's write() method. I will *not* close the filehandle:
1023 I leave that up to the originator of the filehandle. The download process
1024 will return the filehandle when it completes.
1026 implements(IDownloadTarget)
1027 def __init__(self, filehandle):
1028 self._filehandle = filehandle
1029 def open(self, size):
1031 def write(self, data):
1032 self._filehandle.write(data)
1034 # the originator of the filehandle reserves the right to close it
1036 def fail(self, why):
1038 def register_canceller(self, cb):
# finish() body (header elided): the download's result is the handle
1041 return self._filehandle
# Downloader: Twisted MultiService front-end. Dispatches a URI to the
# right downloader class (LIT vs CHK), tracks statistics, and keeps weak
# references to downloads plus a bounded list of recent statuses. The
# final method is truncated at the bottom of this listing.
1043 class Downloader(service.MultiService):
1044 """I am a service that allows file downloading.
1046 implements(IDownloader)
# cap on how many recent statuses _add_download retains
1048 MAX_DOWNLOAD_STATUSES = 10
1050 def __init__(self, stats_provider=None):
1051 service.MultiService.__init__(self)
1052 self.stats_provider = stats_provider
1053 self._all_downloads = weakref.WeakKeyDictionary() # for debugging
1054 self._all_download_statuses = weakref.WeakKeyDictionary()
1055 self._recent_download_statuses = []
1057 def download(self, u, t):
1061 t = IDownloadTarget(t)
1065 if self.stats_provider:
1066 self.stats_provider.count('downloader.files_downloaded', 1)
1067 self.stats_provider.count('downloader.bytes_downloaded', u.get_size())
1069 if isinstance(u, uri.LiteralFileURI):
1070 dl = LiteralDownloader(self.parent, u, t)
1071 elif isinstance(u, uri.CHKFileURI):
1072 dl = FileDownloader(self.parent, u, t)
1074 raise RuntimeError("I don't know how to download a %s" % u)
1075 self._add_download(dl)
# convenience wrappers binding each target type
1080 def download_to_data(self, uri):
1081 return self.download(uri, Data())
1082 def download_to_filename(self, uri, filename):
1083 return self.download(uri, FileName(filename))
1084 def download_to_filehandle(self, uri, filehandle):
1085 return self.download(uri, FileHandle(filehandle))
1087 def _add_download(self, downloader):
1088 self._all_downloads[downloader] = None
1089 s = downloader.get_download_status()
1090 self._all_download_statuses[s] = None
1091 self._recent_download_statuses.append(s)
1092 while len(self._recent_download_statuses) > self.MAX_DOWNLOAD_STATUSES:
1093 self._recent_download_statuses.pop(0)
# generator over live statuses (body continues past this listing)
1095 def list_all_download_statuses(self):
1096 for ds in self._all_download_statuses: