
import os, random, weakref, itertools, time
from zope.interface import implements
from twisted.internet import defer
from twisted.internet.interfaces import IPushProducer, IConsumer
from twisted.application import service
from foolscap.eventual import eventually

from allmydata.util import base32, mathutil, hashutil, log
from allmydata.util.assertutil import _assert
from allmydata import codec, hashtree, storage, uri
from allmydata.interfaces import IDownloadTarget, IDownloader, IFileURI, \
     IDownloadStatus, IDownloadResults
from allmydata.encode import NotEnoughSharesError
from pycryptopp.cipher.aes import AES

class HaveAllPeersError(Exception):
    # we use this to jump out of the loop
    pass

class BadURIExtensionHashValue(Exception):
    pass
class BadPlaintextHashValue(Exception):
    pass
class BadCrypttextHashValue(Exception):
    pass

class DownloadStopped(Exception):
    pass

class DownloadResults:
    implements(IDownloadResults)

    def __init__(self):
        self.servers_used = set()
        self.server_problems = {}
        self.servermap = {}
        self.timings = {}
        self.file_size = None

class Output:
    def __init__(self, downloadable, key, total_length, log_parent,
                 download_status):
        self.downloadable = downloadable
        self._decryptor = AES(key)
        self._crypttext_hasher = hashutil.crypttext_hasher()
        self._plaintext_hasher = hashutil.plaintext_hasher()
        self.length = 0
        self.total_length = total_length
        self._segment_number = 0
        self._plaintext_hash_tree = None
        self._crypttext_hash_tree = None
        self._opened = False
        self._log_parent = log_parent
        self._status = download_status
        self._status.set_progress(0.0)

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_parent
        if "facility" not in kwargs:
            kwargs["facility"] = "download.output"
        return log.msg(*args, **kwargs)

    def setup_hashtrees(self, plaintext_hashtree, crypttext_hashtree):
        self._plaintext_hash_tree = plaintext_hashtree
        self._crypttext_hash_tree = crypttext_hashtree

    def write_segment(self, crypttext):
        self.length += len(crypttext)
        self._status.set_progress( float(self.length) / self.total_length )

        # memory footprint: 'crypttext' is the only segment_size usage
        # outstanding. While we decrypt it into 'plaintext', we hit
        # 2*segment_size.
        self._crypttext_hasher.update(crypttext)
        if self._crypttext_hash_tree:
            ch = hashutil.crypttext_segment_hasher()
            ch.update(crypttext)
            crypttext_leaves = {self._segment_number: ch.digest()}
            self.log(format="crypttext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s",
                     bytes=len(crypttext),
                     segnum=self._segment_number, hash=base32.b2a(ch.digest()),
                     level=log.NOISY)
            self._crypttext_hash_tree.set_hashes(leaves=crypttext_leaves)

        plaintext = self._decryptor.process(crypttext)
        del crypttext

        # now we're back down to 1*segment_size.

        self._plaintext_hasher.update(plaintext)
        if self._plaintext_hash_tree:
            ph = hashutil.plaintext_segment_hasher()
            ph.update(plaintext)
            plaintext_leaves = {self._segment_number: ph.digest()}
            self.log(format="plaintext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s",
                     bytes=len(plaintext),
                     segnum=self._segment_number, hash=base32.b2a(ph.digest()),
                     level=log.NOISY)
            self._plaintext_hash_tree.set_hashes(leaves=plaintext_leaves)

        self._segment_number += 1
        # We're still at 1*segment_size. The Downloadable is responsible for
        # any memory usage beyond this.
        if not self._opened:
            self._opened = True
            self.downloadable.open(self.total_length)
        self.downloadable.write(plaintext)

    def fail(self, why):
        # this is really unusual, and deserves maximum forensics
        if why.check(DownloadStopped):
            # except DownloadStopped just means the consumer aborted the
            # download, not so scary
            self.log("download stopped", level=log.UNUSUAL)
        else:
            self.log("download failed!", failure=why, level=log.SCARY)
        self.downloadable.fail(why)

    def close(self):
        self.crypttext_hash = self._crypttext_hasher.digest()
        self.plaintext_hash = self._plaintext_hasher.digest()
        self.log("download finished, closing IDownloadable", level=log.NOISY)
        self.downloadable.close()

    def finish(self):
        return self.downloadable.finish()

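
# Illustrative sketch (comment only): the per-segment hash-then-decrypt step
# that Output.write_segment() performs. The key and ciphertext here are made
# up for the example; only the AES/hashutil calls mirror the code above.
#
#   from pycryptopp.cipher.aes import AES
#   from allmydata.util import hashutil, base32
#
#   key = "\x00" * 16                        # hypothetical 16-byte AES key
#   crypttext = "one segment of ciphertext"  # hypothetical segment
#   ch = hashutil.crypttext_segment_hasher()
#   ch.update(crypttext)
#   leaf = ch.digest()                       # leaf for the crypttext hash tree
#   plaintext = AES(key).process(crypttext)
#   print base32.b2a(leaf), len(plaintext)
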
class ValidatedBucket:
    """I am a front-end for a remote storage bucket, responsible for
    retrieving and validating data from that bucket.

    My get_block() method is used by BlockDownloaders.
    """

    def __init__(self, sharenum, bucket,
                 share_hash_tree, roothash,
                 num_blocks):
        self.sharenum = sharenum
        self.bucket = bucket
        self._share_hash = None # None means not validated yet
        self.share_hash_tree = share_hash_tree
        self._roothash = roothash
        self.block_hash_tree = hashtree.IncompleteHashTree(num_blocks)
        self.started = False

    def get_block(self, blocknum):
        if not self.started:
            d = self.bucket.start()
            def _started(res):
                self.started = True
                return self.get_block(blocknum)
            d.addCallback(_started)
            return d

        # the first time we use this bucket, we need to fetch enough elements
        # of the share hash tree to validate it from our share hash up to the
        # hashroot.
        if not self._share_hash:
            d1 = self.bucket.get_share_hashes()
        else:
            d1 = defer.succeed([])

        # we might need to grab some elements of our block hash tree, to
        # validate the requested block up to the share hash
        needed = self.block_hash_tree.needed_hashes(blocknum)
        if needed:
            # TODO: get fewer hashes, use get_block_hashes(needed)
            d2 = self.bucket.get_block_hashes()
        else:
            d2 = defer.succeed([])

        d3 = self.bucket.get_block(blocknum)

        d = defer.gatherResults([d1, d2, d3])
        d.addCallback(self._got_data, blocknum)
        return d

    def _got_data(self, res, blocknum):
        sharehashes, blockhashes, blockdata = res
        blockhash = None # to make logging it safe

        try:
            if not self._share_hash:
                sh = dict(sharehashes)
                sh[0] = self._roothash # always use our own root, from the URI
                sht = self.share_hash_tree
                if sht.get_leaf_index(self.sharenum) not in sh:
                    raise hashtree.NotEnoughHashesError
                sht.set_hashes(sh)
                self._share_hash = sht.get_leaf(self.sharenum)

            blockhash = hashutil.block_hash(blockdata)
            #log.msg("checking block_hash(shareid=%d, blocknum=%d) len=%d "
            #        "%r .. %r: %s" %
            #        (self.sharenum, blocknum, len(blockdata),
            #         blockdata[:50], blockdata[-50:], base32.b2a(blockhash)))

            # we always validate the blockhash
            bh = dict(enumerate(blockhashes))
            # replace blockhash root with validated value
            bh[0] = self._share_hash
            self.block_hash_tree.set_hashes(bh, {blocknum: blockhash})

        except (hashtree.BadHashError, hashtree.NotEnoughHashesError):
            # log.WEIRD: indicates undetected disk/network error, or more
            # likely a programming error
            log.msg("hash failure in block=%d, shnum=%d on %s" %
                    (blocknum, self.sharenum, self.bucket))
            if self._share_hash:
                log.msg(""" failure occurred when checking the block_hash_tree.
                This suggests that either the block data was bad, or that the
                block hashes we received along with it were bad.""")
            else:
                log.msg(""" the failure probably occurred when checking the
                share_hash_tree, which suggests that the share hashes we
                received from the remote peer were bad.""")
            log.msg(" have self._share_hash: %s" % bool(self._share_hash))
            log.msg(" block length: %d" % len(blockdata))
            log.msg(" block hash: %s" % base32.b2a_or_none(blockhash))
            if len(blockdata) < 100:
                log.msg(" block data: %r" % (blockdata,))
            else:
                log.msg(" block data start/end: %r .. %r" %
                        (blockdata[:50], blockdata[-50:]))
            log.msg(" root hash: %s" % base32.b2a(self._roothash))
            log.msg(" share hash tree:\n" + self.share_hash_tree.dump())
            log.msg(" block hash tree:\n" + self.block_hash_tree.dump())
            lines = []
            for i,h in sorted(sharehashes):
                lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
            log.msg(" sharehashes:\n" + "\n".join(lines) + "\n")
            lines = []
            for i,h in enumerate(blockhashes):
                lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
            log.msg(" blockhashes:\n" + "\n".join(lines) + "\n")
            raise

        # If we made it here, the block is good. If the hash trees didn't
        # like what they saw, they would have raised a BadHashError, causing
        # our caller to see a Failure and thus ignore this block (as well as
        # dropping this bucket).
        return blockdata

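# Illustrative sketch (comment only): the incremental Merkle-tree validation
# that ValidatedBucket performs, using the same hashtree/hashutil calls as
# above. The block contents and tree size are made up for the example.
#
#   from allmydata import hashtree
#   from allmydata.util import hashutil
#
#   blocks = ["block-%d" % i for i in range(4)]
#   leaf_hashes = [hashutil.block_hash(b) for b in blocks]
#   full = hashtree.HashTree(leaf_hashes)       # what the uploader computed
#   partial = hashtree.IncompleteHashTree(4)    # what a downloader starts with
#   partial.set_hashes({0: full[0]})            # trust only the root
#   needed = partial.needed_hashes(2)           # interior nodes needed for leaf 2
#   partial.set_hashes(dict([(i, full[i]) for i in needed]),
#                      leaves={2: leaf_hashes[2]})  # BadHashError on mismatch
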

class BlockDownloader:
    """I am responsible for downloading a single block (from a single bucket)
    for a single segment.

    I am a child of the SegmentDownloader.
    """

    def __init__(self, vbucket, blocknum, parent, results):
        self.vbucket = vbucket
        self.blocknum = blocknum
        self.parent = parent
        self.results = results
        self._log_number = self.parent.log("starting block %d" % blocknum)

    def log(self, msg, parent=None):
        if parent is None:
            parent = self._log_number
        return self.parent.log(msg, parent=parent)

    def start(self, segnum):
        lognum = self.log("get_block(segnum=%d)" % segnum)
        started = time.time()
        d = self.vbucket.get_block(segnum)
        d.addCallbacks(self._hold_block, self._got_block_error,
                       callbackArgs=(started, lognum,), errbackArgs=(lognum,))
        return d

    def _hold_block(self, data, started, lognum):
        if self.results:
            elapsed = time.time() - started
            peerid = self.vbucket.bucket.get_peerid()
            if peerid not in self.results.timings["fetch_per_server"]:
                self.results.timings["fetch_per_server"][peerid] = []
            self.results.timings["fetch_per_server"][peerid].append(elapsed)
        self.log("got block", parent=lognum)
        self.parent.hold_block(self.blocknum, data)

    def _got_block_error(self, f, lognum):
        self.log("BlockDownloader[%d] got error: %s" % (self.blocknum, f),
                 parent=lognum)
        if self.results:
            peerid = self.vbucket.bucket.get_peerid()
            self.results.server_problems[peerid] = str(f)
        self.parent.bucket_failed(self.vbucket)

class SegmentDownloader:
    """I am responsible for downloading all the blocks for a single segment
    of data.

    I am a child of the FileDownloader.
    """

    def __init__(self, parent, segmentnumber, needed_shares, results):
        self.parent = parent
        self.segmentnumber = segmentnumber
        self.needed_blocks = needed_shares
        self.blocks = {} # k: blocknum, v: data
        self.results = results
        self._log_number = self.parent.log("starting segment %d" %
                                           segmentnumber)

    def log(self, msg, parent=None):
        if parent is None:
            parent = self._log_number
        return self.parent.log(msg, parent=parent)

    def start(self):
        return self._download()

    def _download(self):
        d = self._try()
        def _done(res):
            if len(self.blocks) >= self.needed_blocks:
                # we only need self.needed_blocks blocks
                # we want to get the smallest blockids, because they are
                # more likely to be fast "primary blocks"
                blockids = sorted(self.blocks.keys())[:self.needed_blocks]
                blocks = []
                for blocknum in blockids:
                    blocks.append(self.blocks[blocknum])
                return (blocks, blockids)
            else:
                return self._download()
        d.addCallback(_done)
        return d

    def _try(self):
        # fill our set of active buckets, maybe raising NotEnoughSharesError
        active_buckets = self.parent._activate_enough_buckets()
        # Now we have enough buckets, in self.parent.active_buckets.

        # in test cases, bd.start might mutate active_buckets right away, so
        # we need to put off calling start() until we've iterated all the way
        # through it.
        downloaders = []
        for blocknum, vbucket in active_buckets.iteritems():
            bd = BlockDownloader(vbucket, blocknum, self, self.results)
            downloaders.append(bd)
            if self.results:
                self.results.servers_used.add(vbucket.bucket.get_peerid())
        l = [bd.start(self.segmentnumber) for bd in downloaders]
        return defer.DeferredList(l, fireOnOneErrback=True)

    def hold_block(self, blocknum, data):
        self.blocks[blocknum] = data

    def bucket_failed(self, vbucket):
        self.parent.bucket_failed(vbucket)

class DownloadStatus:
    implements(IDownloadStatus)
    statusid_counter = itertools.count(0)

    def __init__(self):
        self.storage_index = None
        self.size = None
        self.helper = False
        self.status = "Not started"
        self.progress = 0.0
        self.paused = False
        self.stopped = False
        self.active = True
        self.results = None
        self.counter = self.statusid_counter.next()
        self.started = time.time()

    def get_started(self):
        return self.started
    def get_storage_index(self):
        return self.storage_index
    def get_size(self):
        return self.size
    def using_helper(self):
        return self.helper
    def get_status(self):
        status = self.status
        if self.paused:
            status += " (output paused)"
        if self.stopped:
            status += " (output stopped)"
        return status
    def get_progress(self):
        return self.progress
    def get_active(self):
        return self.active
    def get_results(self):
        return self.results
    def get_counter(self):
        return self.counter

    def set_storage_index(self, si):
        self.storage_index = si
    def set_size(self, size):
        self.size = size
    def set_helper(self, helper):
        self.helper = helper
    def set_status(self, status):
        self.status = status
    def set_paused(self, paused):
        self.paused = paused
    def set_stopped(self, stopped):
        self.stopped = stopped
    def set_progress(self, value):
        self.progress = value
    def set_active(self, value):
        self.active = value
    def set_results(self, value):
        self.results = value

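# Illustrative sketch (comment only): how the status object above is driven
# by a downloader and read back by the status-display code. Values are made
# up for the example.
#
#   s = DownloadStatus()
#   s.set_storage_index("\x00" * 16)   # hypothetical storage index
#   s.set_size(12345)
#   s.set_status("Downloading segment 2 of 5")
#   s.set_progress(0.4)
#   s.set_paused(True)
#   print s.get_status()    # "Downloading segment 2 of 5 (output paused)"
#   print s.get_progress()  # 0.4
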
class FileDownloader:
    implements(IPushProducer)
    check_crypttext_hash = True
    check_plaintext_hash = True
    _status = None

    def __init__(self, client, u, downloadable):
        self._client = client

        u = IFileURI(u)
        self._storage_index = u.storage_index
        self._uri_extension_hash = u.uri_extension_hash
        self._total_shares = u.total_shares
        self._size = u.size
        self._num_needed_shares = u.needed_shares

        self._si_s = storage.si_b2a(self._storage_index)
        self.init_logging()

        self._started = time.time()
        self._status = s = DownloadStatus()
        s.set_status("Starting")
        s.set_storage_index(self._storage_index)
        s.set_size(self._size)
        s.set_helper(False)
        s.set_active(True)

        self._results = DownloadResults()
        s.set_results(self._results)
        self._results.file_size = self._size
        self._results.timings["servers_peer_selection"] = {}
        self._results.timings["fetch_per_server"] = {}
        self._results.timings["cumulative_fetch"] = 0.0
        self._results.timings["cumulative_decode"] = 0.0
        self._results.timings["cumulative_decrypt"] = 0.0
        self._results.timings["paused"] = 0.0

        if IConsumer.providedBy(downloadable):
            downloadable.registerProducer(self, True)
        self._downloadable = downloadable
        self._output = Output(downloadable, u.key, self._size, self._log_number,
                              self._status)
        self._paused = False
        self._stopped = False

        self.active_buckets = {} # k: shnum, v: bucket
        self._share_buckets = [] # list of (sharenum, bucket) tuples
        self._share_vbuckets = {} # k: shnum, v: set of ValidatedBuckets
        self._uri_extension_sources = []

        self._uri_extension_data = None

        self._fetch_failures = {"uri_extension": 0,
                                "plaintext_hashroot": 0,
                                "plaintext_hashtree": 0,
                                "crypttext_hashroot": 0,
                                "crypttext_hashtree": 0,
                                }

    def init_logging(self):
        self._log_prefix = prefix = storage.si_b2a(self._storage_index)[:5]
        num = self._client.log(format="FileDownloader(%(si)s): starting",
                               si=storage.si_b2a(self._storage_index))
        self._log_number = num

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_number
        if "facility" not in kwargs:
            kwargs["facility"] = "tahoe.download"
        return log.msg(*args, **kwargs)

    def pauseProducing(self):
        if self._paused:
            return
        self._paused = defer.Deferred()
        self._paused_at = time.time()
        if self._status:
            self._status.set_paused(True)

    def resumeProducing(self):
        if self._paused:
            # accumulate the paused time here, where we know _paused_at was
            # set by pauseProducing; doing it unconditionally in
            # stopProducing would raise AttributeError if we were never
            # paused.
            paused_for = time.time() - self._paused_at
            self._results.timings['paused'] += paused_for
            p = self._paused
            self._paused = None
            eventually(p.callback, None)
            if self._status:
                self._status.set_paused(False)

    def stopProducing(self):
        self.log("Download.stopProducing")
        self._stopped = True
        if self._status:
            self._status.set_stopped(True)
            self._status.set_active(False)

    def start(self):
        self.log("starting download")

        # first step: who should we download from?
        d = defer.maybeDeferred(self._get_all_shareholders)
        d.addCallback(self._got_all_shareholders)
        # now get the uri_extension block from somebody and validate it
        d.addCallback(self._obtain_uri_extension)
        d.addCallback(self._got_uri_extension)
        d.addCallback(self._get_hashtrees)
        d.addCallback(self._create_validated_buckets)
        # once we know that, we can download blocks from everybody
        d.addCallback(self._download_all_segments)
        def _finished(res):
            if self._status:
                self._status.set_status("Finished")
                self._status.set_active(False)
                self._status.set_paused(False)
            if IConsumer.providedBy(self._downloadable):
                self._downloadable.unregisterProducer()
            return res
        d.addBoth(_finished)
        def _failed(why):
            if self._status:
                self._status.set_status("Failed")
                self._status.set_active(False)
            self._output.fail(why)
            return why
        d.addErrback(_failed)
        d.addCallback(self._done)
        return d

    def _get_all_shareholders(self):
        dl = []
        for (peerid,ss) in self._client.get_permuted_peers("storage",
                                                           self._storage_index):
            d = ss.callRemote("get_buckets", self._storage_index)
            d.addCallbacks(self._got_response, self._got_error,
                           callbackArgs=(peerid,))
            dl.append(d)
        self._responses_received = 0
        self._queries_sent = len(dl)
        if self._status:
            self._status.set_status("Locating Shares (%d/%d)" %
                                    (self._responses_received,
                                     self._queries_sent))
        return defer.DeferredList(dl)

    def _got_response(self, buckets, peerid):
        self._responses_received += 1
        if self._results:
            elapsed = time.time() - self._started
            self._results.timings["servers_peer_selection"][peerid] = elapsed
        if self._status:
            self._status.set_status("Locating Shares (%d/%d)" %
                                    (self._responses_received,
                                     self._queries_sent))
        for sharenum, bucket in buckets.iteritems():
            b = storage.ReadBucketProxy(bucket, peerid, self._si_s)
            self.add_share_bucket(sharenum, b)
            self._uri_extension_sources.append(b)
            if self._results:
                if peerid not in self._results.servermap:
                    self._results.servermap[peerid] = set()
                self._results.servermap[peerid].add(sharenum)

    def add_share_bucket(self, sharenum, bucket):
        # this is split out for the benefit of test_encode.py
        self._share_buckets.append( (sharenum, bucket) )

    def _got_error(self, f):
        self._client.log("Somebody failed. -- %s" % (f,))

    def bucket_failed(self, vbucket):
        shnum = vbucket.sharenum
        del self.active_buckets[shnum]
        s = self._share_vbuckets[shnum]
        # s is a set of ValidatedBucket instances
        s.remove(vbucket)
        # ... which might now be empty
        if not s:
            # there are no more buckets which can provide this share, so
            # remove the key. This may prompt us to use a different share.
            del self._share_vbuckets[shnum]

    def _got_all_shareholders(self, res):
        if self._results:
            now = time.time()
            self._results.timings["peer_selection"] = now - self._started

        if len(self._share_buckets) < self._num_needed_shares:
            raise NotEnoughSharesError

        #for s in self._share_vbuckets.values():
        #    for vb in s:
        #        assert isinstance(vb, ValidatedBucket), \
        #               "vb is %s but should be a ValidatedBucket" % (vb,)

    def _unpack_uri_extension_data(self, data):
        return uri.unpack_extension(data)

    def _obtain_uri_extension(self, ignored):
        # all shareholders are supposed to have a copy of uri_extension, and
        # all are supposed to be identical. We compute the hash of the data
        # that comes back, and compare it against the version in our URI. If
        # they don't match, ignore their data and try someone else.
        if self._status:
            self._status.set_status("Obtaining URI Extension")

        self._uri_extension_fetch_started = time.time()
        def _validate(proposal, bucket):
            h = hashutil.uri_extension_hash(proposal)
            if h != self._uri_extension_hash:
                self._fetch_failures["uri_extension"] += 1
                msg = ("The copy of uri_extension we received from "
                       "%s was bad: wanted %s, got %s" %
                       (bucket,
                        base32.b2a(self._uri_extension_hash),
                        base32.b2a(h)))
                self.log(msg, level=log.SCARY)
                raise BadURIExtensionHashValue(msg)
            return self._unpack_uri_extension_data(proposal)
        return self._obtain_validated_thing(None,
                                            self._uri_extension_sources,
                                            "uri_extension",
                                            "get_uri_extension", (), _validate)

    def _obtain_validated_thing(self, ignored, sources, name, methname, args,
                                validatorfunc):
        if not sources:
            raise NotEnoughSharesError("started with zero peers while fetching "
                                       "%s" % name)
        bucket = sources[0]
        sources = sources[1:]
        #d = bucket.callRemote(methname, *args)
        d = bucket.startIfNecessary()
        d.addCallback(lambda res: getattr(bucket, methname)(*args))
        d.addCallback(validatorfunc, bucket)
        def _bad(f):
            self.log("%s from vbucket %s failed:" % (name, bucket),
                     failure=f, level=log.WEIRD)
            if not sources:
                raise NotEnoughSharesError("ran out of peers, last error was %s"
                                           % (f,))
            # try again with a different one
            return self._obtain_validated_thing(None, sources, name,
                                                methname, args, validatorfunc)
        d.addErrback(_bad)
        return d

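    # Illustrative sketch (comment only): the try-each-source-until-one-
    # validates pattern that _obtain_validated_thing implements, reduced to
    # its Deferred skeleton. fetch() is a hypothetical one-argument callable.
    #
    #   from twisted.internet import defer
    #
    #   def fetch_from_any(sources, fetch):
    #       if not sources:
    #           return defer.fail(NotEnoughSharesError("ran out of peers"))
    #       d = defer.maybeDeferred(fetch, sources[0])
    #       d.addErrback(lambda f: fetch_from_any(sources[1:], fetch))
    #       return d
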
    def _got_uri_extension(self, uri_extension_data):
        if self._results:
            elapsed = time.time() - self._uri_extension_fetch_started
            self._results.timings["uri_extension"] = elapsed

        d = self._uri_extension_data = uri_extension_data

        self._codec = codec.get_decoder_by_name(d['codec_name'])
        self._codec.set_serialized_params(d['codec_params'])
        self._tail_codec = codec.get_decoder_by_name(d['codec_name'])
        self._tail_codec.set_serialized_params(d['tail_codec_params'])

        crypttext_hash = d.get('crypttext_hash', None) # optional
        if crypttext_hash:
            assert isinstance(crypttext_hash, str)
            assert len(crypttext_hash) == 32
        self._crypttext_hash = crypttext_hash
        self._plaintext_hash = d.get('plaintext_hash', None) # optional

        self._roothash = d['share_root_hash']

        self._segment_size = segment_size = d['segment_size']
        self._total_segments = mathutil.div_ceil(self._size, segment_size)
        self._current_segnum = 0

        self._share_hashtree = hashtree.IncompleteHashTree(d['total_shares'])
        self._share_hashtree.set_hashes({0: self._roothash})

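    # Illustrative sketch (comment only): the shape of the validated
    # uri_extension_data dict consumed above. The keys are the ones this
    # module reads; the values are made-up examples, not real encodings.
    #
    #   uri_extension_data = {
    #       'codec_name': "crs",                 # erasure-codec name
    #       'codec_params': "...",               # serialized codec parameters
    #       'tail_codec_params': "...",          # params for the short tail segment
    #       'segment_size': 131072,
    #       'total_shares': 10,
    #       'share_root_hash': "\x00" * 32,      # root of the share hash tree
    #       'crypttext_hash': "\x00" * 32,       # optional
    #       'plaintext_hash': "\x00" * 32,       # optional
    #       'crypttext_root_hash': "\x00" * 32,  # optional segment-tree root
    #       'plaintext_root_hash': "\x00" * 32,  # optional segment-tree root
    #   }
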
    def _get_hashtrees(self, res):
        self._get_hashtrees_started = time.time()
        if self._status:
            self._status.set_status("Retrieving Hash Trees")
        d = defer.maybeDeferred(self._get_plaintext_hashtrees)
        d.addCallback(self._get_crypttext_hashtrees)
        d.addCallback(self._setup_hashtrees)
        return d

    def _get_plaintext_hashtrees(self):
        # plaintext hashes are optional. If the root isn't in the UEB, then
        # the share will be holding an empty list. We don't even bother
        # fetching it.
        if "plaintext_root_hash" not in self._uri_extension_data:
            self._plaintext_hashtree = None
            return
        def _validate_plaintext_hashtree(proposal, bucket):
            if proposal[0] != self._uri_extension_data['plaintext_root_hash']:
                self._fetch_failures["plaintext_hashroot"] += 1
                msg = ("The copy of the plaintext_root_hash we received from"
                       " %s was bad" % bucket)
                raise BadPlaintextHashValue(msg)
            pt_hashtree = hashtree.IncompleteHashTree(self._total_segments)
            pt_hashes = dict(list(enumerate(proposal)))
            try:
                pt_hashtree.set_hashes(pt_hashes)
            except hashtree.BadHashError:
                # the hashes they gave us were not self-consistent, even
                # though the root matched what we saw in the uri_extension
                # block
                self._fetch_failures["plaintext_hashtree"] += 1
                raise
            self._plaintext_hashtree = pt_hashtree
        d = self._obtain_validated_thing(None,
                                         self._uri_extension_sources,
                                         "plaintext_hashes",
                                         "get_plaintext_hashes", (),
                                         _validate_plaintext_hashtree)
        return d

    def _get_crypttext_hashtrees(self, res):
        # crypttext hashes are optional too
        if "crypttext_root_hash" not in self._uri_extension_data:
            self._crypttext_hashtree = None
            return
        def _validate_crypttext_hashtree(proposal, bucket):
            if proposal[0] != self._uri_extension_data['crypttext_root_hash']:
                self._fetch_failures["crypttext_hashroot"] += 1
                msg = ("The copy of the crypttext_root_hash we received from"
                       " %s was bad" % bucket)
                raise BadCrypttextHashValue(msg)
            ct_hashtree = hashtree.IncompleteHashTree(self._total_segments)
            ct_hashes = dict(list(enumerate(proposal)))
            try:
                ct_hashtree.set_hashes(ct_hashes)
            except hashtree.BadHashError:
                # the hashes they gave us were not self-consistent, even
                # though the root matched the uri_extension block
                self._fetch_failures["crypttext_hashtree"] += 1
                raise
            self._crypttext_hashtree = ct_hashtree
        d = self._obtain_validated_thing(None,
                                         self._uri_extension_sources,
                                         "crypttext_hashes",
                                         "get_crypttext_hashes", (),
                                         _validate_crypttext_hashtree)
        return d

    def _setup_hashtrees(self, res):
        self._output.setup_hashtrees(self._plaintext_hashtree,
                                     self._crypttext_hashtree)
        if self._results:
            elapsed = time.time() - self._get_hashtrees_started
            self._results.timings["hashtrees"] = elapsed

    def _create_validated_buckets(self, ignored=None):
        self._share_vbuckets = {}
        for sharenum, bucket in self._share_buckets:
            vbucket = ValidatedBucket(sharenum, bucket,
                                      self._share_hashtree,
                                      self._roothash,
                                      self._total_segments)
            s = self._share_vbuckets.setdefault(sharenum, set())
            s.add(vbucket)

    def _activate_enough_buckets(self):
        """either return a mapping from shnum to a ValidatedBucket that can
        provide data for that share, or raise NotEnoughSharesError"""

        while len(self.active_buckets) < self._num_needed_shares:
            # need some more
            handled_shnums = set(self.active_buckets.keys())
            available_shnums = set(self._share_vbuckets.keys())
            potential_shnums = list(available_shnums - handled_shnums)
            if not potential_shnums:
                raise NotEnoughSharesError
            # choose a random share
            shnum = random.choice(potential_shnums)
            # and a random bucket that will provide it
            validated_bucket = random.choice(list(self._share_vbuckets[shnum]))
            self.active_buckets[shnum] = validated_bucket
        return self.active_buckets


    def _download_all_segments(self, res):
        # the promise: upon entry to this function, self._share_vbuckets
        # contains enough buckets to complete the download, and some extra
        # ones to tolerate some buckets dropping out or having errors.
        # self._share_vbuckets is a dictionary that maps from shnum to a set
        # of ValidatedBuckets, which themselves are wrappers around
        # RIBucketReader references.
        self.active_buckets = {} # k: shnum, v: ValidatedBucket instance

        self._started_fetching = time.time()

        d = defer.succeed(None)
        for segnum in range(self._total_segments-1):
            d.addCallback(self._download_segment, segnum)
            # this pause, at the end of write, prevents pre-fetch from
            # happening until the consumer is ready for more data.
            d.addCallback(self._check_for_pause)
        d.addCallback(self._download_tail_segment, self._total_segments-1)
        return d

    def _check_for_pause(self, res):
        if self._paused:
            d = defer.Deferred()
            self._paused.addCallback(lambda ignored: d.callback(res))
            return d
        if self._stopped:
            raise DownloadStopped("our Consumer called stopProducing()")
        return res

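    # Illustrative sketch (comment only): the pause mechanism used by
    # _check_for_pause. While paused, the segment pipeline chains onto a
    # Deferred that resumeProducing() fires, so no further segments are
    # fetched until the consumer asks for more.
    #
    #   from twisted.internet import defer
    #
    #   paused = False        # pauseProducing() replaces this with a Deferred
    #   def check_for_pause(res):
    #       if paused:
    #           d = defer.Deferred()
    #           paused.addCallback(lambda ignored: d.callback(res))
    #           return d      # the pipeline waits here until resume
    #       return res
    #   # resumeProducing() eventually does: paused.callback(None)
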
    def _download_segment(self, res, segnum):
        if self._status:
            self._status.set_status("Downloading segment %d of %d" %
                                    (segnum+1, self._total_segments))
        self.log("downloading seg#%d of %d (%d%%)"
                 % (segnum, self._total_segments,
                    100.0 * segnum / self._total_segments))
        # memory footprint: when the SegmentDownloader finishes pulling down
        # all shares, we have 1*segment_size of usage.
        segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares,
                                        self._results)
        started = time.time()
        d = segmentdler.start()
        def _finished_fetching(res):
            elapsed = time.time() - started
            self._results.timings["cumulative_fetch"] += elapsed
            return res
        if self._results:
            d.addCallback(_finished_fetching)
        # pause before using more memory
        d.addCallback(self._check_for_pause)
        # while the codec does its job, we hit 2*segment_size
        def _started_decode(res):
            self._started_decode = time.time()
            return res
        if self._results:
            d.addCallback(_started_decode)
        d.addCallback(lambda (shares, shareids):
                      self._codec.decode(shares, shareids))
        # once the codec is done, we drop back to 1*segment_size, because
        # 'shares' goes out of scope. The memory usage is all in the
        # plaintext now, spread out into a bunch of tiny buffers.
        def _finished_decode(res):
            elapsed = time.time() - self._started_decode
            self._results.timings["cumulative_decode"] += elapsed
            return res
        if self._results:
            d.addCallback(_finished_decode)

        # pause/check-for-stop just before writing, to honor stopProducing
        d.addCallback(self._check_for_pause)
        def _done(buffers):
            # we start by joining all these buffers together into a single
            # string. This makes Output.write easier, since it wants to hash
            # data one segment at a time anyways, and doesn't impact our
            # memory footprint since we're already peaking at 2*segment_size
            # inside the codec a moment ago.
            segment = "".join(buffers)
            del buffers
            # we're down to 1*segment_size right now, but write_segment()
            # will decrypt a copy of the segment internally, which will push
            # us up to 2*segment_size while it runs.
            started_decrypt = time.time()
            self._output.write_segment(segment)
            if self._results:
                elapsed = time.time() - started_decrypt
                self._results.timings["cumulative_decrypt"] += elapsed
        d.addCallback(_done)
        return d

    def _download_tail_segment(self, res, segnum):
        self.log("downloading seg#%d of %d (%d%%)"
                 % (segnum, self._total_segments,
                    100.0 * segnum / self._total_segments))
        segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares,
                                        self._results)
        started = time.time()
        d = segmentdler.start()
        def _finished_fetching(res):
            elapsed = time.time() - started
            self._results.timings["cumulative_fetch"] += elapsed
            return res
        if self._results:
            d.addCallback(_finished_fetching)
        # pause before using more memory
        d.addCallback(self._check_for_pause)
        def _started_decode(res):
            self._started_decode = time.time()
            return res
        if self._results:
            d.addCallback(_started_decode)
        d.addCallback(lambda (shares, shareids):
                      self._tail_codec.decode(shares, shareids))
        def _finished_decode(res):
            elapsed = time.time() - self._started_decode
            self._results.timings["cumulative_decode"] += elapsed
            return res
        if self._results:
            d.addCallback(_finished_decode)
        # pause/check-for-stop just before writing, to honor stopProducing
        d.addCallback(self._check_for_pause)
        def _done(buffers):
            # trim off any padding added by the upload side
            segment = "".join(buffers)
            del buffers
            # we never send empty segments. If the data was an exact multiple
            # of the segment size, the last segment will be full.
            pad_size = mathutil.pad_size(self._size, self._segment_size)
            tail_size = self._segment_size - pad_size
            segment = segment[:tail_size]
            started_decrypt = time.time()
            self._output.write_segment(segment)
            if self._results:
                elapsed = time.time() - started_decrypt
                self._results.timings["cumulative_decrypt"] += elapsed
        d.addCallback(_done)
        return d

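    # Illustrative sketch (comment only): the segment arithmetic used above,
    # with made-up sizes. mathutil.div_ceil and mathutil.pad_size are the
    # same helpers used in _got_uri_extension and _download_tail_segment.
    #
    #   from allmydata.util import mathutil
    #   size, segment_size = 2500, 1024
    #   total_segments = mathutil.div_ceil(size, segment_size)  # 3
    #   pad_size = mathutil.pad_size(size, segment_size)        # 572 bytes of padding
    #   tail_size = segment_size - pad_size                     # 452 real bytes in the tail
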
    def _done(self, res):
        self.log("download done")
        if self._results:
            now = time.time()
            self._results.timings["total"] = now - self._started
            self._results.timings["segments"] = now - self._started_fetching
        self._output.close()
        if self.check_crypttext_hash and self._crypttext_hash:
            _assert(self._crypttext_hash == self._output.crypttext_hash,
                    "bad crypttext_hash: computed=%s, expected=%s" %
                    (base32.b2a(self._output.crypttext_hash),
                     base32.b2a(self._crypttext_hash)))
        if self.check_plaintext_hash and self._plaintext_hash:
            _assert(self._plaintext_hash == self._output.plaintext_hash,
                    "bad plaintext_hash: computed=%s, expected=%s" %
                    (base32.b2a(self._output.plaintext_hash),
                     base32.b2a(self._plaintext_hash)))
        _assert(self._output.length == self._size,
                got=self._output.length, expected=self._size)
        return self._output.finish()

    def get_download_status(self):
        return self._status


class LiteralDownloader:
    def __init__(self, client, u, downloadable):
        self._uri = IFileURI(u)
        assert isinstance(self._uri, uri.LiteralFileURI)
        self._downloadable = downloadable
        self._status = s = DownloadStatus()
        s.set_storage_index(None)
        s.set_helper(False)
        s.set_status("Done")
        s.set_active(False)
        s.set_progress(1.0)

    def start(self):
        data = self._uri.data
        self._status.set_size(len(data))
        self._downloadable.open(len(data))
        self._downloadable.write(data)
        self._downloadable.close()
        return defer.maybeDeferred(self._downloadable.finish)

    def get_download_status(self):
        return self._status

class FileName:
    implements(IDownloadTarget)
    def __init__(self, filename):
        self._filename = filename
        self.f = None
    def open(self, size):
        self.f = open(self._filename, "wb")
        return self.f
    def write(self, data):
        self.f.write(data)
    def close(self):
        if self.f:
            self.f.close()
    def fail(self, why):
        if self.f:
            self.f.close()
            os.unlink(self._filename)
    def register_canceller(self, cb):
        pass # we won't use it
    def finish(self):
        pass

class Data:
    implements(IDownloadTarget)
    def __init__(self):
        self._data = []
    def open(self, size):
        pass
    def write(self, data):
        self._data.append(data)
    def close(self):
        self.data = "".join(self._data)
        del self._data
    def fail(self, why):
        del self._data
    def register_canceller(self, cb):
        pass # we won't use it
    def finish(self):
        return self.data

class FileHandle:
    """Use me to download data to a pre-defined filehandle-like object. I
    will use the target's write() method. I will *not* close the filehandle:
    I leave that up to the originator of the filehandle. The download process
    will return the filehandle when it completes.
    """
    implements(IDownloadTarget)
    def __init__(self, filehandle):
        self._filehandle = filehandle
    def open(self, size):
        pass
    def write(self, data):
        self._filehandle.write(data)
    def close(self):
        # the originator of the filehandle reserves the right to close it
        pass
    def fail(self, why):
        pass
    def register_canceller(self, cb):
        pass
    def finish(self):
        return self._filehandle

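# Illustrative sketch (comment only): a hypothetical IDownloadTarget that
# hashes the plaintext as it arrives instead of storing it. Any object with
# these methods can be handed to Downloader.download(); the hashlib usage is
# an example, not something this module depends on.
#
#   import hashlib
#
#   class HashingTarget:
#       implements(IDownloadTarget)
#       def __init__(self):
#           self._hasher = hashlib.sha256()
#       def open(self, size):
#           pass
#       def write(self, data):
#           self._hasher.update(data)
#       def close(self):
#           self.digest = self._hasher.hexdigest()
#       def fail(self, why):
#           pass
#       def register_canceller(self, cb):
#           pass
#       def finish(self):
#           return self.digest
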
class Downloader(service.MultiService):
    """I am a service that allows file downloading.
    """
    implements(IDownloader)
    name = "downloader"
    MAX_DOWNLOAD_STATUSES = 10

    def __init__(self, stats_provider=None):
        service.MultiService.__init__(self)
        self.stats_provider = stats_provider
        self._all_downloads = weakref.WeakKeyDictionary() # for debugging
        self._all_download_statuses = weakref.WeakKeyDictionary()
        self._recent_download_statuses = []

    def download(self, u, t):
        assert self.parent
        assert self.running
        u = IFileURI(u)
        t = IDownloadTarget(t)
        assert t.write
        assert t.close

        if self.stats_provider:
            self.stats_provider.count('downloader.files_downloaded', 1)
            self.stats_provider.count('downloader.bytes_downloaded', u.get_size())

        if isinstance(u, uri.LiteralFileURI):
            dl = LiteralDownloader(self.parent, u, t)
        elif isinstance(u, uri.CHKFileURI):
            dl = FileDownloader(self.parent, u, t)
        else:
            raise RuntimeError("I don't know how to download a %s" % u)
        self._add_download(dl)
        d = dl.start()
        return d

    # utility functions
    def download_to_data(self, uri):
        return self.download(uri, Data())
    def download_to_filename(self, uri, filename):
        return self.download(uri, FileName(filename))
    def download_to_filehandle(self, uri, filehandle):
        return self.download(uri, FileHandle(filehandle))

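    # Illustrative usage sketch (comment only): how client code typically
    # drives this service. 'client' and 'filecap' are hypothetical; the
    # service is registered under the name "downloader".
    #
    #   downloader = client.getServiceNamed("downloader")
    #   d = downloader.download_to_data(filecap)
    #   def _got(data):
    #       print "downloaded %d bytes" % len(data)
    #   d.addCallback(_got)
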
    def _add_download(self, downloader):
        self._all_downloads[downloader] = None
        s = downloader.get_download_status()
        self._all_download_statuses[s] = None
        self._recent_download_statuses.append(s)
        while len(self._recent_download_statuses) > self.MAX_DOWNLOAD_STATUSES:
            self._recent_download_statuses.pop(0)

    def list_all_download_statuses(self):
        for ds in self._all_download_statuses:
            yield ds