
import os, random, weakref, itertools, time
from zope.interface import implements
from twisted.internet import defer
from twisted.internet.interfaces import IPushProducer, IConsumer
from twisted.application import service
from foolscap import DeadReferenceError
from foolscap.eventual import eventually

from allmydata.util import base32, mathutil, hashutil, log, observer
from allmydata.util.assertutil import _assert
from allmydata import codec, hashtree, storage, uri
from allmydata.interfaces import IDownloadTarget, IDownloader, IFileURI, \
     IDownloadStatus, IDownloadResults, NotEnoughSharesError
from allmydata.immutable import layout
from pycryptopp.cipher.aes import AES

class HaveAllPeersError(Exception):
    # we use this to jump out of the loop
    pass

class IntegrityCheckError(Exception):
    pass

class BadURIExtensionHashValue(IntegrityCheckError):
    pass
class BadURIExtension(IntegrityCheckError):
    pass
class BadPlaintextHashValue(IntegrityCheckError):
    pass
class BadCrypttextHashValue(IntegrityCheckError):
    pass

class DownloadStopped(Exception):
    pass

class DownloadResults:
    implements(IDownloadResults)

    def __init__(self):
        self.servers_used = set()
        self.server_problems = {}
        self.servermap = {}
        self.timings = {}
        self.file_size = None

class Output:
    def __init__(self, downloadable, key, total_length, log_parent,
                 download_status):
        self.downloadable = downloadable
        self._decryptor = AES(key)
        self._crypttext_hasher = hashutil.crypttext_hasher()
        self._plaintext_hasher = hashutil.plaintext_hasher()
        self.length = 0
        self.total_length = total_length
        self._segment_number = 0
        self._plaintext_hash_tree = None
        self._crypttext_hash_tree = None
        self._opened = False
        self._log_parent = log_parent
        self._status = download_status
        self._status.set_progress(0.0)

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_parent
        if "facility" not in kwargs:
            kwargs["facility"] = "download.output"
        return log.msg(*args, **kwargs)

    def setup_hashtrees(self, plaintext_hashtree, crypttext_hashtree):
        self._plaintext_hash_tree = plaintext_hashtree
        self._crypttext_hash_tree = crypttext_hashtree

    def write_segment(self, crypttext):
        self.length += len(crypttext)
        self._status.set_progress( float(self.length) / self.total_length )

        # memory footprint: 'crypttext' is the only segment_size usage
        # outstanding. While we decrypt it into 'plaintext', we hit
        # 2*segment_size.
        self._crypttext_hasher.update(crypttext)
        if self._crypttext_hash_tree:
            ch = hashutil.crypttext_segment_hasher()
            ch.update(crypttext)
            crypttext_leaves = {self._segment_number: ch.digest()}
            self.log(format="crypttext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s",
                     bytes=len(crypttext),
                     segnum=self._segment_number, hash=base32.b2a(ch.digest()),
                     level=log.NOISY)
            self._crypttext_hash_tree.set_hashes(leaves=crypttext_leaves)

        plaintext = self._decryptor.process(crypttext)
        del crypttext

        # now we're back down to 1*segment_size.

        self._plaintext_hasher.update(plaintext)
        if self._plaintext_hash_tree:
            ph = hashutil.plaintext_segment_hasher()
            ph.update(plaintext)
            plaintext_leaves = {self._segment_number: ph.digest()}
            self.log(format="plaintext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s",
                     bytes=len(plaintext),
                     segnum=self._segment_number, hash=base32.b2a(ph.digest()),
                     level=log.NOISY)
            self._plaintext_hash_tree.set_hashes(leaves=plaintext_leaves)

        self._segment_number += 1
        # We're still at 1*segment_size. The Downloadable is responsible for
        # any memory usage beyond this.
        if not self._opened:
            self._opened = True
            self.downloadable.open(self.total_length)
        self.downloadable.write(plaintext)

    def fail(self, why):
        # this is really unusual, and deserves maximum forensics
        if why.check(DownloadStopped):
            # except DownloadStopped just means the consumer aborted the
            # download, not so scary
            self.log("download stopped", level=log.UNUSUAL)
        else:
            self.log("download failed!", failure=why,
                     level=log.SCARY, umid="lp1vaQ")
        self.downloadable.fail(why)

    def close(self):
        self.crypttext_hash = self._crypttext_hasher.digest()
        self.plaintext_hash = self._plaintext_hasher.digest()
        self.log("download finished, closing IDownloadable", level=log.NOISY)
        self.downloadable.close()

    def finish(self):
        return self.downloadable.finish()
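
# A minimal, illustrative sketch of how a single segment flows through
# Output.write_segment(): the ciphertext is hashed, a leaf is added to the
# crypttext hash tree, the segment is decrypted with the AES instance built
# from the read-cap key, and the plaintext is hashed and handed to the
# IDownloadTarget. The names 'target', 'key', 'status', and
# 'fetched_segments' below are hypothetical stand-ins, not real objects.
#
#   out = Output(target, key, total_length=filesize,
#                log_parent=None, download_status=status)
#   out.setup_hashtrees(plaintext_hashtree, crypttext_hashtree)
#   for crypttext_segment in fetched_segments:
#       out.write_segment(crypttext_segment)  # decrypts and writes plaintext
#   out.close()     # finalizes the whole-file crypttext/plaintext hashes
#   result = out.finish()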

class ValidatedBucket:
    """I am a front-end for a remote storage bucket, responsible for
    retrieving and validating data from that bucket.

    My get_block() method is used by BlockDownloaders.
    """

    def __init__(self, sharenum, bucket,
                 share_hash_tree, roothash,
                 num_blocks):
        self.sharenum = sharenum
        self.bucket = bucket
        self._share_hash = None # None means not validated yet
        self.share_hash_tree = share_hash_tree
        self._roothash = roothash
        self.block_hash_tree = hashtree.IncompleteHashTree(num_blocks)
        self.started = False

    def get_block(self, blocknum):
        if not self.started:
            d = self.bucket.start()
            def _started(res):
                self.started = True
                return self.get_block(blocknum)
            d.addCallback(_started)
            return d

        # the first time we use this bucket, we need to fetch enough elements
        # of the share hash tree to validate it from our share hash up to the
        # hashroot.
        if not self._share_hash:
            d1 = self.bucket.get_share_hashes()
        else:
            d1 = defer.succeed([])

        # we might need to grab some elements of our block hash tree, to
        # validate the requested block up to the share hash
        needed = self.block_hash_tree.needed_hashes(blocknum)
        if needed:
            # TODO: get fewer hashes, use get_block_hashes(needed)
            d2 = self.bucket.get_block_hashes()
        else:
            d2 = defer.succeed([])

        d3 = self.bucket.get_block(blocknum)

        d = defer.gatherResults([d1, d2, d3])
        d.addCallback(self._got_data, blocknum)
        return d

    def _got_data(self, res, blocknum):
        sharehashes, blockhashes, blockdata = res
        blockhash = None # to make logging it safe

        try:
            if not self._share_hash:
                sh = dict(sharehashes)
                sh[0] = self._roothash # always use our own root, from the URI
                sht = self.share_hash_tree
                if sht.get_leaf_index(self.sharenum) not in sh:
                    raise hashtree.NotEnoughHashesError
                sht.set_hashes(sh)
                self._share_hash = sht.get_leaf(self.sharenum)

            blockhash = hashutil.block_hash(blockdata)
            #log.msg("checking block_hash(shareid=%d, blocknum=%d) len=%d "
            #        "%r .. %r: %s" %
            #        (self.sharenum, blocknum, len(blockdata),
            #         blockdata[:50], blockdata[-50:], base32.b2a(blockhash)))

            # we always validate the blockhash
            bh = dict(enumerate(blockhashes))
            # replace blockhash root with validated value
            bh[0] = self._share_hash
            self.block_hash_tree.set_hashes(bh, {blocknum: blockhash})

        except (hashtree.BadHashError, hashtree.NotEnoughHashesError):
            # log.WEIRD: indicates undetected disk/network error, or more
            # likely a programming error
            log.msg("hash failure in block=%d, shnum=%d on %s" %
                    (blocknum, self.sharenum, self.bucket))
            if self._share_hash:
                log.msg(""" failure occurred when checking the block_hash_tree.
                This suggests that either the block data was bad, or that the
                block hashes we received along with it were bad.""")
            else:
                log.msg(""" the failure probably occurred when checking the
                share_hash_tree, which suggests that the share hashes we
                received from the remote peer were bad.""")
            log.msg(" have self._share_hash: %s" % bool(self._share_hash))
            log.msg(" block length: %d" % len(blockdata))
            log.msg(" block hash: %s" % base32.b2a_or_none(blockhash))
            if len(blockdata) < 100:
                log.msg(" block data: %r" % (blockdata,))
            else:
                log.msg(" block data start/end: %r .. %r" %
                        (blockdata[:50], blockdata[-50:]))
            log.msg(" root hash: %s" % base32.b2a(self._roothash))
            log.msg(" share hash tree:\n" + self.share_hash_tree.dump())
            log.msg(" block hash tree:\n" + self.block_hash_tree.dump())
            lines = []
            for i,h in sorted(sharehashes):
                lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
            log.msg(" sharehashes:\n" + "\n".join(lines) + "\n")
            lines = []
            for i,h in enumerate(blockhashes):
                lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
            log.msg(" blockhashes:\n" + "\n".join(lines) + "\n")
            raise

        # If we made it here, the block is good. If the hash trees didn't
        # like what they saw, they would have raised a BadHashError, causing
        # our caller to see a Failure and thus ignore this block (as well as
        # dropping this bucket).
        return blockdata

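# A minimal, hypothetical sketch of the merkle-tree validation pattern that
# ValidatedBucket uses via allmydata.hashtree.IncompleteHashTree: anchor the
# tree at a trusted root, ask which uncle hashes are needed for the leaf,
# then set_hashes() with those hashes plus the leaf being checked; any
# mismatch raises BadHashError. The values below (num_blocks, the hash dicts)
# are placeholders, not real data.
#
#   bht = hashtree.IncompleteHashTree(num_blocks)
#   bht.set_hashes({0: validated_share_hash})        # trusted root
#   needed = bht.needed_hashes(blocknum)             # which uncles to fetch
#   bht.set_hashes(fetched_hashes,                   # {index: hash, ...}
#                  leaves={blocknum: hashutil.block_hash(blockdata)})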


class BlockDownloader:
    """I am responsible for downloading a single block (from a single bucket)
    for a single segment.

    I am a child of the SegmentDownloader.
    """

    def __init__(self, vbucket, blocknum, parent, results):
        self.vbucket = vbucket
        self.blocknum = blocknum
        self.parent = parent
        self.results = results
        self._log_number = self.parent.log("starting block %d" % blocknum)

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_number
        return self.parent.log(*args, **kwargs)

    def start(self, segnum):
        lognum = self.log("get_block(segnum=%d)" % segnum)
        started = time.time()
        d = self.vbucket.get_block(segnum)
        d.addCallbacks(self._hold_block, self._got_block_error,
                       callbackArgs=(started, lognum,), errbackArgs=(lognum,))
        return d

    def _hold_block(self, data, started, lognum):
        if self.results:
            elapsed = time.time() - started
            peerid = self.vbucket.bucket.get_peerid()
            if peerid not in self.results.timings["fetch_per_server"]:
                self.results.timings["fetch_per_server"][peerid] = []
            self.results.timings["fetch_per_server"][peerid].append(elapsed)
        self.log("got block", parent=lognum)
        self.parent.hold_block(self.blocknum, data)

    def _got_block_error(self, f, lognum):
        level = log.WEIRD
        if f.check(DeadReferenceError):
            level = log.UNUSUAL
        self.log("BlockDownloader[%d] got error" % self.blocknum,
                 failure=f, level=level, parent=lognum, umid="5Z4uHQ")
        if self.results:
            peerid = self.vbucket.bucket.get_peerid()
            self.results.server_problems[peerid] = str(f)
        self.parent.bucket_failed(self.vbucket)

class SegmentDownloader:
    """I am responsible for downloading all the blocks for a single segment
    of data.

    I am a child of the FileDownloader.
    """

    def __init__(self, parent, segmentnumber, needed_shares, results):
        self.parent = parent
        self.segmentnumber = segmentnumber
        self.needed_blocks = needed_shares
        self.blocks = {} # k: blocknum, v: data
        self.results = results
        self._log_number = self.parent.log("starting segment %d" %
                                           segmentnumber)

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_number
        return self.parent.log(*args, **kwargs)

    def start(self):
        return self._download()

    def _download(self):
        d = self._try()
        def _done(res):
            if len(self.blocks) >= self.needed_blocks:
                # we only need self.needed_blocks blocks
                # we want to get the smallest blockids, because they are
                # more likely to be fast "primary blocks"
                blockids = sorted(self.blocks.keys())[:self.needed_blocks]
                blocks = []
                for blocknum in blockids:
                    blocks.append(self.blocks[blocknum])
                return (blocks, blockids)
            else:
                return self._download()
        d.addCallback(_done)
        return d
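
    # Note: with a k-of-N erasure code such as the zfec-based codec tahoe
    # normally uses, the first k shares are typically the "primary" shares,
    # which can be reassembled into the ciphertext with essentially no decode
    # work; that is why _done() above prefers the lowest-numbered blockids.
    # For example (illustrative numbers), with 3-of-10 encoding, collecting
    # blocks {0, 1, 2} lets the codec skip most of the math, while {4, 7, 9}
    # requires a full decode.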

    def _try(self):
        # fill our set of active buckets, maybe raising NotEnoughSharesError
        active_buckets = self.parent._activate_enough_buckets()
        # Now we have enough buckets, in self.parent.active_buckets.

        # in test cases, bd.start might mutate active_buckets right away, so
        # we need to put off calling start() until we've iterated all the way
        # through it.
        downloaders = []
        for blocknum, vbucket in active_buckets.iteritems():
            bd = BlockDownloader(vbucket, blocknum, self, self.results)
            downloaders.append(bd)
            if self.results:
                self.results.servers_used.add(vbucket.bucket.get_peerid())
        l = [bd.start(self.segmentnumber) for bd in downloaders]
        return defer.DeferredList(l, fireOnOneErrback=True)

    def hold_block(self, blocknum, data):
        self.blocks[blocknum] = data

    def bucket_failed(self, vbucket):
        self.parent.bucket_failed(vbucket)

class DownloadStatus:
    implements(IDownloadStatus)
    statusid_counter = itertools.count(0)

    def __init__(self):
        self.storage_index = None
        self.size = None
        self.helper = False
        self.status = "Not started"
        self.progress = 0.0
        self.paused = False
        self.stopped = False
        self.active = True
        self.results = None
        self.counter = self.statusid_counter.next()
        self.started = time.time()

    def get_started(self):
        return self.started
    def get_storage_index(self):
        return self.storage_index
    def get_size(self):
        return self.size
    def using_helper(self):
        return self.helper
    def get_status(self):
        status = self.status
        if self.paused:
            status += " (output paused)"
        if self.stopped:
            status += " (output stopped)"
        return status
    def get_progress(self):
        return self.progress
    def get_active(self):
        return self.active
    def get_results(self):
        return self.results
    def get_counter(self):
        return self.counter

    def set_storage_index(self, si):
        self.storage_index = si
    def set_size(self, size):
        self.size = size
    def set_helper(self, helper):
        self.helper = helper
    def set_status(self, status):
        self.status = status
    def set_paused(self, paused):
        self.paused = paused
    def set_stopped(self, stopped):
        self.stopped = stopped
    def set_progress(self, value):
        self.progress = value
    def set_active(self, value):
        self.active = value
    def set_results(self, value):
        self.results = value

class FileDownloader:
    implements(IPushProducer)
    check_crypttext_hash = True
    check_plaintext_hash = True
    _status = None

    def __init__(self, client, u, downloadable):
        self._client = client

        u = IFileURI(u)
        self._storage_index = u.storage_index
        self._uri_extension_hash = u.uri_extension_hash
        self._total_shares = u.total_shares
        self._size = u.size
        self._num_needed_shares = u.needed_shares

        self._si_s = storage.si_b2a(self._storage_index)
        self.init_logging()

        self._started = time.time()
        self._status = s = DownloadStatus()
        s.set_status("Starting")
        s.set_storage_index(self._storage_index)
        s.set_size(self._size)
        s.set_helper(False)
        s.set_active(True)

        self._results = DownloadResults()
        s.set_results(self._results)
        self._results.file_size = self._size
        self._results.timings["servers_peer_selection"] = {}
        self._results.timings["fetch_per_server"] = {}
        self._results.timings["cumulative_fetch"] = 0.0
        self._results.timings["cumulative_decode"] = 0.0
        self._results.timings["cumulative_decrypt"] = 0.0
        self._results.timings["paused"] = 0.0

        self._paused = False
        self._stopped = False
        if IConsumer.providedBy(downloadable):
            downloadable.registerProducer(self, True)
        self._downloadable = downloadable
        self._output = Output(downloadable, u.key, self._size, self._log_number,
                              self._status)

        self.active_buckets = {} # k: shnum, v: bucket
        self._share_buckets = [] # list of (sharenum, bucket) tuples
        self._share_vbuckets = {} # k: shnum, v: set of ValidatedBuckets
        self._uri_extension_sources = []

        self._uri_extension_data = None

        self._fetch_failures = {"uri_extension": 0,
                                "plaintext_hashroot": 0,
                                "plaintext_hashtree": 0,
                                "crypttext_hashroot": 0,
                                "crypttext_hashtree": 0,
                                }

    def init_logging(self):
        self._log_prefix = prefix = storage.si_b2a(self._storage_index)[:5]
        num = self._client.log(format="FileDownloader(%(si)s): starting",
                               si=storage.si_b2a(self._storage_index))
        self._log_number = num

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_number
        if "facility" not in kwargs:
            kwargs["facility"] = "tahoe.download"
        return log.msg(*args, **kwargs)

    def pauseProducing(self):
        if self._paused:
            return
        self._paused = defer.Deferred()
        self._paused_at = time.time()
        if self._status:
            self._status.set_paused(True)

    def resumeProducing(self):
        if self._paused:
            paused_for = time.time() - self._paused_at
            self._results.timings['paused'] += paused_for
            p = self._paused
            self._paused = None
            eventually(p.callback, None)
            if self._status:
                self._status.set_paused(False)

    def stopProducing(self):
        self.log("Download.stopProducing")
        self._stopped = True
        self.resumeProducing()
        if self._status:
            self._status.set_stopped(True)
            self._status.set_active(False)

    def start(self):
        self.log("starting download")

        # first step: who should we download from?
        d = defer.maybeDeferred(self._get_all_shareholders)
        d.addCallback(self._got_all_shareholders)
        # now get the uri_extension block from somebody and validate it
        d.addCallback(self._obtain_uri_extension)
        d.addCallback(self._got_uri_extension)
        d.addCallback(self._get_hashtrees)
        d.addCallback(self._create_validated_buckets)
        # once we know that, we can download blocks from everybody
        d.addCallback(self._download_all_segments)
        def _finished(res):
            if self._status:
                self._status.set_status("Finished")
                self._status.set_active(False)
                self._status.set_paused(False)
            if IConsumer.providedBy(self._downloadable):
                self._downloadable.unregisterProducer()
            return res
        d.addBoth(_finished)
        def _failed(why):
            if self._status:
                self._status.set_status("Failed")
                self._status.set_active(False)
            self._output.fail(why)
            return why
        d.addErrback(_failed)
        d.addCallback(self._done)
        return d

    def _get_all_shareholders(self):
        dl = []
        for (peerid,ss) in self._client.get_permuted_peers("storage",
                                                           self._storage_index):
            d = ss.callRemote("get_buckets", self._storage_index)
            d.addCallbacks(self._got_response, self._got_error,
                           callbackArgs=(peerid,))
            dl.append(d)
        self._responses_received = 0
        self._queries_sent = len(dl)
        if self._status:
            self._status.set_status("Locating Shares (%d/%d)" %
                                    (self._responses_received,
                                     self._queries_sent))
        return defer.DeferredList(dl)

    def _got_response(self, buckets, peerid):
        self._responses_received += 1
        if self._results:
            elapsed = time.time() - self._started
            self._results.timings["servers_peer_selection"][peerid] = elapsed
        if self._status:
            self._status.set_status("Locating Shares (%d/%d)" %
                                    (self._responses_received,
                                     self._queries_sent))
        for sharenum, bucket in buckets.iteritems():
            b = layout.ReadBucketProxy(bucket, peerid, self._si_s)
            self.add_share_bucket(sharenum, b)
            self._uri_extension_sources.append(b)
            if self._results:
                if peerid not in self._results.servermap:
                    self._results.servermap[peerid] = set()
                self._results.servermap[peerid].add(sharenum)

    def add_share_bucket(self, sharenum, bucket):
        # this is split out for the benefit of test_encode.py
        self._share_buckets.append( (sharenum, bucket) )

    def _got_error(self, f):
        level = log.WEIRD
        if f.check(DeadReferenceError):
            level = log.UNUSUAL
        self._client.log("Error during get_buckets", failure=f, level=level,
                         umid="3uuBUQ")

    def bucket_failed(self, vbucket):
        shnum = vbucket.sharenum
        del self.active_buckets[shnum]
        s = self._share_vbuckets[shnum]
        # s is a set of ValidatedBucket instances
        s.remove(vbucket)
        # ... which might now be empty
        if not s:
            # there are no more buckets which can provide this share, so
            # remove the key. This may prompt us to use a different share.
            del self._share_vbuckets[shnum]

    def _got_all_shareholders(self, res):
        if self._results:
            now = time.time()
            self._results.timings["peer_selection"] = now - self._started

        if len(self._share_buckets) < self._num_needed_shares:
            raise NotEnoughSharesError

        #for s in self._share_vbuckets.values():
        #    for vb in s:
        #        assert isinstance(vb, ValidatedBucket), \
        #               "vb is %s but should be a ValidatedBucket" % (vb,)

    def _unpack_uri_extension_data(self, data):
        return uri.unpack_extension(data)

    def _obtain_uri_extension(self, ignored):
        # all shareholders are supposed to have a copy of uri_extension, and
        # all are supposed to be identical. We compute the hash of the data
        # that comes back, and compare it against the version in our URI. If
        # they don't match, ignore their data and try someone else.
        if self._status:
            self._status.set_status("Obtaining URI Extension")

        self._uri_extension_fetch_started = time.time()
        def _validate(proposal, bucket):
            h = hashutil.uri_extension_hash(proposal)
            if h != self._uri_extension_hash:
                self._fetch_failures["uri_extension"] += 1
                msg = ("The copy of uri_extension we received from "
                       "%s was bad: wanted %s, got %s" %
                       (bucket,
                        base32.b2a(self._uri_extension_hash),
                        base32.b2a(h)))
                self.log(msg, level=log.SCARY, umid="jnkTtQ")
                raise BadURIExtensionHashValue(msg)
            return self._unpack_uri_extension_data(proposal)
        return self._obtain_validated_thing(None,
                                            self._uri_extension_sources,
                                            "uri_extension",
                                            "get_uri_extension", (), _validate)

    def _obtain_validated_thing(self, ignored, sources, name, methname, args,
                                validatorfunc):
        if not sources:
            raise NotEnoughSharesError("started with zero peers while fetching "
                                      "%s" % name)
        bucket = sources[0]
        sources = sources[1:]
        #d = bucket.callRemote(methname, *args)
        d = bucket.startIfNecessary()
        d.addCallback(lambda res: getattr(bucket, methname)(*args))
        d.addCallback(validatorfunc, bucket)
        def _bad(f):
            level = log.WEIRD
            if f.check(DeadReferenceError):
                level = log.UNUSUAL
            self.log(format="operation %(op)s from vbucket %(vbucket)s failed",
                     op=name, vbucket=str(bucket),
                     failure=f, level=level, umid="JGXxBA")
            if not sources:
                raise NotEnoughSharesError("ran out of peers, last error was %s"
                                          % (f,))
            # try again with a different one
            return self._obtain_validated_thing(None, sources, name,
                                                methname, args, validatorfunc)
        d.addErrback(_bad)
        return d

    def _got_uri_extension(self, uri_extension_data):
        if self._results:
            elapsed = time.time() - self._uri_extension_fetch_started
            self._results.timings["uri_extension"] = elapsed

        d = self._uri_extension_data = uri_extension_data

        self._codec = codec.get_decoder_by_name(d['codec_name'])
        self._codec.set_serialized_params(d['codec_params'])
        self._tail_codec = codec.get_decoder_by_name(d['codec_name'])
        self._tail_codec.set_serialized_params(d['tail_codec_params'])

        crypttext_hash = d.get('crypttext_hash', None) # optional
        if crypttext_hash:
            assert isinstance(crypttext_hash, str)
            assert len(crypttext_hash) == 32
        self._crypttext_hash = crypttext_hash
        self._plaintext_hash = d.get('plaintext_hash', None) # optional

        self._roothash = d['share_root_hash']

        self._segment_size = segment_size = d['segment_size']
        self._total_segments = mathutil.div_ceil(self._size, segment_size)
        self._current_segnum = 0

        self._share_hashtree = hashtree.IncompleteHashTree(d['total_shares'])
        self._share_hashtree.set_hashes({0: self._roothash})

    def _get_hashtrees(self, res):
        self._get_hashtrees_started = time.time()
        if self._status:
            self._status.set_status("Retrieving Hash Trees")
        d = defer.maybeDeferred(self._get_plaintext_hashtrees)
        d.addCallback(self._get_crypttext_hashtrees)
        d.addCallback(self._setup_hashtrees)
        return d

    def _get_plaintext_hashtrees(self):
        # plaintext hashes are optional. If the root isn't in the UEB, then
        # the share will be holding an empty list. We don't even bother
        # fetching it.
        if "plaintext_root_hash" not in self._uri_extension_data:
            self._plaintext_hashtree = None
            return
        def _validate_plaintext_hashtree(proposal, bucket):
            if proposal[0] != self._uri_extension_data['plaintext_root_hash']:
                self._fetch_failures["plaintext_hashroot"] += 1
                msg = ("The copy of the plaintext_root_hash we received from"
                       " %s was bad" % bucket)
                raise BadPlaintextHashValue(msg)
            pt_hashtree = hashtree.IncompleteHashTree(self._total_segments)
            pt_hashes = dict(list(enumerate(proposal)))
            try:
                pt_hashtree.set_hashes(pt_hashes)
            except hashtree.BadHashError:
                # the hashes they gave us were not self-consistent, even
                # though the root matched what we saw in the uri_extension
                # block
                self._fetch_failures["plaintext_hashtree"] += 1
                raise
            self._plaintext_hashtree = pt_hashtree
        d = self._obtain_validated_thing(None,
                                         self._uri_extension_sources,
                                         "plaintext_hashes",
                                         "get_plaintext_hashes", (),
                                         _validate_plaintext_hashtree)
        return d

    def _get_crypttext_hashtrees(self, res):
        # Ciphertext hash tree root is mandatory, so that there is at
        # most one ciphertext that matches this read-cap or
        # verify-cap.  The integrity check on the shares is not
        # sufficient to prevent the original encoder from creating
        # some shares of file A and other shares of file B.
        if "crypttext_root_hash" not in self._uri_extension_data:
            raise BadURIExtension("URI Extension block did not have the ciphertext hash tree root")
        def _validate_crypttext_hashtree(proposal, bucket):
            if proposal[0] != self._uri_extension_data['crypttext_root_hash']:
                self._fetch_failures["crypttext_hashroot"] += 1
                msg = ("The copy of the crypttext_root_hash we received from"
                       " %s was bad" % bucket)
                raise BadCrypttextHashValue(msg)
            ct_hashtree = hashtree.IncompleteHashTree(self._total_segments)
            ct_hashes = dict(list(enumerate(proposal)))
            try:
                ct_hashtree.set_hashes(ct_hashes)
            except hashtree.BadHashError:
                self._fetch_failures["crypttext_hashtree"] += 1
                raise
            self._crypttext_hashtree = ct_hashtree
        d = self._obtain_validated_thing(None,
                                         self._uri_extension_sources,
                                         "crypttext_hashes",
                                         "get_crypttext_hashes", (),
                                         _validate_crypttext_hashtree)
        return d

    def _setup_hashtrees(self, res):
        self._output.setup_hashtrees(self._plaintext_hashtree,
                                     self._crypttext_hashtree)
        if self._results:
            elapsed = time.time() - self._get_hashtrees_started
            self._results.timings["hashtrees"] = elapsed

    def _create_validated_buckets(self, ignored=None):
        self._share_vbuckets = {}
        for sharenum, bucket in self._share_buckets:
            vbucket = ValidatedBucket(sharenum, bucket,
                                      self._share_hashtree,
                                      self._roothash,
                                      self._total_segments)
            s = self._share_vbuckets.setdefault(sharenum, set())
            s.add(vbucket)

    def _activate_enough_buckets(self):
        """either return a mapping from shnum to a ValidatedBucket that can
        provide data for that share, or raise NotEnoughSharesError"""

        while len(self.active_buckets) < self._num_needed_shares:
            # need some more
            handled_shnums = set(self.active_buckets.keys())
            available_shnums = set(self._share_vbuckets.keys())
            potential_shnums = list(available_shnums - handled_shnums)
            if not potential_shnums:
                raise NotEnoughSharesError
            # choose a random share
            shnum = random.choice(potential_shnums)
            # and a random bucket that will provide it
            validated_bucket = random.choice(list(self._share_vbuckets[shnum]))
            self.active_buckets[shnum] = validated_bucket
        return self.active_buckets

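    # For example (illustrative numbers only): with needed_shares=3 and
    # validated buckets available for shnums {0, 1, 4, 7}, the loop above
    # picks three distinct shnums at random, say {1, 4, 7}, plus one
    # ValidatedBucket for each; if fewer than three shnums remain available
    # it raises NotEnoughSharesError instead.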

    def _download_all_segments(self, res):
        # the promise: upon entry to this function, self._share_vbuckets
        # contains enough buckets to complete the download, and some extra
        # ones to tolerate some buckets dropping out or having errors.
        # self._share_vbuckets is a dictionary that maps from shnum to a set
        # of ValidatedBuckets, which themselves are wrappers around
        # RIBucketReader references.
        self.active_buckets = {} # k: shnum, v: ValidatedBucket instance

        self._started_fetching = time.time()

        d = defer.succeed(None)
        for segnum in range(self._total_segments-1):
            d.addCallback(self._download_segment, segnum)
            # this pause, at the end of write, prevents pre-fetch from
            # happening until the consumer is ready for more data.
            d.addCallback(self._check_for_pause)
        d.addCallback(self._download_tail_segment, self._total_segments-1)
        return d

    def _check_for_pause(self, res):
        if self._paused:
            d = defer.Deferred()
            self._paused.addCallback(lambda ignored: d.callback(res))
            return d
        if self._stopped:
            raise DownloadStopped("our Consumer called stopProducing()")
        return res

    def _download_segment(self, res, segnum):
        if self._status:
            self._status.set_status("Downloading segment %d of %d" %
                                    (segnum+1, self._total_segments))
        self.log("downloading seg#%d of %d (%d%%)"
                 % (segnum, self._total_segments,
                    100.0 * segnum / self._total_segments))
        # memory footprint: when the SegmentDownloader finishes pulling down
        # all shares, we have 1*segment_size of usage.
        segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares,
                                        self._results)
        started = time.time()
        d = segmentdler.start()
        def _finished_fetching(res):
            elapsed = time.time() - started
            self._results.timings["cumulative_fetch"] += elapsed
            return res
        if self._results:
            d.addCallback(_finished_fetching)
        # pause before using more memory
        d.addCallback(self._check_for_pause)
        # while the codec does its job, we hit 2*segment_size
        def _started_decode(res):
            self._started_decode = time.time()
            return res
        if self._results:
            d.addCallback(_started_decode)
        d.addCallback(lambda (shares, shareids):
                      self._codec.decode(shares, shareids))
        # once the codec is done, we drop back to 1*segment_size, because
        # 'shares' goes out of scope. The memory usage is all in the
        # plaintext now, spread out into a bunch of tiny buffers.
        def _finished_decode(res):
            elapsed = time.time() - self._started_decode
            self._results.timings["cumulative_decode"] += elapsed
            return res
        if self._results:
            d.addCallback(_finished_decode)

        # pause/check-for-stop just before writing, to honor stopProducing
        d.addCallback(self._check_for_pause)
        def _done(buffers):
            # we start by joining all these buffers together into a single
            # string. This makes Output.write_segment easier, since it wants
            # to hash data one segment at a time anyway, and doesn't impact
            # our memory footprint since we're already peaking at
            # 2*segment_size inside the codec a moment ago.
            segment = "".join(buffers)
            del buffers
            # we're down to 1*segment_size right now, but write_segment()
            # will decrypt a copy of the segment internally, which will push
            # us up to 2*segment_size while it runs.
            started_decrypt = time.time()
            self._output.write_segment(segment)
            if self._results:
                elapsed = time.time() - started_decrypt
                self._results.timings["cumulative_decrypt"] += elapsed
        d.addCallback(_done)
        return d

    def _download_tail_segment(self, res, segnum):
        self.log("downloading seg#%d of %d (%d%%)"
                 % (segnum, self._total_segments,
                    100.0 * segnum / self._total_segments))
        segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares,
                                        self._results)
        started = time.time()
        d = segmentdler.start()
        def _finished_fetching(res):
            elapsed = time.time() - started
            self._results.timings["cumulative_fetch"] += elapsed
            return res
        if self._results:
            d.addCallback(_finished_fetching)
        # pause before using more memory
        d.addCallback(self._check_for_pause)
        def _started_decode(res):
            self._started_decode = time.time()
            return res
        if self._results:
            d.addCallback(_started_decode)
        d.addCallback(lambda (shares, shareids):
                      self._tail_codec.decode(shares, shareids))
        def _finished_decode(res):
            elapsed = time.time() - self._started_decode
            self._results.timings["cumulative_decode"] += elapsed
            return res
        if self._results:
            d.addCallback(_finished_decode)
        # pause/check-for-stop just before writing, to honor stopProducing
        d.addCallback(self._check_for_pause)
        def _done(buffers):
            # trim off any padding added by the upload side
            segment = "".join(buffers)
            del buffers
            # we never send empty segments. If the data was an exact multiple
            # of the segment size, the last segment will be full.
            pad_size = mathutil.pad_size(self._size, self._segment_size)
            tail_size = self._segment_size - pad_size
            segment = segment[:tail_size]
            started_decrypt = time.time()
            self._output.write_segment(segment)
            if self._results:
                elapsed = time.time() - started_decrypt
                self._results.timings["cumulative_decrypt"] += elapsed
        d.addCallback(_done)
        return d
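
    # Worked example of the tail trimming above (illustrative numbers only),
    # assuming mathutil.pad_size(n, k) returns the number of bytes needed to
    # round n up to a multiple of k: with size=1000000 and segment_size=131072,
    # pad_size = 131072 - (1000000 % 131072) = 48576, so
    # tail_size = 131072 - 48576 = 82496, which equals size % segment_size.
    # If size were an exact multiple of segment_size, pad_size would be 0 and
    # the tail segment would be kept at full length.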

    def _done(self, res):
        self.log("download done")
        if self._results:
            now = time.time()
            self._results.timings["total"] = now - self._started
            self._results.timings["segments"] = now - self._started_fetching
        self._output.close()
        if self.check_crypttext_hash and self._crypttext_hash:
            _assert(self._crypttext_hash == self._output.crypttext_hash,
                    "bad crypttext_hash: computed=%s, expected=%s" %
                    (base32.b2a(self._output.crypttext_hash),
                     base32.b2a(self._crypttext_hash)))
        if self.check_plaintext_hash and self._plaintext_hash:
            _assert(self._plaintext_hash == self._output.plaintext_hash,
                    "bad plaintext_hash: computed=%s, expected=%s" %
                    (base32.b2a(self._output.plaintext_hash),
                     base32.b2a(self._plaintext_hash)))
        _assert(self._output.length == self._size,
                got=self._output.length, expected=self._size)
        return self._output.finish()

    def get_download_status(self):
        return self._status


class FileName:
    implements(IDownloadTarget)
    def __init__(self, filename):
        self._filename = filename
        self.f = None
    def open(self, size):
        self.f = open(self._filename, "wb")
        return self.f
    def write(self, data):
        self.f.write(data)
    def close(self):
        if self.f:
            self.f.close()
    def fail(self, why):
        if self.f:
            self.f.close()
            os.unlink(self._filename)
    def register_canceller(self, cb):
        pass # we won't use it
    def finish(self):
        pass

class Data:
    implements(IDownloadTarget)
    def __init__(self):
        self._data = []
    def open(self, size):
        pass
    def write(self, data):
        self._data.append(data)
    def close(self):
        self.data = "".join(self._data)
        del self._data
    def fail(self, why):
        del self._data
    def register_canceller(self, cb):
        pass # we won't use it
    def finish(self):
        return self.data

class FileHandle:
    """Use me to download data to a pre-defined filehandle-like object. I
    will use the target's write() method. I will *not* close the filehandle:
    I leave that up to the originator of the filehandle. The download process
    will return the filehandle when it completes.
    """
    implements(IDownloadTarget)
    def __init__(self, filehandle):
        self._filehandle = filehandle
    def open(self, size):
        pass
    def write(self, data):
        self._filehandle.write(data)
    def close(self):
        # the originator of the filehandle reserves the right to close it
        pass
    def fail(self, why):
        pass
    def register_canceller(self, cb):
        pass
    def finish(self):
        return self._filehandle

class ConsumerAdapter:
    implements(IDownloadTarget, IConsumer)
    def __init__(self, consumer):
        self._consumer = consumer
        self._when_finished = observer.OneShotObserverList()

    def when_finished(self):
        # I think this is unused, along with self._when_finished. But I need
        # to trace the error paths to be sure.
        return self._when_finished.when_fired()

    def registerProducer(self, producer, streaming):
        self._consumer.registerProducer(producer, streaming)
    def unregisterProducer(self):
        self._consumer.unregisterProducer()

    def open(self, size):
        pass
    def write(self, data):
        self._consumer.write(data)
    def close(self):
        self._when_finished.fire(None)

    def fail(self, why):
        self._when_finished.fire(why)
    def register_canceller(self, cb):
        pass
    def finish(self):
        return None

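# The four classes above (FileName, Data, FileHandle, ConsumerAdapter) all
# follow the same IDownloadTarget shape. A hypothetical minimal target, shown
# here for illustration only, would look like this:
#
#   class NullTarget:
#       implements(IDownloadTarget)
#       def open(self, size): pass        # called once, before any write()
#       def write(self, data): pass       # called with each plaintext chunk
#       def close(self): pass             # called after the last write()
#       def fail(self, why): pass         # called instead of close() on error
#       def register_canceller(self, cb): pass
#       def finish(self): return None     # value the download Deferred fires with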

class Downloader(service.MultiService):
    """I am a service that allows file downloading.
    """
    # TODO: in fact, this service only downloads immutable files (URI:CHK:).
    # It is scheduled to go away, to be replaced by filenode.download()
    implements(IDownloader)
    name = "downloader"
    MAX_DOWNLOAD_STATUSES = 10

    def __init__(self, stats_provider=None):
        service.MultiService.__init__(self)
        self.stats_provider = stats_provider
        self._all_downloads = weakref.WeakKeyDictionary() # for debugging
        self._all_download_statuses = weakref.WeakKeyDictionary()
        self._recent_download_statuses = []

    def download(self, u, t):
        assert self.parent
        assert self.running
        u = IFileURI(u)
        t = IDownloadTarget(t)
        assert t.write
        assert t.close

        assert isinstance(u, uri.CHKFileURI)
        if self.stats_provider:
            # these counters are meant for network traffic, and don't
            # include LIT files
            self.stats_provider.count('downloader.files_downloaded', 1)
            self.stats_provider.count('downloader.bytes_downloaded', u.get_size())
        dl = FileDownloader(self.parent, u, t)
        self._add_download(dl)
        d = dl.start()
        return d

    # utility functions
    def download_to_data(self, uri):
        return self.download(uri, Data())
    def download_to_filename(self, uri, filename):
        return self.download(uri, FileName(filename))
    def download_to_filehandle(self, uri, filehandle):
        return self.download(uri, FileHandle(filehandle))
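
    # Usage sketch (names are hypothetical): this is a service attached to
    # the client, and each helper above returns a Deferred that fires with
    # the corresponding target's finish() value, e.g. the plaintext bytes
    # for download_to_data():
    #
    #   downloader = client.getServiceNamed("downloader")
    #   d = downloader.download_to_data(some_chk_uri)
    #   d.addCallback(lambda plaintext: ...)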

    def _add_download(self, downloader):
        self._all_downloads[downloader] = None
        s = downloader.get_download_status()
        self._all_download_statuses[s] = None
        self._recent_download_statuses.append(s)
        while len(self._recent_download_statuses) > self.MAX_DOWNLOAD_STATUSES:
            self._recent_download_statuses.pop(0)

    def list_all_download_statuses(self):
        for ds in self._all_download_statuses:
            yield ds