import os, random, weakref, itertools
from zope.interface import implements
from twisted.internet import defer
from twisted.internet.interfaces import IPushProducer, IConsumer
from twisted.application import service
from foolscap.eventual import eventually

from allmydata.util import base32, mathutil, hashutil, log, idlib
from allmydata.util.assertutil import _assert
from allmydata import codec, hashtree, storage, uri
from allmydata.interfaces import IDownloadTarget, IDownloader, IFileURI, \
     IDownloadStatus
from allmydata.encode import NotEnoughPeersError
from pycryptopp.cipher.aes import AES

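# Overview of the machinery below:
#
#   Downloader (service) -> FileDownloader (one per CHK file)
#     -> SegmentDownloader (one per segment)
#       -> BlockDownloader (one per block of that segment)
#         -> ValidatedBucket (fetches and hash-checks data from one server)
#
# Decoded segments are handed to an Output instance, which decrypts them
# and writes the plaintext to the IDownloadTarget.
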
class HaveAllPeersError(Exception):
    # we use this to jump out of the loop
    pass

class BadURIExtensionHashValue(Exception):
    pass
class BadPlaintextHashValue(Exception):
    pass
class BadCrypttextHashValue(Exception):
    pass

class DownloadStopped(Exception):
    pass

class Output:
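    """I accept crypttext segments in order, decrypt them, feed the
    plaintext and crypttext hashers and (once provided) the segment hash
    trees, and write the resulting plaintext to my Downloadable."""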
    def __init__(self, downloadable, key, total_length, log_parent,
                 download_status):
        self.downloadable = downloadable
        self._decryptor = AES(key)
        self._crypttext_hasher = hashutil.crypttext_hasher()
        self._plaintext_hasher = hashutil.plaintext_hasher()
        self.length = 0
        self.total_length = total_length
        self._segment_number = 0
        self._plaintext_hash_tree = None
        self._crypttext_hash_tree = None
        self._opened = False
        self._log_parent = log_parent
        self._status = download_status
        self._status.set_progress(0.0)

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_parent
        if "facility" not in kwargs:
            kwargs["facility"] = "download.output"
        return log.msg(*args, **kwargs)

    def setup_hashtrees(self, plaintext_hashtree, crypttext_hashtree):
        self._plaintext_hash_tree = plaintext_hashtree
        self._crypttext_hash_tree = crypttext_hashtree

    def write_segment(self, crypttext):
        self.length += len(crypttext)
        self._status.set_progress(float(self.length) / self.total_length)

        # memory footprint: 'crypttext' is the only segment_size usage
        # outstanding. While we decrypt it into 'plaintext', we hit
        # 2*segment_size.
        self._crypttext_hasher.update(crypttext)
        if self._crypttext_hash_tree:
            ch = hashutil.crypttext_segment_hasher()
            ch.update(crypttext)
            crypttext_leaves = {self._segment_number: ch.digest()}
            self.log(format="crypttext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s",
                     bytes=len(crypttext),
                     segnum=self._segment_number, hash=base32.b2a(ch.digest()),
                     level=log.NOISY)
            self._crypttext_hash_tree.set_hashes(leaves=crypttext_leaves)

        plaintext = self._decryptor.process(crypttext)
        del crypttext

        # now we're back down to 1*segment_size.

        self._plaintext_hasher.update(plaintext)
        if self._plaintext_hash_tree:
            ph = hashutil.plaintext_segment_hasher()
            ph.update(plaintext)
            plaintext_leaves = {self._segment_number: ph.digest()}
            self.log(format="plaintext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s",
                     bytes=len(plaintext),
                     segnum=self._segment_number, hash=base32.b2a(ph.digest()),
                     level=log.NOISY)
            self._plaintext_hash_tree.set_hashes(leaves=plaintext_leaves)

        self._segment_number += 1
        # We're still at 1*segment_size. The Downloadable is responsible for
        # any memory usage beyond this.
        if not self._opened:
            self._opened = True
            self.downloadable.open(self.total_length)
        self.downloadable.write(plaintext)

    def fail(self, why):
        # this is really unusual, and deserves maximum forensics
        self.log("download failed!", failure=why, level=log.SCARY)
        self.downloadable.fail(why)

    def close(self):
        self.crypttext_hash = self._crypttext_hasher.digest()
        self.plaintext_hash = self._plaintext_hasher.digest()
        self.log("download finished, closing IDownloadable", level=log.NOISY)
        self.downloadable.close()

    def finish(self):
        return self.downloadable.finish()

class ValidatedBucket:
    """I am a front-end for a remote storage bucket, responsible for
    retrieving and validating data from that bucket.

    My get_block() method is used by BlockDownloaders.
    """

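    # Validation chain: the URI pins the share root hash; the share hash
    # tree authenticates this share's hash; that hash then serves as the
    # root of the per-share block hash tree, which authenticates each
    # block returned from get_block().
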
    def __init__(self, sharenum, bucket,
                 share_hash_tree, roothash,
                 num_blocks):
        self.sharenum = sharenum
        self.bucket = bucket
        self._share_hash = None # None means not validated yet
        self.share_hash_tree = share_hash_tree
        self._roothash = roothash
        self.block_hash_tree = hashtree.IncompleteHashTree(num_blocks)
        self.started = False

    def get_block(self, blocknum):
        if not self.started:
            d = self.bucket.start()
            def _started(res):
                self.started = True
                return self.get_block(blocknum)
            d.addCallback(_started)
            return d

        # the first time we use this bucket, we need to fetch enough elements
        # of the share hash tree to validate it from our share hash up to the
        # hashroot.
        if not self._share_hash:
            d1 = self.bucket.get_share_hashes()
        else:
            d1 = defer.succeed([])

        # we might need to grab some elements of our block hash tree, to
        # validate the requested block up to the share hash
        needed = self.block_hash_tree.needed_hashes(blocknum)
        if needed:
            # TODO: get fewer hashes, use get_block_hashes(needed)
            d2 = self.bucket.get_block_hashes()
        else:
            d2 = defer.succeed([])

        d3 = self.bucket.get_block(blocknum)

        d = defer.gatherResults([d1, d2, d3])
        d.addCallback(self._got_data, blocknum)
        return d

    def _got_data(self, res, blocknum):
        sharehashes, blockhashes, blockdata = res
        blockhash = None # to make logging it safe

        try:
            if not self._share_hash:
                sh = dict(sharehashes)
                sh[0] = self._roothash # always use our own root, from the URI
                sht = self.share_hash_tree
                if sht.get_leaf_index(self.sharenum) not in sh:
                    raise hashtree.NotEnoughHashesError
                sht.set_hashes(sh)
                self._share_hash = sht.get_leaf(self.sharenum)

            blockhash = hashutil.block_hash(blockdata)
            #log.msg("checking block_hash(shareid=%d, blocknum=%d) len=%d "
            #        "%r .. %r: %s" %
            #        (self.sharenum, blocknum, len(blockdata),
            #         blockdata[:50], blockdata[-50:], base32.b2a(blockhash)))

            # we always validate the blockhash
            bh = dict(enumerate(blockhashes))
            # replace blockhash root with validated value
            bh[0] = self._share_hash
            self.block_hash_tree.set_hashes(bh, {blocknum: blockhash})

        except (hashtree.BadHashError, hashtree.NotEnoughHashesError):
            # log.WEIRD: indicates undetected disk/network error, or more
            # likely a programming error
            log.msg("hash failure in block=%d, shnum=%d on %s" %
                    (blocknum, self.sharenum, self.bucket))
            if self._share_hash:
                log.msg(""" failure occurred when checking the block_hash_tree.
                This suggests that either the block data was bad, or that the
                block hashes we received along with it were bad.""")
            else:
                log.msg(""" the failure probably occurred when checking the
                share_hash_tree, which suggests that the share hashes we
                received from the remote peer were bad.""")
            log.msg(" have self._share_hash: %s" % bool(self._share_hash))
            log.msg(" block length: %d" % len(blockdata))
            log.msg(" block hash: %s" % base32.b2a_or_none(blockhash))
            if len(blockdata) < 100:
                log.msg(" block data: %r" % (blockdata,))
            else:
                log.msg(" block data start/end: %r .. %r" %
                        (blockdata[:50], blockdata[-50:]))
            log.msg(" root hash: %s" % base32.b2a(self._roothash))
            log.msg(" share hash tree:\n" + self.share_hash_tree.dump())
            log.msg(" block hash tree:\n" + self.block_hash_tree.dump())
            lines = []
            for i, h in sorted(sharehashes):
                lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
            log.msg(" sharehashes:\n" + "\n".join(lines) + "\n")
            lines = []
            for i, h in enumerate(blockhashes):
                lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
            log.msg(" blockhashes:\n" + "\n".join(lines) + "\n")
            raise

        # If we made it here, the block is good. If the hash trees didn't
        # like what they saw, they would have raised a BadHashError, causing
        # our caller to see a Failure and thus ignore this block (as well as
        # dropping this bucket).
        return blockdata


class BlockDownloader:
    """I am responsible for downloading a single block (from a single bucket)
    for a single segment.

    I am a child of the SegmentDownloader.
    """

    def __init__(self, vbucket, blocknum, parent):
        self.vbucket = vbucket
        self.blocknum = blocknum
        self.parent = parent
        self._log_number = self.parent.log("starting block %d" % blocknum)

    def log(self, msg, parent=None):
        if parent is None:
            parent = self._log_number
        return self.parent.log(msg, parent=parent)

    def start(self, segnum):
        lognum = self.log("get_block(segnum=%d)" % segnum)
        d = self.vbucket.get_block(segnum)
        d.addCallbacks(self._hold_block, self._got_block_error,
                       callbackArgs=(lognum,), errbackArgs=(lognum,))
        return d

    def _hold_block(self, data, lognum):
        self.log("got block", parent=lognum)
        self.parent.hold_block(self.blocknum, data)

    def _got_block_error(self, f, lognum):
        self.log("BlockDownloader[%d] got error: %s" % (self.blocknum, f),
                 parent=lognum)
        self.parent.bucket_failed(self.vbucket)

class SegmentDownloader:
    """I am responsible for downloading all the blocks for a single segment
    of data.

    I am a child of the FileDownloader.
    """

    def __init__(self, parent, segmentnumber, needed_shares):
        self.parent = parent
        self.segmentnumber = segmentnumber
        self.needed_blocks = needed_shares
        self.blocks = {} # k: blocknum, v: data
        self._log_number = self.parent.log("starting segment %d" %
                                           segmentnumber)

    def log(self, msg, parent=None):
        if parent is None:
            parent = self._log_number
        return self.parent.log(msg, parent=parent)

    def start(self):
        return self._download()

    def _download(self):
        d = self._try()
        def _done(res):
            if len(self.blocks) >= self.needed_blocks:
                # we only need self.needed_blocks blocks
                # we want to get the smallest blockids, because they are
                # more likely to be fast "primary blocks"
                blockids = sorted(self.blocks.keys())[:self.needed_blocks]
                blocks = []
                for blocknum in blockids:
                    blocks.append(self.blocks[blocknum])
                return (blocks, blockids)
            else:
                return self._download()
        d.addCallback(_done)
        return d

    def _try(self):
        # fill our set of active buckets, maybe raising NotEnoughPeersError
        active_buckets = self.parent._activate_enough_buckets()
        # Now we have enough buckets, in self.parent.active_buckets.

        # in test cases, bd.start might mutate active_buckets right away, so
        # we need to put off calling start() until we've iterated all the way
        # through it.
        downloaders = []
        for blocknum, vbucket in active_buckets.iteritems():
            bd = BlockDownloader(vbucket, blocknum, self)
            downloaders.append(bd)
        l = [bd.start(self.segmentnumber) for bd in downloaders]
        return defer.DeferredList(l, fireOnOneErrback=True)

    def hold_block(self, blocknum, data):
        self.blocks[blocknum] = data

    def bucket_failed(self, vbucket):
        self.parent.bucket_failed(vbucket)

class DownloadStatus:
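    """I collect status, progress, and activity information for a single
    download, on behalf of IDownloadStatus consumers."""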
    implements(IDownloadStatus)
    statusid_counter = itertools.count(0)

    def __init__(self):
        self.storage_index = None
        self.size = None
        self.helper = False
        self.status = "Not started"
        self.progress = 0.0
        self.paused = False
        self.stopped = False
        self.active = True
        self.counter = self.statusid_counter.next()

    def get_storage_index(self):
        return self.storage_index
    def get_size(self):
        return self.size
    def using_helper(self):
        return self.helper
    def get_status(self):
        status = self.status
        if self.paused:
            status += " (output paused)"
        if self.stopped:
            status += " (output stopped)"
        return status
    def get_progress(self):
        return self.progress
    def get_active(self):
        return self.active
    def get_counter(self):
        return self.counter

    def set_storage_index(self, si):
        self.storage_index = si
    def set_size(self, size):
        self.size = size
    def set_helper(self, helper):
        self.helper = helper
    def set_status(self, status):
        self.status = status
    def set_paused(self, paused):
        self.paused = paused
    def set_stopped(self, stopped):
        self.stopped = stopped
    def set_progress(self, value):
        self.progress = value
    def set_active(self, value):
        self.active = value

class FileDownloader:
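    """I coordinate the download of a single CHK file: locate shareholders,
    fetch and validate the URI extension block and the hash trees, then
    download, decode, and decrypt each segment in turn."""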
    implements(IPushProducer)
    check_crypttext_hash = True
    check_plaintext_hash = True
    _status = None

    def __init__(self, client, u, downloadable):
        self._client = client

        u = IFileURI(u)
        self._storage_index = u.storage_index
        self._uri_extension_hash = u.uri_extension_hash
        self._total_shares = u.total_shares
        self._size = u.size
        self._num_needed_shares = u.needed_shares

        self._si_s = storage.si_b2a(self._storage_index)
        self.init_logging()

        self._status = s = DownloadStatus()
        s.set_status("Starting")
        s.set_storage_index(self._storage_index)
        s.set_size(self._size)
        s.set_helper(False)
        s.set_active(True)

        if IConsumer.providedBy(downloadable):
            downloadable.registerProducer(self, True)
        self._downloadable = downloadable
        self._output = Output(downloadable, u.key, self._size, self._log_number,
                              self._status)
        self._paused = False
        self._stopped = False

        self.active_buckets = {} # k: shnum, v: bucket
        self._share_buckets = [] # list of (sharenum, bucket) tuples
        self._share_vbuckets = {} # k: shnum, v: set of ValidatedBuckets
        self._uri_extension_sources = []

        self._uri_extension_data = None

        self._fetch_failures = {"uri_extension": 0,
                                "plaintext_hashroot": 0,
                                "plaintext_hashtree": 0,
                                "crypttext_hashroot": 0,
                                "crypttext_hashtree": 0,
                                }

    def init_logging(self):
        self._log_prefix = prefix = storage.si_b2a(self._storage_index)[:5]
        num = self._client.log(format="FileDownloader(%(si)s): starting",
                               si=storage.si_b2a(self._storage_index))
        self._log_number = num

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_number
        if "facility" not in kwargs:
            kwargs["facility"] = "tahoe.download"
        return log.msg(*args, **kwargs)

    def pauseProducing(self):
        if self._paused:
            return
        self._paused = defer.Deferred()
        if self._status:
            self._status.set_paused(True)

    def resumeProducing(self):
        if self._paused:
            p = self._paused
            self._paused = None
            eventually(p.callback, None)
            if self._status:
                self._status.set_paused(False)

    def stopProducing(self):
        self.log("Download.stopProducing")
        self._stopped = True
        if self._status:
            self._status.set_stopped(True)
            self._status.set_active(False)

    def start(self):
        self.log("starting download")

        # first step: who should we download from?
        d = defer.maybeDeferred(self._get_all_shareholders)
        d.addCallback(self._got_all_shareholders)
        # now get the uri_extension block from somebody and validate it
        d.addCallback(self._obtain_uri_extension)
        d.addCallback(self._got_uri_extension)
        d.addCallback(self._get_hashtrees)
        d.addCallback(self._create_validated_buckets)
        # once we know that, we can download blocks from everybody
        d.addCallback(self._download_all_segments)
        def _finished(res):
            if self._status:
                self._status.set_status("Finished")
                self._status.set_active(False)
            if IConsumer.providedBy(self._downloadable):
                self._downloadable.unregisterProducer()
            return res
        d.addBoth(_finished)
        def _failed(why):
            if self._status:
                self._status.set_status("Failed")
                self._status.set_active(False)
            self._output.fail(why)
            return why
        d.addErrback(_failed)
        d.addCallback(self._done)
        return d

    def _get_all_shareholders(self):
        dl = []
        for (peerid, ss) in self._client.get_permuted_peers("storage",
                                                            self._storage_index):
            peerid_s = idlib.shortnodeid_b2a(peerid)
            d = ss.callRemote("get_buckets", self._storage_index)
            d.addCallbacks(self._got_response, self._got_error,
                           callbackArgs=(peerid_s,))
            dl.append(d)
        self._responses_received = 0
        self._queries_sent = len(dl)
        if self._status:
            self._status.set_status("Locating Shares (%d/%d)" %
                                    (self._responses_received,
                                     self._queries_sent))
        return defer.DeferredList(dl)

    def _got_response(self, buckets, peerid_s):
        self._responses_received += 1
        if self._status:
            self._status.set_status("Locating Shares (%d/%d)" %
                                    (self._responses_received,
                                     self._queries_sent))
        for sharenum, bucket in buckets.iteritems():
            b = storage.ReadBucketProxy(bucket, peerid_s, self._si_s)
            self.add_share_bucket(sharenum, b)
            self._uri_extension_sources.append(b)

    def add_share_bucket(self, sharenum, bucket):
        # this is split out for the benefit of test_encode.py
        self._share_buckets.append((sharenum, bucket))

    def _got_error(self, f):
        self._client.log("Somebody failed. -- %s" % (f,))

    def bucket_failed(self, vbucket):
        shnum = vbucket.sharenum
        del self.active_buckets[shnum]
        s = self._share_vbuckets[shnum]
        # s is a set of ValidatedBucket instances
        s.remove(vbucket)
        # ... which might now be empty
        if not s:
            # there are no more buckets which can provide this share, so
            # remove the key. This may prompt us to use a different share.
            del self._share_vbuckets[shnum]

    def _got_all_shareholders(self, res):
        if len(self._share_buckets) < self._num_needed_shares:
            raise NotEnoughPeersError

        #for s in self._share_vbuckets.values():
        #    for vb in s:
        #        assert isinstance(vb, ValidatedBucket), \
        #               "vb is %s but should be a ValidatedBucket" % (vb,)

    def _unpack_uri_extension_data(self, data):
        return uri.unpack_extension(data)

    def _obtain_uri_extension(self, ignored):
        # all shareholders are supposed to have a copy of uri_extension, and
        # all are supposed to be identical. We compute the hash of the data
        # that comes back, and compare it against the version in our URI. If
        # they don't match, ignore their data and try someone else.
        if self._status:
            self._status.set_status("Obtaining URI Extension")

        def _validate(proposal, bucket):
            h = hashutil.uri_extension_hash(proposal)
            if h != self._uri_extension_hash:
                self._fetch_failures["uri_extension"] += 1
                msg = ("The copy of uri_extension we received from "
                       "%s was bad: wanted %s, got %s" %
                       (bucket,
                        base32.b2a(self._uri_extension_hash),
                        base32.b2a(h)))
                self.log(msg, level=log.SCARY)
                raise BadURIExtensionHashValue(msg)
            return self._unpack_uri_extension_data(proposal)
        return self._obtain_validated_thing(None,
                                            self._uri_extension_sources,
                                            "uri_extension",
                                            "get_uri_extension", (), _validate)

    def _obtain_validated_thing(self, ignored, sources, name, methname, args,
                                validatorfunc):
        if not sources:
            raise NotEnoughPeersError("started with zero peers while fetching "
                                      "%s" % name)
        bucket = sources[0]
        sources = sources[1:]
        #d = bucket.callRemote(methname, *args)
        d = bucket.startIfNecessary()
        d.addCallback(lambda res: getattr(bucket, methname)(*args))
        d.addCallback(validatorfunc, bucket)
        def _bad(f):
            self.log("%s from vbucket %s failed:" % (name, bucket),
                     failure=f, level=log.WEIRD)
            if not sources:
                raise NotEnoughPeersError("ran out of peers, last error was %s"
                                          % (f,))
            # try again with a different one
            return self._obtain_validated_thing(None, sources, name,
                                                methname, args, validatorfunc)
        d.addErrback(_bad)
        return d

    def _got_uri_extension(self, uri_extension_data):
        d = self._uri_extension_data = uri_extension_data

        self._codec = codec.get_decoder_by_name(d['codec_name'])
        self._codec.set_serialized_params(d['codec_params'])
        self._tail_codec = codec.get_decoder_by_name(d['codec_name'])
        self._tail_codec.set_serialized_params(d['tail_codec_params'])

        crypttext_hash = d['crypttext_hash']
        assert isinstance(crypttext_hash, str)
        assert len(crypttext_hash) == 32
        self._crypttext_hash = crypttext_hash
        self._plaintext_hash = d['plaintext_hash']
        self._roothash = d['share_root_hash']

        self._segment_size = segment_size = d['segment_size']
        self._total_segments = mathutil.div_ceil(self._size, segment_size)
        self._current_segnum = 0

        self._share_hashtree = hashtree.IncompleteHashTree(d['total_shares'])
        self._share_hashtree.set_hashes({0: self._roothash})

    def _get_hashtrees(self, res):
        if self._status:
            self._status.set_status("Retrieving Hash Trees")
        d = self._get_plaintext_hashtrees()
        d.addCallback(self._get_crypttext_hashtrees)
        d.addCallback(self._setup_hashtrees)
        return d

    def _get_plaintext_hashtrees(self):
        def _validate_plaintext_hashtree(proposal, bucket):
            if proposal[0] != self._uri_extension_data['plaintext_root_hash']:
                self._fetch_failures["plaintext_hashroot"] += 1
                msg = ("The copy of the plaintext_root_hash we received from"
                       " %s was bad" % bucket)
                raise BadPlaintextHashValue(msg)
            pt_hashtree = hashtree.IncompleteHashTree(self._total_segments)
            pt_hashes = dict(list(enumerate(proposal)))
            try:
                pt_hashtree.set_hashes(pt_hashes)
            except hashtree.BadHashError:
                # the hashes they gave us were not self-consistent, even
                # though the root matched what we saw in the uri_extension
                # block
                self._fetch_failures["plaintext_hashtree"] += 1
                raise
            self._plaintext_hashtree = pt_hashtree
        d = self._obtain_validated_thing(None,
                                         self._uri_extension_sources,
                                         "plaintext_hashes",
                                         "get_plaintext_hashes", (),
                                         _validate_plaintext_hashtree)
        return d

    def _get_crypttext_hashtrees(self, res):
        def _validate_crypttext_hashtree(proposal, bucket):
            if proposal[0] != self._uri_extension_data['crypttext_root_hash']:
                self._fetch_failures["crypttext_hashroot"] += 1
                msg = ("The copy of the crypttext_root_hash we received from"
                       " %s was bad" % bucket)
                raise BadCrypttextHashValue(msg)
            ct_hashtree = hashtree.IncompleteHashTree(self._total_segments)
            ct_hashes = dict(list(enumerate(proposal)))
            try:
                ct_hashtree.set_hashes(ct_hashes)
            except hashtree.BadHashError:
                # the hashes they gave us were not self-consistent, even
                # though the root matched what we saw in the uri_extension
                # block
                self._fetch_failures["crypttext_hashtree"] += 1
                raise
            self._crypttext_hashtree = ct_hashtree
        d = self._obtain_validated_thing(None,
                                         self._uri_extension_sources,
                                         "crypttext_hashes",
                                         "get_crypttext_hashes", (),
                                         _validate_crypttext_hashtree)
        return d

    def _setup_hashtrees(self, res):
        self._output.setup_hashtrees(self._plaintext_hashtree,
                                     self._crypttext_hashtree)


    def _create_validated_buckets(self, ignored=None):
        self._share_vbuckets = {}
        for sharenum, bucket in self._share_buckets:
            vbucket = ValidatedBucket(sharenum, bucket,
                                      self._share_hashtree,
                                      self._roothash,
                                      self._total_segments)
            s = self._share_vbuckets.setdefault(sharenum, set())
            s.add(vbucket)

    def _activate_enough_buckets(self):
        """either return a mapping from shnum to a ValidatedBucket that can
        provide data for that share, or raise NotEnoughPeersError"""

        while len(self.active_buckets) < self._num_needed_shares:
            # need some more
            handled_shnums = set(self.active_buckets.keys())
            available_shnums = set(self._share_vbuckets.keys())
            potential_shnums = list(available_shnums - handled_shnums)
            if not potential_shnums:
                raise NotEnoughPeersError
            # choose a random share
            shnum = random.choice(potential_shnums)
            # and a random bucket that will provide it
            validated_bucket = random.choice(list(self._share_vbuckets[shnum]))
            self.active_buckets[shnum] = validated_bucket
        return self.active_buckets


    def _download_all_segments(self, res):
        # the promise: upon entry to this function, self._share_vbuckets
        # contains enough buckets to complete the download, and some extra
        # ones to tolerate some buckets dropping out or having errors.
        # self._share_vbuckets is a dictionary that maps from shnum to a set
        # of ValidatedBuckets, which themselves are wrappers around
        # RIBucketReader references.
        self.active_buckets = {} # k: shnum, v: ValidatedBucket instance

        d = defer.succeed(None)
        for segnum in range(self._total_segments-1):
            d.addCallback(self._download_segment, segnum)
            # this pause, at the end of write, prevents pre-fetch from
            # happening until the consumer is ready for more data.
            d.addCallback(self._check_for_pause)
        d.addCallback(self._download_tail_segment, self._total_segments-1)
        return d

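    # pausing: pauseProducing() swaps a Deferred into self._paused; each
    # step of the download chains through _check_for_pause, so a paused
    # download simply waits for resumeProducing() to fire that Deferred.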
    def _check_for_pause(self, res):
        if self._paused:
            d = defer.Deferred()
            self._paused.addCallback(lambda ignored: d.callback(res))
            return d
        if self._stopped:
            raise DownloadStopped("our Consumer called stopProducing()")
        return res

    def _download_segment(self, res, segnum):
        if self._status:
            self._status.set_status("Downloading segment %d of %d" %
                                    (segnum+1, self._total_segments))
        self.log("downloading seg#%d of %d (%d%%)"
                 % (segnum, self._total_segments,
                    100.0 * segnum / self._total_segments))
        # memory footprint: when the SegmentDownloader finishes pulling down
        # all shares, we have 1*segment_size of usage.
        segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares)
        d = segmentdler.start()
        # pause before using more memory
        d.addCallback(self._check_for_pause)
        # while the codec does its job, we hit 2*segment_size
        d.addCallback(lambda (shares, shareids):
                      self._codec.decode(shares, shareids))
        # once the codec is done, we drop back to 1*segment_size, because
        # 'shares' goes out of scope. The memory usage is all in the
        # plaintext now, spread out into a bunch of tiny buffers.

        # pause/check-for-stop just before writing, to honor stopProducing
        d.addCallback(self._check_for_pause)
        def _done(buffers):
            # we start by joining all these buffers together into a single
            # string. This makes Output.write easier, since it wants to hash
            # data one segment at a time anyway, and doesn't impact our
            # memory footprint since we were already peaking at 2*segment_size
            # inside the codec a moment ago.
            segment = "".join(buffers)
            del buffers
            # we're down to 1*segment_size right now, but write_segment()
            # will decrypt a copy of the segment internally, which will push
            # us up to 2*segment_size while it runs.
            self._output.write_segment(segment)
        d.addCallback(_done)
        return d

    def _download_tail_segment(self, res, segnum):
        self.log("downloading seg#%d of %d (%d%%)"
                 % (segnum, self._total_segments,
                    100.0 * segnum / self._total_segments))
        segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares)
        d = segmentdler.start()
        # pause before using more memory
        d.addCallback(self._check_for_pause)
        d.addCallback(lambda (shares, shareids):
                      self._tail_codec.decode(shares, shareids))
        # pause/check-for-stop just before writing, to honor stopProducing
        d.addCallback(self._check_for_pause)
        def _done(buffers):
            # trim off any padding added by the upload side
            segment = "".join(buffers)
            del buffers
            # we never send empty segments. If the data was an exact multiple
            # of the segment size, the last segment will be full.
            pad_size = mathutil.pad_size(self._size, self._segment_size)
            tail_size = self._segment_size - pad_size
            segment = segment[:tail_size]
            self._output.write_segment(segment)
        d.addCallback(_done)
        return d

    def _done(self, res):
        self.log("download done")
        self._output.close()
        if self.check_crypttext_hash:
            _assert(self._crypttext_hash == self._output.crypttext_hash,
                    "bad crypttext_hash: computed=%s, expected=%s" %
                    (base32.b2a(self._output.crypttext_hash),
                     base32.b2a(self._crypttext_hash)))
        if self.check_plaintext_hash:
            _assert(self._plaintext_hash == self._output.plaintext_hash,
                    "bad plaintext_hash: computed=%s, expected=%s" %
                    (base32.b2a(self._output.plaintext_hash),
                     base32.b2a(self._plaintext_hash)))
        _assert(self._output.length == self._size,
                got=self._output.length, expected=self._size)
        return self._output.finish()

    def get_download_status(self):
        return self._status


class LiteralDownloader:
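    """I handle LiteralFileURIs, which carry the entire file contents
    inside the URI itself, so no network access is required."""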
    def __init__(self, client, u, downloadable):
        self._uri = IFileURI(u)
        assert isinstance(self._uri, uri.LiteralFileURI)
        self._downloadable = downloadable
        self._status = s = DownloadStatus()
        s.set_storage_index(None)
        s.set_helper(False)
        s.set_status("Done")
        s.set_active(False)
        s.set_progress(1.0)

    def start(self):
        data = self._uri.data
        self._status.set_size(len(data))
        self._downloadable.open(len(data))
        self._downloadable.write(data)
        self._downloadable.close()
        return defer.maybeDeferred(self._downloadable.finish)

    def get_download_status(self):
        return self._status

class FileName:
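    """I write the downloaded plaintext to a named local file, and delete
    any partial output if the download fails."""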
    implements(IDownloadTarget)
    def __init__(self, filename):
        self._filename = filename
        self.f = None
    def open(self, size):
        self.f = open(self._filename, "wb")
        return self.f
    def write(self, data):
        self.f.write(data)
    def close(self):
        if self.f:
            self.f.close()
    def fail(self, why):
        if self.f:
            self.f.close()
            os.unlink(self._filename)
    def register_canceller(self, cb):
        pass # we won't use it
    def finish(self):
        pass

class Data:
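    """I accumulate the downloaded plaintext in memory; my finish() method
    returns it as a single string."""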
    implements(IDownloadTarget)
    def __init__(self):
        self._data = []
    def open(self, size):
        pass
    def write(self, data):
        self._data.append(data)
    def close(self):
        self.data = "".join(self._data)
        del self._data
    def fail(self, why):
        del self._data
    def register_canceller(self, cb):
        pass # we won't use it
    def finish(self):
        return self.data

class FileHandle:
    """Use me to download data to a pre-defined filehandle-like object. I
    will use the target's write() method. I will *not* close the filehandle:
    I leave that up to the originator of the filehandle. The download process
    will return the filehandle when it completes.
    """
    implements(IDownloadTarget)
    def __init__(self, filehandle):
        self._filehandle = filehandle
    def open(self, size):
        pass
    def write(self, data):
        self._filehandle.write(data)
    def close(self):
        # the originator of the filehandle reserves the right to close it
        pass
    def fail(self, why):
        pass
    def register_canceller(self, cb):
        pass
    def finish(self):
        return self._filehandle

class Downloader(service.MultiService):
    """I am a service that allows file downloading."""
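    # Typical use, as a sketch (`client` and `u` come from outside this
    # file; `u` is a CHKFileURI or LiteralFileURI instance):
    #   downloader = client.getServiceNamed("downloader")
    #   d = downloader.download_to_data(u)
    #   d.addCallback(lambda plaintext: ...)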
    implements(IDownloader)
    name = "downloader"
    MAX_DOWNLOAD_STATUSES = 10

    def __init__(self):
        service.MultiService.__init__(self)
        self._all_downloads = weakref.WeakKeyDictionary()
        self._recent_download_status = []

    def download(self, u, t):
        assert self.parent
        assert self.running
        u = IFileURI(u)
        t = IDownloadTarget(t)
        assert t.write
        assert t.close
        if isinstance(u, uri.LiteralFileURI):
            dl = LiteralDownloader(self.parent, u, t)
        elif isinstance(u, uri.CHKFileURI):
            dl = FileDownloader(self.parent, u, t)
        else:
            raise RuntimeError("I don't know how to download a %s" % u)
        self._all_downloads[dl] = None
        self._recent_download_status.append(dl.get_download_status())
        while len(self._recent_download_status) > self.MAX_DOWNLOAD_STATUSES:
            self._recent_download_status.pop(0)
        d = dl.start()
        return d

    # utility functions
    def download_to_data(self, uri):
        return self.download(uri, Data())
    def download_to_filename(self, uri, filename):
        return self.download(uri, FileName(filename))
    def download_to_filehandle(self, uri, filehandle):
        return self.download(uri, FileHandle(filehandle))


    def list_all_downloads(self):
        return self._all_downloads.keys()
    def list_active_downloads(self):
        return [d.get_download_status() for d in self._all_downloads.keys()
                if d.get_download_status().get_active()]
    def list_recent_downloads(self):
        return self._recent_download_status