[tahoe-lafs/tahoe-lafs.git] / src / allmydata / immutable / download.py (blob 3dde40f1f0ac2883af6f607c33c3f34819f0f269)
1 import os, random, weakref, itertools, time
2 from zope.interface import implements
3 from twisted.internet import defer
4 from twisted.internet.interfaces import IPushProducer, IConsumer
5 from foolscap.api import DeadReferenceError, RemoteException, eventually
6
7 from allmydata.util import base32, deferredutil, hashutil, log, mathutil, idlib
8 from allmydata.util.assertutil import _assert, precondition
9 from allmydata import codec, hashtree, uri
10 from allmydata.interfaces import IDownloadTarget, IDownloader, \
11      IFileURI, IVerifierURI, \
12      IDownloadStatus, IDownloadResults, IValidatedThingProxy, \
13      IStorageBroker, NotEnoughSharesError, NoSharesError, NoServersError, \
14      UnableToFetchCriticalDownloadDataError
15 from allmydata.immutable import layout
16 from allmydata.monitor import Monitor
17 from pycryptopp.cipher.aes import AES
18
19 class IntegrityCheckReject(Exception):
20     pass
21
22 class BadURIExtensionHashValue(IntegrityCheckReject):
23     pass
24 class BadURIExtension(IntegrityCheckReject):
25     pass
26 class UnsupportedErasureCodec(BadURIExtension):
27     pass
28 class BadCrypttextHashValue(IntegrityCheckReject):
29     pass
30 class BadOrMissingHash(IntegrityCheckReject):
31     pass
32
33 class DownloadStopped(Exception):
34     pass
35
36 class DownloadResults:
37     implements(IDownloadResults)
38
39     def __init__(self):
40         self.servers_used = set()
41         self.server_problems = {}
42         self.servermap = {}
43         self.timings = {}
44         self.file_size = None
45
46 class DecryptingTarget(log.PrefixingLogMixin):
47     implements(IDownloadTarget, IConsumer)
48     def __init__(self, target, key, _log_msg_id=None):
49         precondition(IDownloadTarget.providedBy(target), target)
50         self.target = target
51         self._decryptor = AES(key)
52         prefix = str(target)
53         log.PrefixingLogMixin.__init__(self, "allmydata.immutable.download", _log_msg_id, prefix=prefix)
54     # methods to satisfy the IConsumer interface
55     def registerProducer(self, producer, streaming):
56         if IConsumer.providedBy(self.target):
57             self.target.registerProducer(producer, streaming)
58     def unregisterProducer(self):
59         if IConsumer.providedBy(self.target):
60             self.target.unregisterProducer()
61     def write(self, ciphertext):
62         plaintext = self._decryptor.process(ciphertext)
63         self.target.write(plaintext)
64     def open(self, size):
65         self.target.open(size)
66     def close(self):
67         self.target.close()
68     def finish(self):
69         return self.target.finish()
70     # The following methods just pass through to the next target, because
71     # that target might be a repairer.DownUpConnector, and because the
72     # current CHKUpload object expects to find the storage index in its
73     # Uploadable.
74     def set_storageindex(self, storageindex):
75         self.target.set_storageindex(storageindex)
76     def set_encodingparams(self, encodingparams):
77         self.target.set_encodingparams(encodingparams)
78
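# Illustrative sketch (not part of the download pipeline): DecryptingTarget
# wraps any IDownloadTarget and decrypts each write() on the way through. The
# key below is a made-up 16-byte AES key, `ciphertext` stands for data that
# was encrypted with that same key, and Data is the in-memory target defined
# near the end of this file; a real key comes from the read-cap.
#
#   key = "\x00" * 16
#   plaintext_target = Data()                    # collects plaintext in memory
#   target = DecryptingTarget(plaintext_target, key)
#   target.open(len(ciphertext))
#   target.write(ciphertext)                     # decrypted on the way through
#   target.close()
#   plaintext = target.finish()                  # the decrypted bytes
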
79 class ValidatedThingObtainer:
80     def __init__(self, validatedthingproxies, debugname, log_id):
81         self._validatedthingproxies = validatedthingproxies
82         self._debugname = debugname
83         self._log_id = log_id
84
85     def _bad(self, f, validatedthingproxy):
86         failtype = f.trap(RemoteException, DeadReferenceError,
87                           IntegrityCheckReject, layout.LayoutInvalid,
88                           layout.ShareVersionIncompatible)
89         level = log.WEIRD
90         if f.check(DeadReferenceError):
91             level = log.UNUSUAL
92         elif f.check(RemoteException):
93             level = log.WEIRD
94         else:
95             level = log.SCARY
96         log.msg(parent=self._log_id, facility="tahoe.immutable.download",
97                 format="operation %(op)s from validatedthingproxy %(validatedthingproxy)s failed",
98                 op=self._debugname, validatedthingproxy=str(validatedthingproxy),
99                 failure=f, level=level, umid="JGXxBA")
100         if not self._validatedthingproxies:
101             raise UnableToFetchCriticalDownloadDataError("ran out of peers, last error was %s" % (f,))
102         # try again with a different one
103         d = self._try_the_next_one()
104         return d
105
106     def _try_the_next_one(self):
107         vtp = self._validatedthingproxies.pop(0)
108         # start() obtains, validates, and callsback-with the thing or else
109         # errbacks
110         d = vtp.start()
111         d.addErrback(self._bad, vtp)
112         return d
113
114     def start(self):
115         return self._try_the_next_one()
116
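# Illustrative sketch (not part of the download pipeline): the "Validated
# Thing protocol" used by ValidatedThingObtainer is just "start() returns a
# Deferred that fires with the validated thing, or errbacks". _StubProxy
# below is a made-up stand-in for the real proxy classes that follow.
#
#   class _StubProxy:
#       def __init__(self, value, good):
#           self.value, self.good = value, good
#       def start(self):
#           if self.good:
#               return defer.succeed(self.value)
#           return defer.fail(IntegrityCheckReject("bad copy"))
#
#   proxies = [_StubProxy("a", False), _StubProxy("b", True)]
#   d = ValidatedThingObtainer(proxies, debugname="stub", log_id=None).start()
#   # The first proxy errbacks, so the obtainer falls back to the second and
#   # d eventually fires with "b".
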
117 class ValidatedCrypttextHashTreeProxy:
118     """I am a front-end for a remote crypttext hash tree using a local
119     ReadBucketProxy -- I use its get_crypttext_hashes() method and offer the
120     Validated Thing protocol (i.e., I have a start() method that fires with
121     self once I get a valid one)."""
122     implements(IValidatedThingProxy)
123     def __init__(self, readbucketproxy, crypttext_hash_tree, num_segments,
124                  fetch_failures=None):
125         # fetch_failures is for debugging -- see test_encode.py
126         self._readbucketproxy = readbucketproxy
127         self._num_segments = num_segments
128         self._fetch_failures = fetch_failures
129         self._crypttext_hash_tree = crypttext_hash_tree
130
131     def _validate(self, proposal):
132         ct_hashes = dict(list(enumerate(proposal)))
133         try:
134             self._crypttext_hash_tree.set_hashes(ct_hashes)
135         except (hashtree.BadHashError, hashtree.NotEnoughHashesError), le:
136             if self._fetch_failures is not None:
137                 self._fetch_failures["crypttext_hash_tree"] += 1
138             raise BadOrMissingHash(le)
139         # If we now have enough of the crypttext hash tree to integrity-check
140         # *every* segment of ciphertext, then we are done. TODO: alacrity
141         # would improve if we fetched only the part of the crypttext hash
142         # tree needed for each segment, as it is needed.
143         for segnum in range(self._num_segments):
144             if self._crypttext_hash_tree.needed_hashes(segnum):
145                 raise BadOrMissingHash("not enough hashes to validate segment number %d" % (segnum,))
146         return self
147
148     def start(self):
149         d = self._readbucketproxy.get_crypttext_hashes()
150         d.addCallback(self._validate)
151         return d
152
153 class ValidatedExtendedURIProxy:
154     """I am a front-end for a remote UEB (using a local ReadBucketProxy),
155     responsible for retrieving and validating the elements from the UEB."""
156     implements(IValidatedThingProxy)
157
158     def __init__(self, readbucketproxy, verifycap, fetch_failures=None):
159         # fetch_failures is for debugging -- see test_encode.py
160         self._fetch_failures = fetch_failures
161         self._readbucketproxy = readbucketproxy
162         precondition(IVerifierURI.providedBy(verifycap), verifycap)
163         self._verifycap = verifycap
164
165         # required
166         self.segment_size = None
167         self.crypttext_root_hash = None
168         self.share_root_hash = None
169
170         # computed
171         self.block_size = None
172         self.share_size = None
173         self.num_segments = None
174         self.tail_data_size = None
175         self.tail_segment_size = None
176
177         # optional
178         self.crypttext_hash = None
179
180     def __str__(self):
181         return "<%s %s>" % (self.__class__.__name__, self._verifycap.to_string())
182
183     def _check_integrity(self, data):
184         h = hashutil.uri_extension_hash(data)
185         if h != self._verifycap.uri_extension_hash:
186             msg = ("The copy of uri_extension we received from %s was bad: wanted %s, got %s" %
187                    (self._readbucketproxy,
188                     base32.b2a(self._verifycap.uri_extension_hash),
189                     base32.b2a(h)))
190             if self._fetch_failures is not None:
191                 self._fetch_failures["uri_extension"] += 1
192             raise BadURIExtensionHashValue(msg)
193         else:
194             return data
195
196     def _parse_and_validate(self, data):
197         self.share_size = mathutil.div_ceil(self._verifycap.size,
198                                             self._verifycap.needed_shares)
199
200         d = uri.unpack_extension(data)
201
202         # There are several kinds of things that can be found in a UEB.
203         # First, things that we really need to learn from the UEB in order to
204         # do this download. Next: things which are optional but not redundant
205         # -- if they are present in the UEB they will get used. Next, things
206         # that are optional and redundant. These things are required to be
207         # consistent: they don't have to be in the UEB, but if they are in
208         # the UEB then they will be checked for consistency with the
209         # already-known facts, and if they are inconsistent then an exception
210         # will be raised. These things aren't actually used -- they are just
211         # tested for consistency and ignored. Finally: things which are
212         # deprecated -- they ought not be in the UEB at all, and if they are
213         # present then a warning will be logged but they are otherwise
214         # ignored.
215
216         # First, things that we really need to learn from the UEB:
217         # segment_size, crypttext_root_hash, and share_root_hash.
218         self.segment_size = d['segment_size']
219
220         self.block_size = mathutil.div_ceil(self.segment_size,
221                                             self._verifycap.needed_shares)
222         self.num_segments = mathutil.div_ceil(self._verifycap.size,
223                                               self.segment_size)
224
225         self.tail_data_size = self._verifycap.size % self.segment_size
226         if not self.tail_data_size:
227             self.tail_data_size = self.segment_size
228         # padding for erasure code
229         self.tail_segment_size = mathutil.next_multiple(self.tail_data_size,
230                                                         self._verifycap.needed_shares)
231
232         # Ciphertext hash tree root is mandatory, so that there is at most
233         # one ciphertext that matches this read-cap or verify-cap. The
234         # integrity check on the shares is not sufficient to prevent the
235         # original encoder from creating some shares of file A and other
236         # shares of file B.
237         self.crypttext_root_hash = d['crypttext_root_hash']
238
239         self.share_root_hash = d['share_root_hash']
240
241
242         # Next: things that are optional and not redundant: crypttext_hash
243         if d.has_key('crypttext_hash'):
244             self.crypttext_hash = d['crypttext_hash']
245             if len(self.crypttext_hash) != hashutil.CRYPTO_VAL_SIZE:
246                 raise BadURIExtension('crypttext_hash is required to be hashutil.CRYPTO_VAL_SIZE bytes, not %s bytes' % (len(self.crypttext_hash),))
247
248
249         # Next: things that are optional, redundant, and required to be
250         # consistent: codec_name, codec_params, tail_codec_params,
251         # num_segments, size, needed_shares, total_shares
252         if d.has_key('codec_name'):
253             if d['codec_name'] != "crs":
254                 raise UnsupportedErasureCodec(d['codec_name'])
255
256         if d.has_key('codec_params'):
257             ucpss, ucpns, ucpts = codec.parse_params(d['codec_params'])
258             if ucpss != self.segment_size:
259                 raise BadURIExtension("inconsistent erasure code params: "
260                                       "ucpss: %s != self.segment_size: %s" %
261                                       (ucpss, self.segment_size))
262             if ucpns != self._verifycap.needed_shares:
263                 raise BadURIExtension("inconsistent erasure code params: ucpns: %s != "
264                                       "self._verifycap.needed_shares: %s" %
265                                       (ucpns, self._verifycap.needed_shares))
266             if ucpts != self._verifycap.total_shares:
267                 raise BadURIExtension("inconsistent erasure code params: ucpts: %s != "
268                                       "self._verifycap.total_shares: %s" %
269                                       (ucpts, self._verifycap.total_shares))
270
271         if d.has_key('tail_codec_params'):
272             utcpss, utcpns, utcpts = codec.parse_params(d['tail_codec_params'])
273             if utcpss != self.tail_segment_size:
274                 raise BadURIExtension("inconsistent erasure code params: utcpss: %s != "
275                                       "self.tail_segment_size: %s, self._verifycap.size: %s, "
276                                       "self.segment_size: %s, self._verifycap.needed_shares: %s"
277                                       % (utcpss, self.tail_segment_size, self._verifycap.size,
278                                          self.segment_size, self._verifycap.needed_shares))
279             if utcpns != self._verifycap.needed_shares:
280                 raise BadURIExtension("inconsistent erasure code params: utcpns: %s != "
281                                       "self._verifycap.needed_shares: %s" % (utcpns,
282                                                                              self._verifycap.needed_shares))
283             if utcpts != self._verifycap.total_shares:
284                 raise BadURIExtension("inconsistent erasure code params: utcpts: %s != "
285                                       "self._verifycap.total_shares: %s" % (utcpts,
286                                                                             self._verifycap.total_shares))
287
288         if d.has_key('num_segments'):
289             if d['num_segments'] != self.num_segments:
290                 raise BadURIExtension("inconsistent num_segments: size: %s, "
291                                       "segment_size: %s, computed_num_segments: %s, "
292                                       "ueb_num_segments: %s" % (self._verifycap.size,
293                                                                 self.segment_size,
294                                                                 self.num_segments, d['num_segments']))
295
296         if d.has_key('size'):
297             if d['size'] != self._verifycap.size:
298                 raise BadURIExtension("inconsistent size: URI size: %s, UEB size: %s" %
299                                       (self._verifycap.size, d['size']))
300
301         if d.has_key('needed_shares'):
302             if d['needed_shares'] != self._verifycap.needed_shares:
303                 raise BadURIExtension("inconsistent needed shares: URI needed shares: %s, UEB "
304                                       "needed shares: %s" % (self._verifycap.needed_shares,
305                                                              d['needed_shares']))
306
307         if d.has_key('total_shares'):
308             if d['total_shares'] != self._verifycap.total_shares:
309                 raise BadURIExtension("inconsistent total shares: URI total shares: %s, UEB "
310                                       "total shares: %s" % (self._verifycap.total_shares,
311                                                             d['total_shares']))
312
313         # Finally, things that are deprecated and ignored: plaintext_hash,
314         # plaintext_root_hash
315         if d.get('plaintext_hash'):
316             log.msg("Found plaintext_hash in UEB. This field is deprecated for security reasons "
317                     "and is no longer used.  Ignoring.  %s" % (self,))
318         if d.get('plaintext_root_hash'):
319             log.msg("Found plaintext_root_hash in UEB. This field is deprecated for security "
320                     "reasons and is no longer used.  Ignoring.  %s" % (self,))
321
322         return self
323
324     def start(self):
325         """Fetch the UEB from bucket, compare its hash to the hash from
326         verifycap, then parse it. Returns a deferred which is called back
327         with self once the fetch is successful, or is erred back if it
328         fails."""
329         d = self._readbucketproxy.get_uri_extension()
330         d.addCallback(self._check_integrity)
331         d.addCallback(self._parse_and_validate)
332         return d
333
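# Illustrative sketch (not used by anything in this module): a worked instance
# of the geometry that ValidatedExtendedURIProxy._parse_and_validate() derives
# from the verify-cap plus the UEB. The numbers are made up: a 10000-byte
# file, needed_shares (k) = 3, and a UEB segment_size of 4096.
def _example_ueb_geometry():
    size, k, segment_size = 10000, 3, 4096
    share_size = mathutil.div_ceil(size, k)                        # 3334
    block_size = mathutil.div_ceil(segment_size, k)                # 1366
    num_segments = mathutil.div_ceil(size, segment_size)           # 3
    tail_data_size = size % segment_size or segment_size           # 1808
    tail_segment_size = mathutil.next_multiple(tail_data_size, k)  # 1809
    return (share_size, block_size, num_segments,
            tail_data_size, tail_segment_size)
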
334 class ValidatedReadBucketProxy(log.PrefixingLogMixin):
335     """I am a front-end for a remote storage bucket, responsible for
336     retrieving and validating data from that bucket.
337
338     My get_block() method is used by BlockDownloaders.
339     """
340
341     def __init__(self, sharenum, bucket, share_hash_tree, num_blocks,
342                  block_size, share_size):
343         """ share_hash_tree is required to have already been initialized with
344         the root hash (the number-0 hash), using the share_root_hash from the
345         UEB"""
346         precondition(share_hash_tree[0] is not None, share_hash_tree)
347         prefix = "%d-%s-%s" % (sharenum, bucket,
348                                base32.b2a_l(share_hash_tree[0][:8], 60))
349         log.PrefixingLogMixin.__init__(self,
350                                        facility="tahoe.immutable.download",
351                                        prefix=prefix)
352         self.sharenum = sharenum
353         self.bucket = bucket
354         self.share_hash_tree = share_hash_tree
355         self.num_blocks = num_blocks
356         self.block_size = block_size
357         self.share_size = share_size
358         self.block_hash_tree = hashtree.IncompleteHashTree(self.num_blocks)
359
360     def get_all_sharehashes(self):
361         """Retrieve and validate all the share-hash-tree nodes that are
362         included in this share, regardless of whether we need them to
363         validate the share or not. Each share contains a minimal Merkle tree
364         chain, but there is lots of overlap, so usually we'll be using hashes
365         from other shares and not reading every single hash from this share.
366         The Verifier uses this function to read and validate every single
367         hash from this share.
368
369         Call this (and wait for the Deferred it returns to fire) before
370         calling get_block() for the first time: this lets us check that the
371         share contains enough hashes to validate its own data, and
372         avoids downloading any share hash twice.
373
374         I return a Deferred which errbacks upon failure, probably with
375         BadOrMissingHash."""
376
377         d = self.bucket.get_share_hashes()
378         def _got_share_hashes(sh):
379             sharehashes = dict(sh)
380             try:
381                 self.share_hash_tree.set_hashes(sharehashes)
382             except IndexError, le:
383                 raise BadOrMissingHash(le)
384             except (hashtree.BadHashError, hashtree.NotEnoughHashesError), le:
385                 raise BadOrMissingHash(le)
386         d.addCallback(_got_share_hashes)
387         return d
388
389     def get_all_blockhashes(self):
390         """Retrieve and validate all the block-hash-tree nodes that are
391         included in this share. Each share contains a full Merkle tree, but
392         we usually only fetch the minimal subset necessary for any particular
393         block. This function fetches everything at once. The Verifier uses
394         this function to validate the block hash tree.
395
396         Call this (and wait for the Deferred it returns to fire) after
397         calling get_all_sharehashes() and before calling get_block() for the
398         first time: this lets us check that the share contains all block
399         hashes and avoids downloading them multiple times.
400
401         I return a Deferred which errbacks upon failure, probably with
402         BadOrMissingHash.
403         """
404
405         # get_block_hashes(anything) currently always returns everything
406         needed = list(range(len(self.block_hash_tree)))
407         d = self.bucket.get_block_hashes(needed)
408         def _got_block_hashes(blockhashes):
409             if len(blockhashes) < len(self.block_hash_tree):
410                 raise BadOrMissingHash()
411             bh = dict(enumerate(blockhashes))
412
413             try:
414                 self.block_hash_tree.set_hashes(bh)
415             except IndexError, le:
416                 raise BadOrMissingHash(le)
417             except (hashtree.BadHashError, hashtree.NotEnoughHashesError), le:
418                 raise BadOrMissingHash(le)
419         d.addCallback(_got_block_hashes)
420         return d
421
422     def get_all_crypttext_hashes(self, crypttext_hash_tree):
423         """Retrieve and validate all the crypttext-hash-tree nodes that are
424         in this share. Normally we don't look at these at all: the download
425         process fetches them incrementally as needed to validate each segment
426         of ciphertext. But this is a convenient place to give the Verifier a
427         function to validate all of these at once.
428
429         Call this with a new hashtree object for each share, initialized with
430         the crypttext hash tree root. I return a Deferred which errbacks upon
431         failure, probably with BadOrMissingHash.
432         """
433
434         # get_crypttext_hashes() always returns everything
435         d = self.bucket.get_crypttext_hashes()
436         def _got_crypttext_hashes(hashes):
437             if len(hashes) < len(crypttext_hash_tree):
438                 raise BadOrMissingHash()
439             ct_hashes = dict(enumerate(hashes))
440             try:
441                 crypttext_hash_tree.set_hashes(ct_hashes)
442             except IndexError, le:
443                 raise BadOrMissingHash(le)
444             except (hashtree.BadHashError, hashtree.NotEnoughHashesError), le:
445                 raise BadOrMissingHash(le)
446         d.addCallback(_got_crypttext_hashes)
447         return d
448
449     def get_block(self, blocknum):
450         # the first time we use this bucket, we need to fetch enough elements
451         # of the share hash tree to validate it from our share hash up to the
452         # hashroot.
453         if self.share_hash_tree.needed_hashes(self.sharenum):
454             d1 = self.bucket.get_share_hashes()
455         else:
456             d1 = defer.succeed([])
457
458         # We might need to grab some elements of our block hash tree, to
459         # validate the requested block up to the share hash.
460         blockhashesneeded = self.block_hash_tree.needed_hashes(blocknum, include_leaf=True)
461         # We don't need the root of the block hash tree, as that comes in the
462         # share tree.
463         blockhashesneeded.discard(0)
464         d2 = self.bucket.get_block_hashes(blockhashesneeded)
465
466         if blocknum < self.num_blocks-1:
467             thisblocksize = self.block_size
468         else:
469             thisblocksize = self.share_size % self.block_size
470             if thisblocksize == 0:
471                 thisblocksize = self.block_size
472         d3 = self.bucket.get_block_data(blocknum,
473                                         self.block_size, thisblocksize)
474
475         dl = deferredutil.gatherResults([d1, d2, d3])
476         dl.addCallback(self._got_data, blocknum)
477         return dl
478
479     def _got_data(self, results, blocknum):
480         precondition(blocknum < self.num_blocks,
481                      self, blocknum, self.num_blocks)
482         sharehashes, blockhashes, blockdata = results
483         try:
484             sharehashes = dict(sharehashes)
485         except ValueError, le:
486             le.args = tuple(le.args + (sharehashes,))
487             raise
488         blockhashes = dict(enumerate(blockhashes))
489
490         candidate_share_hash = None # in case we log it in the except block below
491         blockhash = None # in case we log it in the except block below
492
493         try:
494             if self.share_hash_tree.needed_hashes(self.sharenum):
495                 # This will raise exception if the values being passed do not
496                 # match the root node of self.share_hash_tree.
497                 try:
498                     self.share_hash_tree.set_hashes(sharehashes)
499                 except IndexError, le:
500                     # Weird -- sharehashes contained index numbers outside of
501                     # the range that fit into this hash tree.
502                     raise BadOrMissingHash(le)
503
504             # To validate a block we need the root of the block hash tree,
505             # which is also one of the leafs of the share hash tree, and is
506             # called "the share hash".
507             if not self.block_hash_tree[0]: # empty -- no root node yet
508                 # Get the share hash from the share hash tree.
509                 share_hash = self.share_hash_tree.get_leaf(self.sharenum)
510                 if not share_hash:
511                     # No root node in block_hash_tree and also the share hash
512                     # wasn't sent by the server.
513                     raise hashtree.NotEnoughHashesError
514                 self.block_hash_tree.set_hashes({0: share_hash})
515
516             if self.block_hash_tree.needed_hashes(blocknum):
517                 self.block_hash_tree.set_hashes(blockhashes)
518
519             blockhash = hashutil.block_hash(blockdata)
520             self.block_hash_tree.set_hashes(leaves={blocknum: blockhash})
521             #self.log("checking block_hash(shareid=%d, blocknum=%d) len=%d "
522             #        "%r .. %r: %s" %
523             #        (self.sharenum, blocknum, len(blockdata),
524             #         blockdata[:50], blockdata[-50:], base32.b2a(blockhash)))
525
526         except (hashtree.BadHashError, hashtree.NotEnoughHashesError), le:
527             # log.WEIRD: indicates undetected disk/network error, or more
528             # likely a programming error
529             self.log("hash failure in block=%d, shnum=%d on %s" %
530                     (blocknum, self.sharenum, self.bucket))
531             if self.block_hash_tree.needed_hashes(blocknum):
532                 self.log(""" failure occurred when checking the block_hash_tree.
533                 This suggests that either the block data was bad, or that the
534                 block hashes we received along with it were bad.""")
535             else:
536                 self.log(""" the failure probably occurred when checking the
537                 share_hash_tree, which suggests that the share hashes we
538                 received from the remote peer were bad.""")
539             self.log(" have candidate_share_hash: %s" % bool(candidate_share_hash))
540             self.log(" block length: %d" % len(blockdata))
541             self.log(" block hash: %s" % base32.b2a_or_none(blockhash))
542             if len(blockdata) < 100:
543                 self.log(" block data: %r" % (blockdata,))
544             else:
545                 self.log(" block data start/end: %r .. %r" %
546                         (blockdata[:50], blockdata[-50:]))
547             self.log(" share hash tree:\n" + self.share_hash_tree.dump())
548             self.log(" block hash tree:\n" + self.block_hash_tree.dump())
549             lines = []
550             for i,h in sorted(sharehashes.items()):
551                 lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
552             self.log(" sharehashes:\n" + "\n".join(lines) + "\n")
553             lines = []
554             for i,h in blockhashes.items():
555                 lines.append("%3d: %s" % (i, base32.b2a_or_none(h)))
556             log.msg(" blockhashes:\n" + "\n".join(lines) + "\n")
557             raise BadOrMissingHash(le)
558
559         # If we made it here, the block is good. If the hash trees didn't
560         # like what they saw, they would have raised a BadHashError, causing
561         # our caller to see a Failure and thus ignore this block (as well as
562         # dropping this bucket).
563         return blockdata
564
565
566
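# Illustrative sketch (not part of the download pipeline): the validation
# chain that get_block()/_got_data() implement, in miniature. The names
# total_shares, share_root_hash, sharenum, num_blocks, blocknum, blockdata and
# the *_from_server dicts are placeholders; set_hashes() raises BadHashError
# or NotEnoughHashesError if anything fails to chain up to the seeded root.
#
#   share_hash_tree = hashtree.IncompleteHashTree(total_shares)
#   share_hash_tree.set_hashes({0: share_root_hash})        # root from the UEB
#   share_hash_tree.set_hashes(sharehashes_from_server)     # Merkle chain
#
#   block_hash_tree = hashtree.IncompleteHashTree(num_blocks)
#   block_hash_tree.set_hashes({0: share_hash_tree.get_leaf(sharenum)})
#   block_hash_tree.set_hashes(blockhashes_from_server)
#   block_hash_tree.set_hashes(leaves={blocknum: hashutil.block_hash(blockdata)})
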
567 class BlockDownloader(log.PrefixingLogMixin):
568     """I am responsible for downloading a single block (from a single bucket)
569     for a single segment.
570
571     I am a child of the SegmentDownloader.
572     """
573
574     def __init__(self, vbucket, blocknum, parent, results):
575         precondition(isinstance(vbucket, ValidatedReadBucketProxy), vbucket)
576         prefix = "%s-%d" % (vbucket, blocknum)
577         log.PrefixingLogMixin.__init__(self, facility="tahoe.immutable.download", prefix=prefix)
578         self.vbucket = vbucket
579         self.blocknum = blocknum
580         self.parent = parent
581         self.results = results
582
583     def start(self, segnum):
584         self.log("get_block(segnum=%d)" % segnum)
585         started = time.time()
586         d = self.vbucket.get_block(segnum)
587         d.addCallbacks(self._hold_block, self._got_block_error,
588                        callbackArgs=(started,))
589         return d
590
591     def _hold_block(self, data, started):
592         if self.results:
593             elapsed = time.time() - started
594             peerid = self.vbucket.bucket.get_peerid()
595             if peerid not in self.results.timings["fetch_per_server"]:
596                 self.results.timings["fetch_per_server"][peerid] = []
597             self.results.timings["fetch_per_server"][peerid].append(elapsed)
598         self.log("got block")
599         self.parent.hold_block(self.blocknum, data)
600
601     def _got_block_error(self, f):
602         failtype = f.trap(RemoteException, DeadReferenceError,
603                           IntegrityCheckReject,
604                           layout.LayoutInvalid, layout.ShareVersionIncompatible)
605         if f.check(RemoteException, DeadReferenceError):
606             level = log.UNUSUAL
607         else:
608             level = log.WEIRD
609         self.log("failure to get block", level=level, umid="5Z4uHQ")
610         if self.results:
611             peerid = self.vbucket.bucket.get_peerid()
612             self.results.server_problems[peerid] = str(f)
613         self.parent.bucket_failed(self.vbucket)
614
615 class SegmentDownloader:
616     """I am responsible for downloading all the blocks for a single segment
617     of data.
618
619     I am a child of the CiphertextDownloader.
620     """
621
622     def __init__(self, parent, segmentnumber, needed_shares, results):
623         self.parent = parent
624         self.segmentnumber = segmentnumber
625         self.needed_blocks = needed_shares
626         self.blocks = {} # k: blocknum, v: data
627         self.results = results
628         self._log_number = self.parent.log("starting segment %d" %
629                                            segmentnumber)
630
631     def log(self, *args, **kwargs):
632         if "parent" not in kwargs:
633             kwargs["parent"] = self._log_number
634         return self.parent.log(*args, **kwargs)
635
636     def start(self):
637         return self._download()
638
639     def _download(self):
640         d = self._try()
641         def _done(res):
642             if len(self.blocks) >= self.needed_blocks:
643                 # we only need self.needed_blocks blocks
644                 # we want to get the smallest blockids, because they are
645                 # more likely to be fast "primary blocks"
646                 blockids = sorted(self.blocks.keys())[:self.needed_blocks]
647                 blocks = []
648                 for blocknum in blockids:
649                     blocks.append(self.blocks[blocknum])
650                 return (blocks, blockids)
651             else:
652                 return self._download()
653         d.addCallback(_done)
654         return d
655
656     def _try(self):
657         # fill our set of active buckets, maybe raising NotEnoughSharesError
658         active_buckets = self.parent._activate_enough_buckets()
659         # Now we have enough buckets, in self.parent.active_buckets.
660
661         # in test cases, bd.start might mutate active_buckets right away, so
662         # we need to put off calling start() until we've iterated all the way
663         # through it.
664         downloaders = []
665         for blocknum, vbucket in active_buckets.iteritems():
666             assert isinstance(vbucket, ValidatedReadBucketProxy), vbucket
667             bd = BlockDownloader(vbucket, blocknum, self, self.results)
668             downloaders.append(bd)
669             if self.results:
670                 self.results.servers_used.add(vbucket.bucket.get_peerid())
671         l = [bd.start(self.segmentnumber) for bd in downloaders]
672         return defer.DeferredList(l, fireOnOneErrback=True)
673
674     def hold_block(self, blocknum, data):
675         self.blocks[blocknum] = data
676
677     def bucket_failed(self, vbucket):
678         self.parent.bucket_failed(vbucket)
679
680 class DownloadStatus:
681     implements(IDownloadStatus)
682     statusid_counter = itertools.count(0)
683
684     def __init__(self):
685         self.storage_index = None
686         self.size = None
687         self.helper = False
688         self.status = "Not started"
689         self.progress = 0.0
690         self.paused = False
691         self.stopped = False
692         self.active = True
693         self.results = None
694         self.counter = self.statusid_counter.next()
695         self.started = time.time()
696
697     def get_started(self):
698         return self.started
699     def get_storage_index(self):
700         return self.storage_index
701     def get_size(self):
702         return self.size
703     def using_helper(self):
704         return self.helper
705     def get_status(self):
706         status = self.status
707         if self.paused:
708             status += " (output paused)"
709         if self.stopped:
710             status += " (output stopped)"
711         return status
712     def get_progress(self):
713         return self.progress
714     def get_active(self):
715         return self.active
716     def get_results(self):
717         return self.results
718     def get_counter(self):
719         return self.counter
720
721     def set_storage_index(self, si):
722         self.storage_index = si
723     def set_size(self, size):
724         self.size = size
725     def set_helper(self, helper):
726         self.helper = helper
727     def set_status(self, status):
728         self.status = status
729     def set_paused(self, paused):
730         self.paused = paused
731     def set_stopped(self, stopped):
732         self.stopped = stopped
733     def set_progress(self, value):
734         self.progress = value
735     def set_active(self, value):
736         self.active = value
737     def set_results(self, value):
738         self.results = value
739
740 class CiphertextDownloader(log.PrefixingLogMixin):
741     """ I download shares, check their integrity, then decode them, check the
742     integrity of the resulting ciphertext, and then write it to my target.
743     Before I send any new request to a server, I always ask the 'monitor'
744     object that was passed into my constructor whether this task has been
745     cancelled (by invoking its raise_if_cancelled() method)."""
746     implements(IPushProducer)
747     _status = None
748
749     def __init__(self, storage_broker, v, target, monitor):
750
751         precondition(IStorageBroker.providedBy(storage_broker), storage_broker)
752         precondition(IVerifierURI.providedBy(v), v)
753         precondition(IDownloadTarget.providedBy(target), target)
754
755         prefix=base32.b2a_l(v.storage_index[:8], 60)
756         log.PrefixingLogMixin.__init__(self, facility="tahoe.immutable.download", prefix=prefix)
757         self._storage_broker = storage_broker
758
759         self._verifycap = v
760         self._storage_index = v.storage_index
761         self._uri_extension_hash = v.uri_extension_hash
762
763         self._started = time.time()
764         self._status = s = DownloadStatus()
765         s.set_status("Starting")
766         s.set_storage_index(self._storage_index)
767         s.set_size(self._verifycap.size)
768         s.set_helper(False)
769         s.set_active(True)
770
771         self._results = DownloadResults()
772         s.set_results(self._results)
773         self._results.file_size = self._verifycap.size
774         self._results.timings["servers_peer_selection"] = {}
775         self._results.timings["fetch_per_server"] = {}
776         self._results.timings["cumulative_fetch"] = 0.0
777         self._results.timings["cumulative_decode"] = 0.0
778         self._results.timings["cumulative_decrypt"] = 0.0
779         self._results.timings["paused"] = 0.0
780
781         self._paused = False
782         self._stopped = False
783         if IConsumer.providedBy(target):
784             target.registerProducer(self, True)
785         self._target = target
786         # Repairer (uploader) needs the storageindex.
787         self._target.set_storageindex(self._storage_index)
788         self._monitor = monitor
789         self._opened = False
790
791         self.active_buckets = {} # k: shnum, v: bucket
792         self._share_buckets = [] # list of (sharenum, bucket) tuples
793         self._share_vbuckets = {} # k: shnum, v: set of ValidatedBuckets
794
795         self._fetch_failures = {"uri_extension": 0, "crypttext_hash_tree": 0, }
796
797         self._ciphertext_hasher = hashutil.crypttext_hasher()
798
799         self._bytes_done = 0
800         self._status.set_progress(float(self._bytes_done)/self._verifycap.size)
801
802         # _got_uri_extension() will create the following:
803         # self._crypttext_hash_tree
804         # self._share_hash_tree
805         # self._current_segnum = 0
806         # self._vup # ValidatedExtendedURIProxy
807
808     def pauseProducing(self):
809         if self._paused:
810             return
811         self._paused = defer.Deferred()
812         self._paused_at = time.time()
813         if self._status:
814             self._status.set_paused(True)
815
816     def resumeProducing(self):
817         if self._paused:
818             paused_for = time.time() - self._paused_at
819             self._results.timings['paused'] += paused_for
820             p = self._paused
821             self._paused = None
822             eventually(p.callback, None)
823             if self._status:
824                 self._status.set_paused(False)
825
826     def stopProducing(self):
827         self.log("Download.stopProducing")
828         self._stopped = True
829         self.resumeProducing()
830         if self._status:
831             self._status.set_stopped(True)
832             self._status.set_active(False)
833
834     def start(self):
835         self.log("starting download")
836
837         # first step: who should we download from?
838         d = defer.maybeDeferred(self._get_all_shareholders)
839         d.addCallback(self._got_all_shareholders)
840         # now get the uri_extension block from somebody and integrity check
841         # it and parse and validate its contents
842         d.addCallback(self._obtain_uri_extension)
843         d.addCallback(self._get_crypttext_hash_tree)
844         # once we know that, we can download blocks from everybody
845         d.addCallback(self._download_all_segments)
846         def _finished(res):
847             if self._status:
848                 self._status.set_status("Finished")
849                 self._status.set_active(False)
850                 self._status.set_paused(False)
851             if IConsumer.providedBy(self._target):
852                 self._target.unregisterProducer()
853             return res
854         d.addBoth(_finished)
855         def _failed(why):
856             if self._status:
857                 self._status.set_status("Failed")
858                 self._status.set_active(False)
859             if why.check(DownloadStopped):
860                 # DownloadStopped just means the consumer aborted the
861                 # download; not so scary.
862                 self.log("download stopped", level=log.UNUSUAL)
863             else:
864                 # This is really unusual, and deserves maximum forensics.
865                 self.log("download failed!", failure=why, level=log.SCARY,
866                          umid="lp1vaQ")
867             return why
868         d.addErrback(_failed)
869         d.addCallback(self._done)
870         return d
871
872     def _get_all_shareholders(self):
873         dl = []
874         sb = self._storage_broker
875         servers = sb.get_servers_for_index(self._storage_index)
876         if not servers:
877             raise NoServersError("broker gave us no servers!")
878         for (peerid,ss) in servers:
879             self.log(format="sending DYHB to [%(peerid)s]",
880                      peerid=idlib.shortnodeid_b2a(peerid),
881                      level=log.NOISY, umid="rT03hg")
882             d = ss.callRemote("get_buckets", self._storage_index)
883             d.addCallbacks(self._got_response, self._got_error,
884                            callbackArgs=(peerid,))
885             dl.append(d)
886         self._responses_received = 0
887         self._queries_sent = len(dl)
888         if self._status:
889             self._status.set_status("Locating Shares (%d/%d)" %
890                                     (self._responses_received,
891                                      self._queries_sent))
892         return defer.DeferredList(dl)
893
894     def _got_response(self, buckets, peerid):
895         self.log(format="got results from [%(peerid)s]: shnums %(shnums)s",
896                  peerid=idlib.shortnodeid_b2a(peerid),
897                  shnums=sorted(buckets.keys()),
898                  level=log.NOISY, umid="o4uwFg")
899         self._responses_received += 1
900         if self._results:
901             elapsed = time.time() - self._started
902             self._results.timings["servers_peer_selection"][peerid] = elapsed
903         if self._status:
904             self._status.set_status("Locating Shares (%d/%d)" %
905                                     (self._responses_received,
906                                      self._queries_sent))
907         for sharenum, bucket in buckets.iteritems():
908             b = layout.ReadBucketProxy(bucket, peerid, self._storage_index)
909             self.add_share_bucket(sharenum, b)
910
911             if self._results:
912                 if peerid not in self._results.servermap:
913                     self._results.servermap[peerid] = set()
914                 self._results.servermap[peerid].add(sharenum)
915
916     def add_share_bucket(self, sharenum, bucket):
917         # this is split out for the benefit of test_encode.py
918         self._share_buckets.append( (sharenum, bucket) )
919
920     def _got_error(self, f):
921         level = log.WEIRD
922         if f.check(DeadReferenceError):
923             level = log.UNUSUAL
924         self.log("Error during get_buckets", failure=f, level=level,
925                          umid="3uuBUQ")
926
927     def bucket_failed(self, vbucket):
928         shnum = vbucket.sharenum
929         del self.active_buckets[shnum]
930         s = self._share_vbuckets[shnum]
931         # s is a set of ValidatedReadBucketProxy instances
932         s.remove(vbucket)
933         # ... which might now be empty
934         if not s:
935             # there are no more buckets which can provide this share, so
936             # remove the key. This may prompt us to use a different share.
937             del self._share_vbuckets[shnum]
938
939     def _got_all_shareholders(self, res):
940         if self._results:
941             now = time.time()
942             self._results.timings["peer_selection"] = now - self._started
943
944         if len(self._share_buckets) < self._verifycap.needed_shares:
945             msg = "Failed to get enough shareholders: have %d, need %d" \
946                   % (len(self._share_buckets), self._verifycap.needed_shares)
947             if self._share_buckets:
948                 raise NotEnoughSharesError(msg)
949             else:
950                 raise NoSharesError(msg)
951
952         #for s in self._share_vbuckets.values():
953         #    for vb in s:
954         #        assert isinstance(vb, ValidatedReadBucketProxy), \
955         #               "vb is %s but should be a ValidatedReadBucketProxy" % (vb,)
956
957     def _obtain_uri_extension(self, ignored):
958         # all shareholders are supposed to have a copy of uri_extension, and
959         # all are supposed to be identical. We compute the hash of the data
960         # that comes back, and compare it against the version in our URI. If
961         # they don't match, ignore their data and try someone else.
962         if self._status:
963             self._status.set_status("Obtaining URI Extension")
964
965         uri_extension_fetch_started = time.time()
966
967         vups = []
968         for sharenum, bucket in self._share_buckets:
969             vups.append(ValidatedExtendedURIProxy(bucket, self._verifycap, self._fetch_failures))
970         vto = ValidatedThingObtainer(vups, debugname="vups", log_id=self._parentmsgid)
971         d = vto.start()
972
973         def _got_uri_extension(vup):
974             precondition(isinstance(vup, ValidatedExtendedURIProxy), vup)
975             if self._results:
976                 elapsed = time.time() - uri_extension_fetch_started
977                 self._results.timings["uri_extension"] = elapsed
978
979             self._vup = vup
980             self._codec = codec.CRSDecoder()
981             self._codec.set_params(self._vup.segment_size, self._verifycap.needed_shares, self._verifycap.total_shares)
982             self._tail_codec = codec.CRSDecoder()
983             self._tail_codec.set_params(self._vup.tail_segment_size, self._verifycap.needed_shares, self._verifycap.total_shares)
984
985             self._current_segnum = 0
986
987             self._share_hash_tree = hashtree.IncompleteHashTree(self._verifycap.total_shares)
988             self._share_hash_tree.set_hashes({0: vup.share_root_hash})
989
990             self._crypttext_hash_tree = hashtree.IncompleteHashTree(self._vup.num_segments)
991             self._crypttext_hash_tree.set_hashes({0: self._vup.crypttext_root_hash})
992
993             # Repairer (uploader) needs the encodingparams.
994             self._target.set_encodingparams((
995                 self._verifycap.needed_shares,
996                 self._verifycap.total_shares, # I don't think the target actually cares about "happy".
997                 self._verifycap.total_shares,
998                 self._vup.segment_size
999                 ))
1000         d.addCallback(_got_uri_extension)
1001         return d
1002
1003     def _get_crypttext_hash_tree(self, res):
1004         vchtps = []
1005         for sharenum, bucket in self._share_buckets:
1006             vchtp = ValidatedCrypttextHashTreeProxy(bucket, self._crypttext_hash_tree, self._vup.num_segments, self._fetch_failures)
1007             vchtps.append(vchtp)
1008
1009         _get_crypttext_hash_tree_started = time.time()
1010         if self._status:
1011             self._status.set_status("Retrieving crypttext hash tree")
1012
1013         vto = ValidatedThingObtainer(vchtps, debugname="vchtps",
1014                                      log_id=self._parentmsgid)
1015         d = vto.start()
1016
1017         def _got_crypttext_hash_tree(res):
1018             # Good -- the self._crypttext_hash_tree that we passed to vchtp
1019             # is now populated with hashes.
1020             if self._results:
1021                 elapsed = time.time() - _get_crypttext_hash_tree_started
1022                 self._results.timings["hashtrees"] = elapsed
1023         d.addCallback(_got_crypttext_hash_tree)
1024         return d
1025
1026     def _activate_enough_buckets(self):
1027         """either return a mapping from shnum to a ValidatedReadBucketProxy
1028         that can provide data for that share, or raise NotEnoughSharesError"""
1029
1030         while len(self.active_buckets) < self._verifycap.needed_shares:
1031             # need some more
1032             handled_shnums = set(self.active_buckets.keys())
1033             available_shnums = set(self._share_vbuckets.keys())
1034             potential_shnums = list(available_shnums - handled_shnums)
1035             if len(potential_shnums) < (self._verifycap.needed_shares
1036                                         - len(self.active_buckets)):
1037                 have = len(potential_shnums) + len(self.active_buckets)
1038                 msg = "Unable to activate enough shares: have %d, need %d" \
1039                       % (have, self._verifycap.needed_shares)
1040                 if have:
1041                     raise NotEnoughSharesError(msg)
1042                 else:
1043                     raise NoSharesError(msg)
1044             # For the next share, choose a primary share if available, else a
1045             # randomly chosen secondary share.
1046             potential_shnums.sort()
1047             if potential_shnums[0] < self._verifycap.needed_shares:
1048                 shnum = potential_shnums[0]
1049             else:
1050                 shnum = random.choice(potential_shnums)
1051             # and a random bucket that will provide it
1052             validated_bucket = random.choice(list(self._share_vbuckets[shnum]))
1053             self.active_buckets[shnum] = validated_bucket
1054         return self.active_buckets
1055
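    # Worked example (made-up numbers): with needed_shares = 3, no buckets
    # active yet, and validated buckets available for shnums {1, 4, 7, 9},
    # the loop above first activates shnum 1 (a primary share, since 1 < 3),
    # then picks one of {4, 7, 9} at random, then one of the remaining two.
    # Primary shares are preferred because, as SegmentDownloader notes, they
    # are the fast "primary blocks".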
1056
1057     def _download_all_segments(self, res):
1058         for sharenum, bucket in self._share_buckets:
1059             vbucket = ValidatedReadBucketProxy(sharenum, bucket, self._share_hash_tree, self._vup.num_segments, self._vup.block_size, self._vup.share_size)
1060             self._share_vbuckets.setdefault(sharenum, set()).add(vbucket)
1061
1062         # after the above code, self._share_vbuckets contains enough
1063         # buckets to complete the download, and some extra ones to
1064         # tolerate some buckets dropping out or having
1065         # errors. self._share_vbuckets is a dictionary that maps from
1066         # shnum to a set of ValidatedBuckets, which themselves are
1067         # wrappers around RIBucketReader references.
1068         self.active_buckets = {} # k: shnum, v: ValidatedReadBucketProxy instance
1069
1070         self._started_fetching = time.time()
1071
1072         d = defer.succeed(None)
1073         for segnum in range(self._vup.num_segments):
1074             d.addCallback(self._download_segment, segnum)
1075             # this pause, at the end of write, prevents pre-fetch from
1076             # happening until the consumer is ready for more data.
1077             d.addCallback(self._check_for_pause)
1078         return d
1079
1080     def _check_for_pause(self, res):
1081         if self._paused:
1082             d = defer.Deferred()
1083             self._paused.addCallback(lambda ignored: d.callback(res))
1084             return d
1085         if self._stopped:
1086             raise DownloadStopped("our Consumer called stopProducing()")
1087         self._monitor.raise_if_cancelled()
1088         return res
1089
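    # Illustrative note on the pause mechanism: pauseProducing() swaps a
    # Deferred into self._paused, so _check_for_pause() returns a Deferred
    # that only relays its input after resumeProducing() fires that Deferred
    # (via eventually(), i.e. on a later reactor turn). That is what stalls
    # the per-segment pipeline in _download_all_segments() instead of letting
    # it pre-fetch ahead of the consumer. Sketch (made-up usage, with dl a
    # CiphertextDownloader):
    #
    #   dl.pauseProducing()
    #   d = dl._check_for_pause("segment data")   # d has not fired yet
    #   dl.resumeProducing()                      # d fires with "segment data"
    #                                             # on a later reactor turn
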
1090     def _download_segment(self, res, segnum):
1091         if self._status:
1092             self._status.set_status("Downloading segment %d of %d" %
1093                                     (segnum+1, self._vup.num_segments))
1094         self.log("downloading seg#%d of %d (%d%%)"
1095                  % (segnum, self._vup.num_segments,
1096                     100.0 * segnum / self._vup.num_segments))
1097         # memory footprint: when the SegmentDownloader finishes pulling down
1098         # all shares, we have 1*segment_size of usage.
1099         segmentdler = SegmentDownloader(self, segnum,
1100                                         self._verifycap.needed_shares,
1101                                         self._results)
1102         started = time.time()
1103         d = segmentdler.start()
1104         def _finished_fetching(res):
1105             elapsed = time.time() - started
1106             self._results.timings["cumulative_fetch"] += elapsed
1107             return res
1108         if self._results:
1109             d.addCallback(_finished_fetching)
1110         # pause before using more memory
1111         d.addCallback(self._check_for_pause)
1112         # while the codec does its job, we hit 2*segment_size
1113         def _started_decode(res):
1114             self._started_decode = time.time()
1115             return res
1116         if self._results:
1117             d.addCallback(_started_decode)
1118         if segnum + 1 == self._vup.num_segments:
1119             codec = self._tail_codec
1120         else:
1121             codec = self._codec
1122         d.addCallback(lambda (shares, shareids): codec.decode(shares, shareids))
1123         # once the codec is done, we drop back to 1*segment_size, because
1124         # 'shares' goes out of scope. The memory usage is all in the
1125         # plaintext now, spread out into a bunch of tiny buffers.
1126         def _finished_decode(res):
1127             elapsed = time.time() - self._started_decode
1128             self._results.timings["cumulative_decode"] += elapsed
1129             return res
1130         if self._results:
1131             d.addCallback(_finished_decode)
1132
1133         # pause/check-for-stop just before writing, to honor stopProducing
1134         d.addCallback(self._check_for_pause)
1135         d.addCallback(self._got_segment)
1136         return d
1137
1138     def _got_segment(self, buffers):
1139         precondition(self._crypttext_hash_tree)
1140         started_decrypt = time.time()
1141         self._status.set_progress(float(self._current_segnum)/self._vup.num_segments)
1142
1143         if self._current_segnum + 1 == self._vup.num_segments:
1144             # This is the last segment.
1145             # Trim off any padding added by the upload side. We never send
1146             # empty segments. If the data was an exact multiple of the
1147             # segment size, the last segment will be full.
1148             tail_buf_size = mathutil.div_ceil(self._vup.tail_segment_size, self._verifycap.needed_shares)
1149             num_buffers_used = mathutil.div_ceil(self._vup.tail_data_size, tail_buf_size)
1150             # Remove buffers which don't contain any part of the tail.
1151             del buffers[num_buffers_used:]
1152             # Remove the past-the-tail-part of the last buffer.
1153             tail_in_last_buf = self._vup.tail_data_size % tail_buf_size
1154             if tail_in_last_buf == 0:
1155                 tail_in_last_buf = tail_buf_size
1156             buffers[-1] = buffers[-1][:tail_in_last_buf]
1157
1158         # First compute the hash of this segment and check that it fits.
1159         ch = hashutil.crypttext_segment_hasher()
1160         for buffer in buffers:
1161             self._ciphertext_hasher.update(buffer)
1162             ch.update(buffer)
1163         self._crypttext_hash_tree.set_hashes(leaves={self._current_segnum: ch.digest()})
1164
1165         # Then write this segment to the target.
1166         if not self._opened:
1167             self._opened = True
1168             self._target.open(self._verifycap.size)
1169
1170         for buffer in buffers:
1171             self._target.write(buffer)
1172             self._bytes_done += len(buffer)
1173
1174         self._status.set_progress(float(self._bytes_done)/self._verifycap.size)
1175         self._current_segnum += 1
1176
1177         if self._results:
1178             elapsed = time.time() - started_decrypt
1179             self._results.timings["cumulative_decrypt"] += elapsed
1180
1181     def _done(self, res):
1182         self.log("download done")
1183         if self._results:
1184             now = time.time()
1185             self._results.timings["total"] = now - self._started
1186             self._results.timings["segments"] = now - self._started_fetching
1187         if self._vup.crypttext_hash:
1188             _assert(self._vup.crypttext_hash == self._ciphertext_hasher.digest(),
1189                     "bad crypttext_hash: computed=%s, expected=%s" %
1190                     (base32.b2a(self._ciphertext_hasher.digest()),
1191                      base32.b2a(self._vup.crypttext_hash)))
1192         _assert(self._bytes_done == self._verifycap.size, self._bytes_done, self._verifycap.size)
1193         self._status.set_progress(1)
1194         self._target.close()
1195         return self._target.finish()
1196     def get_download_status(self):
1197         return self._status
1198
1199
1200 class FileName:
1201     implements(IDownloadTarget)
1202     def __init__(self, filename):
1203         self._filename = filename
1204         self.f = None
1205     def open(self, size):
1206         self.f = open(self._filename, "wb")
1207         return self.f
1208     def write(self, data):
1209         self.f.write(data)
1210     def close(self):
1211         if self.f:
1212             self.f.close()
1213     def fail(self, why):
1214         if self.f:
1215             self.f.close()
1216             os.unlink(self._filename)
1217     def register_canceller(self, cb):
1218         pass # we won't use it
1219     def finish(self):
1220         pass
1221     # The following methods exist only because the target might be a
1222     # repairer.DownUpConnector, and because the current CHKUpload object
1223     # expects to find the storage index and encoding parameters in its
1224     # Uploadable.
1225     def set_storageindex(self, storageindex):
1226         pass
1227     def set_encodingparams(self, encodingparams):
1228         pass
1229
1230 class Data:
1231     implements(IDownloadTarget)
1232     def __init__(self):
1233         self._data = []
1234     def open(self, size):
1235         pass
1236     def write(self, data):
1237         self._data.append(data)
1238     def close(self):
1239         self.data = "".join(self._data)
1240         del self._data
1241     def fail(self, why):
1242         del self._data
1243     def register_canceller(self, cb):
1244         pass # we won't use it
1245     def finish(self):
1246         return self.data
1247     # The following methods exist only because the target might be a
1248     # repairer.DownUpConnector, and because the current CHKUpload object
1249     # expects to find the storage index and encoding parameters in its
1250     # Uploadable.
1251     def set_storageindex(self, storageindex):
1252         pass
1253     def set_encodingparams(self, encodingparams):
1254         pass
1255
1256 class FileHandle:
1257     """Use me to download data to a pre-defined filehandle-like object. I
1258     will use the filehandle's write() method. I will *not* close the filehandle:
1259     I leave that up to the originator of the filehandle. The download process
1260     will return the filehandle when it completes.
1261     """
1262     implements(IDownloadTarget)
1263     def __init__(self, filehandle):
1264         self._filehandle = filehandle
1265     def open(self, size):
1266         pass
1267     def write(self, data):
1268         self._filehandle.write(data)
1269     def close(self):
1270         # the originator of the filehandle reserves the right to close it
1271         pass
1272     def fail(self, why):
1273         pass
1274     def register_canceller(self, cb):
1275         pass
1276     def finish(self):
1277         return self._filehandle
1278     # The following methods exist only because the target might be a
1279     # repairer.DownUpConnector, and because the current CHKUpload object
1280     # expects to find the storage index and encoding parameters in its
1281     # Uploadable.
1282     def set_storageindex(self, storageindex):
1283         pass
1284     def set_encodingparams(self, encodingparams):
1285         pass
1286
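# Illustrative usage sketch for FileHandle (hypothetical names: 'downloader' is
# assumed to be a Downloader instance and 'filecap' a CHK file URI; neither is
# defined in this module):
#
#     from StringIO import StringIO
#     fh = StringIO()
#     d = downloader.download(filecap, FileHandle(fh))
#     # when d fires, its result is fh; closing fh is left to the caller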
1287 class ConsumerAdapter:
1288     implements(IDownloadTarget, IConsumer)
1289     def __init__(self, consumer):
1290         self._consumer = consumer
1291
1292     def registerProducer(self, producer, streaming):
1293         self._consumer.registerProducer(producer, streaming)
1294     def unregisterProducer(self):
1295         self._consumer.unregisterProducer()
1296
1297     def open(self, size):
1298         pass
1299     def write(self, data):
1300         self._consumer.write(data)
1301     def close(self):
1302         pass
1303
1304     def fail(self, why):
1305         pass
1306     def register_canceller(self, cb):
1307         pass
1308     def finish(self):
1309         return self._consumer
1310     # The following methods exist only because the target might be a
1311     # repairer.DownUpConnector, and because the current CHKUpload object
1312     # expects to find the storage index and encoding parameters in its
1313     # Uploadable.
1314     def set_storageindex(self, storageindex):
1315         pass
1316     def set_encodingparams(self, encodingparams):
1317         pass
1318
1319
1320 class Downloader:
1321     """I am a service that allows file downloading.
1322     """
1323     # TODO: in fact, this service only downloads immutable files (URI:CHK:).
1324     # It is scheduled to go away, to be replaced by filenode.download()
1325     implements(IDownloader)
1326
1327     def __init__(self, storage_broker, stats_provider):
1328         self.storage_broker = storage_broker
1329         self.stats_provider = stats_provider
1330         self._all_downloads = weakref.WeakKeyDictionary() # for debugging
1331
1332     def download(self, u, t, _log_msg_id=None, monitor=None, history=None):
1333         u = IFileURI(u)
1334         t = IDownloadTarget(t)
1335         assert t.write
1336         assert t.close
1337
1338         assert isinstance(u, uri.CHKFileURI)
1339         if self.stats_provider:
1340             # these counters are meant for network traffic, and don't
1341             # include LIT files
1342             self.stats_provider.count('downloader.files_downloaded', 1)
1343             self.stats_provider.count('downloader.bytes_downloaded', u.get_size())
1344
1345         target = DecryptingTarget(t, u.key, _log_msg_id=_log_msg_id)
1346         if not monitor:
1347             monitor = Monitor()
1348         dl = CiphertextDownloader(self.storage_broker,
1349                                   u.get_verify_cap(), target,
1350                                   monitor=monitor)
1351         self._all_downloads[dl] = None
1352         if history:
1353             history.add_download(dl.get_download_status())
1354         d = dl.start()
1355         return d
1356
1357     # utility functions
1358     def download_to_data(self, uri, _log_msg_id=None, history=None):
1359         return self.download(uri, Data(), _log_msg_id=_log_msg_id, history=history)
1360     def download_to_filename(self, uri, filename, _log_msg_id=None):
1361         return self.download(uri, FileName(filename), _log_msg_id=_log_msg_id)
1362     def download_to_filehandle(self, uri, filehandle, _log_msg_id=None):
1363         return self.download(uri, FileHandle(filehandle), _log_msg_id=_log_msg_id)
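
# Illustrative usage sketch for the Downloader service (hypothetical setup:
# 'storage_broker' and 'stats_provider' are assumed to be supplied by the node,
# and 'filecap' to be adaptable to IFileURI, e.g. a uri.CHKFileURI instance):
#
#     downloader = Downloader(storage_broker, stats_provider)
#     d = downloader.download_to_data(filecap)  # Deferred that fires with bytes
#     d.addCallback(lambda data: log.msg("downloaded %d bytes" % len(data)))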