
import os, random, weakref, itertools, time
from zope.interface import implements
from twisted.internet import defer
from twisted.internet.interfaces import IPushProducer, IConsumer
from twisted.application import service
from foolscap.eventual import eventually

from allmydata.util import base32, mathutil, hashutil, log
from allmydata.util.assertutil import _assert
from allmydata import codec, hashtree, storage, uri
from allmydata.interfaces import IDownloadTarget, IDownloader, IFileURI, \
     IDownloadStatus, IDownloadResults
from allmydata.encode import NotEnoughPeersError
from pycryptopp.cipher.aes import AES

class HaveAllPeersError(Exception):
    # we use this to jump out of the loop
    pass

class BadURIExtensionHashValue(Exception):
    pass
class BadPlaintextHashValue(Exception):
    pass
class BadCrypttextHashValue(Exception):
    pass

class DownloadStopped(Exception):
    pass

class DownloadResults:
    implements(IDownloadResults)

    def __init__(self):
        self.servers_used = set()
        self.server_problems = {}
        self.servermap = {}
        self.timings = {}
        self.file_size = None

class Output:
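    """I accept segments of crypttext from a FileDownloader, one at a time. I
    decrypt each segment with AES, feed both crypttext and plaintext into
    their whole-file hashers and (once they have been provided) their segment
    hash trees, and write the resulting plaintext to the IDownloadTarget I
    was given.
    """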
    def __init__(self, downloadable, key, total_length, log_parent,
                 download_status):
        self.downloadable = downloadable
        self._decryptor = AES(key)
        self._crypttext_hasher = hashutil.crypttext_hasher()
        self._plaintext_hasher = hashutil.plaintext_hasher()
        self.length = 0
        self.total_length = total_length
        self._segment_number = 0
        self._plaintext_hash_tree = None
        self._crypttext_hash_tree = None
        self._opened = False
        self._log_parent = log_parent
        self._status = download_status
        self._status.set_progress(0.0)

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_parent
        if "facility" not in kwargs:
            kwargs["facility"] = "download.output"
        return log.msg(*args, **kwargs)

    def setup_hashtrees(self, plaintext_hashtree, crypttext_hashtree):
        self._plaintext_hash_tree = plaintext_hashtree
        self._crypttext_hash_tree = crypttext_hashtree

    def write_segment(self, crypttext):
        self.length += len(crypttext)
        self._status.set_progress( float(self.length) / self.total_length )

        # memory footprint: 'crypttext' is the only segment_size usage
        # outstanding. While we decrypt it into 'plaintext', we hit
        # 2*segment_size.
        self._crypttext_hasher.update(crypttext)
        if self._crypttext_hash_tree:
            ch = hashutil.crypttext_segment_hasher()
            ch.update(crypttext)
            crypttext_leaves = {self._segment_number: ch.digest()}
            self.log(format="crypttext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s",
                     bytes=len(crypttext),
                     segnum=self._segment_number, hash=base32.b2a(ch.digest()),
                     level=log.NOISY)
            self._crypttext_hash_tree.set_hashes(leaves=crypttext_leaves)

        plaintext = self._decryptor.process(crypttext)
        del crypttext

        # now we're back down to 1*segment_size.

        self._plaintext_hasher.update(plaintext)
        if self._plaintext_hash_tree:
            ph = hashutil.plaintext_segment_hasher()
            ph.update(plaintext)
            plaintext_leaves = {self._segment_number: ph.digest()}
            self.log(format="plaintext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s",
                     bytes=len(plaintext),
                     segnum=self._segment_number, hash=base32.b2a(ph.digest()),
                     level=log.NOISY)
            self._plaintext_hash_tree.set_hashes(leaves=plaintext_leaves)

        self._segment_number += 1
        # We're still at 1*segment_size. The Downloadable is responsible for
        # any memory usage beyond this.
        if not self._opened:
            self._opened = True
            self.downloadable.open(self.total_length)
        self.downloadable.write(plaintext)

    def fail(self, why):
        # this is really unusual, and deserves maximum forensics
        self.log("download failed!", failure=why, level=log.SCARY)
        self.downloadable.fail(why)

    def close(self):
        self.crypttext_hash = self._crypttext_hasher.digest()
        self.plaintext_hash = self._plaintext_hasher.digest()
        self.log("download finished, closing IDownloadable", level=log.NOISY)
        self.downloadable.close()

    def finish(self):
        return self.downloadable.finish()

class ValidatedBucket:
    """I am a front-end for a remote storage bucket, responsible for
    retrieving and validating data from that bucket.

    My get_block() method is used by BlockDownloaders.
    """
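    # Validation chain, as implemented by get_block()/_got_data() below: the
    # share hash tree is anchored at the roothash I was constructed with, the
    # share hashes fetched from the bucket must chain up to that root, and
    # every block hash fetched alongside a block must chain up to the
    # validated share hash. Any mismatch raises BadHashError or
    # NotEnoughHashesError, and the caller drops this bucket.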

    def __init__(self, sharenum, bucket,
                 share_hash_tree, roothash,
                 num_blocks):
        self.sharenum = sharenum
        self.bucket = bucket
        self._share_hash = None # None means not validated yet
        self.share_hash_tree = share_hash_tree
        self._roothash = roothash
        self.block_hash_tree = hashtree.IncompleteHashTree(num_blocks)
        self.started = False

    def get_block(self, blocknum):
        if not self.started:
            d = self.bucket.start()
            def _started(res):
                self.started = True
                return self.get_block(blocknum)
            d.addCallback(_started)
            return d

        # the first time we use this bucket, we need to fetch enough elements
        # of the share hash tree to validate it from our share hash up to the
        # hashroot.
        if not self._share_hash:
            d1 = self.bucket.get_share_hashes()
        else:
            d1 = defer.succeed([])

        # we might need to grab some elements of our block hash tree, to
        # validate the requested block up to the share hash
        needed = self.block_hash_tree.needed_hashes(blocknum)
        if needed:
            # TODO: get fewer hashes, use get_block_hashes(needed)
            d2 = self.bucket.get_block_hashes()
        else:
            d2 = defer.succeed([])

        d3 = self.bucket.get_block(blocknum)

        d = defer.gatherResults([d1, d2, d3])
        d.addCallback(self._got_data, blocknum)
        return d

    def _got_data(self, res, blocknum):
        sharehashes, blockhashes, blockdata = res
        blockhash = None # to make logging it safe

        try:
            if not self._share_hash:
                sh = dict(sharehashes)
                sh[0] = self._roothash # always use our own root, from the URI
                sht = self.share_hash_tree
                if sht.get_leaf_index(self.sharenum) not in sh:
                    raise hashtree.NotEnoughHashesError
                sht.set_hashes(sh)
                self._share_hash = sht.get_leaf(self.sharenum)

            blockhash = hashutil.block_hash(blockdata)
            #log.msg("checking block_hash(shareid=%d, blocknum=%d) len=%d "
            #        "%r .. %r: %s" %
            #        (self.sharenum, blocknum, len(blockdata),
            #         blockdata[:50], blockdata[-50:], base32.b2a(blockhash)))

            # we always validate the blockhash
            bh = dict(enumerate(blockhashes))
            # replace blockhash root with validated value
            bh[0] = self._share_hash
            self.block_hash_tree.set_hashes(bh, {blocknum: blockhash})

        except (hashtree.BadHashError, hashtree.NotEnoughHashesError):
            # log.WEIRD: indicates undetected disk/network error, or more
            # likely a programming error
            log.msg("hash failure in block=%d, shnum=%d on %s" %
                    (blocknum, self.sharenum, self.bucket))
            if self._share_hash:
                log.msg(""" failure occurred when checking the block_hash_tree.
                This suggests that either the block data was bad, or that the
                block hashes we received along with it were bad.""")
            else:
                log.msg(""" the failure probably occurred when checking the
                share_hash_tree, which suggests that the share hashes we
                received from the remote peer were bad.""")
            log.msg(" have self._share_hash: %s" % bool(self._share_hash))
            log.msg(" block length: %d" % len(blockdata))
            log.msg(" block hash: %s" % base32.b2a_or_none(blockhash))
            if len(blockdata) < 100:
                log.msg(" block data: %r" % (blockdata,))
            else:
                log.msg(" block data start/end: %r .. %r" %
                        (blockdata[:50], blockdata[-50:]))
            log.msg(" root hash: %s" % base32.b2a(self._roothash))
            log.msg(" share hash tree:\n" + self.share_hash_tree.dump())
            log.msg(" block hash tree:\n" + self.block_hash_tree.dump())
            lines = []
            for i,h in sorted(sharehashes):
                lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
            log.msg(" sharehashes:\n" + "\n".join(lines) + "\n")
            lines = []
            for i,h in enumerate(blockhashes):
                lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
            log.msg(" blockhashes:\n" + "\n".join(lines) + "\n")
            raise

        # If we made it here, the block is good. If the hash trees didn't
        # like what they saw, they would have raised a BadHashError, causing
        # our caller to see a Failure and thus ignore this block (as well as
        # dropping this bucket).
        return blockdata



class BlockDownloader:
    """I am responsible for downloading a single block (from a single bucket)
    for a single segment.

    I am a child of the SegmentDownloader.
    """
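    # On success I hand the block to my parent via hold_block(); on error I
    # record the problem in the results (if any) and ask my parent to drop
    # the whole bucket via bucket_failed().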

    def __init__(self, vbucket, blocknum, parent, results):
        self.vbucket = vbucket
        self.blocknum = blocknum
        self.parent = parent
        self.results = results
        self._log_number = self.parent.log("starting block %d" % blocknum)

    def log(self, msg, parent=None):
        if parent is None:
            parent = self._log_number
        return self.parent.log(msg, parent=parent)

    def start(self, segnum):
        lognum = self.log("get_block(segnum=%d)" % segnum)
        started = time.time()
        d = self.vbucket.get_block(segnum)
        d.addCallbacks(self._hold_block, self._got_block_error,
                       callbackArgs=(started, lognum,), errbackArgs=(lognum,))
        return d

    def _hold_block(self, data, started, lognum):
        if self.results:
            elapsed = time.time() - started
            peerid = self.vbucket.bucket.get_peerid()
            if peerid not in self.results.timings["fetch_per_server"]:
                self.results.timings["fetch_per_server"][peerid] = []
            self.results.timings["fetch_per_server"][peerid].append(elapsed)
        self.log("got block", parent=lognum)
        self.parent.hold_block(self.blocknum, data)

    def _got_block_error(self, f, lognum):
        self.log("BlockDownloader[%d] got error: %s" % (self.blocknum, f),
                 parent=lognum)
        if self.results:
            peerid = self.vbucket.bucket.get_peerid()
            self.results.server_problems[peerid] = str(f)
        self.parent.bucket_failed(self.vbucket)

class SegmentDownloader:
    """I am responsible for downloading all the blocks for a single segment
    of data.

    I am a child of the FileDownloader.
    """
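    # _download() keeps calling _try() until at least needed_blocks blocks
    # have arrived, then returns the ones with the smallest blockids, since
    # those are more likely to be fast 'primary blocks'.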

    def __init__(self, parent, segmentnumber, needed_shares, results):
        self.parent = parent
        self.segmentnumber = segmentnumber
        self.needed_blocks = needed_shares
        self.blocks = {} # k: blocknum, v: data
        self.results = results
        self._log_number = self.parent.log("starting segment %d" %
                                           segmentnumber)

    def log(self, msg, parent=None):
        if parent is None:
            parent = self._log_number
        return self.parent.log(msg, parent=parent)

    def start(self):
        return self._download()

    def _download(self):
        d = self._try()
        def _done(res):
            if len(self.blocks) >= self.needed_blocks:
                # we only need self.needed_blocks blocks
                # we want to get the smallest blockids, because they are
                # more likely to be fast "primary blocks"
                blockids = sorted(self.blocks.keys())[:self.needed_blocks]
                blocks = []
                for blocknum in blockids:
                    blocks.append(self.blocks[blocknum])
                return (blocks, blockids)
            else:
                return self._download()
        d.addCallback(_done)
        return d

    def _try(self):
        # fill our set of active buckets, maybe raising NotEnoughPeersError
        active_buckets = self.parent._activate_enough_buckets()
        # Now we have enough buckets, in self.parent.active_buckets.

        # in test cases, bd.start might mutate active_buckets right away, so
        # we need to put off calling start() until we've iterated all the way
        # through it.
        downloaders = []
        for blocknum, vbucket in active_buckets.iteritems():
            bd = BlockDownloader(vbucket, blocknum, self, self.results)
            downloaders.append(bd)
            if self.results:
                self.results.servers_used.add(vbucket.bucket.get_peerid())
        l = [bd.start(self.segmentnumber) for bd in downloaders]
        return defer.DeferredList(l, fireOnOneErrback=True)

    def hold_block(self, blocknum, data):
        self.blocks[blocknum] = data

    def bucket_failed(self, vbucket):
        self.parent.bucket_failed(vbucket)

class DownloadStatus:
    implements(IDownloadStatus)
    statusid_counter = itertools.count(0)

    def __init__(self):
        self.storage_index = None
        self.size = None
        self.helper = False
        self.status = "Not started"
        self.progress = 0.0
        self.paused = False
        self.stopped = False
        self.active = True
        self.results = None
        self.counter = self.statusid_counter.next()
        self.started = time.time()

    def get_started(self):
        return self.started
    def get_storage_index(self):
        return self.storage_index
    def get_size(self):
        return self.size
    def using_helper(self):
        return self.helper
    def get_status(self):
        status = self.status
        if self.paused:
            status += " (output paused)"
        if self.stopped:
            status += " (output stopped)"
        return status
    def get_progress(self):
        return self.progress
    def get_active(self):
        return self.active
    def get_results(self):
        return self.results
    def get_counter(self):
        return self.counter

    def set_storage_index(self, si):
        self.storage_index = si
    def set_size(self, size):
        self.size = size
    def set_helper(self, helper):
        self.helper = helper
    def set_status(self, status):
        self.status = status
    def set_paused(self, paused):
        self.paused = paused
    def set_stopped(self, stopped):
        self.stopped = stopped
    def set_progress(self, value):
        self.progress = value
    def set_active(self, value):
        self.active = value
    def set_results(self, value):
        self.results = value


class FileDownloader:
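    """I download a CHK file. I locate its shareholders, fetch and validate
    the URI extension block and the plaintext/crypttext hash trees, then pull
    the file down one segment at a time: each segment is FEC-decoded and
    handed to an Output for decryption, hashing, and delivery to the
    IDownloadTarget. I implement IPushProducer so my consumer can pause,
    resume, or stop the download.
    """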
    implements(IPushProducer)
    check_crypttext_hash = True
    check_plaintext_hash = True
    _status = None

    def __init__(self, client, u, downloadable):
        self._client = client

        u = IFileURI(u)
        self._storage_index = u.storage_index
        self._uri_extension_hash = u.uri_extension_hash
        self._total_shares = u.total_shares
        self._size = u.size
        self._num_needed_shares = u.needed_shares

        self._si_s = storage.si_b2a(self._storage_index)
        self.init_logging()

        self._started = time.time()
        self._status = s = DownloadStatus()
        s.set_status("Starting")
        s.set_storage_index(self._storage_index)
        s.set_size(self._size)
        s.set_helper(False)
        s.set_active(True)

        self._results = DownloadResults()
        s.set_results(self._results)
        self._results.file_size = self._size
        self._results.timings["servers_peer_selection"] = {}
        self._results.timings["fetch_per_server"] = {}
        self._results.timings["cumulative_fetch"] = 0.0
        self._results.timings["cumulative_decode"] = 0.0
        self._results.timings["cumulative_decrypt"] = 0.0

        if IConsumer.providedBy(downloadable):
            downloadable.registerProducer(self, True)
        self._downloadable = downloadable
        self._output = Output(downloadable, u.key, self._size, self._log_number,
                              self._status)
        self._paused = False
        self._stopped = False

        self.active_buckets = {} # k: shnum, v: bucket
        self._share_buckets = [] # list of (sharenum, bucket) tuples
        self._share_vbuckets = {} # k: shnum, v: set of ValidatedBuckets
        self._uri_extension_sources = []

        self._uri_extension_data = None

        self._fetch_failures = {"uri_extension": 0,
                                "plaintext_hashroot": 0,
                                "plaintext_hashtree": 0,
                                "crypttext_hashroot": 0,
                                "crypttext_hashtree": 0,
                                }

    def init_logging(self):
        self._log_prefix = prefix = storage.si_b2a(self._storage_index)[:5]
        num = self._client.log(format="FileDownloader(%(si)s): starting",
                               si=storage.si_b2a(self._storage_index))
        self._log_number = num

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_number
        if "facility" not in kwargs:
            kwargs["facility"] = "tahoe.download"
        return log.msg(*args, **kwargs)

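    # IPushProducer interface, called by our consumer (the IDownloadTarget,
    # when it also provides IConsumer). Pausing stashes a Deferred in
    # self._paused; _check_for_pause() chains on that Deferred between
    # segments, so no further segments are fetched until resumeProducing()
    # fires it.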
    def pauseProducing(self):
        if self._paused:
            return
        self._paused = defer.Deferred()
        if self._status:
            self._status.set_paused(True)

    def resumeProducing(self):
        if self._paused:
            p = self._paused
            self._paused = None
            eventually(p.callback, None)
            if self._status:
                self._status.set_paused(False)

    def stopProducing(self):
        self.log("Download.stopProducing")
        self._stopped = True
        if self._status:
            self._status.set_stopped(True)
            self._status.set_active(False)

    def start(self):
        self.log("starting download")

        # first step: who should we download from?
        d = defer.maybeDeferred(self._get_all_shareholders)
        d.addCallback(self._got_all_shareholders)
        # now get the uri_extension block from somebody and validate it
        d.addCallback(self._obtain_uri_extension)
        d.addCallback(self._got_uri_extension)
        d.addCallback(self._get_hashtrees)
        d.addCallback(self._create_validated_buckets)
        # once we know that, we can download blocks from everybody
        d.addCallback(self._download_all_segments)
        def _finished(res):
            if self._status:
                self._status.set_status("Finished")
                self._status.set_active(False)
                self._status.set_paused(False)
            if IConsumer.providedBy(self._downloadable):
                self._downloadable.unregisterProducer()
            return res
        d.addBoth(_finished)
        def _failed(why):
            if self._status:
                self._status.set_status("Failed")
                self._status.set_active(False)
            self._output.fail(why)
            return why
        d.addErrback(_failed)
        d.addCallback(self._done)
        return d

    def _get_all_shareholders(self):
        dl = []
        for (peerid,ss) in self._client.get_permuted_peers("storage",
                                                           self._storage_index):
            d = ss.callRemote("get_buckets", self._storage_index)
            d.addCallbacks(self._got_response, self._got_error,
                           callbackArgs=(peerid,))
            dl.append(d)
        self._responses_received = 0
        self._queries_sent = len(dl)
        if self._status:
            self._status.set_status("Locating Shares (%d/%d)" %
                                    (self._responses_received,
                                     self._queries_sent))
        return defer.DeferredList(dl)

    def _got_response(self, buckets, peerid):
        self._responses_received += 1
        if self._results:
            elapsed = time.time() - self._started
            self._results.timings["servers_peer_selection"][peerid] = elapsed
        if self._status:
            self._status.set_status("Locating Shares (%d/%d)" %
                                    (self._responses_received,
                                     self._queries_sent))
        for sharenum, bucket in buckets.iteritems():
            b = storage.ReadBucketProxy(bucket, peerid, self._si_s)
            self.add_share_bucket(sharenum, b)
            self._uri_extension_sources.append(b)
            if self._results:
                if peerid not in self._results.servermap:
                    self._results.servermap[peerid] = set()
                self._results.servermap[peerid].add(sharenum)

    def add_share_bucket(self, sharenum, bucket):
        # this is split out for the benefit of test_encode.py
        self._share_buckets.append( (sharenum, bucket) )

    def _got_error(self, f):
        self._client.log("Somebody failed. -- %s" % (f,))

    def bucket_failed(self, vbucket):
        shnum = vbucket.sharenum
        del self.active_buckets[shnum]
        s = self._share_vbuckets[shnum]
        # s is a set of ValidatedBucket instances
        s.remove(vbucket)
        # ... which might now be empty
        if not s:
            # there are no more buckets which can provide this share, so
            # remove the key. This may prompt us to use a different share.
            del self._share_vbuckets[shnum]

    def _got_all_shareholders(self, res):
        if self._results:
            now = time.time()
            self._results.timings["peer_selection"] = now - self._started

        if len(self._share_buckets) < self._num_needed_shares:
            raise NotEnoughPeersError

        #for s in self._share_vbuckets.values():
        #    for vb in s:
        #        assert isinstance(vb, ValidatedBucket), \
        #               "vb is %s but should be a ValidatedBucket" % (vb,)

    def _unpack_uri_extension_data(self, data):
        return uri.unpack_extension(data)

    def _obtain_uri_extension(self, ignored):
        # all shareholders are supposed to have a copy of uri_extension, and
        # all are supposed to be identical. We compute the hash of the data
        # that comes back, and compare it against the version in our URI. If
        # they don't match, ignore their data and try someone else.
        if self._status:
            self._status.set_status("Obtaining URI Extension")

        self._uri_extension_fetch_started = time.time()
        def _validate(proposal, bucket):
            h = hashutil.uri_extension_hash(proposal)
            if h != self._uri_extension_hash:
                self._fetch_failures["uri_extension"] += 1
                msg = ("The copy of uri_extension we received from "
                       "%s was bad: wanted %s, got %s" %
                       (bucket,
                        base32.b2a(self._uri_extension_hash),
                        base32.b2a(h)))
                self.log(msg, level=log.SCARY)
                raise BadURIExtensionHashValue(msg)
            return self._unpack_uri_extension_data(proposal)
        return self._obtain_validated_thing(None,
                                            self._uri_extension_sources,
                                            "uri_extension",
                                            "get_uri_extension", (), _validate)

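    # Generic fetch-and-validate helper: ask the first source, run the
    # validator on whatever it returns, and on any failure (bad hash, remote
    # error) fall back to the next source, raising NotEnoughPeersError only
    # once every source has been exhausted.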
    def _obtain_validated_thing(self, ignored, sources, name, methname, args,
                                validatorfunc):
        if not sources:
            raise NotEnoughPeersError("started with zero peers while fetching "
                                      "%s" % name)
        bucket = sources[0]
        sources = sources[1:]
        #d = bucket.callRemote(methname, *args)
        d = bucket.startIfNecessary()
        d.addCallback(lambda res: getattr(bucket, methname)(*args))
        d.addCallback(validatorfunc, bucket)
        def _bad(f):
            self.log("%s from vbucket %s failed:" % (name, bucket),
                     failure=f, level=log.WEIRD)
            if not sources:
                raise NotEnoughPeersError("ran out of peers, last error was %s"
                                          % (f,))
            # try again with a different one
            return self._obtain_validated_thing(None, sources, name,
                                                methname, args, validatorfunc)
        d.addErrback(_bad)
        return d

    def _got_uri_extension(self, uri_extension_data):
        if self._results:
            elapsed = time.time() - self._uri_extension_fetch_started
            self._results.timings["uri_extension"] = elapsed

        d = self._uri_extension_data = uri_extension_data

        self._codec = codec.get_decoder_by_name(d['codec_name'])
        self._codec.set_serialized_params(d['codec_params'])
        self._tail_codec = codec.get_decoder_by_name(d['codec_name'])
        self._tail_codec.set_serialized_params(d['tail_codec_params'])

        crypttext_hash = d['crypttext_hash']
        assert isinstance(crypttext_hash, str)
        assert len(crypttext_hash) == 32
        self._crypttext_hash = crypttext_hash
        self._plaintext_hash = d['plaintext_hash']
        self._roothash = d['share_root_hash']

        self._segment_size = segment_size = d['segment_size']
        self._total_segments = mathutil.div_ceil(self._size, segment_size)
        self._current_segnum = 0

        self._share_hashtree = hashtree.IncompleteHashTree(d['total_shares'])
        self._share_hashtree.set_hashes({0: self._roothash})

    def _get_hashtrees(self, res):
        self._get_hashtrees_started = time.time()
        if self._status:
            self._status.set_status("Retrieving Hash Trees")
        d = self._get_plaintext_hashtrees()
        d.addCallback(self._get_crypttext_hashtrees)
        d.addCallback(self._setup_hashtrees)
        return d

    def _get_plaintext_hashtrees(self):
        def _validate_plaintext_hashtree(proposal, bucket):
            if proposal[0] != self._uri_extension_data['plaintext_root_hash']:
                self._fetch_failures["plaintext_hashroot"] += 1
                msg = ("The copy of the plaintext_root_hash we received from"
                       " %s was bad" % bucket)
                raise BadPlaintextHashValue(msg)
            pt_hashtree = hashtree.IncompleteHashTree(self._total_segments)
            pt_hashes = dict(list(enumerate(proposal)))
            try:
                pt_hashtree.set_hashes(pt_hashes)
            except hashtree.BadHashError:
                # the hashes they gave us were not self-consistent, even
                # though the root matched what we saw in the uri_extension
                # block
                self._fetch_failures["plaintext_hashtree"] += 1
                raise
            self._plaintext_hashtree = pt_hashtree
        d = self._obtain_validated_thing(None,
                                         self._uri_extension_sources,
                                         "plaintext_hashes",
                                         "get_plaintext_hashes", (),
                                         _validate_plaintext_hashtree)
        return d

    def _get_crypttext_hashtrees(self, res):
        def _validate_crypttext_hashtree(proposal, bucket):
            if proposal[0] != self._uri_extension_data['crypttext_root_hash']:
                self._fetch_failures["crypttext_hashroot"] += 1
                msg = ("The copy of the crypttext_root_hash we received from"
                       " %s was bad" % bucket)
                raise BadCrypttextHashValue(msg)
            ct_hashtree = hashtree.IncompleteHashTree(self._total_segments)
            ct_hashes = dict(list(enumerate(proposal)))
            try:
                ct_hashtree.set_hashes(ct_hashes)
            except hashtree.BadHashError:
                self._fetch_failures["crypttext_hashtree"] += 1
                raise
            self._crypttext_hashtree = ct_hashtree
        d = self._obtain_validated_thing(None,
                                         self._uri_extension_sources,
                                         "crypttext_hashes",
                                         "get_crypttext_hashes", (),
                                         _validate_crypttext_hashtree)
        return d

    def _setup_hashtrees(self, res):
        self._output.setup_hashtrees(self._plaintext_hashtree,
                                     self._crypttext_hashtree)
        if self._results:
            elapsed = time.time() - self._get_hashtrees_started
            self._results.timings["hashtrees"] = elapsed

    def _create_validated_buckets(self, ignored=None):
        self._share_vbuckets = {}
        for sharenum, bucket in self._share_buckets:
            vbucket = ValidatedBucket(sharenum, bucket,
                                      self._share_hashtree,
                                      self._roothash,
                                      self._total_segments)
            s = self._share_vbuckets.setdefault(sharenum, set())
            s.add(vbucket)

    def _activate_enough_buckets(self):
        """either return a mapping from shnum to a ValidatedBucket that can
        provide data for that share, or raise NotEnoughPeersError"""

        while len(self.active_buckets) < self._num_needed_shares:
            # need some more
            handled_shnums = set(self.active_buckets.keys())
            available_shnums = set(self._share_vbuckets.keys())
            potential_shnums = list(available_shnums - handled_shnums)
            if not potential_shnums:
                raise NotEnoughPeersError
            # choose a random share
            shnum = random.choice(potential_shnums)
            # and a random bucket that will provide it
            validated_bucket = random.choice(list(self._share_vbuckets[shnum]))
            self.active_buckets[shnum] = validated_bucket
        return self.active_buckets


    def _download_all_segments(self, res):
        # the promise: upon entry to this function, self._share_vbuckets
        # contains enough buckets to complete the download, and some extra
        # ones to tolerate some buckets dropping out or having errors.
        # self._share_vbuckets is a dictionary that maps from shnum to a set
        # of ValidatedBuckets, which themselves are wrappers around
        # RIBucketReader references.
        self.active_buckets = {} # k: shnum, v: ValidatedBucket instance

        self._started_fetching = time.time()

        d = defer.succeed(None)
        for segnum in range(self._total_segments-1):
            d.addCallback(self._download_segment, segnum)
            # this pause, at the end of write, prevents pre-fetch from
            # happening until the consumer is ready for more data.
            d.addCallback(self._check_for_pause)
        d.addCallback(self._download_tail_segment, self._total_segments-1)
        return d

    def _check_for_pause(self, res):
        if self._paused:
            d = defer.Deferred()
            self._paused.addCallback(lambda ignored: d.callback(res))
            return d
        if self._stopped:
            raise DownloadStopped("our Consumer called stopProducing()")
        return res

    def _download_segment(self, res, segnum):
        if self._status:
            self._status.set_status("Downloading segment %d of %d" %
                                    (segnum+1, self._total_segments))
        self.log("downloading seg#%d of %d (%d%%)"
                 % (segnum, self._total_segments,
                    100.0 * segnum / self._total_segments))
        # memory footprint: when the SegmentDownloader finishes pulling down
        # all shares, we have 1*segment_size of usage.
        segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares,
                                        self._results)
        started = time.time()
        d = segmentdler.start()
        def _finished_fetching(res):
            elapsed = time.time() - started
            self._results.timings["cumulative_fetch"] += elapsed
            return res
        if self._results:
            d.addCallback(_finished_fetching)
        # pause before using more memory
        d.addCallback(self._check_for_pause)
        # while the codec does its job, we hit 2*segment_size
        def _started_decode(res):
            self._started_decode = time.time()
            return res
        if self._results:
            d.addCallback(_started_decode)
        d.addCallback(lambda (shares, shareids):
                      self._codec.decode(shares, shareids))
        # once the codec is done, we drop back to 1*segment_size, because
        # 'shares' goes out of scope. The memory usage is all in the
        # plaintext now, spread out into a bunch of tiny buffers.
        def _finished_decode(res):
            elapsed = time.time() - self._started_decode
            self._results.timings["cumulative_decode"] += elapsed
            return res
        if self._results:
            d.addCallback(_finished_decode)

        # pause/check-for-stop just before writing, to honor stopProducing
        d.addCallback(self._check_for_pause)
        def _done(buffers):
            # we start by joining all these buffers together into a single
            # string. This makes Output.write easier, since it wants to hash
            # data one segment at a time anyway, and it doesn't impact our
            # memory footprint, since we already peaked at 2*segment_size
            # inside the codec a moment ago.
            segment = "".join(buffers)
            del buffers
            # we're down to 1*segment_size right now, but write_segment()
            # will decrypt a copy of the segment internally, which will push
            # us up to 2*segment_size while it runs.
            started_decrypt = time.time()
            self._output.write_segment(segment)
            if self._results:
                elapsed = time.time() - started_decrypt
                self._results.timings["cumulative_decrypt"] += elapsed
        d.addCallback(_done)
        return d

    def _download_tail_segment(self, res, segnum):
        self.log("downloading seg#%d of %d (%d%%)"
                 % (segnum, self._total_segments,
                    100.0 * segnum / self._total_segments))
        segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares,
                                        self._results)
        started = time.time()
        d = segmentdler.start()
        def _finished_fetching(res):
            elapsed = time.time() - started
            self._results.timings["cumulative_fetch"] += elapsed
            return res
        if self._results:
            d.addCallback(_finished_fetching)
        # pause before using more memory
        d.addCallback(self._check_for_pause)
        def _started_decode(res):
            self._started_decode = time.time()
            return res
        if self._results:
            d.addCallback(_started_decode)
        d.addCallback(lambda (shares, shareids):
                      self._tail_codec.decode(shares, shareids))
        def _finished_decode(res):
            elapsed = time.time() - self._started_decode
            self._results.timings["cumulative_decode"] += elapsed
            return res
        if self._results:
            d.addCallback(_finished_decode)
        # pause/check-for-stop just before writing, to honor stopProducing
        d.addCallback(self._check_for_pause)
        def _done(buffers):
            # trim off any padding added by the upload side
            segment = "".join(buffers)
            del buffers
            # we never send empty segments. If the data was an exact multiple
            # of the segment size, the last segment will be full.
            pad_size = mathutil.pad_size(self._size, self._segment_size)
            tail_size = self._segment_size - pad_size
            segment = segment[:tail_size]
            started_decrypt = time.time()
            self._output.write_segment(segment)
            if self._results:
                elapsed = time.time() - started_decrypt
                self._results.timings["cumulative_decrypt"] += elapsed
        d.addCallback(_done)
        return d

    def _done(self, res):
        self.log("download done")
        if self._results:
            now = time.time()
            self._results.timings["total"] = now - self._started
            self._results.timings["segments"] = now - self._started_fetching
        self._output.close()
        if self.check_crypttext_hash:
            _assert(self._crypttext_hash == self._output.crypttext_hash,
                    "bad crypttext_hash: computed=%s, expected=%s" %
                    (base32.b2a(self._output.crypttext_hash),
                     base32.b2a(self._crypttext_hash)))
        if self.check_plaintext_hash:
            _assert(self._plaintext_hash == self._output.plaintext_hash,
                    "bad plaintext_hash: computed=%s, expected=%s" %
                    (base32.b2a(self._output.plaintext_hash),
                     base32.b2a(self._plaintext_hash)))
        _assert(self._output.length == self._size,
                got=self._output.length, expected=self._size)
        return self._output.finish()

    def get_download_status(self):
        return self._status


class LiteralDownloader:
    def __init__(self, client, u, downloadable):
        self._uri = IFileURI(u)
        assert isinstance(self._uri, uri.LiteralFileURI)
        self._downloadable = downloadable
        self._status = s = DownloadStatus()
        s.set_storage_index(None)
        s.set_helper(False)
        s.set_status("Done")
        s.set_active(False)
        s.set_progress(1.0)

    def start(self):
        data = self._uri.data
        self._status.set_size(len(data))
        self._downloadable.open(len(data))
        self._downloadable.write(data)
        self._downloadable.close()
        return defer.maybeDeferred(self._downloadable.finish)

    def get_download_status(self):
        return self._status

class FileName:
    implements(IDownloadTarget)
    def __init__(self, filename):
        self._filename = filename
        self.f = None
    def open(self, size):
        self.f = open(self._filename, "wb")
        return self.f
    def write(self, data):
        self.f.write(data)
    def close(self):
        if self.f:
            self.f.close()
    def fail(self, why):
        if self.f:
            self.f.close()
            os.unlink(self._filename)
    def register_canceller(self, cb):
        pass # we won't use it
    def finish(self):
        pass

class Data:
    implements(IDownloadTarget)
    def __init__(self):
        self._data = []
    def open(self, size):
        pass
    def write(self, data):
        self._data.append(data)
    def close(self):
        self.data = "".join(self._data)
        del self._data
    def fail(self, why):
        del self._data
    def register_canceller(self, cb):
        pass # we won't use it
    def finish(self):
        return self.data

class FileHandle:
    """Use me to download data to a pre-defined filehandle-like object. I
    will use the target's write() method. I will *not* close the filehandle:
    I leave that up to the originator of the filehandle. The download process
    will return the filehandle when it completes.
    """
    implements(IDownloadTarget)
    def __init__(self, filehandle):
        self._filehandle = filehandle
    def open(self, size):
        pass
    def write(self, data):
        self._filehandle.write(data)
    def close(self):
        # the originator of the filehandle reserves the right to close it
        pass
    def fail(self, why):
        pass
    def register_canceller(self, cb):
        pass
    def finish(self):
        return self._filehandle

class Downloader(service.MultiService):
    """I am a service that allows file downloading.
    """
    implements(IDownloader)
    name = "downloader"
    MAX_DOWNLOAD_STATUSES = 10

    def __init__(self):
        service.MultiService.__init__(self)
        self._all_downloads = weakref.WeakKeyDictionary()
        self._recent_download_status = []

    def download(self, u, t):
        assert self.parent
        assert self.running
        u = IFileURI(u)
        t = IDownloadTarget(t)
        assert t.write
        assert t.close
        if isinstance(u, uri.LiteralFileURI):
            dl = LiteralDownloader(self.parent, u, t)
        elif isinstance(u, uri.CHKFileURI):
            dl = FileDownloader(self.parent, u, t)
        else:
            raise RuntimeError("I don't know how to download a %s" % u)
        self._all_downloads[dl] = None
        self._recent_download_status.append(dl.get_download_status())
        while len(self._recent_download_status) > self.MAX_DOWNLOAD_STATUSES:
            self._recent_download_status.pop(0)
        d = dl.start()
        return d

    # utility functions
    def download_to_data(self, uri):
        return self.download(uri, Data())
    def download_to_filename(self, uri, filename):
        return self.download(uri, FileName(filename))
    def download_to_filehandle(self, uri, filehandle):
        return self.download(uri, FileHandle(filehandle))
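
    # A minimal usage sketch (assuming a running client node to which this
    # service is attached; the 'client' variable below is hypothetical):
    #
    #   downloader = client.getServiceNamed("downloader")
    #   d = downloader.download_to_data(uri)
    #   d.addCallback(lambda data: ...)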


    def list_all_downloads(self):
        return self._all_downloads.keys()
    def list_active_downloads(self):
        return [d.get_download_status() for d in self._all_downloads.keys()
                if d.get_download_status().get_active()]
    def list_recent_downloads(self):
        return self._recent_download_status