2 import os, random, weakref, itertools, time
3 from zope.interface import implements
4 from twisted.internet import defer
5 from twisted.internet.interfaces import IPushProducer, IConsumer
6 from twisted.application import service
7 from foolscap.eventual import eventually
9 from allmydata.util import base32, mathutil, hashutil, log
10 from allmydata.util.assertutil import _assert
11 from allmydata import codec, hashtree, storage, uri
12 from allmydata.interfaces import IDownloadTarget, IDownloader, IFileURI, \
13 IDownloadStatus, IDownloadResults
14 from allmydata.encode import NotEnoughPeersError
15 from pycryptopp.cipher.aes import AES
class HaveAllPeersError(Exception):
    """Internal control-flow exception."""
    # we use this to jump out of the loop
21 class BadURIExtensionHashValue(Exception):
23 class BadPlaintextHashValue(Exception):
25 class BadCrypttextHashValue(Exception):
28 class DownloadStopped(Exception):
class DownloadResults:
    """Plain record describing how a download went; provides IDownloadResults.

    NOTE(review): the __init__ header and several other attribute
    initializations are missing from this excerpt; only the two attributes
    below are visible.
    """
    implements(IDownloadResults)

        self.servers_used = set()    # peerids we fetched blocks from
        self.server_problems = {}    # peerid -> stringified Failure
    # Constructor for the decrypt-and-deliver output object (its class
    # statement is not visible in this excerpt).
    def __init__(self, downloadable, key, total_length, log_parent,
        # NOTE(review): the signature continuation is missing from this
        # excerpt; self._status below suggests a download_status parameter.
        self.downloadable = downloadable
        self._decryptor = AES(key)   # decrypts each segment in write_segment()
        self._crypttext_hasher = hashutil.crypttext_hasher()   # whole-file ciphertext hash
        self._plaintext_hasher = hashutil.plaintext_hasher()   # whole-file plaintext hash
        self.total_length = total_length
        self._segment_number = 0     # index of the next segment write_segment() will see
        self._plaintext_hash_tree = None   # installed later via setup_hashtrees()
        self._crypttext_hash_tree = None
        self._log_parent = log_parent
        self._status = download_status
        self._status.set_progress(0.0)
def log(self, *args, **kwargs):
    """Emit a log message, defaulting parent/facility for this output object."""
    kwargs.setdefault("parent", self._log_parent)
    kwargs.setdefault("facility", "download.output")
    return log.msg(*args, **kwargs)
def setup_hashtrees(self, plaintext_hashtree, crypttext_hashtree):
    """Install the incomplete segment hash trees used to validate each
    segment as write_segment() processes it."""
    self._crypttext_hash_tree = crypttext_hashtree
    self._plaintext_hash_tree = plaintext_hashtree
    def write_segment(self, crypttext):
        """Hash, validate, decrypt, and account for one segment of ciphertext.

        NOTE(review): several lines of this method are missing from this
        excerpt (including the hasher .update() calls on ch/ph and the log
        call continuations); comments describe only the visible code.
        """
        self.length += len(crypttext)
        self._status.set_progress( float(self.length) / self.total_length )

        # memory footprint: 'crypttext' is the only segment_size usage
        # outstanding. While we decrypt it into 'plaintext', we hit
        self._crypttext_hasher.update(crypttext)
        if self._crypttext_hash_tree:
            ch = hashutil.crypttext_segment_hasher()
            crypttext_leaves = {self._segment_number: ch.digest()}
            self.log(format="crypttext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s",
                     segnum=self._segment_number, hash=base32.b2a(ch.digest()),
            # validate this segment's leaf against the crypttext hash tree
            self._crypttext_hash_tree.set_hashes(leaves=crypttext_leaves)

        plaintext = self._decryptor.process(crypttext)
        # now we're back down to 1*segment_size.

        self._plaintext_hasher.update(plaintext)
        if self._plaintext_hash_tree:
            ph = hashutil.plaintext_segment_hasher()
            plaintext_leaves = {self._segment_number: ph.digest()}
            self.log(format="plaintext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s",
                     segnum=self._segment_number, hash=base32.b2a(ph.digest()),
            # validate this segment's leaf against the plaintext hash tree
            self._plaintext_hash_tree.set_hashes(leaves=plaintext_leaves)

        self._segment_number += 1
        # We're still at 1*segment_size. The Downloadable is responsible for
        # any memory usage beyond this.
        # NOTE(review): the 'def' headers for the remaining methods of this
        # class are missing from this excerpt; the visible bodies appear to
        # belong to open/fail/close/finish-style methods — confirm against
        # the full source.

        # open the target and deliver decrypted plaintext to it
        self.downloadable.open(self.total_length)
        self.downloadable.write(plaintext)

        # this is really unusual, and deserves maximum forensics
        self.log("download failed!", failure=why, level=log.SCARY)
        self.downloadable.fail(why)

        # finalize the whole-file hashes and close the target
        self.crypttext_hash = self._crypttext_hasher.digest()
        self.plaintext_hash = self._plaintext_hasher.digest()
        self.log("download finished, closing IDownloadable", level=log.NOISY)
        self.downloadable.close()

        return self.downloadable.finish()
class ValidatedBucket:
    """I am a front-end for a remote storage bucket, responsible for
    retrieving and validating data from that bucket.

    My get_block() method is used by BlockDownloaders.
    """

    def __init__(self, sharenum, bucket,
                 share_hash_tree, roothash,
        # NOTE(review): the signature continuation (and the line saving the
        # 'bucket' argument) is missing from this excerpt; the
        # IncompleteHashTree call below suggests a num_blocks parameter.
        self.sharenum = sharenum
        self._share_hash = None # None means not validated yet
        self.share_hash_tree = share_hash_tree
        self._roothash = roothash   # share root hash, taken from the URI extension
        self.block_hash_tree = hashtree.IncompleteHashTree(num_blocks)
    def get_block(self, blocknum):
        """Fetch and validate one block from the remote bucket.

        NOTE(review): several lines (the start-check around bucket.start(),
        the else branches, and the final 'return d') are missing from this
        excerpt.
        """
        d = self.bucket.start()
            return self.get_block(blocknum)
        d.addCallback(_started)

        # the first time we use this bucket, we need to fetch enough elements
        # of the share hash tree to validate it from our share hash up to the
        if not self._share_hash:
            d1 = self.bucket.get_share_hashes()
            d1 = defer.succeed([])

        # we might need to grab some elements of our block hash tree, to
        # validate the requested block up to the share hash
        needed = self.block_hash_tree.needed_hashes(blocknum)
            # TODO: get fewer hashes, use get_block_hashes(needed)
            d2 = self.bucket.get_block_hashes()
            d2 = defer.succeed([])

        d3 = self.bucket.get_block(blocknum)

        # wait for share hashes, block hashes, and the block itself
        d = defer.gatherResults([d1, d2, d3])
        d.addCallback(self._got_data, blocknum)
    def _got_data(self, res, blocknum):
        """Validate the share-hash chain and the block hash, then return the
        block data (by falling through to the final comment's contract).

        NOTE(review): the enclosing try/except structure, the else branch of
        the forensics logging, and the 'lines = []' initializations are
        missing from this excerpt.
        """
        sharehashes, blockhashes, blockdata = res
        blockhash = None # to make logging it safe

        if not self._share_hash:
            sh = dict(sharehashes)
            sh[0] = self._roothash # always use our own root, from the URI
            sht = self.share_hash_tree
            if sht.get_leaf_index(self.sharenum) not in sh:
                raise hashtree.NotEnoughHashesError
            self._share_hash = sht.get_leaf(self.sharenum)

        blockhash = hashutil.block_hash(blockdata)
        #log.msg("checking block_hash(shareid=%d, blocknum=%d) len=%d "
        #        (self.sharenum, blocknum, len(blockdata),
        #         blockdata[:50], blockdata[-50:], base32.b2a(blockhash)))

        # we always validate the blockhash
        bh = dict(enumerate(blockhashes))
        # replace blockhash root with validated value
        bh[0] = self._share_hash
        self.block_hash_tree.set_hashes(bh, {blocknum: blockhash})

        except (hashtree.BadHashError, hashtree.NotEnoughHashesError):
            # log.WEIRD: indicates undetected disk/network error, or more
            # likely a programming error
            log.msg("hash failure in block=%d, shnum=%d on %s" %
                    (blocknum, self.sharenum, self.bucket))
            log.msg(""" failure occurred when checking the block_hash_tree.
            This suggests that either the block data was bad, or that the
            block hashes we received along with it were bad.""")
            log.msg(""" the failure probably occurred when checking the
            share_hash_tree, which suggests that the share hashes we
            received from the remote peer were bad.""")
            log.msg(" have self._share_hash: %s" % bool(self._share_hash))
            log.msg(" block length: %d" % len(blockdata))
            log.msg(" block hash: %s" % base32.b2a_or_none(blockhash))
            if len(blockdata) < 100:
                log.msg(" block data: %r" % (blockdata,))
            log.msg(" block data start/end: %r .. %r" %
                    (blockdata[:50], blockdata[-50:]))
            log.msg(" root hash: %s" % base32.b2a(self._roothash))
            log.msg(" share hash tree:\n" + self.share_hash_tree.dump())
            log.msg(" block hash tree:\n" + self.block_hash_tree.dump())
            for i,h in sorted(sharehashes):
                lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
            log.msg(" sharehashes:\n" + "\n".join(lines) + "\n")
            for i,h in enumerate(blockhashes):
                lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
            log.msg(" blockhashes:\n" + "\n".join(lines) + "\n")

        # If we made it here, the block is good. If the hash trees didn't
        # like what they saw, they would have raised a BadHashError, causing
        # our caller to see a Failure and thus ignore this block (as well as
        # dropping this bucket).
class BlockDownloader:
    """I am responsible for downloading a single block (from a single bucket)
    for a single segment.

    I am a child of the SegmentDownloader.

    NOTE(review): several lines of this class (saving 'parent' in __init__,
    the None-check in log(), 'return d' in start(), and log-call
    continuations) are missing from this excerpt.
    """

    def __init__(self, vbucket, blocknum, parent, results):
        self.vbucket = vbucket      # ValidatedBucket providing this share
        self.blocknum = blocknum
        self.results = results      # DownloadResults for timing accounting
        self._log_number = self.parent.log("starting block %d" % blocknum)

    def log(self, msg, parent=None):
            parent = self._log_number
        return self.parent.log(msg, parent=parent)

    def start(self, segnum):
        """Begin fetching block #segnum from our validated bucket."""
        lognum = self.log("get_block(segnum=%d)" % segnum)
        started = time.time()
        d = self.vbucket.get_block(segnum)
        d.addCallbacks(self._hold_block, self._got_block_error,
                       callbackArgs=(started, lognum,), errbackArgs=(lognum,))

    def _hold_block(self, data, started, lognum):
        """Record per-server fetch timing and hand the block to our parent."""
        elapsed = time.time() - started
        peerid = self.vbucket.bucket.get_peerid()
        if peerid not in self.results.timings["fetch_per_server"]:
            self.results.timings["fetch_per_server"][peerid] = []
        self.results.timings["fetch_per_server"][peerid].append(elapsed)
        self.log("got block", parent=lognum)
        self.parent.hold_block(self.blocknum, data)

    def _got_block_error(self, f, lognum):
        """Record the failure and tell our parent this bucket is unusable."""
        self.log("BlockDownloader[%d] got error: %s" % (self.blocknum, f),
        peerid = self.vbucket.bucket.get_peerid()
        self.results.server_problems[peerid] = str(f)
        self.parent.bucket_failed(self.vbucket)
class SegmentDownloader:
    """I am responsible for downloading all the blocks for a single segment

    I am a child of the FileDownloader.

    NOTE(review): several lines of this class (parts of __init__, the
    'def start'/'def _download' headers, and loop bodies) are missing from
    this excerpt.
    """

    def __init__(self, parent, segmentnumber, needed_shares, results):
        self.segmentnumber = segmentnumber
        self.needed_blocks = needed_shares   # how many blocks we must collect
        self.blocks = {} # k: blocknum, v: data
        self.results = results
        self._log_number = self.parent.log("starting segment %d" %

    def log(self, msg, parent=None):
            parent = self._log_number
        return self.parent.log(msg, parent=parent)

        return self._download()

        # retry/collect loop: once enough blocks are held, pick the smallest
        # blockids and return them; otherwise keep downloading.
        if len(self.blocks) >= self.needed_blocks:
            # we only need self.needed_blocks blocks
            # we want to get the smallest blockids, because they are
            # more likely to be fast "primary blocks"
            blockids = sorted(self.blocks.keys())[:self.needed_blocks]
            for blocknum in blockids:
                blocks.append(self.blocks[blocknum])
            return (blocks, blockids)
        return self._download()

        # fill our set of active buckets, maybe raising NotEnoughPeersError
        active_buckets = self.parent._activate_enough_buckets()
        # Now we have enough buckets, in self.parent.active_buckets.

        # in test cases, bd.start might mutate active_buckets right away, so
        # we need to put off calling start() until we've iterated all the way
        for blocknum, vbucket in active_buckets.iteritems():
            bd = BlockDownloader(vbucket, blocknum, self, self.results)
            downloaders.append(bd)
            self.results.servers_used.add(vbucket.bucket.get_peerid())
        l = [bd.start(self.segmentnumber) for bd in downloaders]
        return defer.DeferredList(l, fireOnOneErrback=True)

    def hold_block(self, blocknum, data):
        # called by a child BlockDownloader with a validated block
        self.blocks[blocknum] = data

    def bucket_failed(self, vbucket):
        # propagate a bucket failure up to the FileDownloader
        self.parent.bucket_failed(vbucket)
class DownloadStatus:
    """Mutable status record for one download; provides IDownloadStatus.

    NOTE(review): most one-line getter/setter bodies and part of __init__
    are missing from this excerpt.
    """
    implements(IDownloadStatus)
    statusid_counter = itertools.count(0)   # shared counter: unique id per status

        self.storage_index = None
        self.status = "Not started"
        self.counter = self.statusid_counter.next()
        self.started = time.time()

    def get_started(self):
    def get_storage_index(self):
        return self.storage_index
    def using_helper(self):
    def get_status(self):
            status += " (output paused)"
            status += " (output stopped)"
    def get_progress(self):
    def get_active(self):
    def get_results(self):
    def get_counter(self):

    def set_storage_index(self, si):
        self.storage_index = si
    def set_size(self, size):
    def set_helper(self, helper):
    def set_status(self, status):
    def set_paused(self, paused):
    def set_stopped(self, stopped):
        self.stopped = stopped
    def set_progress(self, value):
        self.progress = value
    def set_active(self, value):
    def set_results(self, value):
class FileDownloader:
    """Download a CHK file: locate shares, validate hashes, decode, decrypt.

    NOTE(review): several lines of __init__ (the IFileURI adaptation,
    self._size assignment, init_logging() call, and the paused/stopped
    flags) are missing from this excerpt.
    """
    implements(IPushProducer)
    # validation knobs, overridable by tests
    check_crypttext_hash = True
    check_plaintext_hash = True

    def __init__(self, client, u, downloadable):
        self._client = client

        # unpack the verification material from the CHK URI
        self._storage_index = u.storage_index
        self._uri_extension_hash = u.uri_extension_hash
        self._total_shares = u.total_shares
        self._num_needed_shares = u.needed_shares

        self._si_s = storage.si_b2a(self._storage_index)

        self._started = time.time()
        self._status = s = DownloadStatus()
        s.set_status("Starting")
        s.set_storage_index(self._storage_index)
        s.set_size(self._size)

        self._results = DownloadResults()
        s.set_results(self._results)
        self._results.file_size = self._size
        # timing buckets accumulated over the whole download
        self._results.timings["servers_peer_selection"] = {}
        self._results.timings["fetch_per_server"] = {}
        self._results.timings["cumulative_fetch"] = 0.0
        self._results.timings["cumulative_decode"] = 0.0
        self._results.timings["cumulative_decrypt"] = 0.0

        # if the target is an IConsumer, drive it as a push producer
        if IConsumer.providedBy(downloadable):
            downloadable.registerProducer(self, True)
        self._downloadable = downloadable
        self._output = Output(downloadable, u.key, self._size, self._log_number,
        self._stopped = False

        self.active_buckets = {} # k: shnum, v: bucket
        self._share_buckets = [] # list of (sharenum, bucket) tuples
        self._share_vbuckets = {} # k: shnum, v: set of ValidatedBuckets
        self._uri_extension_sources = []

        self._uri_extension_data = None

        # per-category counters, used by tests to observe retry behavior
        self._fetch_failures = {"uri_extension": 0,
                                "plaintext_hashroot": 0,
                                "plaintext_hashtree": 0,
                                "crypttext_hashroot": 0,
                                "crypttext_hashtree": 0,
def init_logging(self):
    """Establish this download's log prefix and parent log number."""
    si_s = storage.si_b2a(self._storage_index)
    self._log_prefix = si_s[:5]
    self._log_number = self._client.log(
        format="FileDownloader(%(si)s): starting", si=si_s)
def log(self, *args, **kwargs):
    """Emit a log message, defaulting parent/facility for this download."""
    kwargs.setdefault("parent", self._log_number)
    kwargs.setdefault("facility", "tahoe.download")
    return log.msg(*args, **kwargs)
    # IPushProducer interface.
    # NOTE(review): the guard conditions and several body lines of all three
    # methods are missing from this excerpt.

    def pauseProducing(self):
        # remember a Deferred that _check_for_pause will wait on
            self._paused = defer.Deferred()
            self._status.set_paused(True)

    def resumeProducing(self):
        # fire the paused Deferred (via eventual-send) so the segment loop resumes
            eventually(p.callback, None)
            self._status.set_paused(False)

    def stopProducing(self):
        self.log("Download.stopProducing")
            self._status.set_stopped(True)
            self._status.set_active(False)
        # NOTE(review): the 'def start(self):' header and the _finished/_failed
        # inner-function headers are missing from this excerpt.
        self.log("starting download")

        # first step: who should we download from?
        d = defer.maybeDeferred(self._get_all_shareholders)
        d.addCallback(self._got_all_shareholders)
        # now get the uri_extension block from somebody and validate it
        d.addCallback(self._obtain_uri_extension)
        d.addCallback(self._got_uri_extension)
        d.addCallback(self._get_hashtrees)
        d.addCallback(self._create_validated_buckets)
        # once we know that, we can download blocks from everybody
        d.addCallback(self._download_all_segments)
            # success path: mark finished and detach from the consumer
            self._status.set_status("Finished")
            self._status.set_active(False)
            self._status.set_paused(False)
            if IConsumer.providedBy(self._downloadable):
                self._downloadable.unregisterProducer()
            # failure path: mark failed and propagate to the output target
            self._status.set_status("Failed")
            self._status.set_active(False)
            self._output.fail(why)
        d.addErrback(_failed)
        d.addCallback(self._done)
    def _get_all_shareholders(self):
        """Ask every permuted storage peer for buckets holding our shares.

        NOTE(review): the 'dl = []' initialization and 'dl.append(d)' lines
        are missing from this excerpt.
        """
        for (peerid,ss) in self._client.get_permuted_peers("storage",
                                                           self._storage_index):
            d = ss.callRemote("get_buckets", self._storage_index)
            d.addCallbacks(self._got_response, self._got_error,
                           callbackArgs=(peerid,))

        self._responses_received = 0
        self._queries_sent = len(dl)
            self._status.set_status("Locating Shares (%d/%d)" %
                                    (self._responses_received,
        # fires when every peer has answered (or errbacked)
        return defer.DeferredList(dl)
    def _got_response(self, buckets, peerid):
        """Record the buckets one peer offered and wrap each in a ReadBucketProxy.

        NOTE(review): a few guard lines (e.g. around self._results) are
        missing from this excerpt.
        """
        self._responses_received += 1
            elapsed = time.time() - self._started
            self._results.timings["servers_peer_selection"][peerid] = elapsed
            self._status.set_status("Locating Shares (%d/%d)" %
                                    (self._responses_received,
        for sharenum, bucket in buckets.iteritems():
            b = storage.ReadBucketProxy(bucket, peerid, self._si_s)
            self.add_share_bucket(sharenum, b)
            # any bucket can also serve us the URI extension block
            self._uri_extension_sources.append(b)
                if peerid not in self._results.servermap:
                    self._results.servermap[peerid] = set()
                self._results.servermap[peerid].add(sharenum)
def add_share_bucket(self, sharenum, bucket):
    """Record one discovered share bucket.

    Kept as a separate method for the benefit of test_encode.py.
    """
    pair = (sharenum, bucket)
    self._share_buckets.append(pair)
575 def _got_error(self, f):
576 self._client.log("Somebody failed. -- %s" % (f,))
    def bucket_failed(self, vbucket):
        """Forget a ValidatedBucket that returned bad data.

        NOTE(review): the set-removal and emptiness-check lines are missing
        from this excerpt.
        """
        shnum = vbucket.sharenum
        del self.active_buckets[shnum]
        s = self._share_vbuckets[shnum]
        # s is a set of ValidatedBucket instances
        # ... which might now be empty
            # there are no more buckets which can provide this share, so
            # remove the key. This may prompt us to use a different share.
            del self._share_vbuckets[shnum]
    def _got_all_shareholders(self, res):
        """After peer selection completes: record timing and check we found
        enough distinct shares to decode.

        NOTE(review): the 'now = time.time()' line and surrounding guards
        are missing from this excerpt.
        """
            self._results.timings["peer_selection"] = now - self._started

        if len(self._share_buckets) < self._num_needed_shares:
            raise NotEnoughPeersError

        #for s in self._share_vbuckets.values():
        #    assert isinstance(vb, ValidatedBucket), \
        #           "vb is %s but should be a ValidatedBucket" % (vb,)
def _unpack_uri_extension_data(self, data):
    """Deserialize a raw URI-extension block into its field dict."""
    unpacked = uri.unpack_extension(data)
    return unpacked
    def _obtain_uri_extension(self, ignored):
        # all shareholders are supposed to have a copy of uri_extension, and
        # all are supposed to be identical. We compute the hash of the data
        # that comes back, and compare it against the version in our URI. If
        # they don't match, ignore their data and try someone else.
        # NOTE(review): guard lines and parts of the error-message tuple are
        # missing from this excerpt.
            self._status.set_status("Obtaining URI Extension")

        self._uri_extension_fetch_started = time.time()
        def _validate(proposal, bucket):
            h = hashutil.uri_extension_hash(proposal)
            if h != self._uri_extension_hash:
                # count the failure so tests can observe retry behavior
                self._fetch_failures["uri_extension"] += 1
                msg = ("The copy of uri_extension we received from "
                       "%s was bad: wanted %s, got %s" %
                       base32.b2a(self._uri_extension_hash),
                self.log(msg, level=log.SCARY)
                raise BadURIExtensionHashValue(msg)
            return self._unpack_uri_extension_data(proposal)
        return self._obtain_validated_thing(None,
                                            self._uri_extension_sources,
                                            "get_uri_extension", (), _validate)
    def _obtain_validated_thing(self, ignored, sources, name, methname, args,
        # Generic fetch-with-retry: try each source in turn until
        # validatorfunc accepts the result, or raise NotEnoughPeersError.
        # NOTE(review): the signature continuation (validatorfunc), the
        # bucket = sources[0] line, the errback wrapper header, and the
        # final 'return d' are missing from this excerpt.
            raise NotEnoughPeersError("started with zero peers while fetching "
        sources = sources[1:]
        #d = bucket.callRemote(methname, *args)
        d = bucket.startIfNecessary()
        d.addCallback(lambda res: getattr(bucket, methname)(*args))
        d.addCallback(validatorfunc, bucket)
            self.log("%s from vbucket %s failed:" % (name, bucket),
                     failure=f, level=log.WEIRD)
                raise NotEnoughPeersError("ran out of peers, last error was %s"
            # try again with a different one
            return self._obtain_validated_thing(None, sources, name,
                                                methname, args, validatorfunc)
    def _got_uri_extension(self, uri_extension_data):
        """Unpack the validated URI extension: set up codecs, expected
        hashes, segment geometry, and the share hash tree.

        NOTE(review): guard lines around self._results are missing from
        this excerpt.
        """
            elapsed = time.time() - self._uri_extension_fetch_started
            self._results.timings["uri_extension"] = elapsed

        d = self._uri_extension_data = uri_extension_data

        # full-size segments use one codec, the final (short) segment another
        self._codec = codec.get_decoder_by_name(d['codec_name'])
        self._codec.set_serialized_params(d['codec_params'])
        self._tail_codec = codec.get_decoder_by_name(d['codec_name'])
        self._tail_codec.set_serialized_params(d['tail_codec_params'])

        crypttext_hash = d['crypttext_hash']
        assert isinstance(crypttext_hash, str)
        assert len(crypttext_hash) == 32
        self._crypttext_hash = crypttext_hash
        self._plaintext_hash = d['plaintext_hash']
        self._roothash = d['share_root_hash']

        self._segment_size = segment_size = d['segment_size']
        self._total_segments = mathutil.div_ceil(self._size, segment_size)
        self._current_segnum = 0

        # seed the share hash tree with the root we just validated
        self._share_hashtree = hashtree.IncompleteHashTree(d['total_shares'])
        self._share_hashtree.set_hashes({0: self._roothash})
    def _get_hashtrees(self, res):
        """Fetch and validate the plaintext and crypttext hash trees.

        NOTE(review): a guard line and the final 'return d' are missing from
        this excerpt.
        """
        self._get_hashtrees_started = time.time()
            self._status.set_status("Retrieving Hash Trees")
        d = self._get_plaintext_hashtrees()
        d.addCallback(self._get_crypttext_hashtrees)
        d.addCallback(self._setup_hashtrees)
    def _get_plaintext_hashtrees(self):
        # NOTE(review): the try: line, the raise inside the except clause,
        # and the final 'return d' are missing from this excerpt.
        def _validate_plaintext_hashtree(proposal, bucket):
            # the root must match the value recorded in the URI extension
            if proposal[0] != self._uri_extension_data['plaintext_root_hash']:
                self._fetch_failures["plaintext_hashroot"] += 1
                msg = ("The copy of the plaintext_root_hash we received from"
                       " %s was bad" % bucket)
                raise BadPlaintextHashValue(msg)
            pt_hashtree = hashtree.IncompleteHashTree(self._total_segments)
            pt_hashes = dict(list(enumerate(proposal)))
                pt_hashtree.set_hashes(pt_hashes)
            except hashtree.BadHashError:
                # the hashes they gave us were not self-consistent, even
                # though the root matched what we saw in the uri_extension
                self._fetch_failures["plaintext_hashtree"] += 1
            self._plaintext_hashtree = pt_hashtree
        d = self._obtain_validated_thing(None,
                                         self._uri_extension_sources,
                                         "get_plaintext_hashes", (),
                                         _validate_plaintext_hashtree)
    def _get_crypttext_hashtrees(self, res):
        # NOTE(review): the try: line, the raise inside the except clause,
        # and the final 'return d' are missing from this excerpt; the second
        # set_hashes call below likely belongs to elided control flow.
        def _validate_crypttext_hashtree(proposal, bucket):
            # the root must match the value recorded in the URI extension
            if proposal[0] != self._uri_extension_data['crypttext_root_hash']:
                self._fetch_failures["crypttext_hashroot"] += 1
                msg = ("The copy of the crypttext_root_hash we received from"
                       " %s was bad" % bucket)
                raise BadCrypttextHashValue(msg)
            ct_hashtree = hashtree.IncompleteHashTree(self._total_segments)
            ct_hashes = dict(list(enumerate(proposal)))
                ct_hashtree.set_hashes(ct_hashes)
            except hashtree.BadHashError:
                self._fetch_failures["crypttext_hashtree"] += 1
            ct_hashtree.set_hashes(ct_hashes)
            self._crypttext_hashtree = ct_hashtree
        d = self._obtain_validated_thing(None,
                                         self._uri_extension_sources,
                                         "get_crypttext_hashes", (),
                                         _validate_crypttext_hashtree)
    def _setup_hashtrees(self, res):
        """Hand the validated hash trees to the Output and record timing.

        NOTE(review): a guard line around self._results is missing from this
        excerpt.
        """
        self._output.setup_hashtrees(self._plaintext_hashtree,
                                     self._crypttext_hashtree)
            elapsed = time.time() - self._get_hashtrees_started
            self._results.timings["hashtrees"] = elapsed
    def _create_validated_buckets(self, ignored=None):
        """Wrap each discovered (sharenum, bucket) pair in a ValidatedBucket.

        NOTE(review): part of the ValidatedBucket constructor call and the
        's.add(vbucket)' line are missing from this excerpt.
        """
        self._share_vbuckets = {}
        for sharenum, bucket in self._share_buckets:
            vbucket = ValidatedBucket(sharenum, bucket,
                                      self._share_hashtree,
                                      self._total_segments)
            s = self._share_vbuckets.setdefault(sharenum, set())
    def _activate_enough_buckets(self):
        """either return a mapping from shnum to a ValidatedBucket that can
        provide data for that share, or raise NotEnoughPeersError"""
        # NOTE(review): a couple of lines inside the loop are missing from
        # this excerpt.
        while len(self.active_buckets) < self._num_needed_shares:
            # need some more
            handled_shnums = set(self.active_buckets.keys())
            available_shnums = set(self._share_vbuckets.keys())
            potential_shnums = list(available_shnums - handled_shnums)
            if not potential_shnums:
                raise NotEnoughPeersError
            # choose a random share
            shnum = random.choice(potential_shnums)
            # and a random bucket that will provide it
            validated_bucket = random.choice(list(self._share_vbuckets[shnum]))
            self.active_buckets[shnum] = validated_bucket
        return self.active_buckets
    def _download_all_segments(self, res):
        # the promise: upon entry to this function, self._share_vbuckets
        # contains enough buckets to complete the download, and some extra
        # ones to tolerate some buckets dropping out or having errors.
        # self._share_vbuckets is a dictionary that maps from shnum to a set
        # of ValidatedBuckets, which themselves are wrappers around
        # RIBucketReader references.
        # NOTE(review): a status guard and the final 'return d' are missing
        # from this excerpt.
        self.active_buckets = {} # k: shnum, v: ValidatedBucket instance

        self._started_fetching = time.time()

        # chain one _download_segment per full segment, then the tail
        d = defer.succeed(None)
        for segnum in range(self._total_segments-1):
            d.addCallback(self._download_segment, segnum)
            # this pause, at the end of write, prevents pre-fetch from
            # happening until the consumer is ready for more data.
            d.addCallback(self._check_for_pause)
        d.addCallback(self._download_tail_segment, self._total_segments-1)
    def _check_for_pause(self, res):
        """Between segments: wait while paused, abort if stopped.

        NOTE(review): the paused/stopped guard lines and the Deferred
        creation are missing from this excerpt.
        """
            self._paused.addCallback(lambda ignored: d.callback(res))
            raise DownloadStopped("our Consumer called stopProducing()")
    def _download_segment(self, res, segnum):
        """Fetch, decode, and decrypt one full-size segment.

        NOTE(review): guard lines, the SegmentDownloader constructor
        continuation, inner 'return res' lines, and the '_done' callback
        header around the final section are missing from this excerpt.
        """
            self._status.set_status("Downloading segment %d of %d" %
                                    (segnum+1, self._total_segments))
        self.log("downloading seg#%d of %d (%d%%)"
                 % (segnum, self._total_segments,
                    100.0 * segnum / self._total_segments))
        # memory footprint: when the SegmentDownloader finishes pulling down
        # all shares, we have 1*segment_size of usage.
        segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares,
        started = time.time()
        d = segmentdler.start()
        def _finished_fetching(res):
            elapsed = time.time() - started
            self._results.timings["cumulative_fetch"] += elapsed
        d.addCallback(_finished_fetching)
        # pause before using more memory
        d.addCallback(self._check_for_pause)
        # while the codec does its job, we hit 2*segment_size
        def _started_decode(res):
            self._started_decode = time.time()
        d.addCallback(_started_decode)
        d.addCallback(lambda (shares, shareids):
                      self._codec.decode(shares, shareids))
        # once the codec is done, we drop back to 1*segment_size, because
        # 'shares' goes out of scope. The memory usage is all in the
        # plaintext now, spread out into a bunch of tiny buffers.
        def _finished_decode(res):
            elapsed = time.time() - self._started_decode
            self._results.timings["cumulative_decode"] += elapsed
        d.addCallback(_finished_decode)

        # pause/check-for-stop just before writing, to honor stopProducing
        d.addCallback(self._check_for_pause)

            # we start by joining all these buffers together into a single
            # string. This makes Output.write easier, since it wants to hash
            # data one segment at a time anyways, and doesn't impact our
            # memory footprint since we're already peaking at 2*segment_size
            # inside the codec a moment ago.
            segment = "".join(buffers)

            # we're down to 1*segment_size right now, but write_segment()
            # will decrypt a copy of the segment internally, which will push
            # us up to 2*segment_size while it runs.
            started_decrypt = time.time()
            self._output.write_segment(segment)
                elapsed = time.time() - started_decrypt
                self._results.timings["cumulative_decrypt"] += elapsed
    def _download_tail_segment(self, res, segnum):
        """Fetch, decode, trim, and decrypt the final (possibly short) segment.

        NOTE(review): the SegmentDownloader constructor continuation, inner
        'return res' lines, and the '_done' callback header are missing from
        this excerpt.
        """
        self.log("downloading seg#%d of %d (%d%%)"
                 % (segnum, self._total_segments,
                    100.0 * segnum / self._total_segments))
        segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares,
        started = time.time()
        d = segmentdler.start()
        def _finished_fetching(res):
            elapsed = time.time() - started
            self._results.timings["cumulative_fetch"] += elapsed
        d.addCallback(_finished_fetching)
        # pause before using more memory
        d.addCallback(self._check_for_pause)
        def _started_decode(res):
            self._started_decode = time.time()
        d.addCallback(_started_decode)
        # the tail segment was encoded with its own (shorter) parameters
        d.addCallback(lambda (shares, shareids):
                      self._tail_codec.decode(shares, shareids))
        def _finished_decode(res):
            elapsed = time.time() - self._started_decode
            self._results.timings["cumulative_decode"] += elapsed
        d.addCallback(_finished_decode)
        # pause/check-for-stop just before writing, to honor stopProducing
        d.addCallback(self._check_for_pause)

            # trim off any padding added by the upload side
            segment = "".join(buffers)

            # we never send empty segments. If the data was an exact multiple
            # of the segment size, the last segment will be full.
            pad_size = mathutil.pad_size(self._size, self._segment_size)
            tail_size = self._segment_size - pad_size
            segment = segment[:tail_size]
            started_decrypt = time.time()
            self._output.write_segment(segment)
                elapsed = time.time() - started_decrypt
                self._results.timings["cumulative_decrypt"] += elapsed
    def _done(self, res):
        """Final callback: check whole-file hashes/length, close the output.

        NOTE(review): the 'now = time.time()' line and surrounding guards
        are missing from this excerpt.
        """
        self.log("download done")
            self._results.timings["total"] = now - self._started
            self._results.timings["segments"] = now - self._started_fetching
        # compare the accumulated hashes against the values from the URI
        # extension (both checks can be disabled via the class flags)
        if self.check_crypttext_hash:
            _assert(self._crypttext_hash == self._output.crypttext_hash,
                    "bad crypttext_hash: computed=%s, expected=%s" %
                    (base32.b2a(self._output.crypttext_hash),
                     base32.b2a(self._crypttext_hash)))
        if self.check_plaintext_hash:
            _assert(self._plaintext_hash == self._output.plaintext_hash,
                    "bad plaintext_hash: computed=%s, expected=%s" %
                    (base32.b2a(self._output.plaintext_hash),
                     base32.b2a(self._plaintext_hash)))
        _assert(self._output.length == self._size,
                got=self._output.length, expected=self._size)
        return self._output.finish()
932 def get_download_status(self):
class LiteralDownloader:
    """Deliver a LIT file: the data is embedded in the URI itself, so no
    network fetch is needed.

    NOTE(review): several __init__ lines and the 'def start' header are
    missing from this excerpt.
    """
    def __init__(self, client, u, downloadable):
        self._uri = IFileURI(u)
        assert isinstance(self._uri, uri.LiteralFileURI)
        self._downloadable = downloadable
        self._status = s = DownloadStatus()
        s.set_storage_index(None)

        # deliver the embedded data straight to the target
        data = self._uri.data
        self._status.set_size(len(data))
        self._downloadable.open(len(data))
        self._downloadable.write(data)
        self._downloadable.close()
        return defer.maybeDeferred(self._downloadable.finish)

    def get_download_status(self):
    # IDownloadTarget that writes to a named local file. NOTE(review): the
    # class statement and several method headers (write's body, fail/close)
    # are missing from this excerpt; the os.unlink below appears to belong
    # to an elided failure path — confirm against the full source.
    implements(IDownloadTarget)
    def __init__(self, filename):
        self._filename = filename
    def open(self, size):
        self.f = open(self._filename, "wb")
    def write(self, data):
            os.unlink(self._filename)
    def register_canceller(self, cb):
        pass # we won't use it
    # IDownloadTarget that accumulates the plaintext in memory. NOTE(review):
    # the class statement, __init__, and the close/fail method headers are
    # missing from this excerpt.
    implements(IDownloadTarget)
    def open(self, size):
    def write(self, data):
        self._data.append(data)
        # joined into a single string when the download closes
        self.data = "".join(self._data)
    def register_canceller(self, cb):
        pass # we won't use it
    """Use me to download data to a pre-defined filehandle-like object. I
    will use the target's write() method. I will *not* close the filehandle:
    I leave that up to the originator of the filehandle. The download process
    will return the filehandle when it completes.
    """
    # NOTE(review): the class statement and several method bodies/headers
    # (open, close, finish) are missing from this excerpt.
    implements(IDownloadTarget)
    def __init__(self, filehandle):
        self._filehandle = filehandle
    def open(self, size):
    def write(self, data):
        self._filehandle.write(data)
        # the originator of the filehandle reserves the right to close it
    def fail(self, why):
    def register_canceller(self, cb):
        return self._filehandle
class Downloader(service.MultiService):
    """I am a service that allows file downloading.

    NOTE(review): the __init__ header, parts of download() (including the
    'else:' that guards the RuntimeError below), and the final return are
    missing from this excerpt.
    """
    implements(IDownloader)
    # cap on how many recent DownloadStatus objects we retain
    MAX_DOWNLOAD_STATUSES = 10

        service.MultiService.__init__(self)
        # weak refs: finished downloads can be garbage-collected
        self._all_downloads = weakref.WeakKeyDictionary()
        self._recent_download_status = []

    def download(self, u, t):
        t = IDownloadTarget(t)
        # pick a downloader implementation by URI type
        if isinstance(u, uri.LiteralFileURI):
            dl = LiteralDownloader(self.parent, u, t)
        elif isinstance(u, uri.CHKFileURI):
            dl = FileDownloader(self.parent, u, t)
            raise RuntimeError("I don't know how to download a %s" % u)
        self._all_downloads[dl] = None
        self._recent_download_status.append(dl.get_download_status())
        # retain only the most recent statuses
        while len(self._recent_download_status) > self.MAX_DOWNLOAD_STATUSES:
            self._recent_download_status.pop(0)
def download_to_data(self, uri):
    """Download the file at *uri* into memory, using a Data target."""
    target = Data()
    return self.download(uri, target)
def download_to_filename(self, uri, filename):
    """Download the file at *uri* to the local file *filename*."""
    target = FileName(filename)
    return self.download(uri, target)
def download_to_filehandle(self, uri, filehandle):
    """Download the file at *uri* into an open filehandle-like object."""
    target = FileHandle(filehandle)
    return self.download(uri, target)
def list_all_downloads(self):
    """Return every download object still alive in our weak-key map."""
    return self._all_downloads.keys()
def list_active_downloads(self):
    """Return the status objects of downloads that report themselves active."""
    active = []
    for dl in self._all_downloads.keys():
        if dl.get_download_status().get_active():
            active.append(dl.get_download_status())
    return active
def list_recent_downloads(self):
    """Return the bounded list of recently-started download statuses."""
    return self._recent_download_status