
import os, random, weakref
from zope.interface import implements
from twisted.internet import defer
from twisted.internet.interfaces import IPushProducer, IConsumer
from twisted.application import service
from foolscap.eventual import eventually

from allmydata.util import base32, mathutil, hashutil, log
from allmydata.util.assertutil import _assert
from allmydata import codec, hashtree, storage, uri
from allmydata.interfaces import IDownloadTarget, IDownloader, IFileURI, \
     IDownloadStatus
from allmydata.encode import NotEnoughPeersError
from pycryptopp.cipher.aes import AES

class HaveAllPeersError(Exception):
    # we use this to jump out of the loop
    pass

class BadURIExtensionHashValue(Exception):
    pass
class BadPlaintextHashValue(Exception):
    pass
class BadCrypttextHashValue(Exception):
    pass

class DownloadStopped(Exception):
    pass

class Output:
    def __init__(self, downloadable, key, total_length, log_parent,
                 download_status):
        self.downloadable = downloadable
        self._decryptor = AES(key)
        self._crypttext_hasher = hashutil.crypttext_hasher()
        self._plaintext_hasher = hashutil.plaintext_hasher()
        self.length = 0
        self.total_length = total_length
        self._segment_number = 0
        self._plaintext_hash_tree = None
        self._crypttext_hash_tree = None
        self._opened = False
        self._log_parent = log_parent
        self._status = download_status
        self._status.set_progress(0.0)

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_parent
        if "facility" not in kwargs:
            kwargs["facility"] = "download.output"
        return log.msg(*args, **kwargs)

    def setup_hashtrees(self, plaintext_hashtree, crypttext_hashtree):
        self._plaintext_hash_tree = plaintext_hashtree
        self._crypttext_hash_tree = crypttext_hashtree

    def write_segment(self, crypttext):
        self.length += len(crypttext)
        self._status.set_progress( float(self.length) / self.total_length )

        # memory footprint: 'crypttext' is the only segment_size usage
        # outstanding. While we decrypt it into 'plaintext', we hit
        # 2*segment_size.
        self._crypttext_hasher.update(crypttext)
        if self._crypttext_hash_tree:
            ch = hashutil.crypttext_segment_hasher()
            ch.update(crypttext)
            crypttext_leaves = {self._segment_number: ch.digest()}
            self.log(format="crypttext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s",
                     bytes=len(crypttext),
                     segnum=self._segment_number, hash=base32.b2a(ch.digest()),
                     level=log.NOISY)
            self._crypttext_hash_tree.set_hashes(leaves=crypttext_leaves)

        plaintext = self._decryptor.process(crypttext)
        del crypttext

        # now we're back down to 1*segment_size.

        self._plaintext_hasher.update(plaintext)
        if self._plaintext_hash_tree:
            ph = hashutil.plaintext_segment_hasher()
            ph.update(plaintext)
            plaintext_leaves = {self._segment_number: ph.digest()}
            self.log(format="plaintext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s",
                     bytes=len(plaintext),
                     segnum=self._segment_number, hash=base32.b2a(ph.digest()),
                     level=log.NOISY)
            self._plaintext_hash_tree.set_hashes(leaves=plaintext_leaves)

        self._segment_number += 1
        # We're still at 1*segment_size. The Downloadable is responsible for
        # any memory usage beyond this.
        if not self._opened:
            self._opened = True
            self.downloadable.open(self.total_length)
        self.downloadable.write(plaintext)

    def fail(self, why):
        # this is really unusual, and deserves maximum forensics
        self.log("download failed!", failure=why, level=log.SCARY)
        self.downloadable.fail(why)

    def close(self):
        self.crypttext_hash = self._crypttext_hasher.digest()
        self.plaintext_hash = self._plaintext_hasher.digest()
        self.log("download finished, closing IDownloadable", level=log.NOISY)
        self.downloadable.close()

    def finish(self):
        return self.downloadable.finish()
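
# A minimal sketch (in comments, to keep this module importable) of the
# decrypt-and-hash pipeline that Output.write_segment implements. It relies
# on pycryptopp's AES object being a streaming cipher, so feeding it one
# segment at a time yields the same plaintext as decrypting the whole file
# at once. The names 'segments' and 'expected_plaintext_hash' are
# hypothetical stand-ins:
#
#   decryptor = AES(key)
#   plaintext_hasher = hashutil.plaintext_hasher()
#   for crypttext in segments:            # segments arrive in order
#       plaintext = decryptor.process(crypttext)
#       plaintext_hasher.update(plaintext)
#   assert plaintext_hasher.digest() == expected_plaintext_hash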

class ValidatedBucket:
    """I am a front-end for a remote storage bucket, responsible for
    retrieving and validating data from that bucket.

    My get_block() method is used by BlockDownloaders.
    """
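
    # The validation chain, top to bottom (a summary of what get_block() and
    # _got_data() below do): the URI supplies a trusted root hash; the share
    # hash tree proves this share's hash up to that root; the block hash
    # tree, whose root is the now-validated share hash, proves each block.
    # Only blocks that survive the whole chain are handed to the caller.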

    def __init__(self, sharenum, bucket,
                 share_hash_tree, roothash,
                 num_blocks):
        self.sharenum = sharenum
        self.bucket = bucket
        self._share_hash = None # None means not validated yet
        self.share_hash_tree = share_hash_tree
        self._roothash = roothash
        self.block_hash_tree = hashtree.IncompleteHashTree(num_blocks)
        self.started = False

    def get_block(self, blocknum):
        if not self.started:
            d = self.bucket.start()
            def _started(res):
                self.started = True
                return self.get_block(blocknum)
            d.addCallback(_started)
            return d

        # the first time we use this bucket, we need to fetch enough elements
        # of the share hash tree to validate our share hash all the way up to
        # the root hash.
        if not self._share_hash:
            d1 = self.bucket.get_share_hashes()
        else:
            d1 = defer.succeed([])

        # we might need to grab some elements of our block hash tree, to
        # validate the requested block up to the share hash
        needed = self.block_hash_tree.needed_hashes(blocknum)
        if needed:
            # TODO: get fewer hashes, use get_block_hashes(needed)
            d2 = self.bucket.get_block_hashes()
        else:
            d2 = defer.succeed([])

        d3 = self.bucket.get_block(blocknum)

        d = defer.gatherResults([d1, d2, d3])
        d.addCallback(self._got_data, blocknum)
        return d

    def _got_data(self, res, blocknum):
        sharehashes, blockhashes, blockdata = res
        blockhash = None # to make logging it safe

        try:
            if not self._share_hash:
                sh = dict(sharehashes)
                sh[0] = self._roothash # always use our own root, from the URI
                sht = self.share_hash_tree
                if sht.get_leaf_index(self.sharenum) not in sh:
                    raise hashtree.NotEnoughHashesError
                sht.set_hashes(sh)
                self._share_hash = sht.get_leaf(self.sharenum)

            blockhash = hashutil.block_hash(blockdata)
            #log.msg("checking block_hash(shareid=%d, blocknum=%d) len=%d "
            #        "%r .. %r: %s" %
            #        (self.sharenum, blocknum, len(blockdata),
            #         blockdata[:50], blockdata[-50:], base32.b2a(blockhash)))

            # we always validate the blockhash
            bh = dict(enumerate(blockhashes))
            # replace blockhash root with validated value
            bh[0] = self._share_hash
            self.block_hash_tree.set_hashes(bh, {blocknum: blockhash})

        except (hashtree.BadHashError, hashtree.NotEnoughHashesError):
            # log.WEIRD: indicates undetected disk/network error, or more
            # likely a programming error
            log.msg("hash failure in block=%d, shnum=%d on %s" %
                    (blocknum, self.sharenum, self.bucket))
            if self._share_hash:
                log.msg(""" failure occurred when checking the block_hash_tree.
                This suggests that either the block data was bad, or that the
                block hashes we received along with it were bad.""")
            else:
                log.msg(""" the failure probably occurred when checking the
                share_hash_tree, which suggests that the share hashes we
                received from the remote peer were bad.""")
            log.msg(" have self._share_hash: %s" % bool(self._share_hash))
            log.msg(" block length: %d" % len(blockdata))
            log.msg(" block hash: %s" % base32.b2a_or_none(blockhash))
            if len(blockdata) < 100:
                log.msg(" block data: %r" % (blockdata,))
            else:
                log.msg(" block data start/end: %r .. %r" %
                        (blockdata[:50], blockdata[-50:]))
            log.msg(" root hash: %s" % base32.b2a(self._roothash))
            log.msg(" share hash tree:\n" + self.share_hash_tree.dump())
            log.msg(" block hash tree:\n" + self.block_hash_tree.dump())
            lines = []
            for i,h in sorted(sharehashes):
                lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
            log.msg(" sharehashes:\n" + "\n".join(lines) + "\n")
            lines = []
            for i,h in enumerate(blockhashes):
                lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
            log.msg(" blockhashes:\n" + "\n".join(lines) + "\n")
            raise

        # If we made it here, the block is good. If the hash trees didn't
        # like what they saw, they would have raised a BadHashError, causing
        # our caller to see a Failure and thus ignore this block (as well as
        # dropping this bucket).
        return blockdata



class BlockDownloader:
    """I am responsible for downloading a single block (from a single bucket)
    for a single segment.

    I am a child of the SegmentDownloader.
    """

    def __init__(self, vbucket, blocknum, parent):
        self.vbucket = vbucket
        self.blocknum = blocknum
        self.parent = parent
        self._log_number = self.parent.log("starting block %d" % blocknum)

    def log(self, msg, parent=None):
        if parent is None:
            parent = self._log_number
        return self.parent.log(msg, parent=parent)

    def start(self, segnum):
        lognum = self.log("get_block(segnum=%d)" % segnum)
        d = self.vbucket.get_block(segnum)
        d.addCallbacks(self._hold_block, self._got_block_error,
                       callbackArgs=(lognum,), errbackArgs=(lognum,))
        return d

    def _hold_block(self, data, lognum):
        self.log("got block", parent=lognum)
        self.parent.hold_block(self.blocknum, data)

    def _got_block_error(self, f, lognum):
        self.log("BlockDownloader[%d] got error: %s" % (self.blocknum, f),
                 parent=lognum)
        self.parent.bucket_failed(self.vbucket)

class SegmentDownloader:
    """I am responsible for downloading all the blocks for a single segment
    of data.

    I am a child of the FileDownloader.
    """

    def __init__(self, parent, segmentnumber, needed_shares):
        self.parent = parent
        self.segmentnumber = segmentnumber
        self.needed_blocks = needed_shares
        self.blocks = {} # k: blocknum, v: data
        self._log_number = self.parent.log("starting segment %d" %
                                           segmentnumber)

    def log(self, msg, parent=None):
        if parent is None:
            parent = self._log_number
        return self.parent.log(msg, parent=parent)

    def start(self):
        return self._download()

    def _download(self):
        d = self._try()
        def _done(res):
            if len(self.blocks) >= self.needed_blocks:
                # we only need self.needed_blocks blocks
                # we want to get the smallest blockids, because they are
                # more likely to be fast "primary blocks"
                blockids = sorted(self.blocks.keys())[:self.needed_blocks]
                blocks = []
                for blocknum in blockids:
                    blocks.append(self.blocks[blocknum])
                return (blocks, blockids)
            else:
                return self._download()
        d.addCallback(_done)
        return d
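
    # Illustrative walkthrough of the retry loop above (hypothetical numbers,
    # not from the code): with needed_blocks=3, suppose the first _try()
    # round loses one bucket to an error, leaving self.blocks == {0:..., 2:...}.
    # _done() sees only two blocks and recurses into _download(), which asks
    # the parent to activate a replacement bucket and tries again; once three
    # or more blocks are held, the three smallest blockids are returned.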

    def _try(self):
        # fill our set of active buckets, maybe raising NotEnoughPeersError
        active_buckets = self.parent._activate_enough_buckets()
        # Now we have enough buckets, in self.parent.active_buckets.

        # in test cases, bd.start might mutate active_buckets right away, so
        # we need to put off calling start() until we've iterated all the way
        # through it.
        downloaders = []
        for blocknum, vbucket in active_buckets.iteritems():
            bd = BlockDownloader(vbucket, blocknum, self)
            downloaders.append(bd)
        l = [bd.start(self.segmentnumber) for bd in downloaders]
        return defer.DeferredList(l, fireOnOneErrback=True)

    def hold_block(self, blocknum, data):
        self.blocks[blocknum] = data

    def bucket_failed(self, vbucket):
        self.parent.bucket_failed(vbucket)

class DownloadStatus:
    implements(IDownloadStatus)

    def __init__(self):
        self.storage_index = None
        self.size = None
        self.helper = False
        self.status = "Not started"
        self.progress = 0.0
        self.paused = False
        self.stopped = False

    def get_storage_index(self):
        return self.storage_index
    def get_size(self):
        return self.size
    def using_helper(self):
        return self.helper
    def get_status(self):
        status = self.status
        if self.paused:
            status += " (output paused)"
        if self.stopped:
            status += " (output stopped)"
        return status
    def get_progress(self):
        return self.progress

    def set_storage_index(self, si):
        self.storage_index = si
    def set_size(self, size):
        self.size = size
    def set_helper(self, helper):
        self.helper = helper
    def set_status(self, status):
        self.status = status
    def set_paused(self, paused):
        self.paused = paused
    def set_stopped(self, stopped):
        self.stopped = stopped
    def set_progress(self, value):
        self.progress = value


class FileDownloader:
    implements(IPushProducer)
    check_crypttext_hash = True
    check_plaintext_hash = True
    _status = None

    def __init__(self, client, u, downloadable):
        self._client = client

        u = IFileURI(u)
        self._storage_index = u.storage_index
        self._uri_extension_hash = u.uri_extension_hash
        self._total_shares = u.total_shares
        self._size = u.size
        self._num_needed_shares = u.needed_shares

        self.init_logging()

        self._status = s = DownloadStatus()
        s.set_status("Starting")
        s.set_storage_index(self._storage_index)
        s.set_size(self._size)
        s.set_helper(False)

        if IConsumer.providedBy(downloadable):
            downloadable.registerProducer(self, True)
        self._downloadable = downloadable
        self._output = Output(downloadable, u.key, self._size, self._log_number,
                              self._status)
        self._paused = False
        self._stopped = False

        self.active_buckets = {} # k: shnum, v: bucket
        self._share_buckets = [] # list of (sharenum, bucket) tuples
        self._share_vbuckets = {} # k: shnum, v: set of ValidatedBuckets
        self._uri_extension_sources = []

        self._uri_extension_data = None

        self._fetch_failures = {"uri_extension": 0,
                                "plaintext_hashroot": 0,
                                "plaintext_hashtree": 0,
                                "crypttext_hashroot": 0,
                                "crypttext_hashtree": 0,
                                }

    def init_logging(self):
        self._log_prefix = prefix = storage.si_b2a(self._storage_index)[:5]
        num = self._client.log(format="FileDownloader(%(si)s): starting",
                               si=storage.si_b2a(self._storage_index))
        self._log_number = num

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_number
        if "facility" not in kwargs:
            kwargs["facility"] = "tahoe.download"
        return log.msg(*args, **kwargs)

    def pauseProducing(self):
        if self._paused:
            return
        self._paused = defer.Deferred()
        if self._status:
            self._status.set_paused(True)

    def resumeProducing(self):
        if self._paused:
            p = self._paused
            self._paused = None
            eventually(p.callback, None)
            if self._status:
                self._status.set_paused(False)

    def stopProducing(self):
        self.log("Download.stopProducing")
        self._stopped = True
        if self._status:
            self._status.set_stopped(True)

    def start(self):
        self.log("starting download")

        # first step: who should we download from?
        d = defer.maybeDeferred(self._get_all_shareholders)
        d.addCallback(self._got_all_shareholders)
        # now get the uri_extension block from somebody and validate it
        d.addCallback(self._obtain_uri_extension)
        d.addCallback(self._got_uri_extension)
        d.addCallback(self._get_hashtrees)
        d.addCallback(self._create_validated_buckets)
        # once we know that, we can download blocks from everybody
        d.addCallback(self._download_all_segments)
        def _finished(res):
            if self._status:
                self._status.set_status("Finished")
            if IConsumer.providedBy(self._downloadable):
                self._downloadable.unregisterProducer()
            return res
        d.addBoth(_finished)
        def _failed(why):
            if self._status:
                self._status.set_status("Failed")
            self._output.fail(why)
            return why
        d.addErrback(_failed)
        d.addCallback(self._done)
        return d

    def _get_all_shareholders(self):
        dl = []
        for (peerid,ss) in self._client.get_permuted_peers("storage",
                                                           self._storage_index):
            d = ss.callRemote("get_buckets", self._storage_index)
            d.addCallbacks(self._got_response, self._got_error)
            dl.append(d)
        self._responses_received = 0
        self._queries_sent = len(dl)
        if self._status:
            self._status.set_status("Locating Shares (%d/%d)" %
                                    (self._responses_received,
                                     self._queries_sent))
        return defer.DeferredList(dl)

    def _got_response(self, buckets):
        self._responses_received += 1
        if self._status:
            self._status.set_status("Locating Shares (%d/%d)" %
                                    (self._responses_received,
                                     self._queries_sent))
        for sharenum, bucket in buckets.iteritems():
            b = storage.ReadBucketProxy(bucket)
            self.add_share_bucket(sharenum, b)
            self._uri_extension_sources.append(b)

    def add_share_bucket(self, sharenum, bucket):
        # this is split out for the benefit of test_encode.py
        self._share_buckets.append( (sharenum, bucket) )

    def _got_error(self, f):
        self._client.log("Somebody failed. -- %s" % (f,))

    def bucket_failed(self, vbucket):
        shnum = vbucket.sharenum
        del self.active_buckets[shnum]
        s = self._share_vbuckets[shnum]
        # s is a set of ValidatedBucket instances
        s.remove(vbucket)
        # ... which might now be empty
        if not s:
            # there are no more buckets which can provide this share, so
            # remove the key. This may prompt us to use a different share.
            del self._share_vbuckets[shnum]

    def _got_all_shareholders(self, res):
        if len(self._share_buckets) < self._num_needed_shares:
            raise NotEnoughPeersError

        #for s in self._share_vbuckets.values():
        #    for vb in s:
        #        assert isinstance(vb, ValidatedBucket), \
        #               "vb is %s but should be a ValidatedBucket" % (vb,)

    def _unpack_uri_extension_data(self, data):
        return uri.unpack_extension(data)

    def _obtain_uri_extension(self, ignored):
        # all shareholders are supposed to have a copy of uri_extension, and
        # all are supposed to be identical. We compute the hash of the data
        # that comes back, and compare it against the version in our URI. If
        # they don't match, ignore their data and try someone else.
        if self._status:
            self._status.set_status("Obtaining URI Extension")

        def _validate(proposal, bucket):
            h = hashutil.uri_extension_hash(proposal)
            if h != self._uri_extension_hash:
                self._fetch_failures["uri_extension"] += 1
                msg = ("The copy of uri_extension we received from "
                       "%s was bad" % bucket)
                raise BadURIExtensionHashValue(msg)
            return self._unpack_uri_extension_data(proposal)
        return self._obtain_validated_thing(None,
                                            self._uri_extension_sources,
                                            "uri_extension",
                                            "get_uri_extension", (), _validate)

    def _obtain_validated_thing(self, ignored, sources, name, methname, args,
                                validatorfunc):
        if not sources:
            raise NotEnoughPeersError("started with zero peers while fetching "
                                      "%s" % name)
        bucket = sources[0]
        sources = sources[1:]
        #d = bucket.callRemote(methname, *args)
        d = bucket.startIfNecessary()
        d.addCallback(lambda res: getattr(bucket, methname)(*args))
        d.addCallback(validatorfunc, bucket)
        def _bad(f):
            self.log("WEIRD: %s from vbucket %s failed: %s" % (name, bucket, f))
            if not sources:
                raise NotEnoughPeersError("ran out of peers, last error was %s"
                                          % (f,))
            # try again with a different one
            return self._obtain_validated_thing(None, sources, name,
                                                methname, args, validatorfunc)
        d.addErrback(_bad)
        return d
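
    # The retry pattern above, in miniature: fetch from the first source,
    # validate, and on any failure recurse with the remaining sources until
    # one succeeds or the list is exhausted. A stripped-down sketch with
    # hypothetical helper names ('get_from', 'validate'):
    #
    #   def fetch_validated(sources):
    #       if not sources:
    #           raise NotEnoughPeersError()
    #       d = get_from(sources[0])
    #       d.addCallback(validate)
    #       d.addErrback(lambda f: fetch_validated(sources[1:]))
    #       return d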

    def _got_uri_extension(self, uri_extension_data):
        d = self._uri_extension_data = uri_extension_data

        self._codec = codec.get_decoder_by_name(d['codec_name'])
        self._codec.set_serialized_params(d['codec_params'])
        self._tail_codec = codec.get_decoder_by_name(d['codec_name'])
        self._tail_codec.set_serialized_params(d['tail_codec_params'])

        crypttext_hash = d['crypttext_hash']
        assert isinstance(crypttext_hash, str)
        assert len(crypttext_hash) == 32
        self._crypttext_hash = crypttext_hash
        self._plaintext_hash = d['plaintext_hash']
        self._roothash = d['share_root_hash']

        self._segment_size = segment_size = d['segment_size']
        self._total_segments = mathutil.div_ceil(self._size, segment_size)
        self._current_segnum = 0

        self._share_hashtree = hashtree.IncompleteHashTree(d['total_shares'])
        self._share_hashtree.set_hashes({0: self._roothash})

    def _get_hashtrees(self, res):
        if self._status:
            self._status.set_status("Retrieving Hash Trees")
        d = self._get_plaintext_hashtrees()
        d.addCallback(self._get_crypttext_hashtrees)
        d.addCallback(self._setup_hashtrees)
        return d

    def _get_plaintext_hashtrees(self):
        def _validate_plaintext_hashtree(proposal, bucket):
            if proposal[0] != self._uri_extension_data['plaintext_root_hash']:
                self._fetch_failures["plaintext_hashroot"] += 1
                msg = ("The copy of the plaintext_root_hash we received from"
                       " %s was bad" % bucket)
                raise BadPlaintextHashValue(msg)
            pt_hashtree = hashtree.IncompleteHashTree(self._total_segments)
            pt_hashes = dict(list(enumerate(proposal)))
            try:
                pt_hashtree.set_hashes(pt_hashes)
            except hashtree.BadHashError:
                # the hashes they gave us were not self-consistent, even
                # though the root matched what we saw in the uri_extension
                # block
                self._fetch_failures["plaintext_hashtree"] += 1
                raise
            self._plaintext_hashtree = pt_hashtree
        d = self._obtain_validated_thing(None,
                                         self._uri_extension_sources,
                                         "plaintext_hashes",
                                         "get_plaintext_hashes", (),
                                         _validate_plaintext_hashtree)
        return d

    def _get_crypttext_hashtrees(self, res):
        def _validate_crypttext_hashtree(proposal, bucket):
            if proposal[0] != self._uri_extension_data['crypttext_root_hash']:
                self._fetch_failures["crypttext_hashroot"] += 1
                msg = ("The copy of the crypttext_root_hash we received from"
                       " %s was bad" % bucket)
                raise BadCrypttextHashValue(msg)
            ct_hashtree = hashtree.IncompleteHashTree(self._total_segments)
            ct_hashes = dict(list(enumerate(proposal)))
            try:
                ct_hashtree.set_hashes(ct_hashes)
            except hashtree.BadHashError:
                # the hashes they gave us were not self-consistent, even
                # though the root matched what we saw in the uri_extension
                # block
                self._fetch_failures["crypttext_hashtree"] += 1
                raise
            self._crypttext_hashtree = ct_hashtree
        d = self._obtain_validated_thing(None,
                                         self._uri_extension_sources,
                                         "crypttext_hashes",
                                         "get_crypttext_hashes", (),
                                         _validate_crypttext_hashtree)
        return d

    def _setup_hashtrees(self, res):
        self._output.setup_hashtrees(self._plaintext_hashtree,
                                     self._crypttext_hashtree)


    def _create_validated_buckets(self, ignored=None):
        self._share_vbuckets = {}
        for sharenum, bucket in self._share_buckets:
            vbucket = ValidatedBucket(sharenum, bucket,
                                      self._share_hashtree,
                                      self._roothash,
                                      self._total_segments)
            s = self._share_vbuckets.setdefault(sharenum, set())
            s.add(vbucket)

    def _activate_enough_buckets(self):
        """either return a mapping from shnum to a ValidatedBucket that can
        provide data for that share, or raise NotEnoughPeersError"""

        while len(self.active_buckets) < self._num_needed_shares:
            # need some more
            handled_shnums = set(self.active_buckets.keys())
            available_shnums = set(self._share_vbuckets.keys())
            potential_shnums = list(available_shnums - handled_shnums)
            if not potential_shnums:
                raise NotEnoughPeersError
            # choose a random share
            shnum = random.choice(potential_shnums)
            # and a random bucket that will provide it
            validated_bucket = random.choice(list(self._share_vbuckets[shnum]))
            self.active_buckets[shnum] = validated_bucket
        return self.active_buckets


    def _download_all_segments(self, res):
        # the promise: upon entry to this function, self._share_vbuckets
        # contains enough buckets to complete the download, and some extra
        # ones to tolerate some buckets dropping out or having errors.
        # self._share_vbuckets is a dictionary that maps from shnum to a set
        # of ValidatedBuckets, which themselves are wrappers around
        # RIBucketReader references.
        self.active_buckets = {} # k: shnum, v: ValidatedBucket instance

        d = defer.succeed(None)
        for segnum in range(self._total_segments-1):
            d.addCallback(self._download_segment, segnum)
            # this pause, at the end of write, prevents pre-fetch from
            # happening until the consumer is ready for more data.
            d.addCallback(self._check_for_pause)
        d.addCallback(self._download_tail_segment, self._total_segments-1)
        return d

    def _check_for_pause(self, res):
        if self._paused:
            d = defer.Deferred()
            self._paused.addCallback(lambda ignored: d.callback(res))
            return d
        if self._stopped:
            raise DownloadStopped("our Consumer called stopProducing()")
        return res
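
    # How the pause plumbing above fits together (a summary, not new
    # behavior): pauseProducing() replaces self._paused with a Deferred.
    # _check_for_pause() then returns a fresh Deferred that only fires,
    # passing 'res' through, once resumeProducing() fires self._paused.
    # Downstream callbacks therefore stall while the consumer is not ready.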

    def _download_segment(self, res, segnum):
        if self._status:
            self._status.set_status("Downloading segment %d of %d" %
                                    (segnum+1, self._total_segments))
        self.log("downloading seg#%d of %d (%d%%)"
                 % (segnum, self._total_segments,
                    100.0 * segnum / self._total_segments))
        # memory footprint: when the SegmentDownloader finishes pulling down
        # all shares, we have 1*segment_size of usage.
        segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares)
        d = segmentdler.start()
        # pause before using more memory
        d.addCallback(self._check_for_pause)
        # while the codec does its job, we hit 2*segment_size
        d.addCallback(lambda (shares, shareids):
                      self._codec.decode(shares, shareids))
        # once the codec is done, we drop back to 1*segment_size, because
        # 'shares' goes out of scope. The memory usage is all in the
        # plaintext now, spread out into a bunch of tiny buffers.

        # pause/check-for-stop just before writing, to honor stopProducing
        d.addCallback(self._check_for_pause)
        def _done(buffers):
            # we start by joining all these buffers together into a single
            # string. This makes Output.write_segment easier, since it wants
            # to hash data one segment at a time anyway, and it doesn't hurt
            # our memory footprint, since we already peaked at 2*segment_size
            # inside the codec a moment ago.
            segment = "".join(buffers)
            del buffers
            # we're down to 1*segment_size right now, but write_segment()
            # will decrypt a copy of the segment internally, which will push
            # us up to 2*segment_size while it runs.
            self._output.write_segment(segment)
        d.addCallback(_done)
        return d

    def _download_tail_segment(self, res, segnum):
        self.log("downloading seg#%d of %d (%d%%)"
                 % (segnum, self._total_segments,
                    100.0 * segnum / self._total_segments))
        segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares)
        d = segmentdler.start()
        # pause before using more memory
        d.addCallback(self._check_for_pause)
        d.addCallback(lambda (shares, shareids):
                      self._tail_codec.decode(shares, shareids))
        # pause/check-for-stop just before writing, to honor stopProducing
        d.addCallback(self._check_for_pause)
        def _done(buffers):
            # trim off any padding added by the upload side
            segment = "".join(buffers)
            del buffers
            # we never send empty segments. If the data was an exact multiple
            # of the segment size, the last segment will be full.
            pad_size = mathutil.pad_size(self._size, self._segment_size)
            tail_size = self._segment_size - pad_size
            segment = segment[:tail_size]
            self._output.write_segment(segment)
        d.addCallback(_done)
        return d
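
    # Worked example of the tail trim above (illustrative numbers, not taken
    # from the code): with size=1000 and segment_size=128 the file spans
    # div_ceil(1000, 128) == 8 segments; the tail holds the last 104 bytes of
    # data, so pad_size == 24, tail_size == 128 - 24 == 104, and the slice
    # discards the 24 bytes of padding the uploader appended.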

    def _done(self, res):
        self.log("download done")
        self._output.close()
        if self.check_crypttext_hash:
            _assert(self._crypttext_hash == self._output.crypttext_hash,
                    "bad crypttext_hash: computed=%s, expected=%s" %
                    (base32.b2a(self._output.crypttext_hash),
                     base32.b2a(self._crypttext_hash)))
        if self.check_plaintext_hash:
            _assert(self._plaintext_hash == self._output.plaintext_hash,
                    "bad plaintext_hash: computed=%s, expected=%s" %
                    (base32.b2a(self._output.plaintext_hash),
                     base32.b2a(self._plaintext_hash)))
        _assert(self._output.length == self._size,
                got=self._output.length, expected=self._size)
        return self._output.finish()

    def get_download_status(self):
        return self._status


class LiteralDownloader:
    def __init__(self, client, u, downloadable):
        self._uri = IFileURI(u)
        assert isinstance(self._uri, uri.LiteralFileURI)
        self._downloadable = downloadable
        self._status = s = DownloadStatus()
        s.set_storage_index(None)
        s.set_helper(False)
        s.set_status("Done")
        s.set_progress(1.0)

    def start(self):
        data = self._uri.data
        self._status.set_size(len(data))
        self._downloadable.open(len(data))
        self._downloadable.write(data)
        self._downloadable.close()
        return defer.maybeDeferred(self._downloadable.finish)

    def get_download_status(self):
        return self._status

class FileName:
    implements(IDownloadTarget)
    def __init__(self, filename):
        self._filename = filename
        self.f = None
    def open(self, size):
        self.f = open(self._filename, "wb")
        return self.f
    def write(self, data):
        self.f.write(data)
    def close(self):
        if self.f:
            self.f.close()
    def fail(self, why):
        if self.f:
            self.f.close()
            os.unlink(self._filename)
    def register_canceller(self, cb):
        pass # we won't use it
    def finish(self):
        pass

class Data:
    implements(IDownloadTarget)
    def __init__(self):
        self._data = []
    def open(self, size):
        pass
    def write(self, data):
        self._data.append(data)
    def close(self):
        self.data = "".join(self._data)
        del self._data
    def fail(self, why):
        del self._data
    def register_canceller(self, cb):
        pass # we won't use it
    def finish(self):
        return self.data

class FileHandle:
    """Use me to download data to a pre-defined filehandle-like object. I
    will use the target's write() method. I will *not* close the filehandle:
    I leave that up to the originator of the filehandle. The download process
    will return the filehandle when it completes.
    """
    implements(IDownloadTarget)
    def __init__(self, filehandle):
        self._filehandle = filehandle
    def open(self, size):
        pass
    def write(self, data):
        self._filehandle.write(data)
    def close(self):
        # the originator of the filehandle reserves the right to close it
        pass
    def fail(self, why):
        pass
    def register_canceller(self, cb):
        pass
    def finish(self):
        return self._filehandle
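
# Example use of the IDownloadTarget classes above (a sketch in comments;
# assumes a Downloader service named 'downloader' is already running, and
# 'some_file_uri' is a hypothetical file URI):
#
#   from cStringIO import StringIO
#   fh = StringIO()
#   d = downloader.download(some_file_uri, FileHandle(fh))
#   d.addCallback(lambda fh: fh.getvalue())  # FileHandle.finish() returns fh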

class Downloader(service.MultiService):
    """I am a service that allows file downloading.
    """
    implements(IDownloader)
    name = "downloader"

    def __init__(self):
        service.MultiService.__init__(self)
        self._all_downloads = weakref.WeakKeyDictionary()

    def download(self, u, t):
        assert self.parent
        assert self.running
        u = IFileURI(u)
        t = IDownloadTarget(t)
        assert t.write
        assert t.close
        if isinstance(u, uri.LiteralFileURI):
            dl = LiteralDownloader(self.parent, u, t)
        elif isinstance(u, uri.CHKFileURI):
            dl = FileDownloader(self.parent, u, t)
        else:
            raise RuntimeError("I don't know how to download a %s" % u)
        self._all_downloads[dl.get_download_status()] = None
        d = dl.start()
        return d

    # utility functions
    def download_to_data(self, uri):
        return self.download(uri, Data())
    def download_to_filename(self, uri, filename):
        return self.download(uri, FileName(filename))
    def download_to_filehandle(self, uri, filehandle):
        return self.download(uri, FileHandle(filehandle))
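
    # Typical use of the utility functions (a sketch; 'client' is assumed to
    # be the running Client that has this service attached as a child, and
    # 'filecap' and 'process' are hypothetical):
    #
    #   downloader = client.getServiceNamed("downloader")
    #   d = downloader.download_to_data(filecap)
    #   d.addCallback(lambda data: process(data))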


    def list_downloads(self):
        return self._all_downloads.keys()