src/allmydata/immutable/download.py (tahoe-lafs/tahoe-lafs.git)
1
2 import os, random, weakref, itertools, time
3 from zope.interface import implements
4 from twisted.internet import defer
5 from twisted.internet.interfaces import IPushProducer, IConsumer
6 from twisted.application import service
7 from foolscap import DeadReferenceError
8 from foolscap.eventual import eventually
9
10 from allmydata.util import base32, mathutil, hashutil, log, observer
11 from allmydata.util.assertutil import _assert, precondition
12 from allmydata import codec, hashtree, storage, uri
13 from allmydata.interfaces import IDownloadTarget, IDownloader, IFileURI, IVerifierURI, \
14      IDownloadStatus, IDownloadResults, IValidatedThingProxy, NotEnoughSharesError
15 from allmydata.immutable import layout
16 from pycryptopp.cipher.aes import AES
17
18 class HaveAllPeersError(Exception):
19     # we use this to jump out of the loop
20     pass
21
22 class IntegrityCheckError(Exception):
23     pass
24
25 class BadURIExtensionHashValue(IntegrityCheckError):
26     pass
27 class BadURIExtension(IntegrityCheckError):
28     pass
29 class UnsupportedErasureCodec(BadURIExtension):
30     pass
31 class BadCrypttextHashValue(IntegrityCheckError):
32     pass
33
34 class DownloadStopped(Exception):
35     pass
36
37 class DownloadResults:
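    """I record statistics about a single download for status displays:
    which servers were used, which shares each one provided, any per-server
    problems, assorted timing data, and the size of the file."""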
38     implements(IDownloadResults)
39
40     def __init__(self):
41         self.servers_used = set()
42         self.server_problems = {}
43         self.servermap = {}
44         self.timings = {}
45         self.file_size = None
46
47 class Output:
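    """I accept crypttext segments one at a time via write_segment(), check
    each one against the crypttext hash tree (once got_crypttext_hash_tree()
    has been called), decrypt it with AES keyed by the read-cap key, and
    pass the resulting plaintext to the IDownloadTarget."""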
48     def __init__(self, downloadable, key, total_length, log_parent,
49                  download_status):
50         self.downloadable = downloadable
51         self._decryptor = AES(key)
52         self._crypttext_hasher = hashutil.crypttext_hasher()
53         self.length = 0
54         self.total_length = total_length
55         self._segment_number = 0
56         self._crypttext_hash_tree = None
57         self._opened = False
58         self._log_parent = log_parent
59         self._status = download_status
60         self._status.set_progress(0.0)
61
62     def log(self, *args, **kwargs):
63         if "parent" not in kwargs:
64             kwargs["parent"] = self._log_parent
65         if "facility" not in kwargs:
66             kwargs["facility"] = "download.output"
67         return log.msg(*args, **kwargs)
68
69     def got_crypttext_hash_tree(self, crypttext_hash_tree):
70         self._crypttext_hash_tree = crypttext_hash_tree
71
72     def write_segment(self, crypttext):
73         self.length += len(crypttext)
74         self._status.set_progress( float(self.length) / self.total_length )
75
76         # memory footprint: 'crypttext' is the only segment_size usage
77         # outstanding. While we decrypt it into 'plaintext', we hit
78         # 2*segment_size.
79         self._crypttext_hasher.update(crypttext)
80         if self._crypttext_hash_tree:
81             ch = hashutil.crypttext_segment_hasher()
82             ch.update(crypttext)
83             crypttext_leaves = {self._segment_number: ch.digest()}
84             self.log(format="crypttext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s",
85                      bytes=len(crypttext),
86                      segnum=self._segment_number, hash=base32.b2a(ch.digest()),
87                      level=log.NOISY)
88             self._crypttext_hash_tree.set_hashes(leaves=crypttext_leaves)
89
90         plaintext = self._decryptor.process(crypttext)
91         del crypttext
92
93         # now we're back down to 1*segment_size.
94         self._segment_number += 1
95         # We're still at 1*segment_size. The Downloadable is responsible for
96         # any memory usage beyond this.
97         if not self._opened:
98             self._opened = True
99             self.downloadable.open(self.total_length)
100         self.downloadable.write(plaintext)
101
102     def fail(self, why):
103         # this is really unusual, and deserves maximum forensics
104         if why.check(DownloadStopped):
105             # except DownloadStopped just means the consumer aborted the
106             # download, not so scary
107             self.log("download stopped", level=log.UNUSUAL)
108         else:
109             self.log("download failed!", failure=why,
110                      level=log.SCARY, umid="lp1vaQ")
111         self.downloadable.fail(why)
112
113     def close(self):
114         self.crypttext_hash = self._crypttext_hasher.digest()
115         self.log("download finished, closing IDownloadable", level=log.NOISY)
116         self.downloadable.close()
117
118     def finish(self):
119         return self.downloadable.finish()
120
121 class ValidatedThingObtainer:
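    """I hold a list of IValidatedThingProxy instances and try them one at a
    time: each proxy's start() fetches and validates its thing, and on
    failure I fall back to the next proxy, raising NotEnoughSharesError once
    the list is exhausted."""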
122     def __init__(self, validatedthingproxies, debugname, log_id):
123         self._validatedthingproxies = validatedthingproxies
124         self._debugname = debugname
125         self._log_id = log_id
126
127     def _bad(self, f, validatedthingproxy):
128         level = log.WEIRD
129         if f.check(DeadReferenceError):
130             level = log.UNUSUAL
131         log.msg(parent=self._log_id, facility="tahoe.immutable.download",
132                 format="operation %(op)s from validatedthingproxy %(validatedthingproxy)s failed",
133                 op=self._debugname, validatedthingproxy=str(validatedthingproxy),
134                 failure=f, level=level, umid="JGXxBA")
135         if not self._validatedthingproxies:
136             raise NotEnoughSharesError("ran out of peers, last error was %s" % (f,))
137         # try again with a different one
138         return self._try_the_next_one()
139
140     def _try_the_next_one(self):
141         vtp = self._validatedthingproxies.pop(0)
142         d = vtp.start() # start() obtains and validates the thing, then calls back with it (or errbacks)
143         d.addErrback(self._bad, vtp)
144         return d
145
146     def start(self):
147         return self._try_the_next_one()
148
149 class ValidatedCrypttextHashTreeProxy:
150     """ I am a front-end for a remote crypttext hash tree using a local ReadBucketProxy -- I use
151     its get_crypttext_hashes() method and offer the Validated Thing protocol (i.e., I have a
152     start() method that fires with self once I get a valid one). """
153     implements(IValidatedThingProxy)
154     def __init__(self, readbucketproxy, crypttext_hash_tree, fetch_failures=None):
155         # fetch_failures is for debugging -- see test_encode.py
156         self._readbucketproxy = readbucketproxy
157         self._fetch_failures = fetch_failures
158         self._crypttext_hash_tree = crypttext_hash_tree
159
160     def _validate(self, proposal):
161         ct_hashes = dict(list(enumerate(proposal)))
162         try:
163             self._crypttext_hash_tree.set_hashes(ct_hashes)
164         except hashtree.BadHashError:
165             if self._fetch_failures is not None:
166                 self._fetch_failures["crypttext_hash_tree"] += 1
167             raise
168         return self
169
170     def start(self):
171         d = self._readbucketproxy.startIfNecessary()
172         d.addCallback(lambda ignored: self._readbucketproxy.get_crypttext_hashes())
173         d.addCallback(self._validate)
174         return d
175
176 class ValidatedExtendedURIProxy:
177     """ I am a front-end for a remote UEB (using a local ReadBucketProxy), responsible for
178     retrieving and validating the elements from the UEB. """
179     implements(IValidatedThingProxy)
180
181     def __init__(self, readbucketproxy, verifycap, fetch_failures=None):
182         # fetch_failures is for debugging -- see test_encode.py
183         self._fetch_failures = fetch_failures
184         self._readbucketproxy = readbucketproxy
185         precondition(IVerifierURI.providedBy(verifycap), verifycap)
186         self._verifycap = verifycap
187
188         # required
189         self.segment_size = None
190         self.crypttext_root_hash = None
191         self.share_root_hash = None
192
193         # computed
194         self.num_segments = None
195         self.tail_segment_size = None
196
197         # optional
198         self.crypttext_hash = None
199
200     def __str__(self):
201         return "<%s %s>" % (self.__class__.__name__, self._verifycap.to_string())
202
203     def _check_integrity(self, data):
204         h = hashutil.uri_extension_hash(data)
205         if h != self._verifycap.uri_extension_hash:
206             msg = ("The copy of uri_extension we received from %s was bad: wanted %s, got %s" %
207                    (self._readbucketproxy, base32.b2a(self._verifycap.uri_extension_hash), base32.b2a(h)))
208             if self._fetch_failures is not None:
209                 self._fetch_failures["uri_extension"] += 1
210             raise BadURIExtensionHashValue(msg)
211         else:
212             return data
213
214     def _parse_and_validate(self, data):
215         d = uri.unpack_extension(data)
216
217         # There are several kinds of things that can be found in a UEB.  First, things that we
218         # really need to learn from the UEB in order to do this download. Next: things which are
219         # optional but not redundant -- if they are present in the UEB they will get used. Next,
220         # things that are optional and redundant. These things are required to be consistent:
221         # they don't have to be in the UEB, but if they are in the UEB then they will be checked
222         # for consistency with the already-known facts, and if they are inconsistent then an
223         # exception will be raised. These things aren't actually used -- they are just tested
224         # for consistency and ignored. Finally: things which are deprecated -- they ought not be
225         # in the UEB at all, and if they are present then a warning will be logged but they are
226         # otherwise ignored.
227
228         # First, things that we really need to learn from the UEB: segment_size,
229         # crypttext_root_hash, and share_root_hash.
230         self.segment_size = d['segment_size']
231
232         self.num_segments = mathutil.div_ceil(self._verifycap.size, self.segment_size)
233
234         tail_data_size = self._verifycap.size % self.segment_size
235         if not tail_data_size:
236             tail_data_size = self.segment_size
237         # padding for erasure code
238         self.tail_segment_size = mathutil.next_multiple(tail_data_size, self._verifycap.needed_shares)
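        # Illustrative numbers (made up): with self._verifycap.size=1000000,
        # segment_size=131072, and needed_shares=3, we get num_segments =
        # div_ceil(1000000, 131072) = 8, tail_data_size = 1000000 % 131072 =
        # 82496, and tail_segment_size = next_multiple(82496, 3) = 82497.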
239
240         # Ciphertext hash tree root is mandatory, so that there is at most one ciphertext that
241         # matches this read-cap or verify-cap.  The integrity check on the shares is not
242         # sufficient to prevent the original encoder from creating some shares of file A and
243         # other shares of file B.
244         self.crypttext_root_hash = d['crypttext_root_hash']
245
246         self.share_root_hash = d['share_root_hash']
247
248
249         # Next: things that are optional and not redundant: crypttext_hash
250         if d.has_key('crypttext_hash'):
251             self.crypttext_hash = d['crypttext_hash']
252             if len(self.crypttext_hash) != hashutil.CRYPTO_VAL_SIZE:
253                 raise BadURIExtension('crypttext_hash is required to be hashutil.CRYPTO_VAL_SIZE bytes, not %s bytes' % (len(self.crypttext_hash),))
254
255
256         # Next: things that are optional, redundant, and required to be consistent: codec_name,
257         # codec_params, tail_codec_params, num_segments, size, needed_shares, total_shares
258         if d.has_key('codec_name'):
259             if d['codec_name'] != "crs":
260                 raise UnsupportedErasureCodec(d['codec_name'])
261
262         if d.has_key('codec_params'):
263             ucpss, ucpns, ucpts = codec.parse_params(d['codec_params'])
264             if ucpss != self.segment_size:
265                 raise BadURIExtension("inconsistent erasure code params: ucpss: %s != "
266                                       "self.segment_size: %s" % (ucpss, self.segment_size))
267             if ucpns != self._verifycap.needed_shares:
268                 raise BadURIExtension("inconsistent erasure code params: ucpns: %s != "
269                                       "self._verifycap.needed_shares: %s" % (ucpns,
270                                                                              self._verifycap.needed_shares))
271             if ucpts != self._verifycap.total_shares:
272                 raise BadURIExtension("inconsistent erasure code params: ucpts: %s != "
273                                       "self._verifycap.total_shares: %s" % (ucpts,
274                                                                             self._verifycap.total_shares))
275
276         if d.has_key('tail_codec_params'):
277             utcpss, utcpns, utcpts = codec.parse_params(d['tail_codec_params'])
278             if utcpss != self.tail_segment_size:
279                 raise BadURIExtension("inconsistent erasure code params: utcpss: %s != "
280                                       "self.tail_segment_size: %s, self._verifycap.size: %s, "
281                                       "self.segment_size: %s, self._verifycap.needed_shares: %s"
282                                       % (utcpss, self.tail_segment_size, self._verifycap.size,
283                                          self.segment_size, self._verifycap.needed_shares))
284             if utcpns != self._verifycap.needed_shares:
285                 raise BadURIExtension("inconsistent erasure code params: utcpns: %s != "
286                                       "self._verifycap.needed_shares: %s" % (utcpns,
287                                                                              self._verifycap.needed_shares))
288             if utcpts != self._verifycap.total_shares:
289                 raise BadURIExtension("inconsistent erasure code params: utcpts: %s != "
290                                       "self._verifycap.total_shares: %s" % (utcpts,
291                                                                             self._verifycap.total_shares))
292
293         if d.has_key('num_segments'):
294             if d['num_segments'] != self.num_segments:
295                 raise BadURIExtension("inconsistent num_segments: size: %s, "
296                                       "segment_size: %s, computed_num_segments: %s, "
297                                       "ueb_num_segments: %s" % (self._verifycap.size,
298                                                                 self.segment_size,
299                                                                 self.num_segments, d['num_segments']))
300
301         if d.has_key('size'):
302             if d['size'] != self._verifycap.size:
303                 raise BadURIExtension("inconsistent size: URI size: %s, UEB size: %s" %
304                                       (self._verifycap.size, d['size']))
305
306         if d.has_key('needed_shares'):
307             if d['needed_shares'] != self._verifycap.needed_shares:
308                 raise BadURIExtension("inconsistent needed shares: URI needed shares: %s, UEB "
309                                       "needed shares: %s" % (self._verifycap.total_shares,
310                                                              d['needed_shares']))
311
312         if d.has_key('total_shares'):
313             if d['total_shares'] != self._verifycap.total_shares:
314                 raise BadURIExtension("inconsistent total shares: URI total shares: %s, UEB "
315                                       "total shares: %s" % (self._verifycap.total_shares,
316                                                             d['total_shares']))
317
318         # Finally, things that are deprecated and ignored: plaintext_hash, plaintext_root_hash
319         if d.get('plaintext_hash'):
320             log.msg("Found plaintext_hash in UEB. This field is deprecated for security reasons "
321                     "and is no longer used.  Ignoring.  %s" % (self,))
322         if d.get('plaintext_root_hash'):
323             log.msg("Found plaintext_root_hash in UEB. This field is deprecated for security "
324                     "reasons and is no longer used.  Ignoring.  %s" % (self,))
325
326         return self
327
328     def start(self):
329         """ Fetch the UEB from bucket, compare its hash to the hash from verifycap, then parse
330         it.  Returns a deferred which is called back with self once the fetch is successful, or
331         is erred back if it fails. """
332         d = self._readbucketproxy.startIfNecessary()
333         d.addCallback(lambda ignored: self._readbucketproxy.get_uri_extension())
334         d.addCallback(self._check_integrity)
335         d.addCallback(self._parse_and_validate)
336         return d
337
338 class ValidatedReadBucketProxy:
339     """I am a front-end for a remote storage bucket, responsible for
340     retrieving and validating data from that bucket.
341
342     My get_block() method is used by BlockDownloaders.
343     """
344
345     def __init__(self, sharenum, bucket,
346                  share_hash_tree, share_root_hash,
347                  num_blocks):
348         """ share_root_hash is the root of the share hash tree; share_root_hash is stored in the UEB """
349         self.sharenum = sharenum
350         self.bucket = bucket
351         self._share_hash = None # None means not validated yet
352         self.share_hash_tree = share_hash_tree
353         self._share_root_hash = share_root_hash
354         self.block_hash_tree = hashtree.IncompleteHashTree(num_blocks)
355         self.started = False
356
357     def get_block(self, blocknum):
358         if not self.started:
359             d = self.bucket.start()
360             def _started(res):
361                 self.started = True
362                 return self.get_block(blocknum)
363             d.addCallback(_started)
364             return d
365
366         # the first time we use this bucket, we need to fetch enough elements
367         # of the share hash tree to validate it from our share hash up to the
368         # hashroot.
369         if not self._share_hash:
370             d1 = self.bucket.get_share_hashes()
371         else:
372             d1 = defer.succeed([])
373
374         # we might need to grab some elements of our block hash tree, to
375         # validate the requested block up to the share hash
376         needed = self.block_hash_tree.needed_hashes(blocknum)
377         if needed:
378             # TODO: get fewer hashes, use get_block_hashes(needed)
379             d2 = self.bucket.get_block_hashes()
380         else:
381             d2 = defer.succeed([])
382
383         d3 = self.bucket.get_block(blocknum)
384
385         d = defer.gatherResults([d1, d2, d3])
386         d.addCallback(self._got_data, blocknum)
387         return d
388
389     def _got_data(self, res, blocknum):
390         sharehashes, blockhashes, blockdata = res
391         blockhash = None # to make logging it safe
392
393         try:
394             if not self._share_hash:
395                 sh = dict(sharehashes)
396                 sh[0] = self._share_root_hash # always use our own root, from the URI
397                 sht = self.share_hash_tree
398                 if sht.get_leaf_index(self.sharenum) not in sh:
399                     raise hashtree.NotEnoughHashesError
400                 sht.set_hashes(sh)
401                 self._share_hash = sht.get_leaf(self.sharenum)
402
403             blockhash = hashutil.block_hash(blockdata)
404             #log.msg("checking block_hash(shareid=%d, blocknum=%d) len=%d "
405             #        "%r .. %r: %s" %
406             #        (self.sharenum, blocknum, len(blockdata),
407             #         blockdata[:50], blockdata[-50:], base32.b2a(blockhash)))
408
409             # we always validate the blockhash
410             bh = dict(enumerate(blockhashes))
411             # replace blockhash root with validated value
412             bh[0] = self._share_hash
413             self.block_hash_tree.set_hashes(bh, {blocknum: blockhash})
414
415         except (hashtree.BadHashError, hashtree.NotEnoughHashesError):
416             # log.WEIRD: indicates undetected disk/network error, or more
417             # likely a programming error
418             log.msg("hash failure in block=%d, shnum=%d on %s" %
419                     (blocknum, self.sharenum, self.bucket))
420             if self._share_hash:
421                 log.msg(""" failure occurred when checking the block_hash_tree.
422                 This suggests that either the block data was bad, or that the
423                 block hashes we received along with it were bad.""")
424             else:
425                 log.msg(""" the failure probably occurred when checking the
426                 share_hash_tree, which suggests that the share hashes we
427                 received from the remote peer were bad.""")
428             log.msg(" have self._share_hash: %s" % bool(self._share_hash))
429             log.msg(" block length: %d" % len(blockdata))
430             log.msg(" block hash: %s" % base32.b2a_or_none(blockhash))
431             if len(blockdata) < 100:
432                 log.msg(" block data: %r" % (blockdata,))
433             else:
434                 log.msg(" block data start/end: %r .. %r" %
435                         (blockdata[:50], blockdata[-50:]))
436             log.msg(" root hash: %s" % base32.b2a(self._share_root_hash))
437             log.msg(" share hash tree:\n" + self.share_hash_tree.dump())
438             log.msg(" block hash tree:\n" + self.block_hash_tree.dump())
439             lines = []
440             for i,h in sorted(sharehashes):
441                 lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
442             log.msg(" sharehashes:\n" + "\n".join(lines) + "\n")
443             lines = []
444             for i,h in enumerate(blockhashes):
445                 lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
446             log.msg(" blockhashes:\n" + "\n".join(lines) + "\n")
447             raise
448
449         # If we made it here, the block is good. If the hash trees didn't
450         # like what they saw, they would have raised a BadHashError, causing
451         # our caller to see a Failure and thus ignore this block (as well as
452         # dropping this bucket).
453         return blockdata
454
455
456
457 class BlockDownloader:
458     """I am responsible for downloading a single block (from a single bucket)
459     for a single segment.
460
461     I am a child of the SegmentDownloader.
462     """
463
464     def __init__(self, vbucket, blocknum, parent, results):
465         self.vbucket = vbucket
466         self.blocknum = blocknum
467         self.parent = parent
468         self.results = results
469         self._log_number = self.parent.log("starting block %d" % blocknum)
470
471     def log(self, *args, **kwargs):
472         if "parent" not in kwargs:
473             kwargs["parent"] = self._log_number
474         return self.parent.log(*args, **kwargs)
475
476     def start(self, segnum):
477         lognum = self.log("get_block(segnum=%d)" % segnum)
478         started = time.time()
479         d = self.vbucket.get_block(segnum)
480         d.addCallbacks(self._hold_block, self._got_block_error,
481                        callbackArgs=(started, lognum,), errbackArgs=(lognum,))
482         return d
483
484     def _hold_block(self, data, started, lognum):
485         if self.results:
486             elapsed = time.time() - started
487             peerid = self.vbucket.bucket.get_peerid()
488             if peerid not in self.results.timings["fetch_per_server"]:
489                 self.results.timings["fetch_per_server"][peerid] = []
490             self.results.timings["fetch_per_server"][peerid].append(elapsed)
491         self.log("got block", parent=lognum)
492         self.parent.hold_block(self.blocknum, data)
493
494     def _got_block_error(self, f, lognum):
495         level = log.WEIRD
496         if f.check(DeadReferenceError):
497             level = log.UNUSUAL
498         self.log("BlockDownloader[%d] got error" % self.blocknum,
499                  failure=f, level=level, parent=lognum, umid="5Z4uHQ")
500         if self.results:
501             peerid = self.vbucket.bucket.get_peerid()
502             self.results.server_problems[peerid] = str(f)
503         self.parent.bucket_failed(self.vbucket)
504
505 class SegmentDownloader:
506     """I am responsible for downloading all the blocks for a single segment
507     of data.
508
509     I am a child of the FileDownloader.
510     """
511
512     def __init__(self, parent, segmentnumber, needed_shares, results):
513         self.parent = parent
514         self.segmentnumber = segmentnumber
515         self.needed_blocks = needed_shares
516         self.blocks = {} # k: blocknum, v: data
517         self.results = results
518         self._log_number = self.parent.log("starting segment %d" %
519                                            segmentnumber)
520
521     def log(self, *args, **kwargs):
522         if "parent" not in kwargs:
523             kwargs["parent"] = self._log_number
524         return self.parent.log(*args, **kwargs)
525
526     def start(self):
527         return self._download()
528
529     def _download(self):
530         d = self._try()
531         def _done(res):
532             if len(self.blocks) >= self.needed_blocks:
533                 # we only need self.needed_blocks blocks
534                 # we want to get the smallest blockids, because they are
535                 # more likely to be fast "primary blocks"
536                 blockids = sorted(self.blocks.keys())[:self.needed_blocks]
537                 blocks = []
538                 for blocknum in blockids:
539                     blocks.append(self.blocks[blocknum])
540                 return (blocks, blockids)
541             else:
542                 return self._download()
543         d.addCallback(_done)
544         return d
545
546     def _try(self):
547         # fill our set of active buckets, maybe raising NotEnoughSharesError
548         active_buckets = self.parent._activate_enough_buckets()
549         # Now we have enough buckets, in self.parent.active_buckets.
550
551         # in test cases, bd.start might mutate active_buckets right away, so
552         # we need to put off calling start() until we've iterated all the way
553         # through it.
554         downloaders = []
555         for blocknum, vbucket in active_buckets.iteritems():
556             bd = BlockDownloader(vbucket, blocknum, self, self.results)
557             downloaders.append(bd)
558             if self.results:
559                 self.results.servers_used.add(vbucket.bucket.get_peerid())
560         l = [bd.start(self.segmentnumber) for bd in downloaders]
561         return defer.DeferredList(l, fireOnOneErrback=True)
562
563     def hold_block(self, blocknum, data):
564         self.blocks[blocknum] = data
565
566     def bucket_failed(self, vbucket):
567         self.parent.bucket_failed(vbucket)
568
569 class DownloadStatus:
570     implements(IDownloadStatus)
571     statusid_counter = itertools.count(0)
572
573     def __init__(self):
574         self.storage_index = None
575         self.size = None
576         self.helper = False
577         self.status = "Not started"
578         self.progress = 0.0
579         self.paused = False
580         self.stopped = False
581         self.active = True
582         self.results = None
583         self.counter = self.statusid_counter.next()
584         self.started = time.time()
585
586     def get_started(self):
587         return self.started
588     def get_storage_index(self):
589         return self.storage_index
590     def get_size(self):
591         return self.size
592     def using_helper(self):
593         return self.helper
594     def get_status(self):
595         status = self.status
596         if self.paused:
597             status += " (output paused)"
598         if self.stopped:
599             status += " (output stopped)"
600         return status
601     def get_progress(self):
602         return self.progress
603     def get_active(self):
604         return self.active
605     def get_results(self):
606         return self.results
607     def get_counter(self):
608         return self.counter
609
610     def set_storage_index(self, si):
611         self.storage_index = si
612     def set_size(self, size):
613         self.size = size
614     def set_helper(self, helper):
615         self.helper = helper
616     def set_status(self, status):
617         self.status = status
618     def set_paused(self, paused):
619         self.paused = paused
620     def set_stopped(self, stopped):
621         self.stopped = stopped
622     def set_progress(self, value):
623         self.progress = value
624     def set_active(self, value):
625         self.active = value
626     def set_results(self, value):
627         self.results = value
628
629 class FileDownloader:
630     implements(IPushProducer)
631     _status = None
632
633     def __init__(self, client, u, downloadable):
634         precondition(isinstance(u, uri.CHKFileURI), u)
635         self._client = client
636
637         self._uri = u
638         self._storage_index = u.storage_index
639         self._uri_extension_hash = u.uri_extension_hash
640         self._vup = None # ValidatedExtendedURIProxy
641
642         self._si_s = storage.si_b2a(self._storage_index)
643         self.init_logging()
644
645         self._started = time.time()
646         self._status = s = DownloadStatus()
647         s.set_status("Starting")
648         s.set_storage_index(self._storage_index)
649         s.set_size(self._uri.size)
650         s.set_helper(False)
651         s.set_active(True)
652
653         self._results = DownloadResults()
654         s.set_results(self._results)
655         self._results.file_size = self._uri.size
656         self._results.timings["servers_peer_selection"] = {}
657         self._results.timings["fetch_per_server"] = {}
658         self._results.timings["cumulative_fetch"] = 0.0
659         self._results.timings["cumulative_decode"] = 0.0
660         self._results.timings["cumulative_decrypt"] = 0.0
661         self._results.timings["paused"] = 0.0
662
663         self._paused = False
664         self._stopped = False
665         if IConsumer.providedBy(downloadable):
666             downloadable.registerProducer(self, True)
667         self._downloadable = downloadable
668         self._output = Output(downloadable, u.key, self._uri.size, self._log_number,
669                               self._status)
670
671         self.active_buckets = {} # k: shnum, v: bucket
672         self._share_buckets = [] # list of (sharenum, bucket) tuples
673         self._share_vbuckets = {} # k: shnum, v: set of ValidatedBuckets
674
675         self._fetch_failures = {"uri_extension": 0, "crypttext_hash_tree": 0, }
676
677         self._crypttext_hash_tree = None
678
679     def init_logging(self):
680         self._log_prefix = prefix = storage.si_b2a(self._storage_index)[:5]
681         num = self._client.log(format="FileDownloader(%(si)s): starting",
682                                si=storage.si_b2a(self._storage_index))
683         self._log_number = num
684
685     def log(self, *args, **kwargs):
686         if "parent" not in kwargs:
687             kwargs["parent"] = self._log_number
688         if "facility" not in kwargs:
689             kwargs["facility"] = "tahoe.download"
690         return log.msg(*args, **kwargs)
691
692     def pauseProducing(self):
693         if self._paused:
694             return
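        # self._paused doubles as a flag and as a Deferred: _check_for_pause
        # chains on it, and resumeProducing fires it (via eventually) so the
        # segment pipeline can continue.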
695         self._paused = defer.Deferred()
696         self._paused_at = time.time()
697         if self._status:
698             self._status.set_paused(True)
699
700     def resumeProducing(self):
701         if self._paused:
702             paused_for = time.time() - self._paused_at
703             self._results.timings['paused'] += paused_for
704             p = self._paused
705             self._paused = None
706             eventually(p.callback, None)
707             if self._status:
708                 self._status.set_paused(False)
709
710     def stopProducing(self):
711         self.log("Download.stopProducing")
712         self._stopped = True
713         self.resumeProducing()
714         if self._status:
715             self._status.set_stopped(True)
716             self._status.set_active(False)
717
718     def start(self):
719         assert isinstance(self._uri, uri.CHKFileURI), (self._uri, type(self._uri))
720         self.log("starting download")
721
722         # first step: who should we download from?
723         d = defer.maybeDeferred(self._get_all_shareholders)
724         d.addCallback(self._got_all_shareholders)
725         # now get the uri_extension block from somebody, check its integrity, and parse and validate its contents
726         d.addCallback(self._obtain_uri_extension)
727         d.addCallback(self._get_crypttext_hash_tree)
728         # once we know that, we can download blocks from everybody
729         d.addCallback(self._download_all_segments)
730         def _finished(res):
731             if self._status:
732                 self._status.set_status("Finished")
733                 self._status.set_active(False)
734                 self._status.set_paused(False)
735             if IConsumer.providedBy(self._downloadable):
736                 self._downloadable.unregisterProducer()
737             return res
738         d.addBoth(_finished)
739         def _failed(why):
740             if self._status:
741                 self._status.set_status("Failed")
742                 self._status.set_active(False)
743             self._output.fail(why)
744             return why
745         d.addErrback(_failed)
746         d.addCallback(self._done)
747         return d
748
749     def _get_all_shareholders(self):
750         dl = []
751         for (peerid,ss) in self._client.get_permuted_peers("storage",
752                                                            self._storage_index):
753             d = ss.callRemote("get_buckets", self._storage_index)
754             d.addCallbacks(self._got_response, self._got_error,
755                            callbackArgs=(peerid,))
756             dl.append(d)
757         self._responses_received = 0
758         self._queries_sent = len(dl)
759         if self._status:
760             self._status.set_status("Locating Shares (%d/%d)" %
761                                     (self._responses_received,
762                                      self._queries_sent))
763         return defer.DeferredList(dl)
764
765     def _got_response(self, buckets, peerid):
766         self._responses_received += 1
767         if self._results:
768             elapsed = time.time() - self._started
769             self._results.timings["servers_peer_selection"][peerid] = elapsed
770         if self._status:
771             self._status.set_status("Locating Shares (%d/%d)" %
772                                     (self._responses_received,
773                                      self._queries_sent))
774         for sharenum, bucket in buckets.iteritems():
775             b = layout.ReadBucketProxy(bucket, peerid, self._si_s)
776             self.add_share_bucket(sharenum, b)
777
778             if self._results:
779                 if peerid not in self._results.servermap:
780                     self._results.servermap[peerid] = set()
781                 self._results.servermap[peerid].add(sharenum)
782
783     def add_share_bucket(self, sharenum, bucket):
784         # this is split out for the benefit of test_encode.py
785         self._share_buckets.append( (sharenum, bucket) )
786
787     def _got_error(self, f):
788         level = log.WEIRD
789         if f.check(DeadReferenceError):
790             level = log.UNUSUAL
791         self._client.log("Error during get_buckets", failure=f, level=level,
792                          umid="3uuBUQ")
793
794     def bucket_failed(self, vbucket):
795         shnum = vbucket.sharenum
796         del self.active_buckets[shnum]
797         s = self._share_vbuckets[shnum]
798         # s is a set of ValidatedReadBucketProxy instances
799         s.remove(vbucket)
800         # ... which might now be empty
801         if not s:
802             # there are no more buckets which can provide this share, so
803             # remove the key. This may prompt us to use a different share.
804             del self._share_vbuckets[shnum]
805
806     def _got_all_shareholders(self, res):
807         assert isinstance(self._uri, uri.CHKFileURI), (self._uri, type(self._uri))
808         if self._results:
809             now = time.time()
810             self._results.timings["peer_selection"] = now - self._started
811
812         if len(self._share_buckets) < self._uri.needed_shares:
813             raise NotEnoughSharesError
814
815         #for s in self._share_vbuckets.values():
816         #    for vb in s:
817         #        assert isinstance(vb, ValidatedReadBucketProxy), \
818         #               "vb is %s but should be a ValidatedReadBucketProxy" % (vb,)
819
820     def _obtain_uri_extension(self, ignored):
821         assert isinstance(self._uri, uri.CHKFileURI), self._uri
822         # all shareholders are supposed to have a copy of uri_extension, and
823         # all are supposed to be identical. We compute the hash of the data
824         # that comes back, and compare it against the version in our URI. If
825         # they don't match, ignore their data and try someone else.
826         if self._status:
827             self._status.set_status("Obtaining URI Extension")
828
829         uri_extension_fetch_started = time.time()
830
831         vups = []
832         for sharenum, bucket in self._share_buckets:
833             vups.append(ValidatedExtendedURIProxy(bucket, self._uri.get_verifier(), self._fetch_failures))
834         vto = ValidatedThingObtainer(vups, debugname="vups", log_id=self._log_number)
835         d = vto.start()
836
837         def _got_uri_extension(vup):
838             precondition(isinstance(vup, ValidatedExtendedURIProxy), vup)
839             if self._results:
840                 elapsed = time.time() - uri_extension_fetch_started
841                 self._results.timings["uri_extension"] = elapsed
842
843             self._vup = vup
844             self._codec = codec.CRSDecoder()
845             self._codec.set_params(self._vup.segment_size, self._uri.needed_shares, self._uri.total_shares)
846             self._tail_codec = codec.CRSDecoder()
847             self._tail_codec.set_params(self._vup.tail_segment_size, self._uri.needed_shares, self._uri.total_shares)
848
849             self._current_segnum = 0
850
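            # node 0 of an IncompleteHashTree is its root, so we seed each
            # tree with the validated root hash taken from the UEB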
851             self._share_hashtree = hashtree.IncompleteHashTree(self._uri.total_shares)
852             self._share_hashtree.set_hashes({0: vup.share_root_hash})
853
854             self._crypttext_hash_tree = hashtree.IncompleteHashTree(self._vup.num_segments)
855             self._crypttext_hash_tree.set_hashes({0: self._vup.crypttext_root_hash})
856         d.addCallback(_got_uri_extension)
857         return d
858
859     def _get_crypttext_hash_tree(self, res):
860         vchtps = []
861         for sharenum, bucket in self._share_buckets:
862             vchtp = ValidatedCrypttextHashTreeProxy(bucket, self._crypttext_hash_tree, self._fetch_failures)
863             vchtps.append(vchtp)
864
865         _get_crypttext_hash_tree_started = time.time()
866         if self._status:
867             self._status.set_status("Retrieving crypttext hash tree")
868
869         vto = ValidatedThingObtainer(vchtps, debugname="vchtps", log_id=self._log_number)
870         d = vto.start()
871
872         def _got_crypttext_hash_tree(res):
873             self._crypttext_hash_tree = res._crypttext_hash_tree
874             self._output.got_crypttext_hash_tree(self._crypttext_hash_tree)
875             if self._results:
876                 elapsed = time.time() - _get_crypttext_hash_tree_started
877                 self._results.timings["hashtrees"] = elapsed
878         d.addCallback(_got_crypttext_hash_tree)
879         return d
880
881     def _activate_enough_buckets(self):
882         """either return a mapping from shnum to a ValidatedReadBucketProxy that can
883         provide data for that share, or raise NotEnoughSharesError"""
884         assert isinstance(self._uri, uri.CHKFileURI), self._uri
885
886         while len(self.active_buckets) < self._uri.needed_shares:
887             # need some more
888             handled_shnums = set(self.active_buckets.keys())
889             available_shnums = set(self._share_vbuckets.keys())
890             potential_shnums = list(available_shnums - handled_shnums)
891             if not potential_shnums:
892                 raise NotEnoughSharesError
893             # choose a random share
894             shnum = random.choice(potential_shnums)
895             # and a random bucket that will provide it
896             validated_bucket = random.choice(list(self._share_vbuckets[shnum]))
897             self.active_buckets[shnum] = validated_bucket
898         return self.active_buckets
899
900
901     def _download_all_segments(self, res):
902         for sharenum, bucket in self._share_buckets:
903             vbucket = ValidatedReadBucketProxy(sharenum, bucket,
904                                       self._share_hashtree,
905                                       self._vup.share_root_hash,
906                                       self._vup.num_segments)
907             s = self._share_vbuckets.setdefault(sharenum, set())
908             s.add(vbucket)
909
910         # after the above code, self._share_vbuckets contains enough
911         # buckets to complete the download, and some extra ones to
912         # tolerate some buckets dropping out or having
913         # errors. self._share_vbuckets is a dictionary that maps from
914         # shnum to a set of ValidatedBuckets, which themselves are
915         # wrappers around RIBucketReader references.
916         self.active_buckets = {} # k: shnum, v: ValidatedReadBucketProxy instance
917
918         self._started_fetching = time.time()
919
920         d = defer.succeed(None)
921         for segnum in range(self._vup.num_segments-1):
922             d.addCallback(self._download_segment, segnum)
923             # this pause, at the end of write, prevents pre-fetch from
924             # happening until the consumer is ready for more data.
925             d.addCallback(self._check_for_pause)
926         d.addCallback(self._download_tail_segment, self._vup.num_segments-1)
927         return d
928
929     def _check_for_pause(self, res):
930         if self._paused:
931             d = defer.Deferred()
932             self._paused.addCallback(lambda ignored: d.callback(res))
933             return d
934         if self._stopped:
935             raise DownloadStopped("our Consumer called stopProducing()")
936         return res
937
938     def _download_segment(self, res, segnum):
939         assert isinstance(self._uri, uri.CHKFileURI), self._uri
940         if self._status:
941             self._status.set_status("Downloading segment %d of %d" %
942                                     (segnum+1, self._vup.num_segments))
943         self.log("downloading seg#%d of %d (%d%%)"
944                  % (segnum, self._vup.num_segments,
945                     100.0 * segnum / self._vup.num_segments))
946         # memory footprint: when the SegmentDownloader finishes pulling down
947         # all shares, we have 1*segment_size of usage.
948         segmentdler = SegmentDownloader(self, segnum, self._uri.needed_shares,
949                                         self._results)
950         started = time.time()
951         d = segmentdler.start()
952         def _finished_fetching(res):
953             elapsed = time.time() - started
954             self._results.timings["cumulative_fetch"] += elapsed
955             return res
956         if self._results:
957             d.addCallback(_finished_fetching)
958         # pause before using more memory
959         d.addCallback(self._check_for_pause)
960         # while the codec does its job, we hit 2*segment_size
961         def _started_decode(res):
962             self._started_decode = time.time()
963             return res
964         if self._results:
965             d.addCallback(_started_decode)
966         d.addCallback(lambda (shares, shareids):
967                       self._codec.decode(shares, shareids))
968         # once the codec is done, we drop back to 1*segment_size, because
969         # 'shares' goes out of scope. The memory usage is all in the
970         # plaintext now, spread out into a bunch of tiny buffers.
971         def _finished_decode(res):
972             elapsed = time.time() - self._started_decode
973             self._results.timings["cumulative_decode"] += elapsed
974             return res
975         if self._results:
976             d.addCallback(_finished_decode)
977
978         # pause/check-for-stop just before writing, to honor stopProducing
979         d.addCallback(self._check_for_pause)
980         def _done(buffers):
981             # we start by joining all these buffers together into a single
982             # string. This makes Output.write easier, since it wants to hash
983             # data one segment at a time anyways, and doesn't impact our
984             # memory footprint since we're already peaking at 2*segment_size
985             # inside the codec a moment ago.
986             segment = "".join(buffers)
987             del buffers
988             # we're down to 1*segment_size right now, but write_segment()
989             # will decrypt a copy of the segment internally, which will push
990             # us up to 2*segment_size while it runs.
991             started_decrypt = time.time()
992             self._output.write_segment(segment)
993             if self._results:
994                 elapsed = time.time() - started_decrypt
995                 self._results.timings["cumulative_decrypt"] += elapsed
996         d.addCallback(_done)
997         return d
998
999     def _download_tail_segment(self, res, segnum):
1000         assert isinstance(self._uri, uri.CHKFileURI), self._uri
1001         self.log("downloading seg#%d of %d (%d%%)"
1002                  % (segnum, self._vup.num_segments,
1003                     100.0 * segnum / self._vup.num_segments))
1004         segmentdler = SegmentDownloader(self, segnum, self._uri.needed_shares,
1005                                         self._results)
1006         started = time.time()
1007         d = segmentdler.start()
1008         def _finished_fetching(res):
1009             elapsed = time.time() - started
1010             self._results.timings["cumulative_fetch"] += elapsed
1011             return res
1012         if self._results:
1013             d.addCallback(_finished_fetching)
1014         # pause before using more memory
1015         d.addCallback(self._check_for_pause)
1016         def _started_decode(res):
1017             self._started_decode = time.time()
1018             return res
1019         if self._results:
1020             d.addCallback(_started_decode)
1021         d.addCallback(lambda (shares, shareids):
1022                       self._tail_codec.decode(shares, shareids))
1023         def _finished_decode(res):
1024             elapsed = time.time() - self._started_decode
1025             self._results.timings["cumulative_decode"] += elapsed
1026             return res
1027         if self._results:
1028             d.addCallback(_finished_decode)
1029         # pause/check-for-stop just before writing, to honor stopProducing
1030         d.addCallback(self._check_for_pause)
1031         def _done(buffers):
1032             # trim off any padding added by the upload side
1033             segment = "".join(buffers)
1034             del buffers
1035             # we never send empty segments. If the data was an exact multiple
1036             # of the segment size, the last segment will be full.
1037             pad_size = mathutil.pad_size(self._uri.size, self._vup.segment_size)
1038             tail_size = self._vup.segment_size - pad_size
1039             segment = segment[:tail_size]
1040             started_decrypt = time.time()
1041             self._output.write_segment(segment)
1042             if self._results:
1043                 elapsed = time.time() - started_decrypt
1044                 self._results.timings["cumulative_decrypt"] += elapsed
1045         d.addCallback(_done)
1046         return d
1047
1048     def _done(self, res):
1049         assert isinstance(self._uri, uri.CHKFileURI), self._uri
1050         self.log("download done")
1051         if self._results:
1052             now = time.time()
1053             self._results.timings["total"] = now - self._started
1054             self._results.timings["segments"] = now - self._started_fetching
1055         self._output.close()
1056         if self._vup.crypttext_hash:
1057             _assert(self._vup.crypttext_hash == self._output.crypttext_hash,
1058                     "bad crypttext_hash: computed=%s, expected=%s" %
1059                     (base32.b2a(self._output.crypttext_hash),
1060                      base32.b2a(self._vup.crypttext_hash)))
1061         _assert(self._output.length == self._uri.size,
1062                 got=self._output.length, expected=self._uri.size)
1063         return self._output.finish()
1064
1065     def get_download_status(self):
1066         return self._status
1067
1068
1069 class FileName:
1070     implements(IDownloadTarget)
1071     def __init__(self, filename):
1072         self._filename = filename
1073         self.f = None
1074     def open(self, size):
1075         self.f = open(self._filename, "wb")
1076         return self.f
1077     def write(self, data):
1078         self.f.write(data)
1079     def close(self):
1080         if self.f:
1081             self.f.close()
1082     def fail(self, why):
1083         if self.f:
1084             self.f.close()
1085             os.unlink(self._filename)
1086     def register_canceller(self, cb):
1087         pass # we won't use it
1088     def finish(self):
1089         pass
1090
1091 class Data:
1092     implements(IDownloadTarget)
1093     def __init__(self):
1094         self._data = []
1095     def open(self, size):
1096         pass
1097     def write(self, data):
1098         self._data.append(data)
1099     def close(self):
1100         self.data = "".join(self._data)
1101         del self._data
1102     def fail(self, why):
1103         del self._data
1104     def register_canceller(self, cb):
1105         pass # we won't use it
1106     def finish(self):
1107         return self.data
1108
1109 class FileHandle:
1110     """Use me to download data to a pre-defined filehandle-like object. I
1111     will use the target's write() method. I will *not* close the filehandle:
1112     I leave that up to the originator of the filehandle. The download process
1113     will return the filehandle when it completes.
1114     """
1115     implements(IDownloadTarget)
1116     def __init__(self, filehandle):
1117         self._filehandle = filehandle
1118     def open(self, size):
1119         pass
1120     def write(self, data):
1121         self._filehandle.write(data)
1122     def close(self):
1123         # the originator of the filehandle reserves the right to close it
1124         pass
1125     def fail(self, why):
1126         pass
1127     def register_canceller(self, cb):
1128         pass
1129     def finish(self):
1130         return self._filehandle
1131
1132 class ConsumerAdapter:
1133     implements(IDownloadTarget, IConsumer)
1134     def __init__(self, consumer):
1135         self._consumer = consumer
1136
1137     def registerProducer(self, producer, streaming):
1138         self._consumer.registerProducer(producer, streaming)
1139     def unregisterProducer(self):
1140         self._consumer.unregisterProducer()
1141
1142     def open(self, size):
1143         pass
1144     def write(self, data):
1145         self._consumer.write(data)
1146     def close(self):
1147         pass
1148
1149     def fail(self, why):
1150         pass
1151     def register_canceller(self, cb):
1152         pass
1153     def finish(self):
1154         return self._consumer
1155
1156
1157 class Downloader(service.MultiService):
1158     """I am a service that allows file downloading.
1159     """
1160     # TODO: in fact, this service only downloads immutable files (URI:CHK:).
1161     # It is scheduled to go away, to be replaced by filenode.download()
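    #
    # Illustrative usage (a sketch, assuming a running client that has this
    # service attached; 'filecap' stands for a CHK read-cap / IFileURI):
    #   downloader = client.getServiceNamed("downloader")
    #   d = downloader.download_to_data(filecap)
    #   d.addCallback(lambda plaintext: ...)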
1162     implements(IDownloader)
1163     name = "downloader"
1164     MAX_DOWNLOAD_STATUSES = 10
1165
1166     def __init__(self, stats_provider=None):
1167         service.MultiService.__init__(self)
1168         self.stats_provider = stats_provider
1169         self._all_downloads = weakref.WeakKeyDictionary() # for debugging
1170         self._all_download_statuses = weakref.WeakKeyDictionary()
1171         self._recent_download_statuses = []
1172
1173     def download(self, u, t):
1174         assert self.parent
1175         assert self.running
1176         u = IFileURI(u)
1177         t = IDownloadTarget(t)
1178         assert t.write
1179         assert t.close
1180
1181         assert isinstance(u, uri.CHKFileURI)
1182         if self.stats_provider:
1183             # these counters are meant for network traffic, and don't
1184             # include LIT files
1185             self.stats_provider.count('downloader.files_downloaded', 1)
1186             self.stats_provider.count('downloader.bytes_downloaded', u.get_size())
1187         dl = FileDownloader(self.parent, u, t)
1188         self._add_download(dl)
1189         d = dl.start()
1190         return d
1191
1192     # utility functions
1193     def download_to_data(self, uri):
1194         return self.download(uri, Data())
1195     def download_to_filename(self, uri, filename):
1196         return self.download(uri, FileName(filename))
1197     def download_to_filehandle(self, uri, filehandle):
1198         return self.download(uri, FileHandle(filehandle))
1199
1200     def _add_download(self, downloader):
1201         self._all_downloads[downloader] = None
1202         s = downloader.get_download_status()
1203         self._all_download_statuses[s] = None
1204         self._recent_download_statuses.append(s)
1205         while len(self._recent_download_statuses) > self.MAX_DOWNLOAD_STATUSES:
1206             self._recent_download_statuses.pop(0)
1207
1208     def list_all_download_statuses(self):
1209         for ds in self._all_download_statuses:
1210             yield ds