import os, random, weakref, itertools, time
from zope.interface import implements
from twisted.internet import defer
from twisted.internet.interfaces import IPushProducer, IConsumer
from twisted.application import service
from foolscap import DeadReferenceError
from foolscap.eventual import eventually

from allmydata.util import base32, mathutil, hashutil, log, observer
from allmydata.util.assertutil import _assert
from allmydata import codec, hashtree, storage, uri
from allmydata.interfaces import IDownloadTarget, IDownloader, IFileURI, \
     IDownloadStatus, IDownloadResults, NotEnoughSharesError
from allmydata.immutable import layout
from pycryptopp.cipher.aes import AES

class HaveAllPeersError(Exception):
    # we use this to jump out of the loop
    pass

class IntegrityCheckError(Exception):
    pass

class BadURIExtensionHashValue(IntegrityCheckError):
    pass

class BadURIExtension(IntegrityCheckError):
    pass

class BadPlaintextHashValue(IntegrityCheckError):
    pass

class BadCrypttextHashValue(IntegrityCheckError):
    pass

class DownloadStopped(Exception):
    pass

class DownloadResults:
    implements(IDownloadResults)

    def __init__(self):
        self.servers_used = set()
        self.server_problems = {}
        self.servermap = {}
        self.timings = {}
        self.file_size = None

class Output:
    def __init__(self, downloadable, key, total_length, log_parent,
                 download_status):
        self.downloadable = downloadable
        self._decryptor = AES(key)
        self._crypttext_hasher = hashutil.crypttext_hasher()
        self._plaintext_hasher = hashutil.plaintext_hasher()
        self.length = 0
        self.total_length = total_length
        self._segment_number = 0
        self._plaintext_hash_tree = None
        self._crypttext_hash_tree = None
        self._opened = False
        self._log_parent = log_parent
        self._status = download_status
        self._status.set_progress(0.0)

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_parent
        if "facility" not in kwargs:
            kwargs["facility"] = "download.output"
        return log.msg(*args, **kwargs)

    def setup_hashtrees(self, plaintext_hashtree, crypttext_hashtree):
        self._plaintext_hash_tree = plaintext_hashtree
        self._crypttext_hash_tree = crypttext_hashtree

    def write_segment(self, crypttext):
        self.length += len(crypttext)
        self._status.set_progress( float(self.length) / self.total_length )

        # memory footprint: 'crypttext' is the only segment_size usage
        # outstanding. While we decrypt it into 'plaintext', we hit
        # 2*segment_size.
        self._crypttext_hasher.update(crypttext)
        if self._crypttext_hash_tree:
            ch = hashutil.crypttext_segment_hasher()
            ch.update(crypttext)
            crypttext_leaves = {self._segment_number: ch.digest()}
            self.log(format="crypttext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s",
                     bytes=len(crypttext),
                     segnum=self._segment_number, hash=base32.b2a(ch.digest()),
                     level=log.NOISY)
            self._crypttext_hash_tree.set_hashes(leaves=crypttext_leaves)

        plaintext = self._decryptor.process(crypttext)
        del crypttext

        # now we're back down to 1*segment_size.

        self._plaintext_hasher.update(plaintext)
        if self._plaintext_hash_tree:
            ph = hashutil.plaintext_segment_hasher()
            ph.update(plaintext)
            plaintext_leaves = {self._segment_number: ph.digest()}
            self.log(format="plaintext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s",
                     bytes=len(plaintext),
                     segnum=self._segment_number, hash=base32.b2a(ph.digest()),
                     level=log.NOISY)
            self._plaintext_hash_tree.set_hashes(leaves=plaintext_leaves)

        self._segment_number += 1
        # We're still at 1*segment_size. The Downloadable is responsible for
        # any memory usage beyond this.
        if not self._opened:
            self._opened = True
            self.downloadable.open(self.total_length)
        self.downloadable.write(plaintext)

    def fail(self, why):
        # this is really unusual, and deserves maximum forensics
        if why.check(DownloadStopped):
            # DownloadStopped just means the consumer aborted the download;
            # that is not so scary.
            self.log("download stopped", level=log.UNUSUAL)
        else:
            self.log("download failed!", failure=why,
                     level=log.SCARY, umid="lp1vaQ")
        self.downloadable.fail(why)

    def close(self):
        self.crypttext_hash = self._crypttext_hasher.digest()
        self.plaintext_hash = self._plaintext_hasher.digest()
        self.log("download finished, closing IDownloadable", level=log.NOISY)
        self.downloadable.close()

    def finish(self):
        return self.downloadable.finish()

class ValidatedBucket:
    """I am a front-end for a remote storage bucket, responsible for
    retrieving and validating data from that bucket.

    My get_block() method is used by BlockDownloaders.
    """

    def __init__(self, sharenum, bucket,
                 share_hash_tree, roothash,
                 num_blocks):
        self.sharenum = sharenum
        self.bucket = bucket
        self._share_hash = None # None means not validated yet
        self.share_hash_tree = share_hash_tree
        self._roothash = roothash
        self.block_hash_tree = hashtree.IncompleteHashTree(num_blocks)
        self.started = False

    def get_block(self, blocknum):
        if not self.started:
            d = self.bucket.start()
            def _started(res):
                self.started = True
                return self.get_block(blocknum)
            d.addCallback(_started)
            return d

        # the first time we use this bucket, we need to fetch enough elements
        # of the share hash tree to validate it from our share hash up to the
        # hashroot.
        if not self._share_hash:
            d1 = self.bucket.get_share_hashes()
        else:
            d1 = defer.succeed([])

        # we might need to grab some elements of our block hash tree, to
        # validate the requested block up to the share hash
        needed = self.block_hash_tree.needed_hashes(blocknum)
        if needed:
            # TODO: get fewer hashes, use get_block_hashes(needed)
            d2 = self.bucket.get_block_hashes()
        else:
            d2 = defer.succeed([])

        d3 = self.bucket.get_block(blocknum)

        d = defer.gatherResults([d1, d2, d3])
        d.addCallback(self._got_data, blocknum)
        return d

    def _got_data(self, res, blocknum):
        sharehashes, blockhashes, blockdata = res
        blockhash = None # to make logging it safe

        try:
            if not self._share_hash:
                sh = dict(sharehashes)
                sh[0] = self._roothash # always use our own root, from the URI
                sht = self.share_hash_tree
                if sht.get_leaf_index(self.sharenum) not in sh:
                    raise hashtree.NotEnoughHashesError
                sht.set_hashes(sh)
                self._share_hash = sht.get_leaf(self.sharenum)

            blockhash = hashutil.block_hash(blockdata)
            #log.msg("checking block_hash(shareid=%d, blocknum=%d) len=%d "
            #        "%r .. %r: %s" %
            #        (self.sharenum, blocknum, len(blockdata),
            #         blockdata[:50], blockdata[-50:], base32.b2a(blockhash)))

            # we always validate the blockhash
            bh = dict(enumerate(blockhashes))
            # replace blockhash root with validated value
            bh[0] = self._share_hash
            self.block_hash_tree.set_hashes(bh, {blocknum: blockhash})

        except (hashtree.BadHashError, hashtree.NotEnoughHashesError):
            # log.WEIRD: indicates undetected disk/network error, or more
            # likely a programming error
            log.msg("hash failure in block=%d, shnum=%d on %s" %
                    (blocknum, self.sharenum, self.bucket))
            if self._share_hash:
                log.msg(""" failure occurred when checking the block_hash_tree.
                This suggests that either the block data was bad, or that the
                block hashes we received along with it were bad.""")
            else:
                log.msg(""" the failure probably occurred when checking the
                share_hash_tree, which suggests that the share hashes we
                received from the remote peer were bad.""")
            log.msg(" have self._share_hash: %s" % bool(self._share_hash))
            log.msg(" block length: %d" % len(blockdata))
            log.msg(" block hash: %s" % base32.b2a_or_none(blockhash))
            if len(blockdata) < 100:
                log.msg(" block data: %r" % (blockdata,))
            else:
                log.msg(" block data start/end: %r .. %r" %
                        (blockdata[:50], blockdata[-50:]))
            log.msg(" root hash: %s" % base32.b2a(self._roothash))
            log.msg(" share hash tree:\n" + self.share_hash_tree.dump())
            log.msg(" block hash tree:\n" + self.block_hash_tree.dump())
            lines = []
            for i,h in sorted(sharehashes):
                lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
            log.msg(" sharehashes:\n" + "\n".join(lines) + "\n")
            lines = []
            for i,h in enumerate(blockhashes):
                lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
            log.msg(" blockhashes:\n" + "\n".join(lines) + "\n")
            raise

        # If we made it here, the block is good. If the hash trees didn't
        # like what they saw, they would have raised a BadHashError, causing
        # our caller to see a Failure and thus ignore this block (as well as
        # dropping this bucket).
        return blockdata
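
# Rough sketch of the validation pattern used above (illustrative values
# only; see allmydata.hashtree for the real API). An IncompleteHashTree is
# seeded with a trusted root, and a block is admitted only if the peer
# supplies enough interior hashes to chain the block's leaf hash up to
# that root:
#
#   ht = hashtree.IncompleteHashTree(num_blocks)
#   ht.set_hashes({0: trusted_roothash})       # node 0 is the root
#   needed = ht.needed_hashes(blocknum)        # which interior nodes to ask for
#   ht.set_hashes(hashes_from_peer,            # raises BadHashError or
#                 {blocknum: block_hash})      # NotEnoughHashesError on mismatch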

class BlockDownloader:
    """I am responsible for downloading a single block (from a single bucket)
    for a single segment.

    I am a child of the SegmentDownloader.
    """

    def __init__(self, vbucket, blocknum, parent, results):
        self.vbucket = vbucket
        self.blocknum = blocknum
        self.parent = parent
        self.results = results
        self._log_number = self.parent.log("starting block %d" % blocknum)

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_number
        return self.parent.log(*args, **kwargs)

    def start(self, segnum):
        lognum = self.log("get_block(segnum=%d)" % segnum)
        started = time.time()
        d = self.vbucket.get_block(segnum)
        d.addCallbacks(self._hold_block, self._got_block_error,
                       callbackArgs=(started, lognum,), errbackArgs=(lognum,))
        return d

    def _hold_block(self, data, started, lognum):
        if self.results:
            elapsed = time.time() - started
            peerid = self.vbucket.bucket.get_peerid()
            if peerid not in self.results.timings["fetch_per_server"]:
                self.results.timings["fetch_per_server"][peerid] = []
            self.results.timings["fetch_per_server"][peerid].append(elapsed)
        self.log("got block", parent=lognum)
        self.parent.hold_block(self.blocknum, data)

    def _got_block_error(self, f, lognum):
        level = log.WEIRD
        if f.check(DeadReferenceError):
            level = log.UNUSUAL
        self.log("BlockDownloader[%d] got error" % self.blocknum,
                 failure=f, level=level, parent=lognum, umid="5Z4uHQ")
        if self.results:
            peerid = self.vbucket.bucket.get_peerid()
            self.results.server_problems[peerid] = str(f)
        self.parent.bucket_failed(self.vbucket)

class SegmentDownloader:
    """I am responsible for downloading all the blocks for a single segment
    of data.

    I am a child of the FileDownloader.
    """

    def __init__(self, parent, segmentnumber, needed_shares, results):
        self.parent = parent
        self.segmentnumber = segmentnumber
        self.needed_blocks = needed_shares
        self.blocks = {} # k: blocknum, v: data
        self.results = results
        self._log_number = self.parent.log("starting segment %d" %
                                           segmentnumber)

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_number
        return self.parent.log(*args, **kwargs)

    def start(self):
        return self._download()

    def _download(self):
        d = self._try()
        def _done(res):
            if len(self.blocks) >= self.needed_blocks:
                # we only need self.needed_blocks blocks
                # we want to get the smallest blockids, because they are
                # more likely to be fast "primary blocks"
                blockids = sorted(self.blocks.keys())[:self.needed_blocks]
                blocks = []
                for blocknum in blockids:
                    blocks.append(self.blocks[blocknum])
                return (blocks, blockids)
            else:
                return self._download()
        d.addCallback(_done)
        return d

    def _try(self):
        # fill our set of active buckets, maybe raising NotEnoughSharesError
        active_buckets = self.parent._activate_enough_buckets()
        # Now we have enough buckets, in self.parent.active_buckets.

        # in test cases, bd.start might mutate active_buckets right away, so
        # we need to put off calling start() until we've iterated all the way
        # through it.
        downloaders = []
        for blocknum, vbucket in active_buckets.iteritems():
            bd = BlockDownloader(vbucket, blocknum, self, self.results)
            downloaders.append(bd)
            if self.results:
                self.results.servers_used.add(vbucket.bucket.get_peerid())
        l = [bd.start(self.segmentnumber) for bd in downloaders]
        return defer.DeferredList(l, fireOnOneErrback=True)

    def hold_block(self, blocknum, data):
        self.blocks[blocknum] = data

    def bucket_failed(self, vbucket):
        self.parent.bucket_failed(vbucket)

class DownloadStatus:
    implements(IDownloadStatus)
    statusid_counter = itertools.count(0)

    def __init__(self):
        self.storage_index = None
        self.size = None
        self.helper = False
        self.status = "Not started"
        self.progress = 0.0
        self.paused = False
        self.stopped = False
        self.active = True
        self.results = None
        self.counter = self.statusid_counter.next()
        self.started = time.time()

    def get_started(self):
        return self.started
    def get_storage_index(self):
        return self.storage_index
    def get_size(self):
        return self.size
    def using_helper(self):
        return self.helper
    def get_status(self):
        status = self.status
        if self.paused:
            status += " (output paused)"
        if self.stopped:
            status += " (output stopped)"
        return status
    def get_progress(self):
        return self.progress
    def get_active(self):
        return self.active
    def get_results(self):
        return self.results
    def get_counter(self):
        return self.counter

    def set_storage_index(self, si):
        self.storage_index = si
    def set_size(self, size):
        self.size = size
    def set_helper(self, helper):
        self.helper = helper
    def set_status(self, status):
        self.status = status
    def set_paused(self, paused):
        self.paused = paused
    def set_stopped(self, stopped):
        self.stopped = stopped
    def set_progress(self, value):
        self.progress = value
    def set_active(self, value):
        self.active = value
    def set_results(self, value):
        self.results = value
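
# The getters above are what status-display code is expected to poll while a
# download runs; a minimal sketch (assuming 'dl' is an active FileDownloader):
#
#   s = dl.get_download_status()
#   print "%s: %.1f%%" % (s.get_status(), 100.0 * s.get_progress())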

class FileDownloader:
    implements(IPushProducer)
    check_crypttext_hash = True
    check_plaintext_hash = True

    def __init__(self, client, u, downloadable):
        self._client = client

        u = IFileURI(u)
        self._storage_index = u.storage_index
        self._uri_extension_hash = u.uri_extension_hash
        self._total_shares = u.total_shares
        self._size = u.size
        self._num_needed_shares = u.needed_shares

        self._si_s = storage.si_b2a(self._storage_index)
        self.init_logging()

        self._started = time.time()
        self._status = s = DownloadStatus()
        s.set_status("Starting")
        s.set_storage_index(self._storage_index)
        s.set_size(self._size)
        s.set_helper(False)
        s.set_active(True)

        self._results = DownloadResults()
        s.set_results(self._results)
        self._results.file_size = self._size
        self._results.timings["servers_peer_selection"] = {}
        self._results.timings["fetch_per_server"] = {}
        self._results.timings["cumulative_fetch"] = 0.0
        self._results.timings["cumulative_decode"] = 0.0
        self._results.timings["cumulative_decrypt"] = 0.0
        self._results.timings["paused"] = 0.0

        self._paused = False
        self._stopped = False
        if IConsumer.providedBy(downloadable):
            downloadable.registerProducer(self, True)
        self._downloadable = downloadable
        self._output = Output(downloadable, u.key, self._size, self._log_number,
                              self._status)

        self.active_buckets = {} # k: shnum, v: bucket
        self._share_buckets = [] # list of (sharenum, bucket) tuples
        self._share_vbuckets = {} # k: shnum, v: set of ValidatedBuckets
        self._uri_extension_sources = []

        self._uri_extension_data = None

        self._fetch_failures = {"uri_extension": 0,
                                "plaintext_hashroot": 0,
                                "plaintext_hashtree": 0,
                                "crypttext_hashroot": 0,
                                "crypttext_hashtree": 0,
                                }

    def init_logging(self):
        self._log_prefix = prefix = storage.si_b2a(self._storage_index)[:5]
        num = self._client.log(format="FileDownloader(%(si)s): starting",
                               si=storage.si_b2a(self._storage_index))
        self._log_number = num

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_number
        if "facility" not in kwargs:
            kwargs["facility"] = "tahoe.download"
        return log.msg(*args, **kwargs)

    def pauseProducing(self):
        if self._paused:
            return
        self._paused = defer.Deferred()
        self._paused_at = time.time()
        if self._status:
            self._status.set_paused(True)

    def resumeProducing(self):
        if self._paused:
            paused_for = time.time() - self._paused_at
            self._results.timings['paused'] += paused_for
            p = self._paused
            self._paused = None
            eventually(p.callback, None)
        if self._status:
            self._status.set_paused(False)

    def stopProducing(self):
        self.log("Download.stopProducing")
        self._stopped = True
        self.resumeProducing()
        if self._status:
            self._status.set_stopped(True)
            self._status.set_active(False)
526 self.log("starting download")
528 # first step: who should we download from?
529 d = defer.maybeDeferred(self._get_all_shareholders)
530 d.addCallback(self._got_all_shareholders)
531 # now get the uri_extension block from somebody and validate it
532 d.addCallback(self._obtain_uri_extension)
533 d.addCallback(self._got_uri_extension)
534 d.addCallback(self._get_hashtrees)
535 d.addCallback(self._create_validated_buckets)
536 # once we know that, we can download blocks from everybody
537 d.addCallback(self._download_all_segments)
540 self._status.set_status("Finished")
541 self._status.set_active(False)
542 self._status.set_paused(False)
543 if IConsumer.providedBy(self._downloadable):
544 self._downloadable.unregisterProducer()
549 self._status.set_status("Failed")
550 self._status.set_active(False)
551 self._output.fail(why)
553 d.addErrback(_failed)
554 d.addCallback(self._done)

    def _get_all_shareholders(self):
        dl = []
        for (peerid,ss) in self._client.get_permuted_peers("storage",
                                                           self._storage_index):
            d = ss.callRemote("get_buckets", self._storage_index)
            d.addCallbacks(self._got_response, self._got_error,
                           callbackArgs=(peerid,))
            dl.append(d)
        self._responses_received = 0
        self._queries_sent = len(dl)
        if self._status:
            self._status.set_status("Locating Shares (%d/%d)" %
                                    (self._responses_received,
                                     self._queries_sent))
        return defer.DeferredList(dl)

    def _got_response(self, buckets, peerid):
        self._responses_received += 1
        if self._results:
            elapsed = time.time() - self._started
            self._results.timings["servers_peer_selection"][peerid] = elapsed
        if self._status:
            self._status.set_status("Locating Shares (%d/%d)" %
                                    (self._responses_received,
                                     self._queries_sent))
        for sharenum, bucket in buckets.iteritems():
            b = layout.ReadBucketProxy(bucket, peerid, self._si_s)
            self.add_share_bucket(sharenum, b)
            self._uri_extension_sources.append(b)
            if self._results:
                if peerid not in self._results.servermap:
                    self._results.servermap[peerid] = set()
                self._results.servermap[peerid].add(sharenum)

    def add_share_bucket(self, sharenum, bucket):
        # this is split out for the benefit of test_encode.py
        self._share_buckets.append( (sharenum, bucket) )

    def _got_error(self, f):
        level = log.WEIRD
        if f.check(DeadReferenceError):
            level = log.UNUSUAL
        self._client.log("Error during get_buckets", failure=f, level=level)

    def bucket_failed(self, vbucket):
        shnum = vbucket.sharenum
        del self.active_buckets[shnum]
        s = self._share_vbuckets[shnum]
        # s is a set of ValidatedBucket instances
        s.remove(vbucket)
        # ... which might now be empty
        if not s:
            # there are no more buckets which can provide this share, so
            # remove the key. This may prompt us to use a different share.
            del self._share_vbuckets[shnum]

    def _got_all_shareholders(self, res):
        if self._results:
            now = time.time()
            self._results.timings["peer_selection"] = now - self._started

        if len(self._share_buckets) < self._num_needed_shares:
            raise NotEnoughSharesError

        #for s in self._share_vbuckets.values():
        #    for vb in s:
        #        assert isinstance(vb, ValidatedBucket), \
        #               "vb is %s but should be a ValidatedBucket" % (vb,)

    def _unpack_uri_extension_data(self, data):
        return uri.unpack_extension(data)

    def _obtain_uri_extension(self, ignored):
        # all shareholders are supposed to have a copy of uri_extension, and
        # all are supposed to be identical. We compute the hash of the data
        # that comes back, and compare it against the version in our URI. If
        # they don't match, ignore their data and try someone else.
        if self._status:
            self._status.set_status("Obtaining URI Extension")

        self._uri_extension_fetch_started = time.time()
        def _validate(proposal, bucket):
            h = hashutil.uri_extension_hash(proposal)
            if h != self._uri_extension_hash:
                self._fetch_failures["uri_extension"] += 1
                msg = ("The copy of uri_extension we received from "
                       "%s was bad: wanted %s, got %s" %
                       (bucket,
                        base32.b2a(self._uri_extension_hash),
                        base32.b2a(h)))
                self.log(msg, level=log.SCARY, umid="jnkTtQ")
                raise BadURIExtensionHashValue(msg)
            return self._unpack_uri_extension_data(proposal)
        return self._obtain_validated_thing(None,
                                            self._uri_extension_sources,
                                            "uri_extension",
                                            "get_uri_extension", (), _validate)

    def _obtain_validated_thing(self, ignored, sources, name, methname, args,
                                validatorfunc):
        if not sources:
            raise NotEnoughSharesError("started with zero peers while fetching "
                                       "%s" % name)
        bucket = sources[0]
        sources = sources[1:]
        #d = bucket.callRemote(methname, *args)
        d = bucket.startIfNecessary()
        d.addCallback(lambda res: getattr(bucket, methname)(*args))
        d.addCallback(validatorfunc, bucket)
        def _bad(f):
            level = log.WEIRD
            if f.check(DeadReferenceError):
                level = log.UNUSUAL
            self.log(format="operation %(op)s from vbucket %(vbucket)s failed",
                     op=name, vbucket=str(bucket),
                     failure=f, level=level, umid="JGXxBA")
            if not sources:
                raise NotEnoughSharesError("ran out of peers, last error was %s"
                                           % (f,))
            # try again with a different one
            return self._obtain_validated_thing(None, sources, name,
                                                methname, args, validatorfunc)
        d.addErrback(_bad)
        return d

    def _got_uri_extension(self, uri_extension_data):
        if self._results:
            elapsed = time.time() - self._uri_extension_fetch_started
            self._results.timings["uri_extension"] = elapsed

        d = self._uri_extension_data = uri_extension_data

        self._codec = codec.get_decoder_by_name(d['codec_name'])
        self._codec.set_serialized_params(d['codec_params'])
        self._tail_codec = codec.get_decoder_by_name(d['codec_name'])
        self._tail_codec.set_serialized_params(d['tail_codec_params'])

        crypttext_hash = d.get('crypttext_hash', None) # optional
        if crypttext_hash:
            assert isinstance(crypttext_hash, str)
            assert len(crypttext_hash) == 32
        self._crypttext_hash = crypttext_hash
        self._plaintext_hash = d.get('plaintext_hash', None) # optional

        self._roothash = d['share_root_hash']

        self._segment_size = segment_size = d['segment_size']
        self._total_segments = mathutil.div_ceil(self._size, segment_size)
        self._current_segnum = 0

        self._share_hashtree = hashtree.IncompleteHashTree(d['total_shares'])
        self._share_hashtree.set_hashes({0: self._roothash})

    def _get_hashtrees(self, res):
        self._get_hashtrees_started = time.time()
        if self._status:
            self._status.set_status("Retrieving Hash Trees")
        d = defer.maybeDeferred(self._get_plaintext_hashtrees)
        d.addCallback(self._get_crypttext_hashtrees)
        d.addCallback(self._setup_hashtrees)
        return d

    def _get_plaintext_hashtrees(self):
        # plaintext hashes are optional. If the root isn't in the UEB, then
        # the share will be holding an empty list. We don't even bother
        # fetching it.
        if "plaintext_root_hash" not in self._uri_extension_data:
            self._plaintext_hashtree = None
            return
        def _validate_plaintext_hashtree(proposal, bucket):
            if proposal[0] != self._uri_extension_data['plaintext_root_hash']:
                self._fetch_failures["plaintext_hashroot"] += 1
                msg = ("The copy of the plaintext_root_hash we received from"
                       " %s was bad" % bucket)
                raise BadPlaintextHashValue(msg)
            pt_hashtree = hashtree.IncompleteHashTree(self._total_segments)
            pt_hashes = dict(list(enumerate(proposal)))
            try:
                pt_hashtree.set_hashes(pt_hashes)
            except hashtree.BadHashError:
                # the hashes they gave us were not self-consistent, even
                # though the root matched what we saw in the uri_extension
                # block
                self._fetch_failures["plaintext_hashtree"] += 1
                raise
            self._plaintext_hashtree = pt_hashtree
        d = self._obtain_validated_thing(None,
                                         self._uri_extension_sources,
                                         "plaintext_hashes",
                                         "get_plaintext_hashes", (),
                                         _validate_plaintext_hashtree)
        return d

    def _get_crypttext_hashtrees(self, res):
        # Ciphertext hash tree root is mandatory, so that there is at
        # most one ciphertext that matches this read-cap or
        # verify-cap. The integrity check on the shares is not
        # sufficient to prevent the original encoder from creating
        # some shares of file A and other shares of file B.
        if "crypttext_root_hash" not in self._uri_extension_data:
            raise BadURIExtension("URI Extension block did not have the "
                                  "ciphertext hash tree root")
        def _validate_crypttext_hashtree(proposal, bucket):
            if proposal[0] != self._uri_extension_data['crypttext_root_hash']:
                self._fetch_failures["crypttext_hashroot"] += 1
                msg = ("The copy of the crypttext_root_hash we received from"
                       " %s was bad" % bucket)
                raise BadCrypttextHashValue(msg)
            ct_hashtree = hashtree.IncompleteHashTree(self._total_segments)
            ct_hashes = dict(list(enumerate(proposal)))
            try:
                ct_hashtree.set_hashes(ct_hashes)
            except hashtree.BadHashError:
                self._fetch_failures["crypttext_hashtree"] += 1
                raise
            self._crypttext_hashtree = ct_hashtree
        d = self._obtain_validated_thing(None,
                                         self._uri_extension_sources,
                                         "crypttext_hashes",
                                         "get_crypttext_hashes", (),
                                         _validate_crypttext_hashtree)
        return d

    def _setup_hashtrees(self, res):
        self._output.setup_hashtrees(self._plaintext_hashtree,
                                     self._crypttext_hashtree)
        if self._results:
            elapsed = time.time() - self._get_hashtrees_started
            self._results.timings["hashtrees"] = elapsed

    def _create_validated_buckets(self, ignored=None):
        self._share_vbuckets = {}
        for sharenum, bucket in self._share_buckets:
            vbucket = ValidatedBucket(sharenum, bucket,
                                      self._share_hashtree,
                                      self._roothash,
                                      self._total_segments)
            s = self._share_vbuckets.setdefault(sharenum, set())
            s.add(vbucket)

    def _activate_enough_buckets(self):
        """either return a mapping from shnum to a ValidatedBucket that can
        provide data for that share, or raise NotEnoughSharesError"""

        while len(self.active_buckets) < self._num_needed_shares:
            handled_shnums = set(self.active_buckets.keys())
            available_shnums = set(self._share_vbuckets.keys())
            potential_shnums = list(available_shnums - handled_shnums)
            if not potential_shnums:
                raise NotEnoughSharesError
            # choose a random share
            shnum = random.choice(potential_shnums)
            # and a random bucket that will provide it
            validated_bucket = random.choice(list(self._share_vbuckets[shnum]))
            self.active_buckets[shnum] = validated_bucket
        return self.active_buckets

    def _download_all_segments(self, res):
        # the promise: upon entry to this function, self._share_vbuckets
        # contains enough buckets to complete the download, and some extra
        # ones to tolerate some buckets dropping out or having errors.
        # self._share_vbuckets is a dictionary that maps from shnum to a set
        # of ValidatedBuckets, which themselves are wrappers around
        # RIBucketReader references.
        self.active_buckets = {} # k: shnum, v: ValidatedBucket instance

        self._started_fetching = time.time()

        d = defer.succeed(None)
        for segnum in range(self._total_segments-1):
            d.addCallback(self._download_segment, segnum)
            # this pause, at the end of write, prevents pre-fetch from
            # happening until the consumer is ready for more data.
            d.addCallback(self._check_for_pause)
        d.addCallback(self._download_tail_segment, self._total_segments-1)
        return d

    def _check_for_pause(self, res):
        if self._paused:
            d = defer.Deferred()
            self._paused.addCallback(lambda ignored: d.callback(res))
            return d
        if self._stopped:
            raise DownloadStopped("our Consumer called stopProducing()")
        return res
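
    # How the IPushProducer hooks interact with _check_for_pause (sketch,
    # assuming a streaming consumer has registered us as its producer):
    #
    #   fd.pauseProducing()  (called by the consumer) -> self._paused becomes
    #       a Deferred; the next _check_for_pause waits on it
    #   fd.resumeProducing() -> fires that Deferred; segment fetching
    #       continues where it left off
    #   fd.stopProducing()   -> sets self._stopped; _check_for_pause raises
    #       DownloadStopped just before the next write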

    def _download_segment(self, res, segnum):
        if self._status:
            self._status.set_status("Downloading segment %d of %d" %
                                    (segnum+1, self._total_segments))
        self.log("downloading seg#%d of %d (%d%%)"
                 % (segnum, self._total_segments,
                    100.0 * segnum / self._total_segments))
        # memory footprint: when the SegmentDownloader finishes pulling down
        # all shares, we have 1*segment_size of usage.
        segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares,
                                        self._results)
        started = time.time()
        d = segmentdler.start()
        def _finished_fetching(res):
            elapsed = time.time() - started
            self._results.timings["cumulative_fetch"] += elapsed
            return res
        if self._results:
            d.addCallback(_finished_fetching)
        # pause before using more memory
        d.addCallback(self._check_for_pause)
        # while the codec does its job, we hit 2*segment_size
        def _started_decode(res):
            self._started_decode = time.time()
            return res
        if self._results:
            d.addCallback(_started_decode)
        d.addCallback(lambda (shares, shareids):
                      self._codec.decode(shares, shareids))
        # once the codec is done, we drop back to 1*segment_size, because
        # 'shares' goes out of scope. The memory usage is all in the
        # plaintext now, spread out into a bunch of tiny buffers.
        def _finished_decode(res):
            elapsed = time.time() - self._started_decode
            self._results.timings["cumulative_decode"] += elapsed
            return res
        if self._results:
            d.addCallback(_finished_decode)

        # pause/check-for-stop just before writing, to honor stopProducing
        d.addCallback(self._check_for_pause)
        def _done(buffers):
            # we start by joining all these buffers together into a single
            # string. This makes Output.write easier, since it wants to hash
            # data one segment at a time anyway, and doesn't impact our
            # memory footprint since we're already peaking at 2*segment_size
            # inside the codec a moment ago.
            segment = "".join(buffers)
            del buffers
            # we're down to 1*segment_size right now, but write_segment()
            # will decrypt a copy of the segment internally, which will push
            # us up to 2*segment_size while it runs.
            started_decrypt = time.time()
            self._output.write_segment(segment)
            if self._results:
                elapsed = time.time() - started_decrypt
                self._results.timings["cumulative_decrypt"] += elapsed
        d.addCallback(_done)
        return d

    def _download_tail_segment(self, res, segnum):
        self.log("downloading seg#%d of %d (%d%%)"
                 % (segnum, self._total_segments,
                    100.0 * segnum / self._total_segments))
        segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares,
                                        self._results)
        started = time.time()
        d = segmentdler.start()
        def _finished_fetching(res):
            elapsed = time.time() - started
            self._results.timings["cumulative_fetch"] += elapsed
            return res
        if self._results:
            d.addCallback(_finished_fetching)
        # pause before using more memory
        d.addCallback(self._check_for_pause)
        def _started_decode(res):
            self._started_decode = time.time()
            return res
        if self._results:
            d.addCallback(_started_decode)
        d.addCallback(lambda (shares, shareids):
                      self._tail_codec.decode(shares, shareids))
        def _finished_decode(res):
            elapsed = time.time() - self._started_decode
            self._results.timings["cumulative_decode"] += elapsed
            return res
        if self._results:
            d.addCallback(_finished_decode)
        # pause/check-for-stop just before writing, to honor stopProducing
        d.addCallback(self._check_for_pause)
        def _done(buffers):
            # trim off any padding added by the upload side
            segment = "".join(buffers)
            del buffers
            # we never send empty segments. If the data was an exact multiple
            # of the segment size, the last segment will be full.
            pad_size = mathutil.pad_size(self._size, self._segment_size)
            tail_size = self._segment_size - pad_size
            segment = segment[:tail_size]
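
            # Worked example (illustrative numbers, not from the code): with
            # self._size=1000 and self._segment_size=128, the upload side
            # padded the 104-byte tail up to 128 bytes, so pad_size=24 and
            # tail_size=104; the slice above keeps only the real bytes.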
            started_decrypt = time.time()
            self._output.write_segment(segment)
            if self._results:
                elapsed = time.time() - started_decrypt
                self._results.timings["cumulative_decrypt"] += elapsed
        d.addCallback(_done)
        return d

    def _done(self, res):
        self.log("download done")
        if self._results:
            now = time.time()
            self._results.timings["total"] = now - self._started
            self._results.timings["segments"] = now - self._started_fetching
        if self.check_crypttext_hash and self._crypttext_hash:
            _assert(self._crypttext_hash == self._output.crypttext_hash,
                    "bad crypttext_hash: computed=%s, expected=%s" %
                    (base32.b2a(self._output.crypttext_hash),
                     base32.b2a(self._crypttext_hash)))
        if self.check_plaintext_hash and self._plaintext_hash:
            _assert(self._plaintext_hash == self._output.plaintext_hash,
                    "bad plaintext_hash: computed=%s, expected=%s" %
                    (base32.b2a(self._output.plaintext_hash),
                     base32.b2a(self._plaintext_hash)))
        _assert(self._output.length == self._size,
                got=self._output.length, expected=self._size)
        return self._output.finish()

    def get_download_status(self):
        return self._status

class FileName:
    implements(IDownloadTarget)
    def __init__(self, filename):
        self._filename = filename
        self.f = None
    def open(self, size):
        self.f = open(self._filename, "wb")
        return self.f
    def write(self, data):
        self.f.write(data)
    def close(self):
        if self.f:
            self.f.close()
    def fail(self, why):
        if self.f:
            self.f.close()
            os.unlink(self._filename)
    def register_canceller(self, cb):
        pass # we won't use it
    def finish(self):
        pass

class Data:
    implements(IDownloadTarget)
    def __init__(self):
        self._data = []
    def open(self, size):
        pass
    def write(self, data):
        self._data.append(data)
    def close(self):
        self.data = "".join(self._data)
        del self._data
    def fail(self, why):
        del self._data
    def register_canceller(self, cb):
        pass # we won't use it
    def finish(self):
        return self.data
1020 """Use me to download data to a pre-defined filehandle-like object. I
1021 will use the target's write() method. I will *not* close the filehandle:
1022 I leave that up to the originator of the filehandle. The download process
1023 will return the filehandle when it completes.
1025 implements(IDownloadTarget)
1026 def __init__(self, filehandle):
1027 self._filehandle = filehandle
1028 def open(self, size):
1030 def write(self, data):
1031 self._filehandle.write(data)
1033 # the originator of the filehandle reserves the right to close it
1035 def fail(self, why):
1037 def register_canceller(self, cb):
1040 return self._filehandle

class ConsumerAdapter:
    implements(IDownloadTarget, IConsumer)
    def __init__(self, consumer):
        self._consumer = consumer
        self._when_finished = observer.OneShotObserverList()

    def when_finished(self):
        # I think this is unused, along with self._when_finished. But I need
        # to trace the error paths to be sure.
        return self._when_finished.when_fired()

    def registerProducer(self, producer, streaming):
        self._consumer.registerProducer(producer, streaming)
    def unregisterProducer(self):
        self._consumer.unregisterProducer()

    def open(self, size):
        pass
    def write(self, data):
        self._consumer.write(data)
    def close(self):
        self._when_finished.fire(None)

    def fail(self, why):
        self._when_finished.fire(why)
    def register_canceller(self, cb):
        pass
    def finish(self):
        return self._consumer
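
# Minimal usage sketch (hedged: 'consumer' stands for any IConsumer
# provider, such as a twisted.web request object):
#
#   target = ConsumerAdapter(consumer)
#   d = downloader.download(uri, target)
#   # d fires with the consumer itself, via finish(), when the download ends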

class Downloader(service.MultiService):
    """I am a service that allows file downloading.
    """
    # TODO: in fact, this service only downloads immutable files (URI:CHK:).
    # It is scheduled to go away, to be replaced by filenode.download()
    implements(IDownloader)
    name = "downloader"
    MAX_DOWNLOAD_STATUSES = 10

    def __init__(self, stats_provider=None):
        service.MultiService.__init__(self)
        self.stats_provider = stats_provider
        self._all_downloads = weakref.WeakKeyDictionary() # for debugging
        self._all_download_statuses = weakref.WeakKeyDictionary()
        self._recent_download_statuses = []

    def download(self, u, t):
        u = IFileURI(u)
        t = IDownloadTarget(t)
        assert t.write
        assert t.close

        assert isinstance(u, uri.CHKFileURI)
        if self.stats_provider:
            # these counters are meant for network traffic, and don't
            # include LIT files
            self.stats_provider.count('downloader.files_downloaded', 1)
            self.stats_provider.count('downloader.bytes_downloaded', u.get_size())
        dl = FileDownloader(self.parent, u, t)
        self._add_download(dl)
        d = dl.start()
        return d

    def download_to_data(self, uri):
        return self.download(uri, Data())
    def download_to_filename(self, uri, filename):
        return self.download(uri, FileName(filename))
    def download_to_filehandle(self, uri, filehandle):
        return self.download(uri, FileHandle(filehandle))
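
    # Usage sketch for the convenience wrappers above (hedged: assumes this
    # service has been attached to a running client, and that the service
    # name is "downloader" as set above):
    #
    #   d = client.getServiceNamed("downloader").download_to_data(uri)
    #   # d fires with the complete plaintext when the download finishes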

    def _add_download(self, downloader):
        self._all_downloads[downloader] = None
        s = downloader.get_download_status()
        self._all_download_statuses[s] = None
        self._recent_download_statuses.append(s)
        while len(self._recent_download_statuses) > self.MAX_DOWNLOAD_STATUSES:
            self._recent_download_statuses.pop(0)

    def list_all_download_statuses(self):
        for ds in self._all_download_statuses:
            yield ds