2 import os, random, weakref
3 from zope.interface import implements
4 from twisted.internet import defer
5 from twisted.internet.interfaces import IPushProducer, IConsumer
6 from twisted.application import service
7 from foolscap.eventual import eventually
9 from allmydata.util import base32, mathutil, hashutil, log
10 from allmydata.util.assertutil import _assert
11 from allmydata import codec, hashtree, storage, uri
12 from allmydata.interfaces import IDownloadTarget, IDownloader, IFileURI, \
14 from allmydata.encode import NotEnoughPeersError
15 from pycryptopp.cipher.aes import AES
# Control-flow exception: raised to break out of the peer-gathering loop
# early (per the inline comment below), not to signal an error condition.
17 class HaveAllPeersError(Exception):
18 # we use this to jump out of the loop
# Raised when the hash of a fetched uri_extension block does not match the
# uri_extension_hash embedded in the URI (see FileDownloader._obtain_uri_extension).
21 class BadURIExtensionHashValue(Exception):
# Raised when a peer's plaintext_root_hash does not match the value recorded
# in the validated uri_extension data (see _get_plaintext_hashtrees).
23 class BadPlaintextHashValue(Exception):
# Raised when a peer's crypttext_root_hash does not match the value recorded
# in the validated uri_extension data (see _get_crypttext_hashtrees).
25 class BadCrypttextHashValue(Exception):
# Raised from _check_for_pause when the consumer has called stopProducing(),
# aborting the rest of the download.
28 class DownloadStopped(Exception):
# --- Output (the 'class Output:' header is elided from this listing) ---
# Output is the sink side of a download: it holds the AES decryptor built
# from the read key, whole-file crypttext/plaintext hashers, the per-segment
# hash trees (installed later via setup_hashtrees), a segment counter, and
# the download-status object used for progress reporting.
32 def __init__(self, downloadable, key, total_length, log_parent,
# NOTE(review): the signature continuation (original line 33) is elided;
# the body below reads a 'download_status' argument -- confirm against the
# full file.
34 self.downloadable = downloadable
35 self._decryptor = AES(key)
36 self._crypttext_hasher = hashutil.crypttext_hasher()
37 self._plaintext_hasher = hashutil.plaintext_hasher()
# NOTE(review): original line 38 is elided here; write_segment() below
# relies on a running self.length counter -- presumably initialized there.
39 self.total_length = total_length
40 self._segment_number = 0
# hash trees are None until setup_hashtrees() is called; write_segment
# skips leaf validation while they are None.
41 self._plaintext_hash_tree = None
42 self._crypttext_hash_tree = None
44 self._log_parent = log_parent
45 self._status = download_status
46 self._status.set_progress(0.0)
# Logging helper: delegate to log.msg, defaulting 'parent' to this
# Output's log entry and 'facility' to "download.output" so callers may
# override either.
48 def log(self, *args, **kwargs):
49 if "parent" not in kwargs:
50 kwargs["parent"] = self._log_parent
51 if "facility" not in kwargs:
52 kwargs["facility"] = "download.output"
53 return log.msg(*args, **kwargs)
# Install the (incomplete) plaintext and crypttext hash trees fetched by
# the FileDownloader; write_segment() uses them to validate each segment's
# leaf hash as it arrives.
55 def setup_hashtrees(self, plaintext_hashtree, crypttext_hashtree):
56 self._plaintext_hash_tree = plaintext_hashtree
57 self._crypttext_hash_tree = crypttext_hashtree
# Process one segment of crypttext: update progress, feed the whole-file
# crypttext hasher, validate the per-segment crypttext leaf hash (if the
# tree is installed), decrypt, do the same on the plaintext side, then
# hand the plaintext to the downloadable.
# NOTE(review): this listing elides several original lines (e.g. 69 and
# 85, where the per-segment hashers are presumably fed the segment data
# via ch.update(...)/ph.update(...), and 96-97 around the open() call) --
# confirm against the full file.
59 def write_segment(self, crypttext):
60 self.length += len(crypttext)
61 self._status.set_progress( float(self.length) / self.total_length )
63 # memory footprint: 'crypttext' is the only segment_size usage
64 # outstanding. While we decrypt it into 'plaintext', we hit
66 self._crypttext_hasher.update(crypttext)
67 if self._crypttext_hash_tree:
68 ch = hashutil.crypttext_segment_hasher()
70 crypttext_leaves = {self._segment_number: ch.digest()}
71 self.log(format="crypttext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s",
73 segnum=self._segment_number, hash=base32.b2a(ch.digest()),
# set_hashes() on the IncompleteHashTree presumably raises a
# hashtree.BadHashError on mismatch (cf. the except clauses in
# ValidatedBucket._got_data) -- TODO confirm.
75 self._crypttext_hash_tree.set_hashes(leaves=crypttext_leaves)
77 plaintext = self._decryptor.process(crypttext)
80 # now we're back down to 1*segment_size.
82 self._plaintext_hasher.update(plaintext)
83 if self._plaintext_hash_tree:
84 ph = hashutil.plaintext_segment_hasher()
86 plaintext_leaves = {self._segment_number: ph.digest()}
87 self.log(format="plaintext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s",
89 segnum=self._segment_number, hash=base32.b2a(ph.digest()),
91 self._plaintext_hash_tree.set_hashes(leaves=plaintext_leaves)
93 self._segment_number += 1
94 # We're still at 1*segment_size. The Downloadable is responsible for
95 # any memory usage beyond this.
# NOTE(review): an open-once guard (original lines 96-97) appears to be
# elided before this open() call -- as shown, open() would run for every
# segment; verify against the full file.
98 self.downloadable.open(self.total_length)
99 self.downloadable.write(plaintext)
# --- Output.fail / Output.close / Output.finish ---
# NOTE(review): the 'def fail(self, why):', 'def close(self):' and
# 'def finish(self):' lines are elided from this listing; the bodies
# below are grouped accordingly.
# fail(): log loudly, then propagate the failure to the target.
102 # this is really unusual, and deserves maximum forensics
103 self.log("download failed!", failure=why, level=log.SCARY)
104 self.downloadable.fail(why)
# close(): finalize the whole-file hash digests (checked later by
# FileDownloader._done) and close the target.
107 self.crypttext_hash = self._crypttext_hasher.digest()
108 self.plaintext_hash = self._plaintext_hasher.digest()
109 self.log("download finished, closing IDownloadable", level=log.NOISY)
110 self.downloadable.close()
# finish(): whatever the Downloadable's finish() returns becomes the
# overall download result.
113 return self.downloadable.finish()
# Wraps one remote storage bucket for one share. Lazily fetches share
# hashes (first use) and block hashes, then validates every block against
# the share hash tree whose root is pinned to the value from the URI.
# NOTE(review): original line numbers skip throughout this listing --
# e.g. 'self.bucket = bucket' in __init__, the 'try:' in _got_data, and
# the 'lines = []' accumulators in the forensics section are elided.
115 class ValidatedBucket:
116 """I am a front-end for a remote storage bucket, responsible for
117 retrieving and validating data from that bucket.
119 My get_block() method is used by BlockDownloaders.
122 def __init__(self, sharenum, bucket,
123 share_hash_tree, roothash,
125 self.sharenum = sharenum
127 self._share_hash = None # None means not validated yet
128 self.share_hash_tree = share_hash_tree
129 self._roothash = roothash
130 self.block_hash_tree = hashtree.IncompleteHashTree(num_blocks)
# get_block: ensure the bucket connection is started, then fetch in
# parallel (share hashes if we have not yet validated our share hash,
# the block-hash chain needed for this block, and the block itself),
# handing the three results to _got_data.
133 def get_block(self, blocknum):
135 d = self.bucket.start()
138 return self.get_block(blocknum)
139 d.addCallback(_started)
142 # the first time we use this bucket, we need to fetch enough elements
143 # of the share hash tree to validate it from our share hash up to the
145 if not self._share_hash:
146 d1 = self.bucket.get_share_hashes()
148 d1 = defer.succeed([])
150 # we might need to grab some elements of our block hash tree, to
151 # validate the requested block up to the share hash
152 needed = self.block_hash_tree.needed_hashes(blocknum)
154 # TODO: get fewer hashes, use get_block_hashes(needed)
155 d2 = self.bucket.get_block_hashes()
157 d2 = defer.succeed([])
159 d3 = self.bucket.get_block(blocknum)
161 d = defer.gatherResults([d1, d2, d3])
162 d.addCallback(self._got_data, blocknum)
# _got_data: validate the share hash (root forced to the URI's roothash),
# then the block hash; on a hash failure, emit extensive forensic logging
# and let the exception propagate so the caller drops this bucket.
165 def _got_data(self, res, blocknum):
166 sharehashes, blockhashes, blockdata = res
167 blockhash = None # to make logging it safe
170 if not self._share_hash:
171 sh = dict(sharehashes)
172 sh[0] = self._roothash # always use our own root, from the URI
173 sht = self.share_hash_tree
174 if sht.get_leaf_index(self.sharenum) not in sh:
175 raise hashtree.NotEnoughHashesError
# NOTE(review): the sht.set_hashes(sh) call (original line 176) appears
# to be elided before this leaf lookup -- confirm.
177 self._share_hash = sht.get_leaf(self.sharenum)
179 blockhash = hashutil.block_hash(blockdata)
180 #log.msg("checking block_hash(shareid=%d, blocknum=%d) len=%d "
182 # (self.sharenum, blocknum, len(blockdata),
183 # blockdata[:50], blockdata[-50:], base32.b2a(blockhash)))
185 # we always validate the blockhash
186 bh = dict(enumerate(blockhashes))
187 # replace blockhash root with validated value
188 bh[0] = self._share_hash
189 self.block_hash_tree.set_hashes(bh, {blocknum: blockhash})
191 except (hashtree.BadHashError, hashtree.NotEnoughHashesError):
192 # log.WEIRD: indicates undetected disk/network error, or more
193 # likely a programming error
194 log.msg("hash failure in block=%d, shnum=%d on %s" %
195 (blocknum, self.sharenum, self.bucket))
197 log.msg(""" failure occurred when checking the block_hash_tree.
198 This suggests that either the block data was bad, or that the
199 block hashes we received along with it were bad.""")
201 log.msg(""" the failure probably occurred when checking the
202 share_hash_tree, which suggests that the share hashes we
203 received from the remote peer were bad.""")
204 log.msg(" have self._share_hash: %s" % bool(self._share_hash))
205 log.msg(" block length: %d" % len(blockdata))
206 log.msg(" block hash: %s" % base32.b2a_or_none(blockhash))
207 if len(blockdata) < 100:
208 log.msg(" block data: %r" % (blockdata,))
210 log.msg(" block data start/end: %r .. %r" %
211 (blockdata[:50], blockdata[-50:]))
212 log.msg(" root hash: %s" % base32.b2a(self._roothash))
213 log.msg(" share hash tree:\n" + self.share_hash_tree.dump())
214 log.msg(" block hash tree:\n" + self.block_hash_tree.dump())
216 for i,h in sorted(sharehashes):
217 lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
218 log.msg(" sharehashes:\n" + "\n".join(lines) + "\n")
220 for i,h in enumerate(blockhashes):
221 lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
222 log.msg(" blockhashes:\n" + "\n".join(lines) + "\n")
225 # If we made it here, the block is good. If the hash trees didn't
226 # like what they saw, they would have raised a BadHashError, causing
227 # our caller to see a Failure and thus ignore this block (as well as
228 # dropping this bucket).
# Downloads one block (one share's contribution) of one segment through a
# ValidatedBucket, reporting success (hold_block) or failure
# (bucket_failed) back to its parent SegmentDownloader.
233 class BlockDownloader:
234 """I am responsible for downloading a single block (from a single bucket)
235 for a single segment.
237 I am a child of the SegmentDownloader.
240 def __init__(self, vbucket, blocknum, parent):
241 self.vbucket = vbucket
242 self.blocknum = blocknum
# NOTE(review): 'self.parent = parent' (original line 243) is elided in
# this listing, yet the next line already dereferences self.parent --
# confirm against the full file.
244 self._log_number = self.parent.log("starting block %d" % blocknum)
# log helper: default 'parent' to this downloader's own log entry.
246 def log(self, msg, parent=None):
248 parent = self._log_number
249 return self.parent.log(msg, parent=parent)
# start: fetch block 'segnum' from the validated bucket; route the
# result/failure to _hold_block/_got_block_error with the log number.
251 def start(self, segnum):
252 lognum = self.log("get_block(segnum=%d)" % segnum)
253 d = self.vbucket.get_block(segnum)
254 d.addCallbacks(self._hold_block, self._got_block_error,
255 callbackArgs=(lognum,), errbackArgs=(lognum,))
258 def _hold_block(self, data, lognum):
259 self.log("got block", parent=lognum)
260 self.parent.hold_block(self.blocknum, data)
# error path: log (continuation arguments elided at original line 264),
# then tell the parent this bucket is bad.
262 def _got_block_error(self, f, lognum):
263 self.log("BlockDownloader[%d] got error: %s" % (self.blocknum, f),
265 self.parent.bucket_failed(self.vbucket)
# Downloads all blocks for one segment: activates enough buckets, spawns
# one BlockDownloader per active bucket, collects blocks keyed by blocknum,
# and returns the needed_blocks smallest-numbered blocks (most likely the
# fast "primary" shares).
267 class SegmentDownloader:
268 """I am responsible for downloading all the blocks for a single segment
271 I am a child of the FileDownloader.
274 def __init__(self, parent, segmentnumber, needed_shares):
# NOTE(review): 'self.parent = parent' (original line 275) is elided
# here, but self.parent is used below -- confirm.
276 self.segmentnumber = segmentnumber
277 self.needed_blocks = needed_shares
278 self.blocks = {} # k: blocknum, v: data
279 self._log_number = self.parent.log("starting segment %d" %
282 def log(self, msg, parent=None):
284 parent = self._log_number
285 return self.parent.log(msg, parent=parent)
# NOTE(review): the surrounding 'def start(self):' and the completion
# callback that checks self.blocks (original lines 286-292) are elided;
# the fragments below belong to that flow.
288 return self._download()
293 if len(self.blocks) >= self.needed_blocks:
294 # we only need self.needed_blocks blocks
295 # we want to get the smallest blockids, because they are
296 # more likely to be fast "primary blocks"
297 blockids = sorted(self.blocks.keys())[:self.needed_blocks]
299 for blocknum in blockids:
300 blocks.append(self.blocks[blocknum])
301 return (blocks, blockids)
# not enough blocks yet: try another round of downloads.
303 return self._download()
308 # fill our set of active buckets, maybe raising NotEnoughPeersError
309 active_buckets = self.parent._activate_enough_buckets()
310 # Now we have enough buckets, in self.parent.active_buckets.
312 # in test cases, bd.start might mutate active_buckets right away, so
313 # we need to put off calling start() until we've iterated all the way
316 for blocknum, vbucket in active_buckets.iteritems():
317 bd = BlockDownloader(vbucket, blocknum, self)
318 downloaders.append(bd)
319 l = [bd.start(self.segmentnumber) for bd in downloaders]
320 return defer.DeferredList(l, fireOnOneErrback=True)
322 def hold_block(self, blocknum, data):
323 self.blocks[blocknum] = data
# bucket failures are simply forwarded up to the FileDownloader.
325 def bucket_failed(self, vbucket):
326 self.parent.bucket_failed(vbucket)
# Passive status record implementing IDownloadStatus: storage index,
# size, status string, paused/stopped flags, and fractional progress.
# NOTE(review): __init__ and several getter bodies are partially elided
# in this listing (original line numbers skip); 'implements' is the
# Python-2-only zope.interface class-advice form.
328 class DownloadStatus:
329 implements(IDownloadStatus)
332 self.storage_index = None
335 self.status = "Not started"
340 def get_storage_index(self):
341 return self.storage_index
344 def using_helper(self):
# get_status appends pause/stop annotations to the base status string.
346 def get_status(self):
349 status += " (output paused)"
351 status += " (output stopped)"
353 def get_progress(self):
356 def set_storage_index(self, si):
357 self.storage_index = si
358 def set_size(self, size):
360 def set_helper(self, helper):
362 def set_status(self, status):
364 def set_paused(self, paused):
366 def set_stopped(self, stopped):
367 self.stopped = stopped
# progress is a float fraction in [0.0, 1.0] (Output.write_segment
# divides bytes written by total_length).
368 def set_progress(self, value):
369 self.progress = value
# Drives a CHK file download end-to-end: locate shareholders, fetch and
# validate the uri_extension, fetch plaintext/crypttext hash trees, build
# ValidatedBuckets, then download/decode/decrypt each segment in order,
# honoring IPushProducer pause/stop requests from the consumer.
# NOTE(review): this listing elides many original lines (blank lines,
# 'return d' statements, some assignments and method headers); confirm
# any apparent omission against the full file before acting on it.
372 class FileDownloader:
373 implements(IPushProducer)
# class-level knobs so tests can disable the final whole-file hash checks.
374 check_crypttext_hash = True
375 check_plaintext_hash = True
# __init__: unpack the CHK URI fields, create the DownloadStatus, register
# ourselves as a producer if the target is an IConsumer, build the Output
# sink, and initialize bucket-tracking state and fetch-failure counters.
378 def __init__(self, client, u, downloadable):
379 self._client = client
382 self._storage_index = u.storage_index
383 self._uri_extension_hash = u.uri_extension_hash
384 self._total_shares = u.total_shares
386 self._num_needed_shares = u.needed_shares
390 self._status = s = DownloadStatus()
391 s.set_status("Starting")
392 s.set_storage_index(self._storage_index)
393 s.set_size(self._size)
396 if IConsumer.providedBy(downloadable):
397 downloadable.registerProducer(self, True)
398 self._downloadable = downloadable
399 self._output = Output(downloadable, u.key, self._size, self._log_number,
402 self._stopped = False
404 self.active_buckets = {} # k: shnum, v: bucket
405 self._share_buckets = [] # list of (sharenum, bucket) tuples
406 self._share_vbuckets = {} # k: shnum, v: set of ValidatedBuckets
407 self._uri_extension_sources = []
409 self._uri_extension_data = None
411 self._fetch_failures = {"uri_extension": 0,
412 "plaintext_hashroot": 0,
413 "plaintext_hashtree": 0,
414 "crypttext_hashroot": 0,
415 "crypttext_hashtree": 0,
# set up a log entry keyed by the abbreviated storage index; all later
# log() calls parent to it.
418 def init_logging(self):
419 self._log_prefix = prefix = storage.si_b2a(self._storage_index)[:5]
420 num = self._client.log(format="FileDownloader(%(si)s): starting",
421 si=storage.si_b2a(self._storage_index))
422 self._log_number = num
424 def log(self, *args, **kwargs):
425 if "parent" not in kwargs:
426 kwargs["parent"] = self._log_number
427 if "facility" not in kwargs:
428 kwargs["facility"] = "tahoe.download"
429 return log.msg(*args, **kwargs)
# IPushProducer: pause creates a Deferred that _check_for_pause waits on;
# resume fires it (via eventually, to stay out of the caller's stack).
431 def pauseProducing(self):
434 self._paused = defer.Deferred()
436 self._status.set_paused(True)
438 def resumeProducing(self):
442 eventually(p.callback, None)
444 self._status.set_paused(False)
446 def stopProducing(self):
447 self.log("Download.stopProducing")
450 self._status.set_stopped(True)
# start (def line elided at original 452): the whole download pipeline as
# a Deferred chain; on success unregister the producer, on failure mark
# the status Failed and propagate into Output.fail.
453 self.log("starting download")
455 # first step: who should we download from?
456 d = defer.maybeDeferred(self._get_all_shareholders)
457 d.addCallback(self._got_all_shareholders)
458 # now get the uri_extension block from somebody and validate it
459 d.addCallback(self._obtain_uri_extension)
460 d.addCallback(self._got_uri_extension)
461 d.addCallback(self._get_hashtrees)
462 d.addCallback(self._create_validated_buckets)
463 # once we know that, we can download blocks from everybody
464 d.addCallback(self._download_all_segments)
467 self._status.set_status("Finished")
468 if IConsumer.providedBy(self._downloadable):
469 self._downloadable.unregisterProducer()
474 self._status.set_status("Failed")
475 self._output.fail(why)
477 d.addErrback(_failed)
478 d.addCallback(self._done)
# query every permuted storage peer for buckets holding our storage index.
481 def _get_all_shareholders(self):
483 for (peerid,ss) in self._client.get_permuted_peers("storage",
484 self._storage_index):
485 d = ss.callRemote("get_buckets", self._storage_index)
486 d.addCallbacks(self._got_response, self._got_error)
488 self._responses_received = 0
489 self._queries_sent = len(dl)
491 self._status.set_status("Locating Shares (%d/%d)" %
492 (self._responses_received,
494 return defer.DeferredList(dl)
496 def _got_response(self, buckets):
497 self._responses_received += 1
499 self._status.set_status("Locating Shares (%d/%d)" %
500 (self._responses_received,
502 for sharenum, bucket in buckets.iteritems():
503 b = storage.ReadBucketProxy(bucket)
504 self.add_share_bucket(sharenum, b)
# every bucket is also a candidate source for the uri_extension block.
505 self._uri_extension_sources.append(b)
507 def add_share_bucket(self, sharenum, bucket):
508 # this is split out for the benefit of test_encode.py
509 self._share_buckets.append( (sharenum, bucket) )
# a failed peer query is logged but otherwise ignored (best-effort).
511 def _got_error(self, f):
512 self._client.log("Somebody failed. -- %s" % (f,))
# drop a bucket that failed validation; if no bucket remains for the
# share, forget the share entirely so a different one gets chosen.
514 def bucket_failed(self, vbucket):
515 shnum = vbucket.sharenum
516 del self.active_buckets[shnum]
517 s = self._share_vbuckets[shnum]
518 # s is a set of ValidatedBucket instances
520 # ... which might now be empty
522 # there are no more buckets which can provide this share, so
523 # remove the key. This may prompt us to use a different share.
524 del self._share_vbuckets[shnum]
526 def _got_all_shareholders(self, res):
527 if len(self._share_buckets) < self._num_needed_shares:
528 raise NotEnoughPeersError
530 #for s in self._share_vbuckets.values():
532 # assert isinstance(vb, ValidatedBucket), \
533 # "vb is %s but should be a ValidatedBucket" % (vb,)
535 def _unpack_uri_extension_data(self, data):
536 return uri.unpack_extension(data)
# fetch the uri_extension from any source, validating its hash against
# the value from the URI before trusting its contents.
538 def _obtain_uri_extension(self, ignored):
539 # all shareholders are supposed to have a copy of uri_extension, and
540 # all are supposed to be identical. We compute the hash of the data
541 # that comes back, and compare it against the version in our URI. If
542 # they don't match, ignore their data and try someone else.
544 self._status.set_status("Obtaining URI Extension")
546 def _validate(proposal, bucket):
547 h = hashutil.uri_extension_hash(proposal)
548 if h != self._uri_extension_hash:
549 self._fetch_failures["uri_extension"] += 1
550 msg = ("The copy of uri_extension we received from "
551 "%s was bad" % bucket)
552 raise BadURIExtensionHashValue(msg)
553 return self._unpack_uri_extension_data(proposal)
554 return self._obtain_validated_thing(None,
555 self._uri_extension_sources,
557 "get_uri_extension", (), _validate)
# generic retry loop: try each source in turn, calling 'methname' and the
# validator; recurse on failure until sources are exhausted, then raise
# NotEnoughPeersError.
559 def _obtain_validated_thing(self, ignored, sources, name, methname, args,
562 raise NotEnoughPeersError("started with zero peers while fetching "
565 sources = sources[1:]
566 #d = bucket.callRemote(methname, *args)
567 d = bucket.startIfNecessary()
568 d.addCallback(lambda res: getattr(bucket, methname)(*args))
569 d.addCallback(validatorfunc, bucket)
571 self.log("WEIRD: %s from vbucket %s failed: %s" % (name, bucket, f))
573 raise NotEnoughPeersError("ran out of peers, last error was %s"
575 # try again with a different one
576 return self._obtain_validated_thing(None, sources, name,
577 methname, args, validatorfunc)
# unpack the validated uri_extension: codec parameters (normal and tail),
# whole-file hashes, segment geometry, and the share hash tree root.
581 def _got_uri_extension(self, uri_extension_data):
582 d = self._uri_extension_data = uri_extension_data
584 self._codec = codec.get_decoder_by_name(d['codec_name'])
585 self._codec.set_serialized_params(d['codec_params'])
586 self._tail_codec = codec.get_decoder_by_name(d['codec_name'])
587 self._tail_codec.set_serialized_params(d['tail_codec_params'])
589 crypttext_hash = d['crypttext_hash']
590 assert isinstance(crypttext_hash, str)
591 assert len(crypttext_hash) == 32
592 self._crypttext_hash = crypttext_hash
593 self._plaintext_hash = d['plaintext_hash']
594 self._roothash = d['share_root_hash']
596 self._segment_size = segment_size = d['segment_size']
597 self._total_segments = mathutil.div_ceil(self._size, segment_size)
598 self._current_segnum = 0
600 self._share_hashtree = hashtree.IncompleteHashTree(d['total_shares'])
601 self._share_hashtree.set_hashes({0: self._roothash})
603 def _get_hashtrees(self, res):
605 self._status.set_status("Retrieving Hash Trees")
606 d = self._get_plaintext_hashtrees()
607 d.addCallback(self._get_crypttext_hashtrees)
608 d.addCallback(self._setup_hashtrees)
# fetch the plaintext hash tree, checking its root against the
# uri_extension and its internal consistency.
611 def _get_plaintext_hashtrees(self):
612 def _validate_plaintext_hashtree(proposal, bucket):
613 if proposal[0] != self._uri_extension_data['plaintext_root_hash']:
614 self._fetch_failures["plaintext_hashroot"] += 1
615 msg = ("The copy of the plaintext_root_hash we received from"
616 " %s was bad" % bucket)
617 raise BadPlaintextHashValue(msg)
618 pt_hashtree = hashtree.IncompleteHashTree(self._total_segments)
619 pt_hashes = dict(list(enumerate(proposal)))
621 pt_hashtree.set_hashes(pt_hashes)
622 except hashtree.BadHashError:
623 # the hashes they gave us were not self-consistent, even
624 # though the root matched what we saw in the uri_extension
626 self._fetch_failures["plaintext_hashtree"] += 1
628 self._plaintext_hashtree = pt_hashtree
629 d = self._obtain_validated_thing(None,
630 self._uri_extension_sources,
632 "get_plaintext_hashes", (),
633 _validate_plaintext_hashtree
# same as above, for the crypttext hash tree.
636 def _get_crypttext_hashtrees(self, res):
637 def _validate_crypttext_hashtree(proposal, bucket):
638 if proposal[0] != self._uri_extension_data['crypttext_root_hash']:
639 self._fetch_failures["crypttext_hashroot"] += 1
640 msg = ("The copy of the crypttext_root_hash we received from"
641 " %s was bad" % bucket)
642 raise BadCrypttextHashValue(msg)
643 ct_hashtree = hashtree.IncompleteHashTree(self._total_segments)
644 ct_hashes = dict(list(enumerate(proposal)))
646 ct_hashtree.set_hashes(ct_hashes)
647 except hashtree.BadHashError:
648 self._fetch_failures["crypttext_hashtree"] += 1
# NOTE(review): set_hashes(ct_hashes) appears twice (original lines 646
# and 650); the second call looks redundant or belongs to an elided
# else/raise structure -- confirm against the full file.
650 ct_hashtree.set_hashes(ct_hashes)
651 self._crypttext_hashtree = ct_hashtree
652 d = self._obtain_validated_thing(None,
653 self._uri_extension_sources,
655 "get_crypttext_hashes", (),
656 _validate_crypttext_hashtree
659 def _setup_hashtrees(self, res):
660 self._output.setup_hashtrees(self._plaintext_hashtree,
661 self._crypttext_hashtree)
# wrap every known (sharenum, bucket) pair in a ValidatedBucket sharing
# the single share hash tree.
664 def _create_validated_buckets(self, ignored=None):
665 self._share_vbuckets = {}
666 for sharenum, bucket in self._share_buckets:
667 vbucket = ValidatedBucket(sharenum, bucket,
668 self._share_hashtree,
670 self._total_segments)
# NOTE(review): the 's.add(vbucket)' call (original line ~672) is elided.
671 s = self._share_vbuckets.setdefault(sharenum, set())
674 def _activate_enough_buckets(self):
675 """either return a mapping from shnum to a ValidatedBucket that can
676 provide data for that share, or raise NotEnoughPeersError"""
678 while len(self.active_buckets) < self._num_needed_shares:
680 handled_shnums = set(self.active_buckets.keys())
681 available_shnums = set(self._share_vbuckets.keys())
682 potential_shnums = list(available_shnums - handled_shnums)
683 if not potential_shnums:
684 raise NotEnoughPeersError
685 # choose a random share
686 shnum = random.choice(potential_shnums)
687 # and a random bucket that will provide it
688 validated_bucket = random.choice(list(self._share_vbuckets[shnum]))
689 self.active_buckets[shnum] = validated_bucket
690 return self.active_buckets
# chain one _download_segment per full segment, then the tail segment
# (which uses the tail codec and strips padding).
693 def _download_all_segments(self, res):
694 # the promise: upon entry to this function, self._share_vbuckets
695 # contains enough buckets to complete the download, and some extra
696 # ones to tolerate some buckets dropping out or having errors.
697 # self._share_vbuckets is a dictionary that maps from shnum to a set
698 # of ValidatedBuckets, which themselves are wrappers around
699 # RIBucketReader references.
700 self.active_buckets = {} # k: shnum, v: ValidatedBucket instance
702 d = defer.succeed(None)
# segments 0 .. total-2 are full-size; the last one is handled by
# _download_tail_segment below.
703 for segnum in range(self._total_segments-1):
704 d.addCallback(self._download_segment, segnum)
705 # this pause, at the end of write, prevents pre-fetch from
706 # happening until the consumer is ready for more data.
707 d.addCallback(self._check_for_pause)
708 d.addCallback(self._download_tail_segment, self._total_segments-1)
# passthrough checkpoint: block while paused, raise if stopped.
711 def _check_for_pause(self, res):
714 self._paused.addCallback(lambda ignored: d.callback(res))
717 raise DownloadStopped("our Consumer called stopProducing()")
720 def _download_segment(self, res, segnum):
722 self._status.set_status("Downloading segment %d of %d" %
723 (segnum+1, self._total_segments))
724 self.log("downloading seg#%d of %d (%d%%)"
725 % (segnum, self._total_segments,
726 100.0 * segnum / self._total_segments))
727 # memory footprint: when the SegmentDownloader finishes pulling down
728 # all shares, we have 1*segment_size of usage.
729 segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares)
730 d = segmentdler.start()
731 # pause before using more memory
732 d.addCallback(self._check_for_pause)
733 # while the codec does its job, we hit 2*segment_size
734 d.addCallback(lambda (shares, shareids):
735 self._codec.decode(shares, shareids))
736 # once the codec is done, we drop back to 1*segment_size, because
737 # 'shares' goes out of scope. The memory usage is all in the
738 # plaintext now, spread out into a bunch of tiny buffers.
740 # pause/check-for-stop just before writing, to honor stopProducing
741 d.addCallback(self._check_for_pause)
743 # we start by joining all these buffers together into a single
744 # string. This makes Output.write easier, since it wants to hash
745 # data one segment at a time anyways, and doesn't impact our
746 # memory footprint since we're already peaking at 2*segment_size
747 # inside the codec a moment ago.
748 segment = "".join(buffers)
750 # we're down to 1*segment_size right now, but write_segment()
751 # will decrypt a copy of the segment internally, which will push
752 # us up to 2*segment_size while it runs.
753 self._output.write_segment(segment)
# same as _download_segment but decodes with the tail codec and trims
# the zero-padding that the uploader added to fill the last segment.
757 def _download_tail_segment(self, res, segnum):
758 self.log("downloading seg#%d of %d (%d%%)"
759 % (segnum, self._total_segments,
760 100.0 * segnum / self._total_segments))
761 segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares)
762 d = segmentdler.start()
763 # pause before using more memory
764 d.addCallback(self._check_for_pause)
765 d.addCallback(lambda (shares, shareids):
766 self._tail_codec.decode(shares, shareids))
767 # pause/check-for-stop just before writing, to honor stopProducing
768 d.addCallback(self._check_for_pause)
770 # trim off any padding added by the upload side
771 segment = "".join(buffers)
773 # we never send empty segments. If the data was an exact multiple
774 # of the segment size, the last segment will be full.
775 pad_size = mathutil.pad_size(self._size, self._segment_size)
776 tail_size = self._segment_size - pad_size
777 segment = segment[:tail_size]
778 self._output.write_segment(segment)
# final verification: computed whole-file hashes and total length must
# match the uri_extension values; then hand back Output.finish().
782 def _done(self, res):
783 self.log("download done")
785 if self.check_crypttext_hash:
786 _assert(self._crypttext_hash == self._output.crypttext_hash,
787 "bad crypttext_hash: computed=%s, expected=%s" %
788 (base32.b2a(self._output.crypttext_hash),
789 base32.b2a(self._crypttext_hash)))
790 if self.check_plaintext_hash:
791 _assert(self._plaintext_hash == self._output.plaintext_hash,
792 "bad plaintext_hash: computed=%s, expected=%s" %
793 (base32.b2a(self._output.plaintext_hash),
794 base32.b2a(self._plaintext_hash)))
795 _assert(self._output.length == self._size,
796 got=self._output.length, expected=self._size)
797 return self._output.finish()
799 def get_download_status(self):
# Downloader for LIT URIs: the file's entire contents are embedded in the
# URI itself, so "downloading" is just writing that data to the target.
803 class LiteralDownloader:
804 def __init__(self, client, u, downloadable):
805 self._uri = IFileURI(u)
806 assert isinstance(self._uri, uri.LiteralFileURI)
807 self._downloadable = downloadable
808 self._status = s = DownloadStatus()
# no storage index exists for literal files.
809 s.set_storage_index(None)
# start() body (def line elided from this listing): open/write/close the
# target with the embedded data, then return its finish() result.
815 data = self._uri.data
816 self._status.set_size(len(data))
817 self._downloadable.open(len(data))
818 self._downloadable.write(data)
819 self._downloadable.close()
820 return defer.maybeDeferred(self._downloadable.finish)
822 def get_download_status(self):
# --- FileName target (the 'class FileName:' header is elided) ---
# IDownloadTarget that writes the download into a named file opened in
# binary mode; on failure the partially-written file is unlinked.
826 implements(IDownloadTarget)
827 def __init__(self, filename):
828 self._filename = filename
830 def open(self, size):
831 self.f = open(self._filename, "wb")
833 def write(self, data):
# fail() body (def line elided): remove the partial file.
841 os.unlink(self._filename)
842 def register_canceller(self, cb):
843 pass # we won't use it
# --- Data target (the class header is elided) ---
# IDownloadTarget that accumulates chunks in memory; on close (def line
# elided) the chunks are joined into self.data.
848 implements(IDownloadTarget)
851 def open(self, size):
853 def write(self, data):
854 self._data.append(data)
856 self.data = "".join(self._data)
860 def register_canceller(self, cb):
861 pass # we won't use it
# --- FileHandle target (the class header and docstring opener are
# elided from this listing) ---
866 """Use me to download data to a pre-defined filehandle-like object. I
867 will use the target's write() method. I will *not* close the filehandle:
868 I leave that up to the originator of the filehandle. The download process
869 will return the filehandle when it completes.
871 implements(IDownloadTarget)
872 def __init__(self, filehandle):
873 self._filehandle = filehandle
874 def open(self, size):
876 def write(self, data):
877 self._filehandle.write(data)
879 # the originator of the filehandle reserves the right to close it
883 def register_canceller(self, cb):
# finish() body (def line elided): the filehandle itself is the result.
886 return self._filehandle
# Service facade implementing IDownloader: dispatches to LiteralDownloader
# or FileDownloader by URI type, tracks status objects in a weak dict so
# finished downloads can be garbage-collected, and offers convenience
# targets (Data / FileName / FileHandle).
888 class Downloader(service.MultiService):
889 """I am a service that allows file downloading.
891 implements(IDownloader)
895 service.MultiService.__init__(self)
896 self._all_downloads = weakref.WeakKeyDictionary()
898 def download(self, u, t):
902 t = IDownloadTarget(t)
905 if isinstance(u, uri.LiteralFileURI):
906 dl = LiteralDownloader(self.parent, u, t)
907 elif isinstance(u, uri.CHKFileURI):
908 dl = FileDownloader(self.parent, u, t)
910 raise RuntimeError("I don't know how to download a %s" % u)
# keyed by status object, value unused: membership is what matters.
911 self._all_downloads[dl.get_download_status()] = None
# NOTE(review): the 'uri' parameters below shadow the module-level
# 'uri' import (allmydata.uri); harmless inside these bodies but worth
# renaming in a future cleanup.
916 def download_to_data(self, uri):
917 return self.download(uri, Data())
918 def download_to_filename(self, uri, filename):
919 return self.download(uri, FileName(filename))
920 def download_to_filehandle(self, uri, filehandle):
921 return self.download(uri, FileHandle(filehandle))
924 def list_downloads(self):
925 return self._all_downloads.keys()