import os, random, weakref, itertools, time
from zope.interface import implements
from twisted.internet import defer
from twisted.internet.interfaces import IPushProducer, IConsumer
from twisted.application import service
from foolscap.eventual import eventually

from allmydata.util import base32, mathutil, hashutil, log
from allmydata.util.assertutil import _assert
from allmydata import codec, hashtree, storage, uri
from allmydata.interfaces import IDownloadTarget, IDownloader, IFileURI, \
     IDownloadStatus, IDownloadResults
from allmydata.encode import NotEnoughSharesError
from pycryptopp.cipher.aes import AES

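# Rough data flow, as a sketch of the classes below: Downloader.download()
# builds a FileDownloader for a CHK URI (or a LiteralDownloader for a
# literal URI). The FileDownloader locates buckets, fetches and validates
# the URI extension block and the hash trees, then pulls the file down one
# segment at a time. Each segment is fetched by a SegmentDownloader, which
# runs one BlockDownloader per needed share; every block is checked by a
# ValidatedBucket against the share/block hash trees before use. Decoded
# segments are handed to Output, which decrypts (AES) and hashes them on
# the way to the IDownloadTarget.
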
class HaveAllPeersError(Exception):
    # we use this to jump out of the loop
    pass

class BadURIExtensionHashValue(Exception):
    pass
class BadPlaintextHashValue(Exception):
    pass
class BadCrypttextHashValue(Exception):
    pass

class DownloadStopped(Exception):
    pass

class DownloadResults:
    implements(IDownloadResults)

    def __init__(self):
        self.servers_used = set()
        self.server_problems = {}
        self.servermap = {}
        self.timings = {}
        self.file_size = None

class Output:
    def __init__(self, downloadable, key, total_length, log_parent,
                 download_status):
        self.downloadable = downloadable
        self._decryptor = AES(key)
        self._crypttext_hasher = hashutil.crypttext_hasher()
        self._plaintext_hasher = hashutil.plaintext_hasher()
        self.length = 0
        self.total_length = total_length
        self._segment_number = 0
        self._plaintext_hash_tree = None
        self._crypttext_hash_tree = None
        self._opened = False
        self._log_parent = log_parent
        self._status = download_status
        self._status.set_progress(0.0)

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_parent
        if "facility" not in kwargs:
            kwargs["facility"] = "download.output"
        return log.msg(*args, **kwargs)

    def setup_hashtrees(self, plaintext_hashtree, crypttext_hashtree):
        self._plaintext_hash_tree = plaintext_hashtree
        self._crypttext_hash_tree = crypttext_hashtree

    def write_segment(self, crypttext):
        self.length += len(crypttext)
        self._status.set_progress(float(self.length) / self.total_length)

        # memory footprint: 'crypttext' is the only segment_size usage
        # outstanding. While we decrypt it into 'plaintext', we hit
        # 2*segment_size.
        self._crypttext_hasher.update(crypttext)
        if self._crypttext_hash_tree:
            ch = hashutil.crypttext_segment_hasher()
            ch.update(crypttext)
            crypttext_leaves = {self._segment_number: ch.digest()}
            self.log(format="crypttext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s",
                     bytes=len(crypttext),
                     segnum=self._segment_number, hash=base32.b2a(ch.digest()),
                     level=log.NOISY)
            self._crypttext_hash_tree.set_hashes(leaves=crypttext_leaves)

        plaintext = self._decryptor.process(crypttext)
        del crypttext

        # now we're back down to 1*segment_size.

        self._plaintext_hasher.update(plaintext)
        if self._plaintext_hash_tree:
            ph = hashutil.plaintext_segment_hasher()
            ph.update(plaintext)
            plaintext_leaves = {self._segment_number: ph.digest()}
            self.log(format="plaintext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s",
                     bytes=len(plaintext),
                     segnum=self._segment_number, hash=base32.b2a(ph.digest()),
                     level=log.NOISY)
            self._plaintext_hash_tree.set_hashes(leaves=plaintext_leaves)

        self._segment_number += 1
        # We're still at 1*segment_size. The Downloadable is responsible for
        # any memory usage beyond this.
        if not self._opened:
            self._opened = True
            self.downloadable.open(self.total_length)
        self.downloadable.write(plaintext)

    def fail(self, why):
        # this is really unusual, and deserves maximum forensics
        if why.check(DownloadStopped):
            # ...but a DownloadStopped failure just means the consumer
            # aborted the download, which is not so scary
            self.log("download stopped", level=log.UNUSUAL)
        else:
            self.log("download failed!", failure=why, level=log.SCARY)
        self.downloadable.fail(why)

    def close(self):
        self.crypttext_hash = self._crypttext_hasher.digest()
        self.plaintext_hash = self._plaintext_hasher.digest()
        self.log("download finished, closing IDownloadable", level=log.NOISY)
        self.downloadable.close()

    def finish(self):
        return self.downloadable.finish()

class ValidatedBucket:
    """I am a front-end for a remote storage bucket, responsible for
    retrieving and validating data from that bucket.

    My get_block() method is used by BlockDownloaders.
    """

    def __init__(self, sharenum, bucket,
                 share_hash_tree, roothash,
                 num_blocks):
        self.sharenum = sharenum
        self.bucket = bucket
        self._share_hash = None # None means not validated yet
        self.share_hash_tree = share_hash_tree
        self._roothash = roothash
        self.block_hash_tree = hashtree.IncompleteHashTree(num_blocks)
        self.started = False

    def get_block(self, blocknum):
        if not self.started:
            d = self.bucket.start()
            def _started(res):
                self.started = True
                return self.get_block(blocknum)
            d.addCallback(_started)
            return d

        # the first time we use this bucket, we need to fetch enough elements
        # of the share hash tree to validate it from our share hash up to the
        # hashroot.
        if not self._share_hash:
            d1 = self.bucket.get_share_hashes()
        else:
            d1 = defer.succeed([])

        # we might need to grab some elements of our block hash tree, to
        # validate the requested block up to the share hash
        needed = self.block_hash_tree.needed_hashes(blocknum)
        if needed:
            # TODO: get fewer hashes, use get_block_hashes(needed)
            d2 = self.bucket.get_block_hashes()
        else:
            d2 = defer.succeed([])

        d3 = self.bucket.get_block(blocknum)

        d = defer.gatherResults([d1, d2, d3])
        d.addCallback(self._got_data, blocknum)
        return d

    def _got_data(self, res, blocknum):
        sharehashes, blockhashes, blockdata = res
        blockhash = None # to make logging it safe

        try:
            if not self._share_hash:
                sh = dict(sharehashes)
                sh[0] = self._roothash # always use our own root, from the URI
                sht = self.share_hash_tree
                if sht.get_leaf_index(self.sharenum) not in sh:
                    raise hashtree.NotEnoughHashesError
                sht.set_hashes(sh)
                self._share_hash = sht.get_leaf(self.sharenum)

            blockhash = hashutil.block_hash(blockdata)
            #log.msg("checking block_hash(shareid=%d, blocknum=%d) len=%d "
            #        "%r .. %r: %s" %
            #        (self.sharenum, blocknum, len(blockdata),
            #         blockdata[:50], blockdata[-50:], base32.b2a(blockhash)))

            # we always validate the blockhash
            bh = dict(enumerate(blockhashes))
            # replace blockhash root with validated value
            bh[0] = self._share_hash
            self.block_hash_tree.set_hashes(bh, {blocknum: blockhash})

        except (hashtree.BadHashError, hashtree.NotEnoughHashesError):
            # log.WEIRD: indicates undetected disk/network error, or more
            # likely a programming error
            log.msg("hash failure in block=%d, shnum=%d on %s" %
                    (blocknum, self.sharenum, self.bucket))
            if self._share_hash:
                log.msg(""" failure occurred when checking the block_hash_tree.
                This suggests that either the block data was bad, or that the
                block hashes we received along with it were bad.""")
            else:
                log.msg(""" the failure probably occurred when checking the
                share_hash_tree, which suggests that the share hashes we
                received from the remote peer were bad.""")
            log.msg(" have self._share_hash: %s" % bool(self._share_hash))
            log.msg(" block length: %d" % len(blockdata))
            log.msg(" block hash: %s" % base32.b2a_or_none(blockhash))
            if len(blockdata) < 100:
                log.msg(" block data: %r" % (blockdata,))
            else:
                log.msg(" block data start/end: %r .. %r" %
                        (blockdata[:50], blockdata[-50:]))
            log.msg(" root hash: %s" % base32.b2a(self._roothash))
            log.msg(" share hash tree:\n" + self.share_hash_tree.dump())
            log.msg(" block hash tree:\n" + self.block_hash_tree.dump())
            lines = []
            for i,h in sorted(sharehashes):
                lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
            log.msg(" sharehashes:\n" + "\n".join(lines) + "\n")
            lines = []
            for i,h in enumerate(blockhashes):
                lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
            log.msg(" blockhashes:\n" + "\n".join(lines) + "\n")
            raise

        # If we made it here, the block is good. If the hash trees didn't
        # like what they saw, they would have raised a BadHashError, causing
        # our caller to see a Failure and thus ignore this block (as well as
        # dropping this bucket).
        return blockdata

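# A minimal sketch (hedged; using only the hashtree/hashutil calls seen
# above) of the two-level validation chain: the share hash tree proves this
# share's hash from the trusted URI root, and the block hash tree proves
# each block against that share hash.
#
#   sht = hashtree.IncompleteHashTree(total_shares)
#   sh = dict(sharehashes_from_peer)
#   sh[0] = roothash                 # trusted root, taken from the URI
#   sht.set_hashes(sh)               # raises BadHashError on inconsistency
#   share_hash = sht.get_leaf(sharenum)
#
#   bht = hashtree.IncompleteHashTree(num_blocks)
#   bh = dict(enumerate(blockhashes_from_peer))
#   bh[0] = share_hash               # validated root for this share
#   bht.set_hashes(bh, {blocknum: hashutil.block_hash(blockdata)})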


class BlockDownloader:
    """I am responsible for downloading a single block (from a single bucket)
    for a single segment.

    I am a child of the SegmentDownloader.
    """

    def __init__(self, vbucket, blocknum, parent, results):
        self.vbucket = vbucket
        self.blocknum = blocknum
        self.parent = parent
        self.results = results
        self._log_number = self.parent.log("starting block %d" % blocknum)

    def log(self, msg, parent=None):
        if parent is None:
            parent = self._log_number
        return self.parent.log(msg, parent=parent)

    def start(self, segnum):
        lognum = self.log("get_block(segnum=%d)" % segnum)
        started = time.time()
        d = self.vbucket.get_block(segnum)
        d.addCallbacks(self._hold_block, self._got_block_error,
                       callbackArgs=(started, lognum,), errbackArgs=(lognum,))
        return d

    def _hold_block(self, data, started, lognum):
        if self.results:
            elapsed = time.time() - started
            peerid = self.vbucket.bucket.get_peerid()
            if peerid not in self.results.timings["fetch_per_server"]:
                self.results.timings["fetch_per_server"][peerid] = []
            self.results.timings["fetch_per_server"][peerid].append(elapsed)
        self.log("got block", parent=lognum)
        self.parent.hold_block(self.blocknum, data)

    def _got_block_error(self, f, lognum):
        self.log("BlockDownloader[%d] got error: %s" % (self.blocknum, f),
                 parent=lognum)
        if self.results:
            peerid = self.vbucket.bucket.get_peerid()
            self.results.server_problems[peerid] = str(f)
        self.parent.bucket_failed(self.vbucket)

class SegmentDownloader:
    """I am responsible for downloading all the blocks for a single segment
    of data.

    I am a child of the FileDownloader.
    """

    def __init__(self, parent, segmentnumber, needed_shares, results):
        self.parent = parent
        self.segmentnumber = segmentnumber
        self.needed_blocks = needed_shares
        self.blocks = {} # k: blocknum, v: data
        self.results = results
        self._log_number = self.parent.log("starting segment %d" %
                                           segmentnumber)

    def log(self, msg, parent=None):
        if parent is None:
            parent = self._log_number
        return self.parent.log(msg, parent=parent)

    def start(self):
        return self._download()

    def _download(self):
        d = self._try()
        def _done(res):
            if len(self.blocks) >= self.needed_blocks:
                # we only need self.needed_blocks blocks
                # we want to get the smallest blockids, because they are
                # more likely to be fast "primary blocks"
                blockids = sorted(self.blocks.keys())[:self.needed_blocks]
                blocks = []
                for blocknum in blockids:
                    blocks.append(self.blocks[blocknum])
                return (blocks, blockids)
            else:
                return self._download()
        d.addCallback(_done)
        return d

    def _try(self):
        # fill our set of active buckets, maybe raising NotEnoughSharesError
        active_buckets = self.parent._activate_enough_buckets()
        # Now we have enough buckets, in self.parent.active_buckets.

        # in test cases, bd.start might mutate active_buckets right away, so
        # we need to put off calling start() until we've iterated all the way
        # through it.
        downloaders = []
        for blocknum, vbucket in active_buckets.iteritems():
            bd = BlockDownloader(vbucket, blocknum, self, self.results)
            downloaders.append(bd)
            if self.results:
                self.results.servers_used.add(vbucket.bucket.get_peerid())
        l = [bd.start(self.segmentnumber) for bd in downloaders]
        return defer.DeferredList(l, fireOnOneErrback=True)

    def hold_block(self, blocknum, data):
        self.blocks[blocknum] = data

    def bucket_failed(self, vbucket):
        self.parent.bucket_failed(vbucket)

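# Fetch-loop sketch (hedged): with 3-of-10 encoding, _try() starts one
# BlockDownloader per active bucket (three of them). A failed bucket never
# reaches hold_block(), so after _try() completes, _done() sees fewer than
# needed_blocks blocks and calls _download() again; _activate_enough_buckets
# then fills the gap from the remaining vbuckets. Once enough blocks are
# held, the lowest-numbered ones are returned, since low block numbers are
# more likely to be fast "primary" blocks.
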
class DownloadStatus:
    implements(IDownloadStatus)
    statusid_counter = itertools.count(0)

    def __init__(self):
        self.storage_index = None
        self.size = None
        self.helper = False
        self.status = "Not started"
        self.progress = 0.0
        self.paused = False
        self.stopped = False
        self.active = True
        self.results = None
        self.counter = self.statusid_counter.next()
        self.started = time.time()

    def get_started(self):
        return self.started
    def get_storage_index(self):
        return self.storage_index
    def get_size(self):
        return self.size
    def using_helper(self):
        return self.helper
    def get_status(self):
        status = self.status
        if self.paused:
            status += " (output paused)"
        if self.stopped:
            status += " (output stopped)"
        return status
    def get_progress(self):
        return self.progress
    def get_active(self):
        return self.active
    def get_results(self):
        return self.results
    def get_counter(self):
        return self.counter

    def set_storage_index(self, si):
        self.storage_index = si
    def set_size(self, size):
        self.size = size
    def set_helper(self, helper):
        self.helper = helper
    def set_status(self, status):
        self.status = status
    def set_paused(self, paused):
        self.paused = paused
    def set_stopped(self, stopped):
        self.stopped = stopped
    def set_progress(self, value):
        self.progress = value
    def set_active(self, value):
        self.active = value
    def set_results(self, value):
        self.results = value

class FileDownloader:
    implements(IPushProducer)
    check_crypttext_hash = True
    check_plaintext_hash = True
    _status = None

    def __init__(self, client, u, downloadable):
        self._client = client

        u = IFileURI(u)
        self._storage_index = u.storage_index
        self._uri_extension_hash = u.uri_extension_hash
        self._total_shares = u.total_shares
        self._size = u.size
        self._num_needed_shares = u.needed_shares

        self._si_s = storage.si_b2a(self._storage_index)
        self.init_logging()

        self._started = time.time()
        self._status = s = DownloadStatus()
        s.set_status("Starting")
        s.set_storage_index(self._storage_index)
        s.set_size(self._size)
        s.set_helper(False)
        s.set_active(True)

        self._results = DownloadResults()
        s.set_results(self._results)
        self._results.file_size = self._size
        self._results.timings["servers_peer_selection"] = {}
        self._results.timings["fetch_per_server"] = {}
        self._results.timings["cumulative_fetch"] = 0.0
        self._results.timings["cumulative_decode"] = 0.0
        self._results.timings["cumulative_decrypt"] = 0.0

        if IConsumer.providedBy(downloadable):
            downloadable.registerProducer(self, True)
        self._downloadable = downloadable
        self._output = Output(downloadable, u.key, self._size, self._log_number,
                              self._status)
        self._paused = False
        self._stopped = False

        self.active_buckets = {} # k: shnum, v: bucket
        self._share_buckets = [] # list of (sharenum, bucket) tuples
        self._share_vbuckets = {} # k: shnum, v: set of ValidatedBuckets
        self._uri_extension_sources = []

        self._uri_extension_data = None

        self._fetch_failures = {"uri_extension": 0,
                                "plaintext_hashroot": 0,
                                "plaintext_hashtree": 0,
                                "crypttext_hashroot": 0,
                                "crypttext_hashtree": 0,
                                }

    def init_logging(self):
        self._log_prefix = prefix = storage.si_b2a(self._storage_index)[:5]
        num = self._client.log(format="FileDownloader(%(si)s): starting",
                               si=storage.si_b2a(self._storage_index))
        self._log_number = num

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_number
        if "facility" not in kwargs:
            kwargs["facility"] = "tahoe.download"
        return log.msg(*args, **kwargs)

    def pauseProducing(self):
        if self._paused:
            return
        self._paused = defer.Deferred()
        if self._status:
            self._status.set_paused(True)

    def resumeProducing(self):
        if self._paused:
            p = self._paused
            self._paused = None
            eventually(p.callback, None)
            if self._status:
                self._status.set_paused(False)

    def stopProducing(self):
        self.log("Download.stopProducing")
        self._stopped = True
        if self._status:
            self._status.set_stopped(True)
            self._status.set_active(False)

    def start(self):
        self.log("starting download")

        # first step: who should we download from?
        d = defer.maybeDeferred(self._get_all_shareholders)
        d.addCallback(self._got_all_shareholders)
        # now get the uri_extension block from somebody and validate it
        d.addCallback(self._obtain_uri_extension)
        d.addCallback(self._got_uri_extension)
        d.addCallback(self._get_hashtrees)
        d.addCallback(self._create_validated_buckets)
        # once we know that, we can download blocks from everybody
        d.addCallback(self._download_all_segments)
        def _finished(res):
            if self._status:
                self._status.set_status("Finished")
                self._status.set_active(False)
                self._status.set_paused(False)
            if IConsumer.providedBy(self._downloadable):
                self._downloadable.unregisterProducer()
            return res
        d.addBoth(_finished)
        def _failed(why):
            if self._status:
                self._status.set_status("Failed")
                self._status.set_active(False)
            self._output.fail(why)
            return why
        d.addErrback(_failed)
        d.addCallback(self._done)
        return d

    def _get_all_shareholders(self):
        dl = []
        for (peerid,ss) in self._client.get_permuted_peers("storage",
                                                           self._storage_index):
            d = ss.callRemote("get_buckets", self._storage_index)
            d.addCallbacks(self._got_response, self._got_error,
                           callbackArgs=(peerid,))
            dl.append(d)
        self._responses_received = 0
        self._queries_sent = len(dl)
        if self._status:
            self._status.set_status("Locating Shares (%d/%d)" %
                                    (self._responses_received,
                                     self._queries_sent))
        return defer.DeferredList(dl)

    def _got_response(self, buckets, peerid):
        self._responses_received += 1
        if self._results:
            elapsed = time.time() - self._started
            self._results.timings["servers_peer_selection"][peerid] = elapsed
        if self._status:
            self._status.set_status("Locating Shares (%d/%d)" %
                                    (self._responses_received,
                                     self._queries_sent))
        for sharenum, bucket in buckets.iteritems():
            b = storage.ReadBucketProxy(bucket, peerid, self._si_s)
            self.add_share_bucket(sharenum, b)
            self._uri_extension_sources.append(b)
            if self._results:
                if peerid not in self._results.servermap:
                    self._results.servermap[peerid] = set()
                self._results.servermap[peerid].add(sharenum)

    def add_share_bucket(self, sharenum, bucket):
        # this is split out for the benefit of test_encode.py
        self._share_buckets.append((sharenum, bucket))

    def _got_error(self, f):
        self._client.log("Somebody failed. -- %s" % (f,))

    def bucket_failed(self, vbucket):
        shnum = vbucket.sharenum
        del self.active_buckets[shnum]
        s = self._share_vbuckets[shnum]
        # s is a set of ValidatedBucket instances
        s.remove(vbucket)
        # ... which might now be empty
        if not s:
            # there are no more buckets which can provide this share, so
            # remove the key. This may prompt us to use a different share.
            del self._share_vbuckets[shnum]

    def _got_all_shareholders(self, res):
        if self._results:
            now = time.time()
            self._results.timings["peer_selection"] = now - self._started

        if len(self._share_buckets) < self._num_needed_shares:
            raise NotEnoughSharesError

        #for s in self._share_vbuckets.values():
        #    for vb in s:
        #        assert isinstance(vb, ValidatedBucket), \
        #               "vb is %s but should be a ValidatedBucket" % (vb,)

    def _unpack_uri_extension_data(self, data):
        return uri.unpack_extension(data)

    def _obtain_uri_extension(self, ignored):
        # all shareholders are supposed to have a copy of uri_extension, and
        # all are supposed to be identical. We compute the hash of the data
        # that comes back, and compare it against the version in our URI. If
        # they don't match, ignore their data and try someone else.
        if self._status:
            self._status.set_status("Obtaining URI Extension")

        self._uri_extension_fetch_started = time.time()
        def _validate(proposal, bucket):
            h = hashutil.uri_extension_hash(proposal)
            if h != self._uri_extension_hash:
                self._fetch_failures["uri_extension"] += 1
                msg = ("The copy of uri_extension we received from "
                       "%s was bad: wanted %s, got %s" %
                       (bucket,
                        base32.b2a(self._uri_extension_hash),
                        base32.b2a(h)))
                self.log(msg, level=log.SCARY)
                raise BadURIExtensionHashValue(msg)
            return self._unpack_uri_extension_data(proposal)
        return self._obtain_validated_thing(None,
                                            self._uri_extension_sources,
                                            "uri_extension",
                                            "get_uri_extension", (), _validate)

    def _obtain_validated_thing(self, ignored, sources, name, methname, args,
                                validatorfunc):
        if not sources:
            raise NotEnoughSharesError("started with zero peers while fetching "
                                       "%s" % name)
        bucket = sources[0]
        sources = sources[1:]
        #d = bucket.callRemote(methname, *args)
        d = bucket.startIfNecessary()
        d.addCallback(lambda res: getattr(bucket, methname)(*args))
        d.addCallback(validatorfunc, bucket)
        def _bad(f):
            self.log("%s from vbucket %s failed:" % (name, bucket),
                     failure=f, level=log.WEIRD)
            if not sources:
                raise NotEnoughSharesError("ran out of peers, last error was %s"
                                           % (f,))
            # try again with a different one
            return self._obtain_validated_thing(None, sources, name,
                                                methname, args, validatorfunc)
        d.addErrback(_bad)
        return d

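    # Failover sketch (hedged): given sources=[b1, b2, b3], a bad answer
    # from b1 (e.g. BadURIExtensionHashValue from the validator) is logged
    # at log.WEIRD and we recurse with sources=[b2, b3]; only when the list
    # is exhausted does NotEnoughSharesError propagate to the caller.
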
    def _got_uri_extension(self, uri_extension_data):
        if self._results:
            elapsed = time.time() - self._uri_extension_fetch_started
            self._results.timings["uri_extension"] = elapsed

        d = self._uri_extension_data = uri_extension_data

        self._codec = codec.get_decoder_by_name(d['codec_name'])
        self._codec.set_serialized_params(d['codec_params'])
        self._tail_codec = codec.get_decoder_by_name(d['codec_name'])
        self._tail_codec.set_serialized_params(d['tail_codec_params'])

        crypttext_hash = d.get('crypttext_hash', None) # optional
        if crypttext_hash:
            assert isinstance(crypttext_hash, str)
            assert len(crypttext_hash) == 32
        self._crypttext_hash = crypttext_hash
        self._plaintext_hash = d.get('plaintext_hash', None) # optional

        self._roothash = d['share_root_hash']

        self._segment_size = segment_size = d['segment_size']
        self._total_segments = mathutil.div_ceil(self._size, segment_size)
        self._current_segnum = 0

        self._share_hashtree = hashtree.IncompleteHashTree(d['total_shares'])
        self._share_hashtree.set_hashes({0: self._roothash})

    def _get_hashtrees(self, res):
        self._get_hashtrees_started = time.time()
        if self._status:
            self._status.set_status("Retrieving Hash Trees")
        d = defer.maybeDeferred(self._get_plaintext_hashtrees)
        d.addCallback(self._get_crypttext_hashtrees)
        d.addCallback(self._setup_hashtrees)
        return d

    def _get_plaintext_hashtrees(self):
        # plaintext hashes are optional. If the root isn't in the UEB, then
        # the share will be holding an empty list. We don't even bother
        # fetching it.
        if "plaintext_root_hash" not in self._uri_extension_data:
            self._plaintext_hashtree = None
            return
        def _validate_plaintext_hashtree(proposal, bucket):
            if proposal[0] != self._uri_extension_data['plaintext_root_hash']:
                self._fetch_failures["plaintext_hashroot"] += 1
                msg = ("The copy of the plaintext_root_hash we received from"
                       " %s was bad" % bucket)
                raise BadPlaintextHashValue(msg)
            pt_hashtree = hashtree.IncompleteHashTree(self._total_segments)
            pt_hashes = dict(list(enumerate(proposal)))
            try:
                pt_hashtree.set_hashes(pt_hashes)
            except hashtree.BadHashError:
                # the hashes they gave us were not self-consistent, even
                # though the root matched what we saw in the uri_extension
                # block
                self._fetch_failures["plaintext_hashtree"] += 1
                raise
            self._plaintext_hashtree = pt_hashtree
        d = self._obtain_validated_thing(None,
                                         self._uri_extension_sources,
                                         "plaintext_hashes",
                                         "get_plaintext_hashes", (),
                                         _validate_plaintext_hashtree)
        return d

    def _get_crypttext_hashtrees(self, res):
        # crypttext hashes are optional too
        if "crypttext_root_hash" not in self._uri_extension_data:
            self._crypttext_hashtree = None
            return
        def _validate_crypttext_hashtree(proposal, bucket):
            if proposal[0] != self._uri_extension_data['crypttext_root_hash']:
                self._fetch_failures["crypttext_hashroot"] += 1
                msg = ("The copy of the crypttext_root_hash we received from"
                       " %s was bad" % bucket)
                raise BadCrypttextHashValue(msg)
            ct_hashtree = hashtree.IncompleteHashTree(self._total_segments)
            ct_hashes = dict(list(enumerate(proposal)))
            try:
                ct_hashtree.set_hashes(ct_hashes)
            except hashtree.BadHashError:
                # as above: the hashes were not self-consistent, even though
                # the root matched the uri_extension block
                self._fetch_failures["crypttext_hashtree"] += 1
                raise
            self._crypttext_hashtree = ct_hashtree
        d = self._obtain_validated_thing(None,
                                         self._uri_extension_sources,
                                         "crypttext_hashes",
                                         "get_crypttext_hashes", (),
                                         _validate_crypttext_hashtree)
        return d

    def _setup_hashtrees(self, res):
        self._output.setup_hashtrees(self._plaintext_hashtree,
                                     self._crypttext_hashtree)
        if self._results:
            elapsed = time.time() - self._get_hashtrees_started
            self._results.timings["hashtrees"] = elapsed

    def _create_validated_buckets(self, ignored=None):
        self._share_vbuckets = {}
        for sharenum, bucket in self._share_buckets:
            vbucket = ValidatedBucket(sharenum, bucket,
                                      self._share_hashtree,
                                      self._roothash,
                                      self._total_segments)
            s = self._share_vbuckets.setdefault(sharenum, set())
            s.add(vbucket)

    def _activate_enough_buckets(self):
        """either return a mapping from shnum to a ValidatedBucket that can
        provide data for that share, or raise NotEnoughSharesError"""

        while len(self.active_buckets) < self._num_needed_shares:
            # need some more
            handled_shnums = set(self.active_buckets.keys())
            available_shnums = set(self._share_vbuckets.keys())
            potential_shnums = list(available_shnums - handled_shnums)
            if not potential_shnums:
                raise NotEnoughSharesError
            # choose a random share
            shnum = random.choice(potential_shnums)
            # and a random bucket that will provide it
            validated_bucket = random.choice(list(self._share_vbuckets[shnum]))
            self.active_buckets[shnum] = validated_bucket
        return self.active_buckets

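    # Example (hedged): with 3-of-10 encoding, _num_needed_shares is 3. If
    # shares {0, 4, 7, 9} each have at least one ValidatedBucket, this picks
    # three distinct shnums at random and one random bucket for each, e.g.
    # {0: vb_a, 7: vb_b, 9: vb_c}. When bucket_failed() later removes an
    # entry, the next call tops the mapping back up from what remains.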

    def _download_all_segments(self, res):
        # the promise: upon entry to this function, self._share_vbuckets
        # contains enough buckets to complete the download, and some extra
        # ones to tolerate some buckets dropping out or having errors.
        # self._share_vbuckets is a dictionary that maps from shnum to a set
        # of ValidatedBuckets, which themselves are wrappers around
        # RIBucketReader references.
        self.active_buckets = {} # k: shnum, v: ValidatedBucket instance

        self._started_fetching = time.time()

        d = defer.succeed(None)
        for segnum in range(self._total_segments-1):
            d.addCallback(self._download_segment, segnum)
            # this pause, at the end of write, prevents pre-fetch from
            # happening until the consumer is ready for more data.
            d.addCallback(self._check_for_pause)
        d.addCallback(self._download_tail_segment, self._total_segments-1)
        return d

    def _check_for_pause(self, res):
        if self._paused:
            d = defer.Deferred()
            self._paused.addCallback(lambda ignored: d.callback(res))
            return d
        if self._stopped:
            raise DownloadStopped("our Consumer called stopProducing()")
        return res

    def _download_segment(self, res, segnum):
        if self._status:
            self._status.set_status("Downloading segment %d of %d" %
                                    (segnum+1, self._total_segments))
        self.log("downloading seg#%d of %d (%d%%)"
                 % (segnum, self._total_segments,
                    100.0 * segnum / self._total_segments))
        # memory footprint: when the SegmentDownloader finishes pulling down
        # all shares, we have 1*segment_size of usage.
        segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares,
                                        self._results)
        started = time.time()
        d = segmentdler.start()
        def _finished_fetching(res):
            elapsed = time.time() - started
            self._results.timings["cumulative_fetch"] += elapsed
            return res
        if self._results:
            d.addCallback(_finished_fetching)
        # pause before using more memory
        d.addCallback(self._check_for_pause)
        # while the codec does its job, we hit 2*segment_size
        def _started_decode(res):
            self._started_decode = time.time()
            return res
        if self._results:
            d.addCallback(_started_decode)
        d.addCallback(lambda (shares, shareids):
                      self._codec.decode(shares, shareids))
        # once the codec is done, we drop back to 1*segment_size, because
        # 'shares' goes out of scope. The memory usage is all in the
        # plaintext now, spread out into a bunch of tiny buffers.
        def _finished_decode(res):
            elapsed = time.time() - self._started_decode
            self._results.timings["cumulative_decode"] += elapsed
            return res
        if self._results:
            d.addCallback(_finished_decode)

        # pause/check-for-stop just before writing, to honor stopProducing
        d.addCallback(self._check_for_pause)
        def _done(buffers):
            # we start by joining all these buffers together into a single
            # string. This makes Output.write easier, since it wants to hash
            # data one segment at a time anyway, and it doesn't add to our
            # memory footprint, since we were already peaking at
            # 2*segment_size inside the codec a moment ago.
            segment = "".join(buffers)
            del buffers
            # we're down to 1*segment_size right now, but write_segment()
            # will decrypt a copy of the segment internally, which will push
            # us up to 2*segment_size while it runs.
            started_decrypt = time.time()
            self._output.write_segment(segment)
            if self._results:
                elapsed = time.time() - started_decrypt
                self._results.timings["cumulative_decrypt"] += elapsed
        d.addCallback(_done)
        return d

    def _download_tail_segment(self, res, segnum):
        self.log("downloading seg#%d of %d (%d%%)"
                 % (segnum, self._total_segments,
                    100.0 * segnum / self._total_segments))
        segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares,
                                        self._results)
        started = time.time()
        d = segmentdler.start()
        def _finished_fetching(res):
            elapsed = time.time() - started
            self._results.timings["cumulative_fetch"] += elapsed
            return res
        if self._results:
            d.addCallback(_finished_fetching)
        # pause before using more memory
        d.addCallback(self._check_for_pause)
        def _started_decode(res):
            self._started_decode = time.time()
            return res
        if self._results:
            d.addCallback(_started_decode)
        d.addCallback(lambda (shares, shareids):
                      self._tail_codec.decode(shares, shareids))
        def _finished_decode(res):
            elapsed = time.time() - self._started_decode
            self._results.timings["cumulative_decode"] += elapsed
            return res
        if self._results:
            d.addCallback(_finished_decode)
        # pause/check-for-stop just before writing, to honor stopProducing
        d.addCallback(self._check_for_pause)
        def _done(buffers):
            # trim off any padding added by the upload side
            segment = "".join(buffers)
            del buffers
            # we never send empty segments. If the data was an exact multiple
            # of the segment size, the last segment will be full.
            pad_size = mathutil.pad_size(self._size, self._segment_size)
            tail_size = self._segment_size - pad_size
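            # worked example (hedged, assuming pad_size(n, k) is
            # (k - n % k) % k): size=1000, segment_size=128 gives
            # pad_size=24 and tail_size=104, so we keep the first 104
            # bytes and drop the 24 padding bytes.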
            segment = segment[:tail_size]
            started_decrypt = time.time()
            self._output.write_segment(segment)
            if self._results:
                elapsed = time.time() - started_decrypt
                self._results.timings["cumulative_decrypt"] += elapsed
        d.addCallback(_done)
        return d

    def _done(self, res):
        self.log("download done")
        if self._results:
            now = time.time()
            self._results.timings["total"] = now - self._started
            self._results.timings["segments"] = now - self._started_fetching
        self._output.close()
        if self.check_crypttext_hash and self._crypttext_hash:
            _assert(self._crypttext_hash == self._output.crypttext_hash,
                    "bad crypttext_hash: computed=%s, expected=%s" %
                    (base32.b2a(self._output.crypttext_hash),
                     base32.b2a(self._crypttext_hash)))
        if self.check_plaintext_hash and self._plaintext_hash:
            _assert(self._plaintext_hash == self._output.plaintext_hash,
                    "bad plaintext_hash: computed=%s, expected=%s" %
                    (base32.b2a(self._output.plaintext_hash),
                     base32.b2a(self._plaintext_hash)))
        _assert(self._output.length == self._size,
                got=self._output.length, expected=self._size)
        return self._output.finish()

    def get_download_status(self):
        return self._status


class LiteralDownloader:
    def __init__(self, client, u, downloadable):
        self._uri = IFileURI(u)
        assert isinstance(self._uri, uri.LiteralFileURI)
        self._downloadable = downloadable
        self._status = s = DownloadStatus()
        s.set_storage_index(None)
        s.set_helper(False)
        s.set_status("Done")
        s.set_active(False)
        s.set_progress(1.0)

    def start(self):
        data = self._uri.data
        self._status.set_size(len(data))
        self._downloadable.open(len(data))
        self._downloadable.write(data)
        self._downloadable.close()
        return defer.maybeDeferred(self._downloadable.finish)

    def get_download_status(self):
        return self._status

class FileName:
    implements(IDownloadTarget)
    def __init__(self, filename):
        self._filename = filename
        self.f = None
    def open(self, size):
        self.f = open(self._filename, "wb")
        return self.f
    def write(self, data):
        self.f.write(data)
    def close(self):
        if self.f:
            self.f.close()
    def fail(self, why):
        if self.f:
            self.f.close()
            os.unlink(self._filename)
    def register_canceller(self, cb):
        pass # we won't use it
    def finish(self):
        pass

class Data:
    implements(IDownloadTarget)
    def __init__(self):
        self._data = []
    def open(self, size):
        pass
    def write(self, data):
        self._data.append(data)
    def close(self):
        self.data = "".join(self._data)
        del self._data
    def fail(self, why):
        del self._data
    def register_canceller(self, cb):
        pass # we won't use it
    def finish(self):
        return self.data

class FileHandle:
    """Use me to download data to a pre-defined filehandle-like object. I
    will use the target's write() method. I will *not* close the filehandle:
    I leave that up to the originator of the filehandle. The download process
    will return the filehandle when it completes.
    """
    implements(IDownloadTarget)
    def __init__(self, filehandle):
        self._filehandle = filehandle
    def open(self, size):
        pass
    def write(self, data):
        self._filehandle.write(data)
    def close(self):
        # the originator of the filehandle reserves the right to close it
        pass
    def fail(self, why):
        pass
    def register_canceller(self, cb):
        pass
    def finish(self):
        return self._filehandle

class Downloader(service.MultiService):
    """I am a service that allows file downloading.
    """
    implements(IDownloader)
    name = "downloader"
    MAX_DOWNLOAD_STATUSES = 10

    def __init__(self, stats_provider=None):
        service.MultiService.__init__(self)
        self._all_downloads = weakref.WeakKeyDictionary()
        self.stats_provider = stats_provider
        self._recent_download_status = []

    def download(self, u, t):
        assert self.parent
        assert self.running
        u = IFileURI(u)
        t = IDownloadTarget(t)
        assert t.write
        assert t.close

        if self.stats_provider:
            self.stats_provider.count('downloader.files_downloaded', 1)
            self.stats_provider.count('downloader.bytes_downloaded', u.get_size())

        if isinstance(u, uri.LiteralFileURI):
            dl = LiteralDownloader(self.parent, u, t)
        elif isinstance(u, uri.CHKFileURI):
            dl = FileDownloader(self.parent, u, t)
        else:
            raise RuntimeError("I don't know how to download a %s" % u)
        self._all_downloads[dl] = None
        self._recent_download_status.append(dl.get_download_status())
        while len(self._recent_download_status) > self.MAX_DOWNLOAD_STATUSES:
            self._recent_download_status.pop(0)
        d = dl.start()
        return d

    # utility functions
    def download_to_data(self, uri):
        return self.download(uri, Data())
    def download_to_filename(self, uri, filename):
        return self.download(uri, FileName(filename))
    def download_to_filehandle(self, uri, filehandle):
        return self.download(uri, FileHandle(filehandle))


    def list_all_downloads(self):
        return self._all_downloads.keys()
    def list_active_downloads(self):
        return [d.get_download_status() for d in self._all_downloads.keys()
                if d.get_download_status().get_active()]
    def list_recent_downloads(self):
        return self._recent_download_status
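
# A minimal usage sketch (hedged; assumes the Downloader is running as a
# child service of a client node, as its name="downloader" suggests):
#
#   downloader = client.getServiceNamed("downloader")
#   d = downloader.download_to_data(uri)    # returns a Deferred
#   d.addCallback(lambda data: do_something(data))
#
# or, streaming to disk instead of collecting in memory:
#
#   d = downloader.download_to_filename(uri, "out.dat")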