# src/allmydata/immutable/upload.py
import os, time, weakref, itertools
from zope.interface import implements
from twisted.python import failure
from twisted.internet import defer
from twisted.application import service
from foolscap.api import Referenceable, Copyable, RemoteCopy, fireEventually

from allmydata.util.hashutil import file_renewal_secret_hash, \
     file_cancel_secret_hash, bucket_renewal_secret_hash, \
     bucket_cancel_secret_hash, plaintext_hasher, \
     storage_index_hash, plaintext_segment_hasher, convergence_hasher
from allmydata import hashtree, uri
from allmydata.storage.server import si_b2a
from allmydata.immutable import encode
from allmydata.util import base32, dictutil, idlib, log, mathutil
from allmydata.util.happinessutil import servers_of_happiness, \
                                         shares_by_server, merge_servers, \
                                         failure_message
from allmydata.util.assertutil import precondition
from allmydata.util.rrefutil import add_version_to_remote_reference
from allmydata.interfaces import IUploadable, IUploader, IUploadResults, \
     IEncryptedUploadable, RIEncryptedUploadable, IUploadStatus, \
     NoServersError, InsufficientVersionError, UploadUnhappinessError, \
     DEFAULT_MAX_SEGMENT_SIZE
from allmydata.immutable import layout
from pycryptopp.cipher.aes import AES

from cStringIO import StringIO


# this wants to live in storage, not here
class TooFullError(Exception):
    pass

class UploadResults(Copyable, RemoteCopy):
    implements(IUploadResults)
    # note: don't change this string, it needs to match the value used on the
    # helper, and it does *not* need to match the fully-qualified
    # package/module/class name
    typeToCopy = "allmydata.upload.UploadResults.tahoe.allmydata.com"
    copytype = typeToCopy

    # also, think twice about changing the shape of any existing attribute,
    # because instances of this class are sent from the helper to its client,
    # so changing this may break compatibility. Consider adding new fields
    # instead of modifying existing ones.

    def __init__(self):
        self.timings = {} # dict of name to number of seconds
        self.sharemap = dictutil.DictOfSets() # {shnum: set(serverid)}
        self.servermap = dictutil.DictOfSets() # {serverid: set(shnum)}
        self.file_size = None
        self.ciphertext_fetched = None # how much the helper fetched
        self.uri = None
        self.preexisting_shares = None # count of shares already present
        self.pushed_shares = None # count of shares we pushed


# our current uri_extension is 846 bytes for small files, a few bytes
# more for larger ones (since the filesize is encoded in decimal in a
# few places). Ask for a little bit more just in case we need it. If
# the extension changes size, we can change EXTENSION_SIZE to
# allocate a more accurate amount of space.
EXTENSION_SIZE = 1000
# TODO: actual extensions are closer to 419 bytes, so we can probably lower
# this.

def pretty_print_shnum_to_servers(s):
    return ', '.join([ "sh%s: %s" % (k, '+'.join([idlib.shortnodeid_b2a(x) for x in v])) for k, v in s.iteritems() ])
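
# Illustrative example (hypothetical nodeids): given a sharemap such as
#   {0: set([idA]), 1: set([idA, idB])}
# this returns a string like "sh0: aaaaa, sh1: aaaaa+bbbbb", where each
# nodeid has been abbreviated by idlib.shortnodeid_b2a().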

class ServerTracker:
    def __init__(self, server,
                 sharesize, blocksize, num_segments, num_share_hashes,
                 storage_index,
                 bucket_renewal_secret, bucket_cancel_secret):
        self._server = server
        self.buckets = {} # k: shareid, v: IRemoteBucketWriter
        self.sharesize = sharesize

        wbp = layout.make_write_bucket_proxy(None, None, sharesize,
                                             blocksize, num_segments,
                                             num_share_hashes,
                                             EXTENSION_SIZE)
        self.wbp_class = wbp.__class__ # to create more of them
        self.allocated_size = wbp.get_allocated_size()
        self.blocksize = blocksize
        self.num_segments = num_segments
        self.num_share_hashes = num_share_hashes
        self.storage_index = storage_index

        self.renew_secret = bucket_renewal_secret
        self.cancel_secret = bucket_cancel_secret

    def __repr__(self):
        return ("<ServerTracker for server %s and SI %s>"
                % (self._server.get_name(), si_b2a(self.storage_index)[:5]))

    def get_serverid(self):
        return self._server.get_serverid()
    def get_name(self):
        return self._server.get_name()

    def query(self, sharenums):
        rref = self._server.get_rref()
        d = rref.callRemote("allocate_buckets",
                            self.storage_index,
                            self.renew_secret,
                            self.cancel_secret,
                            sharenums,
                            self.allocated_size,
                            canary=Referenceable())
        d.addCallback(self._got_reply)
        return d

    def ask_about_existing_shares(self):
        rref = self._server.get_rref()
        return rref.callRemote("get_buckets", self.storage_index)

    def _got_reply(self, (alreadygot, buckets)):
        #log.msg("%s._got_reply(%s)" % (self, (alreadygot, buckets)))
        b = {}
        for sharenum, rref in buckets.iteritems():
            bp = self.wbp_class(rref, self._server, self.sharesize,
                                self.blocksize,
                                self.num_segments,
                                self.num_share_hashes,
                                EXTENSION_SIZE)
            b[sharenum] = bp
        self.buckets.update(b)
        return (alreadygot, set(b.keys()))


    def abort(self):
        """
        I abort the remote bucket writers for all shares. This is a good idea
        to conserve space on the storage server.
        """
        self.abort_some_buckets(self.buckets.keys())

    def abort_some_buckets(self, sharenums):
        """
        I abort the remote bucket writers for the share numbers in sharenums.
        """
        for sharenum in sharenums:
            if sharenum in self.buckets:
                self.buckets[sharenum].abort()
                del self.buckets[sharenum]

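# Typical (illustrative) use of a ServerTracker: st.query(set([0, 1])) sends
# an allocate_buckets request to the server and fires with (alreadygot,
# allocated), where 'allocated' is the set of share numbers whose remote
# bucket writers are now stashed in st.buckets.
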
def str_shareloc(shnum, bucketwriter):
    return "%s: %s" % (shnum, bucketwriter.get_servername(),)

class Tahoe2ServerSelector(log.PrefixingLogMixin):

    def __init__(self, upload_id, logparent=None, upload_status=None):
        self.upload_id = upload_id
        self.query_count, self.good_query_count, self.bad_query_count = 0,0,0
        # Servers that are working normally, but full.
        self.full_count = 0
        self.error_count = 0
        self.num_servers_contacted = 0
        self.last_failure_msg = None
        self._status = IUploadStatus(upload_status)
        log.PrefixingLogMixin.__init__(self, 'tahoe.immutable.upload', logparent, prefix=upload_id)
        self.log("starting", level=log.OPERATIONAL)

    def __repr__(self):
        return "<Tahoe2ServerSelector for upload %s>" % self.upload_id

    def get_shareholders(self, storage_broker, secret_holder,
                         storage_index, share_size, block_size,
                         num_segments, total_shares, needed_shares,
                         servers_of_happiness):
        """
        @return: (upload_trackers, already_serverids), where upload_trackers
                 is a set of ServerTracker instances that have agreed to hold
                 some shares for us (the shareids are stashed inside the
                 ServerTracker), and already_serverids is a dict mapping
                 shnum to a set of serverids for servers which claim to
                 already have the share.
        """

        if self._status:
            self._status.set_status("Contacting Servers..")

        self.total_shares = total_shares
        self.servers_of_happiness = servers_of_happiness
        self.needed_shares = needed_shares

        self.homeless_shares = set(range(total_shares))
        self.use_trackers = set() # ServerTrackers that have shares assigned
                                  # to them
        self.preexisting_shares = {} # shareid => set(serverids) holding shareid

        # These servers have shares -- any shares -- for our SI. We keep
        # track of these to write an error message with them later.
        self.serverids_with_shares = set()

        # this needed_hashes computation should mirror
        # Encoder.send_all_share_hash_trees. We use an IncompleteHashTree
        # (instead of a HashTree) because we don't require actual hashing
        # just to count the levels.
        ht = hashtree.IncompleteHashTree(total_shares)
        num_share_hashes = len(ht.needed_hashes(0, include_leaf=True))

        # figure out how much space to ask for
        wbp = layout.make_write_bucket_proxy(None, None,
                                             share_size, 0, num_segments,
                                             num_share_hashes, EXTENSION_SIZE)
        allocated_size = wbp.get_allocated_size()
        all_servers = storage_broker.get_servers_for_psi(storage_index)
        if not all_servers:
            raise NoServersError("client gave us zero servers")

        # filter the list of servers according to which ones can accommodate
        # this request. This excludes older servers (which used a 4-byte size
        # field) from getting large shares (for files larger than about
        # 12GiB). See #439 for details.
        def _get_maxsize(server):
            v0 = server.get_rref().version
            v1 = v0["http://allmydata.org/tahoe/protocols/storage/v1"]
            return v1["maximum-immutable-share-size"]
        writeable_servers = [server for server in all_servers
                            if _get_maxsize(server) >= allocated_size]
        readonly_servers = set(all_servers[:2*total_shares]) - set(writeable_servers)

        # decide upon the renewal/cancel secrets, to include them in the
        # allocate_buckets query.
        client_renewal_secret = secret_holder.get_renewal_secret()
        client_cancel_secret = secret_holder.get_cancel_secret()

        file_renewal_secret = file_renewal_secret_hash(client_renewal_secret,
                                                       storage_index)
        file_cancel_secret = file_cancel_secret_hash(client_cancel_secret,
                                                     storage_index)
        def _make_trackers(servers):
            trackers = []
            for s in servers:
                seed = s.get_lease_seed()
                renew = bucket_renewal_secret_hash(file_renewal_secret, seed)
                cancel = bucket_cancel_secret_hash(file_cancel_secret, seed)
                st = ServerTracker(s,
                                   share_size, block_size,
                                   num_segments, num_share_hashes,
                                   storage_index,
                                   renew, cancel)
                trackers.append(st)
            return trackers

        # We assign each server/tracker to one of three lists. They all
        # start in the "first pass" list. During the first pass, as we ask
        # each one to hold a share, we move their tracker to the "second
        # pass" list, until the first-pass list is empty. Then during the
        # second pass, as we ask each to hold more shares, we move their
        # tracker to the "next pass" list, until the second-pass list is
        # empty. Then we move everybody from the next-pass list back to the
        # second-pass list and repeat the "second" pass (really the third,
        # fourth, etc pass), until all shares are assigned, or we've run out
        # of potential servers.
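        # Worked example (hypothetical numbers): with 3 writeable servers
        # and 10 shares, the first pass asks each server for one share
        # apiece; the 7 still-homeless shares are then spread over the
        # second pass at div_ceil(7, 3) = 3 shares per query, and so on
        # until every share is placed or every server has refused.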
        self.first_pass_trackers = _make_trackers(writeable_servers)
        self.second_pass_trackers = [] # servers worth asking again
        self.next_pass_trackers = [] # servers that we have asked again
        self._started_second_pass = False

        # We don't try to allocate shares to these servers, since they've
        # said that they're incapable of storing shares of the size that we'd
        # want to store. We ask them about existing shares for this storage
        # index, which we want to know about for accurate
        # servers_of_happiness accounting, then we forget about them.
        readonly_trackers = _make_trackers(readonly_servers)

        # We now ask servers that can't hold any new shares about existing
        # shares that they might have for our SI. Once this is done, we
        # start placing the shares that we haven't already accounted
        # for.
        ds = []
        if self._status and readonly_trackers:
            self._status.set_status("Contacting readonly servers to find "
                                    "any existing shares")
        for tracker in readonly_trackers:
            assert isinstance(tracker, ServerTracker)
            d = tracker.ask_about_existing_shares()
            d.addBoth(self._handle_existing_response, tracker)
            ds.append(d)
            self.num_servers_contacted += 1
            self.query_count += 1
            self.log("asking server %s for any existing shares" %
                     (tracker.get_name(),), level=log.NOISY)
        dl = defer.DeferredList(ds)
        dl.addCallback(lambda ign: self._loop())
        return dl


    def _handle_existing_response(self, res, tracker):
        """
        I handle responses to the get_buckets() queries that
        Tahoe2ServerSelector.get_shareholders sends to the readonly servers.
        """
        serverid = tracker.get_serverid()
        if isinstance(res, failure.Failure):
            self.log("%s got error during existing shares check: %s"
                    % (tracker.get_name(), res), level=log.UNUSUAL)
            self.error_count += 1
            self.bad_query_count += 1
        else:
            buckets = res
            if buckets:
                self.serverids_with_shares.add(serverid)
            self.log("response to get_buckets() from server %s: alreadygot=%s"
                    % (tracker.get_name(), tuple(sorted(buckets))),
                    level=log.NOISY)
            for bucket in buckets:
                self.preexisting_shares.setdefault(bucket, set()).add(serverid)
                self.homeless_shares.discard(bucket)
            self.full_count += 1
            self.bad_query_count += 1


    def _get_progress_message(self):
        if not self.homeless_shares:
            msg = "placed all %d shares, " % (self.total_shares)
        else:
            msg = ("placed %d shares out of %d total (%d homeless), " %
                   (self.total_shares - len(self.homeless_shares),
                    self.total_shares,
                    len(self.homeless_shares)))
        return (msg + "want to place shares on at least %d servers such that "
                      "any %d of them have enough shares to recover the file, "
                      "sent %d queries to %d servers, "
                      "%d queries placed some shares, %d placed none "
                      "(of which %d placed none due to the server being"
                      " full and %d placed none due to an error)" %
                        (self.servers_of_happiness, self.needed_shares,
                         self.query_count, self.num_servers_contacted,
                         self.good_query_count, self.bad_query_count,
                         self.full_count, self.error_count))


    def _loop(self):
        if not self.homeless_shares:
            merged = merge_servers(self.preexisting_shares, self.use_trackers)
            effective_happiness = servers_of_happiness(merged)
            if self.servers_of_happiness <= effective_happiness:
                msg = ("server selection successful for %s: %s: pretty_print_merged: %s, "
                       "self.use_trackers: %s, self.preexisting_shares: %s") \
                       % (self, self._get_progress_message(),
                          pretty_print_shnum_to_servers(merged),
                          [', '.join([str_shareloc(k,v)
                                      for k,v in st.buckets.iteritems()])
                           for st in self.use_trackers],
                          pretty_print_shnum_to_servers(self.preexisting_shares))
                self.log(msg, level=log.OPERATIONAL)
                return (self.use_trackers, self.preexisting_shares)
            else:
                # We're not okay right now, but maybe we can fix it by
                # redistributing some shares. In cases where one or two
                # servers have, before the upload, all or most of the
                # shares for a given SI, this can work by allowing _loop
                # a chance to spread those out over the other servers.
                delta = self.servers_of_happiness - effective_happiness
                shares = shares_by_server(self.preexisting_shares)
                # Each server in shares maps to a set of shares stored on it.
                # Since we want to keep at least one share on each server
                # that has one (otherwise we'd only be making
                # the situation worse by removing distinct servers),
                # each server has len(its shares) - 1 to spread around.
                shares_to_spread = sum([len(list(sharelist)) - 1
                                        for (server, sharelist)
                                        in shares.items()])
                if delta <= len(self.first_pass_trackers) and \
                   shares_to_spread >= delta:
                    items = shares.items()
                    while len(self.homeless_shares) < delta:
                        # Loop through the allocated shares, removing
                        # one from each server that has more than one
                        # and putting it back into self.homeless_shares
                        # until we've done this delta times.
                        server, sharelist = items.pop()
                        if len(sharelist) > 1:
                            share = sharelist.pop()
                            self.homeless_shares.add(share)
                            self.preexisting_shares[share].remove(server)
                            if not self.preexisting_shares[share]:
                                del self.preexisting_shares[share]
                            items.append((server, sharelist))
                        for writer in self.use_trackers:
                            writer.abort_some_buckets(self.homeless_shares)
                    return self._loop()
                else:
                    # Redistribution won't help us; fail.
                    server_count = len(self.serverids_with_shares)
                    failmsg = failure_message(server_count,
                                              self.needed_shares,
                                              self.servers_of_happiness,
                                              effective_happiness)
                    servmsgtempl = "server selection unsuccessful for %r: %s (%s), merged=%s"
                    servmsg = servmsgtempl % (
                        self,
                        failmsg,
                        self._get_progress_message(),
                        pretty_print_shnum_to_servers(merged)
                        )
                    self.log(servmsg, level=log.INFREQUENT)
                    return self._failed("%s (%s)" % (failmsg, self._get_progress_message()))

        if self.first_pass_trackers:
            tracker = self.first_pass_trackers.pop(0)
            # TODO: don't pre-convert all serverids to ServerTrackers
            assert isinstance(tracker, ServerTracker)

            shares_to_ask = set(sorted(self.homeless_shares)[:1])
            self.homeless_shares -= shares_to_ask
            self.query_count += 1
            self.num_servers_contacted += 1
            if self._status:
                self._status.set_status("Contacting Servers [%s] (first query),"
                                        " %d shares left.."
                                        % (tracker.get_name(),
                                           len(self.homeless_shares)))
            d = tracker.query(shares_to_ask)
            d.addBoth(self._got_response, tracker, shares_to_ask,
                      self.second_pass_trackers)
            return d
        elif self.second_pass_trackers:
            # ask a server that we've already asked.
            if not self._started_second_pass:
                self.log("starting second pass",
                        level=log.NOISY)
                self._started_second_pass = True
            num_shares = mathutil.div_ceil(len(self.homeless_shares),
                                           len(self.second_pass_trackers))
            tracker = self.second_pass_trackers.pop(0)
            shares_to_ask = set(sorted(self.homeless_shares)[:num_shares])
            self.homeless_shares -= shares_to_ask
            self.query_count += 1
            if self._status:
                self._status.set_status("Contacting Servers [%s] (second query),"
                                        " %d shares left.."
                                        % (tracker.get_name(),
                                           len(self.homeless_shares)))
            d = tracker.query(shares_to_ask)
            d.addBoth(self._got_response, tracker, shares_to_ask,
                      self.next_pass_trackers)
            return d
        elif self.next_pass_trackers:
            # we've finished the second-or-later pass. Move all the remaining
            # servers back into self.second_pass_trackers for the next pass.
            self.second_pass_trackers.extend(self.next_pass_trackers)
            self.next_pass_trackers[:] = []
            return self._loop()
        else:
            # no more servers. If we haven't placed enough shares, we fail.
            merged = merge_servers(self.preexisting_shares, self.use_trackers)
            effective_happiness = servers_of_happiness(merged)
            if effective_happiness < self.servers_of_happiness:
                msg = failure_message(len(self.serverids_with_shares),
                                      self.needed_shares,
                                      self.servers_of_happiness,
                                      effective_happiness)
                msg = ("server selection failed for %s: %s (%s)" %
                       (self, msg, self._get_progress_message()))
                if self.last_failure_msg:
                    msg += " (%s)" % (self.last_failure_msg,)
                self.log(msg, level=log.UNUSUAL)
                return self._failed(msg)
            else:
                # we placed enough to be happy, so we're done
                if self._status:
                    self._status.set_status("Placed all shares")
                msg = ("server selection successful (no more servers) for %s: %s: %s" % (self,
                            self._get_progress_message(), pretty_print_shnum_to_servers(merged)))
                self.log(msg, level=log.OPERATIONAL)
                return (self.use_trackers, self.preexisting_shares)

    def _got_response(self, res, tracker, shares_to_ask, put_tracker_here):
        if isinstance(res, failure.Failure):
            # This is unusual, and probably indicates a bug or a network
            # problem.
            self.log("%s got error during server selection: %s" % (tracker, res),
                    level=log.UNUSUAL)
            self.error_count += 1
            self.bad_query_count += 1
            self.homeless_shares |= shares_to_ask
            if (self.first_pass_trackers
                or self.second_pass_trackers
                or self.next_pass_trackers):
                # there is still hope, so just loop
                pass
            else:
                # No more servers, so this upload might fail (it depends upon
                # whether we've hit servers_of_happiness or not). Log the last
                # failure we got: if a coding error causes all servers to fail
                # in the same way, this allows the common failure to be seen
                # by the uploader and should help with debugging
                msg = ("last failure (from %s) was: %s" % (tracker, res))
                self.last_failure_msg = msg
        else:
            (alreadygot, allocated) = res
            self.log("response to allocate_buckets() from server %s: alreadygot=%s, allocated=%s"
                    % (tracker.get_name(),
                       tuple(sorted(alreadygot)), tuple(sorted(allocated))),
                    level=log.NOISY)
            progress = False
            for s in alreadygot:
                self.preexisting_shares.setdefault(s, set()).add(tracker.get_serverid())
                if s in self.homeless_shares:
                    self.homeless_shares.remove(s)
                    progress = True
                elif s in shares_to_ask:
                    progress = True

            # the ServerTracker will remember which shares were allocated on
            # that peer. We just have to remember to use them.
            if allocated:
                self.use_trackers.add(tracker)
                progress = True

            if allocated or alreadygot:
                self.serverids_with_shares.add(tracker.get_serverid())

            not_yet_present = set(shares_to_ask) - set(alreadygot)
            still_homeless = not_yet_present - set(allocated)

            if progress:
                # They accepted at least one of the shares that we asked
                # them to accept, or they had a share that we didn't ask
                # them to accept but that we hadn't placed yet, so this
                # was a productive query
                self.good_query_count += 1
            else:
                self.bad_query_count += 1
                self.full_count += 1

            if still_homeless:
                # In networks with lots of space, this is very unusual and
                # probably indicates an error. In networks with servers that
                # are full, it is merely unusual. In networks that are very
                # full, it is common, and many uploads will fail. In most
                # cases, this is obviously not fatal, and we'll just use some
                # other servers.

                # some shares are still homeless, keep trying to find them a
                # home. The ones that were rejected get first priority.
                self.homeless_shares |= still_homeless
                # Since they were unable to accept all of our requests, it
                # is safe to assume that asking them again won't help.
            else:
                # if they *were* able to accept everything, they might be
                # willing to accept even more.
                put_tracker_here.append(tracker)

        # now loop
        return self._loop()


    def _failed(self, msg):
        """
        I am called when server selection fails. I first abort all of the
        remote buckets that I allocated during my unsuccessful attempt to
        place shares for this file. I then raise an
        UploadUnhappinessError with my msg argument.
        """
        for tracker in self.use_trackers:
            assert isinstance(tracker, ServerTracker)
            tracker.abort()
        raise UploadUnhappinessError(msg)

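
# A minimal sketch (not part of the original module) of the happiness
# measure used by Tahoe2ServerSelector._loop above: servers_of_happiness()
# takes a sharemap of {shnum: set(serverids)} and returns the size of a
# maximum matching between shares and servers. The serverids here are
# hypothetical placeholders.
def _example_happiness_check():
    sharemap = {0: set(["serverA"]),
                1: set(["serverA", "serverB"]),
                2: set(["serverC"])}
    # shares 0, 1, 2 match servers A, B, C respectively, so this returns 3
    return servers_of_happiness(sharemap)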

class EncryptAnUploadable:
    """This is a wrapper that takes an IUploadable and provides
    IEncryptedUploadable."""
    implements(IEncryptedUploadable)
    CHUNKSIZE = 50*1024

    def __init__(self, original, log_parent=None):
        self.original = IUploadable(original)
        self._log_number = log_parent
        self._encryptor = None
        self._plaintext_hasher = plaintext_hasher()
        self._plaintext_segment_hasher = None
        self._plaintext_segment_hashes = []
        self._encoding_parameters = None
        self._file_size = None
        self._ciphertext_bytes_read = 0
        self._status = None

    def set_upload_status(self, upload_status):
        self._status = IUploadStatus(upload_status)
        self.original.set_upload_status(upload_status)

    def log(self, *args, **kwargs):
        if "facility" not in kwargs:
            kwargs["facility"] = "upload.encryption"
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_number
        return log.msg(*args, **kwargs)

    def get_size(self):
        if self._file_size is not None:
            return defer.succeed(self._file_size)
        d = self.original.get_size()
        def _got_size(size):
            self._file_size = size
            if self._status:
                self._status.set_size(size)
            return size
        d.addCallback(_got_size)
        return d

    def get_all_encoding_parameters(self):
        if self._encoding_parameters is not None:
            return defer.succeed(self._encoding_parameters)
        d = self.original.get_all_encoding_parameters()
        def _got(encoding_parameters):
            (k, happy, n, segsize) = encoding_parameters
            self._segment_size = segsize # used by segment hashers
            self._encoding_parameters = encoding_parameters
            self.log("my encoding parameters: %s" % (encoding_parameters,),
                     level=log.NOISY)
            return encoding_parameters
        d.addCallback(_got)
        return d

    def _get_encryptor(self):
        if self._encryptor:
            return defer.succeed(self._encryptor)

        d = self.original.get_encryption_key()
        def _got(key):
            e = AES(key)
            self._encryptor = e

            storage_index = storage_index_hash(key)
            assert isinstance(storage_index, str)
            # There's no point to having the SI be longer than the key, so we
            # specify that it is truncated to the same 128 bits as the AES key.
            assert len(storage_index) == 16  # SHA-256 truncated to 128b
            self._storage_index = storage_index
            if self._status:
                self._status.set_storage_index(storage_index)
            return e
        d.addCallback(_got)
        return d

    def get_storage_index(self):
        d = self._get_encryptor()
        d.addCallback(lambda res: self._storage_index)
        return d

    def _get_segment_hasher(self):
        p = self._plaintext_segment_hasher
        if p:
            left = self._segment_size - self._plaintext_segment_hashed_bytes
            return p, left
        p = plaintext_segment_hasher()
        self._plaintext_segment_hasher = p
        self._plaintext_segment_hashed_bytes = 0
        return p, self._segment_size

    def _update_segment_hash(self, chunk):
        offset = 0
        while offset < len(chunk):
            p, segment_left = self._get_segment_hasher()
            chunk_left = len(chunk) - offset
            this_segment = min(chunk_left, segment_left)
            p.update(chunk[offset:offset+this_segment])
            self._plaintext_segment_hashed_bytes += this_segment

            if self._plaintext_segment_hashed_bytes == self._segment_size:
                # we've filled this segment
                self._plaintext_segment_hashes.append(p.digest())
                self._plaintext_segment_hasher = None
                self.log("closed hash [%d]: %dB" %
                         (len(self._plaintext_segment_hashes)-1,
                          self._plaintext_segment_hashed_bytes),
                         level=log.NOISY)
                self.log(format="plaintext leaf hash [%(segnum)d] is %(hash)s",
                         segnum=len(self._plaintext_segment_hashes)-1,
                         hash=base32.b2a(p.digest()),
                         level=log.NOISY)

            offset += this_segment

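    # Worked example (illustrative): with _segment_size=3, feeding the
    # chunks "abcd" then "ef" through _update_segment_hash() closes one
    # hasher over "abc" and a second over "def"; a segment hash is appended
    # to _plaintext_segment_hashes each time a segment boundary is crossed.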

    def read_encrypted(self, length, hash_only):
        # make sure our parameters have been set up first
        d = self.get_all_encoding_parameters()
        # and size
        d.addCallback(lambda ignored: self.get_size())
        d.addCallback(lambda ignored: self._get_encryptor())
        # then fetch and encrypt the plaintext. The unusual structure here
        # (passing a Deferred *into* a function) is needed to avoid
        # overflowing the stack: Deferreds don't optimize out tail recursion.
        # We also pass in a list, to which _read_encrypted will append
        # ciphertext.
        ciphertext = []
        d2 = defer.Deferred()
        d.addCallback(lambda ignored:
                      self._read_encrypted(length, ciphertext, hash_only, d2))
        d.addCallback(lambda ignored: d2)
        return d

    def _read_encrypted(self, remaining, ciphertext, hash_only, fire_when_done):
        if not remaining:
            fire_when_done.callback(ciphertext)
            return None
        # tolerate large length= values without consuming a lot of RAM by
        # reading just a chunk (say 50kB) at a time. This only really matters
        # when hash_only==True (i.e. resuming an interrupted upload), since
        # that's the case where we will be skipping over a lot of data.
        size = min(remaining, self.CHUNKSIZE)
        remaining = remaining - size
        # read a chunk of plaintext..
        d = defer.maybeDeferred(self.original.read, size)
        # N.B.: if read() is synchronous, then since everything else is
        # actually synchronous too, we'd blow the stack unless we stall for a
        # tick. Once you accept a Deferred from IUploadable.read(), you must
        # be prepared to have it fire immediately too.
        d.addCallback(fireEventually)
        def _good(plaintext):
            # and encrypt it..
            # o/' over the fields we go, hashing all the way, sHA! sHA! sHA! o/'
            ct = self._hash_and_encrypt_plaintext(plaintext, hash_only)
            ciphertext.extend(ct)
            self._read_encrypted(remaining, ciphertext, hash_only,
                                 fire_when_done)
        def _err(why):
            fire_when_done.errback(why)
        d.addCallback(_good)
        d.addErrback(_err)
        return None

    def _hash_and_encrypt_plaintext(self, data, hash_only):
        assert isinstance(data, (tuple, list)), type(data)
        data = list(data)
        cryptdata = []
        # we use data.pop(0) instead of 'for chunk in data' to save
        # memory: each chunk is destroyed as soon as we're done with it.
        bytes_processed = 0
        while data:
            chunk = data.pop(0)
            self.log(" read_encrypted handling %dB-sized chunk" % len(chunk),
                     level=log.NOISY)
            bytes_processed += len(chunk)
            self._plaintext_hasher.update(chunk)
            self._update_segment_hash(chunk)
            # TODO: we have to encrypt the data (even if hash_only==True)
            # because pycryptopp's AES-CTR implementation doesn't offer a
            # way to change the counter value. Once pycryptopp acquires
            # this ability, change this to simply update the counter
            # before each call to (hash_only==False) _encryptor.process()
            ciphertext = self._encryptor.process(chunk)
            if hash_only:
                self.log("  skipping encryption", level=log.NOISY)
            else:
                cryptdata.append(ciphertext)
            del ciphertext
            del chunk
        self._ciphertext_bytes_read += bytes_processed
        if self._status:
            progress = float(self._ciphertext_bytes_read) / self._file_size
            self._status.set_progress(1, progress)
        return cryptdata


    def get_plaintext_hashtree_leaves(self, first, last, num_segments):
        # this is currently unused, but will live again when we fix #453
        if len(self._plaintext_segment_hashes) < num_segments:
            # close out the last one
            assert len(self._plaintext_segment_hashes) == num_segments-1
            p, segment_left = self._get_segment_hasher()
            self._plaintext_segment_hashes.append(p.digest())
            del self._plaintext_segment_hasher
            self.log("closing plaintext leaf hasher, hashed %d bytes" %
                     self._plaintext_segment_hashed_bytes,
                     level=log.NOISY)
            self.log(format="plaintext leaf hash [%(segnum)d] is %(hash)s",
                     segnum=len(self._plaintext_segment_hashes)-1,
                     hash=base32.b2a(p.digest()),
                     level=log.NOISY)
        assert len(self._plaintext_segment_hashes) == num_segments
        return defer.succeed(tuple(self._plaintext_segment_hashes[first:last]))

    def get_plaintext_hash(self):
        h = self._plaintext_hasher.digest()
        return defer.succeed(h)

    def close(self):
        return self.original.close()

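# A minimal sketch (illustrative; not used by this module) of the pattern
# EncryptAnUploadable._read_encrypted uses above to avoid unbounded
# recursion: the caller's Deferred is threaded through every step, and
# fireEventually() stalls for one reactor turn so that a synchronous
# read() cannot blow the stack.
def _example_chunked_loop(remaining, out, fire_when_done, chunksize=50*1024):
    if not remaining:
        fire_when_done.callback(out)
        return None
    size = min(remaining, chunksize)
    d = fireEventually("x" * size) # stand-in for a real read; fires next tick
    def _step(chunk):
        out.append(chunk)
        _example_chunked_loop(remaining - size, out, fire_when_done, chunksize)
    d.addCallback(_step)
    d.addErrback(fire_when_done.errback)
    return None
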
class UploadStatus:
    implements(IUploadStatus)
    statusid_counter = itertools.count(0)

    def __init__(self):
        self.storage_index = None
        self.size = None
        self.helper = False
        self.status = "Not started"
        self.progress = [0.0, 0.0, 0.0]
        self.active = True
        self.results = None
        self.counter = self.statusid_counter.next()
        self.started = time.time()

    def get_started(self):
        return self.started
    def get_storage_index(self):
        return self.storage_index
    def get_size(self):
        return self.size
    def using_helper(self):
        return self.helper
    def get_status(self):
        return self.status
    def get_progress(self):
        return tuple(self.progress)
    def get_active(self):
        return self.active
    def get_results(self):
        return self.results
    def get_counter(self):
        return self.counter

    def set_storage_index(self, si):
        self.storage_index = si
    def set_size(self, size):
        self.size = size
    def set_helper(self, helper):
        self.helper = helper
    def set_status(self, status):
        self.status = status
    def set_progress(self, which, value):
        # [0]: chk, [1]: ciphertext, [2]: encode+push
        self.progress[which] = value
    def set_active(self, value):
        self.active = value
    def set_results(self, value):
        self.results = value

class CHKUploader:
    server_selector_class = Tahoe2ServerSelector

    def __init__(self, storage_broker, secret_holder):
        # server_selector needs storage_broker and secret_holder
        self._storage_broker = storage_broker
        self._secret_holder = secret_holder
        self._log_number = self.log("CHKUploader starting", parent=None)
        self._encoder = None
        self._results = UploadResults()
        self._storage_index = None
        self._upload_status = UploadStatus()
        self._upload_status.set_helper(False)
        self._upload_status.set_active(True)
        self._upload_status.set_results(self._results)

        # locate_all_shareholders() will create the following attribute:
        # self._server_trackers = {} # k: shnum, v: instance of ServerTracker

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_number
        if "facility" not in kwargs:
            kwargs["facility"] = "tahoe.upload"
        return log.msg(*args, **kwargs)

    def start(self, encrypted_uploadable):
        """Start uploading the file.

        Returns a Deferred that will fire with the UploadResults instance.
        """

        self._started = time.time()
        eu = IEncryptedUploadable(encrypted_uploadable)
        self.log("starting upload of %s" % eu)

        eu.set_upload_status(self._upload_status)
        d = self.start_encrypted(eu)
        def _done(uploadresults):
            self._upload_status.set_active(False)
            return uploadresults
        d.addBoth(_done)
        return d

    def abort(self):
        """Call this if the upload must be abandoned before it completes.
        This will tell the shareholders to delete their partial shares. I
        return a Deferred that fires when these messages have been acked."""
        if not self._encoder:
            # how did you call abort() before calling start() ?
            return defer.succeed(None)
        return self._encoder.abort()

    def start_encrypted(self, encrypted):
        """ Returns a Deferred that will fire with the UploadResults instance. """
        eu = IEncryptedUploadable(encrypted)

        started = time.time()
        self._encoder = e = encode.Encoder(self._log_number,
                                           self._upload_status)
        d = e.set_encrypted_uploadable(eu)
        d.addCallback(self.locate_all_shareholders, started)
        d.addCallback(self.set_shareholders, e)
        d.addCallback(lambda res: e.start())
        d.addCallback(self._encrypted_done)
        return d

    def locate_all_shareholders(self, encoder, started):
        server_selection_started = now = time.time()
        self._storage_index_elapsed = now - started
        storage_broker = self._storage_broker
        secret_holder = self._secret_holder
        storage_index = encoder.get_param("storage_index")
        self._storage_index = storage_index
        upload_id = si_b2a(storage_index)[:5]
        self.log("using storage index %s" % upload_id)
        server_selector = self.server_selector_class(upload_id,
                                                     self._log_number,
                                                     self._upload_status)

        share_size = encoder.get_param("share_size")
        block_size = encoder.get_param("block_size")
        num_segments = encoder.get_param("num_segments")
        k,desired,n = encoder.get_param("share_counts")

        self._server_selection_started = time.time()
        d = server_selector.get_shareholders(storage_broker, secret_holder,
                                             storage_index,
                                             share_size, block_size,
                                             num_segments, n, k, desired)
        def _done(res):
            self._server_selection_elapsed = time.time() - server_selection_started
            return res
        d.addCallback(_done)
        return d

    def set_shareholders(self, (upload_trackers, already_serverids), encoder):
        """
        @param upload_trackers: a sequence of ServerTracker objects that
                                have agreed to hold some shares for us (the
                                shareids are stashed inside the ServerTracker)

        @param already_serverids: a dict mapping sharenum to a set of
                                  serverids for servers that claim to already
                                  have this share
        """
        msgtempl = "set_shareholders; upload_trackers is %s, already_serverids is %s"
        values = ([', '.join([str_shareloc(k,v)
                              for k,v in st.buckets.iteritems()])
                   for st in upload_trackers], already_serverids)
        self.log(msgtempl % values, level=log.OPERATIONAL)
        # record already-present shares in self._results
        self._results.preexisting_shares = len(already_serverids)

        self._server_trackers = {} # k: shnum, v: instance of ServerTracker
        for tracker in upload_trackers:
            assert isinstance(tracker, ServerTracker)
        buckets = {}
        servermap = already_serverids.copy()
        for tracker in upload_trackers:
            buckets.update(tracker.buckets)
            for shnum in tracker.buckets:
                self._server_trackers[shnum] = tracker
                servermap.setdefault(shnum, set()).add(tracker.get_serverid())
        assert len(buckets) == sum([len(tracker.buckets)
                                    for tracker in upload_trackers]), \
            "%s (%s) != %s (%s)" % (
                len(buckets),
                buckets,
                sum([len(tracker.buckets) for tracker in upload_trackers]),
                [(t.buckets, t.get_serverid()) for t in upload_trackers]
                )
        encoder.set_shareholders(buckets, servermap)

    def _encrypted_done(self, verifycap):
        """ Returns a Deferred that will fire with the UploadResults instance. """
        r = self._results
        for shnum in self._encoder.get_shares_placed():
            server_tracker = self._server_trackers[shnum]
            serverid = server_tracker.get_serverid()
            r.sharemap.add(shnum, serverid)
            r.servermap.add(serverid, shnum)
        r.pushed_shares = len(self._encoder.get_shares_placed())
        now = time.time()
        r.file_size = self._encoder.file_size
        r.timings["total"] = now - self._started
        r.timings["storage_index"] = self._storage_index_elapsed
        r.timings["peer_selection"] = self._server_selection_elapsed
        r.timings.update(self._encoder.get_times())
        r.uri_extension_data = self._encoder.get_uri_extension_data()
        r.verifycapstr = verifycap.to_string()
        return r

    def get_upload_status(self):
        return self._upload_status

def read_this_many_bytes(uploadable, size, prepend_data=[]):
    if size == 0:
        return defer.succeed([])
    d = uploadable.read(size)
    def _got(data):
        assert isinstance(data, list)
        bytes = sum([len(piece) for piece in data])
        assert bytes > 0
        assert bytes <= size
        remaining = size - bytes
        if remaining:
            return read_this_many_bytes(uploadable, remaining,
                                        prepend_data + data)
        return prepend_data + data
    d.addCallback(_got)
    return d
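
# Example (hypothetical) use of read_this_many_bytes: collect exactly 1024
# bytes from an uploadable whose read() may return fewer bytes per call:
#   d = read_this_many_bytes(uploadable, 1024)
#   d.addCallback(lambda pieces: "".join(pieces))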

class LiteralUploader:

    def __init__(self):
        self._results = UploadResults()
        self._status = s = UploadStatus()
        s.set_storage_index(None)
        s.set_helper(False)
        s.set_progress(0, 1.0)
        s.set_active(False)
        s.set_results(self._results)

    def start(self, uploadable):
        uploadable = IUploadable(uploadable)
        d = uploadable.get_size()
        def _got_size(size):
            self._size = size
            self._status.set_size(size)
            self._results.file_size = size
            return read_this_many_bytes(uploadable, size)
        d.addCallback(_got_size)
        d.addCallback(lambda data: uri.LiteralFileURI("".join(data)))
        d.addCallback(lambda u: u.to_string())
        d.addCallback(self._build_results)
        return d

    def _build_results(self, uri):
        self._results.uri = uri
        self._status.set_status("Finished")
        self._status.set_progress(1, 1.0)
        self._status.set_progress(2, 1.0)
        return self._results

    def close(self):
        pass

    def get_upload_status(self):
        return self._status

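# Note (illustrative): a literal cap embeds the file body itself, e.g.
#   uri.LiteralFileURI("hello").to_string() == "URI:LIT:nbswy3dp"
# so LiteralUploader never needs to contact any storage servers.
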
class RemoteEncryptedUploadable(Referenceable):
    implements(RIEncryptedUploadable)

    def __init__(self, encrypted_uploadable, upload_status):
        self._eu = IEncryptedUploadable(encrypted_uploadable)
        self._offset = 0
        self._bytes_sent = 0
        self._status = IUploadStatus(upload_status)
        # we are responsible for updating the status string while we run, and
        # for setting the ciphertext-fetch progress.
        self._size = None

    def get_size(self):
        if self._size is not None:
            return defer.succeed(self._size)
        d = self._eu.get_size()
        def _got_size(size):
            self._size = size
            return size
        d.addCallback(_got_size)
        return d

    def remote_get_size(self):
        return self.get_size()
    def remote_get_all_encoding_parameters(self):
        return self._eu.get_all_encoding_parameters()

    def _read_encrypted(self, length, hash_only):
        d = self._eu.read_encrypted(length, hash_only)
        def _read(strings):
            if hash_only:
                self._offset += length
            else:
                size = sum([len(data) for data in strings])
                self._offset += size
            return strings
        d.addCallback(_read)
        return d

    def remote_read_encrypted(self, offset, length):
        # we don't support seek backwards, but we allow skipping forwards
        precondition(offset >= 0, offset)
        precondition(length >= 0, length)
        lp = log.msg("remote_read_encrypted(%d-%d)" % (offset, offset+length),
                     level=log.NOISY)
        precondition(offset >= self._offset, offset, self._offset)
        if offset > self._offset:
            # read the data from disk anyways, to build up the hash tree
            skip = offset - self._offset
            log.msg("remote_read_encrypted skipping ahead from %d to %d, skip=%d" %
                    (self._offset, offset, skip), level=log.UNUSUAL, parent=lp)
            d = self._read_encrypted(skip, hash_only=True)
        else:
            d = defer.succeed(None)

        def _at_correct_offset(res):
            assert offset == self._offset, "%d != %d" % (offset, self._offset)
            return self._read_encrypted(length, hash_only=False)
        d.addCallback(_at_correct_offset)

        def _read(strings):
            size = sum([len(data) for data in strings])
            self._bytes_sent += size
            return strings
        d.addCallback(_read)
        return d
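
    # Example (illustrative): if a helper resumes an interrupted fetch by
    # calling remote_read_encrypted(100, ...) when self._offset is 0, the
    # first 100 bytes are re-read in hash_only mode (to rebuild the
    # plaintext hash trees) and only the new ciphertext crosses the wire.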
1118
1119     def remote_close(self):
1120         return self._eu.close()
1121
1122
1123 class AssistedUploader:
1124
1125     def __init__(self, helper):
1126         self._helper = helper
1127         self._log_number = log.msg("AssistedUploader starting")
1128         self._storage_index = None
1129         self._upload_status = s = UploadStatus()
1130         s.set_helper(True)
1131         s.set_active(True)
1132
1133     def log(self, *args, **kwargs):
1134         if "parent" not in kwargs:
1135             kwargs["parent"] = self._log_number
1136         return log.msg(*args, **kwargs)
1137
1138     def start(self, encrypted_uploadable, storage_index):
1139         """Start uploading the file.
1140
1141         Returns a Deferred that will fire with the UploadResults instance.
1142         """
1143         precondition(isinstance(storage_index, str), storage_index)
1144         self._started = time.time()
1145         eu = IEncryptedUploadable(encrypted_uploadable)
1146         eu.set_upload_status(self._upload_status)
1147         self._encuploadable = eu
1148         self._storage_index = storage_index
1149         d = eu.get_size()
1150         d.addCallback(self._got_size)
1151         d.addCallback(lambda res: eu.get_all_encoding_parameters())
1152         d.addCallback(self._got_all_encoding_parameters)
1153         d.addCallback(self._contact_helper)
1154         d.addCallback(self._build_verifycap)
1155         def _done(res):
1156             self._upload_status.set_active(False)
1157             return res
1158         d.addBoth(_done)
1159         return d
1160
1161     def _got_size(self, size):
1162         self._size = size
1163         self._upload_status.set_size(size)
1164
1165     def _got_all_encoding_parameters(self, params):
1166         k, happy, n, segment_size = params
1167         # stash these for URI generation later
1168         self._needed_shares = k
1169         self._total_shares = n
1170         self._segment_size = segment_size
1171
1172     def _contact_helper(self, res):
1173         now = self._time_contacting_helper_start = time.time()
1174         self._storage_index_elapsed = now - self._started
1175         self.log(format="contacting helper for SI %(si)s..",
1176                  si=si_b2a(self._storage_index), level=log.NOISY)
1177         self._upload_status.set_status("Contacting Helper")
1178         d = self._helper.callRemote("upload_chk", self._storage_index)
1179         d.addCallback(self._contacted_helper)
1180         return d
1181
1182     def _contacted_helper(self, (upload_results, upload_helper)):
1183         now = time.time()
1184         elapsed = now - self._time_contacting_helper_start
1185         self._elapsed_time_contacting_helper = elapsed
1186         if upload_helper:
1187             self.log("helper says we need to upload", level=log.NOISY)
1188             self._upload_status.set_status("Uploading Ciphertext")
1189             # we need to upload the file
1190             reu = RemoteEncryptedUploadable(self._encuploadable,
1191                                             self._upload_status)
1192             # let it pre-compute the size for progress purposes
1193             d = reu.get_size()
1194             d.addCallback(lambda ignored:
1195                           upload_helper.callRemote("upload", reu))
1196             # this Deferred will fire with the upload results
1197             return d
1198         self.log("helper says file is already uploaded", level=log.OPERATIONAL)
1199         self._upload_status.set_progress(1, 1.0)
1200         self._upload_status.set_results(upload_results)
1201         return upload_results
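
    # To restate the "upload_chk" contract exercised above: the helper
    # returns an (upload_results, upload_helper) pair. upload_helper is
    # None when the grid already holds the file, in which case
    # upload_results is complete; otherwise we must feed our ciphertext to
    # upload_helper through a RemoteEncryptedUploadable.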
1202
1203     def _convert_old_upload_results(self, upload_results):
1204         # pre-1.3.0 helpers return upload results which contain a mapping
1205         # from shnum to a single human-readable string, containing things
1206         # like "Found on [x],[y],[z]" (for healthy files that were already in
1207         # the grid), "Found on [x]" (for files that needed upload but which
1208         # discovered pre-existing shares), and "Placed on [x]" (for newly
1209         # uploaded shares). The 1.3.0 helper returns a mapping from shnum to
1210         # set of binary serverid strings.
1211
1212         # the old results are too hard to deal with (they don't even contain
1213         # as much information as the new results, since the nodeids are
1214         # abbreviated), so if we detect old results, just clobber them.
1215
1216         sharemap = upload_results.sharemap
1217         if str in [type(v) for v in sharemap.values()]:
1218             upload_results.sharemap = None
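
    # For illustration: a pre-1.3.0 sharemap looked roughly like
    #     {0: "Found on [x],[y],[z]"}
    # while the 1.3.0 form is
    #     {0: set(["..20-byte binary serverid.."])}
    # so a single str-typed value is enough to mark the results as old.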
1219
1220     def _build_verifycap(self, upload_results):
1221         self.log("upload finished, building readcap", level=log.OPERATIONAL)
1222         self._convert_old_upload_results(upload_results)
1223         self._upload_status.set_status("Building Readcap")
1224         r = upload_results
1225         assert r.uri_extension_data["needed_shares"] == self._needed_shares
1226         assert r.uri_extension_data["total_shares"] == self._total_shares
1227         assert r.uri_extension_data["segment_size"] == self._segment_size
1228         assert r.uri_extension_data["size"] == self._size
1229         r.verifycapstr = uri.CHKFileVerifierURI(self._storage_index,
1230                                                 uri_extension_hash=r.uri_extension_hash,
1231                                                 needed_shares=self._needed_shares,
1232                                                 total_shares=self._total_shares,
1233                                                 size=self._size).to_string()
1234         now = time.time()
1235         r.file_size = self._size
1236         r.timings["storage_index"] = self._storage_index_elapsed
1237         r.timings["contacting_helper"] = self._elapsed_time_contacting_helper
1238         if "total" in r.timings:
1239             r.timings["helper_total"] = r.timings["total"]
1240         r.timings["total"] = now - self._started
1241         self._upload_status.set_status("Finished")
1242         self._upload_status.set_results(r)
1243         return r
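
    # The verifycap built above serializes to something shaped like
    #     URI:CHK-Verifier:<storage_index>:<UEB_hash>:<k>:<N>:<size>
    # It deliberately omits the encryption key, so the helper can verify
    # shares without being able to read the plaintext; Uploader.upload()
    # later adds the key back to turn it into a readcap.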
1244
1245     def get_upload_status(self):
1246         return self._upload_status
1247
1248 class BaseUploadable:
1249     # these defaults are overridden by the matching per-instance attributes
1250     default_max_segment_size = DEFAULT_MAX_SEGMENT_SIZE # vs. max_segment_size
1251     default_encoding_param_k = 3 # vs. encoding_param_k, etc.
1252     default_encoding_param_happy = 7
1253     default_encoding_param_n = 10
1254
1255     max_segment_size = None
1256     encoding_param_k = None
1257     encoding_param_happy = None
1258     encoding_param_n = None
1259
1260     _all_encoding_parameters = None
1261     _status = None
1262
1263     def set_upload_status(self, upload_status):
1264         self._status = IUploadStatus(upload_status)
1265
1266     def set_default_encoding_parameters(self, default_params):
1267         assert isinstance(default_params, dict)
1268         for k,v in default_params.items():
1269             precondition(isinstance(k, str), k, v)
1270             precondition(isinstance(v, int), k, v)
1271         if "k" in default_params:
1272             self.default_encoding_param_k = default_params["k"]
1273         if "happy" in default_params:
1274             self.default_encoding_param_happy = default_params["happy"]
1275         if "n" in default_params:
1276             self.default_encoding_param_n = default_params["n"]
1277         if "max_segment_size" in default_params:
1278             self.default_max_segment_size = default_params["max_segment_size"]
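
    # For example, with the usual defaults this is called with something
    # like (values illustrative):
    #     set_default_encoding_parameters({"k": 3, "happy": 7, "n": 10,
    #                                      "max_segment_size": 131072})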
1279
1280     def get_all_encoding_parameters(self):
1281         if self._all_encoding_parameters:
1282             return defer.succeed(self._all_encoding_parameters)
1283
1284         max_segsize = self.max_segment_size or self.default_max_segment_size
1285         k = self.encoding_param_k or self.default_encoding_param_k
1286         happy = self.encoding_param_happy or self.default_encoding_param_happy
1287         n = self.encoding_param_n or self.default_encoding_param_n
1288
1289         d = self.get_size()
1290         def _got_size(file_size):
1291             # for small files, shrink the segment size to avoid wasting space
1292             segsize = min(max_segsize, file_size)
1293             # this must be a multiple of 'required_shares'==k
1294             segsize = mathutil.next_multiple(segsize, k)
1295             encoding_parameters = (k, happy, n, segsize)
1296             self._all_encoding_parameters = encoding_parameters
1297             return encoding_parameters
1298         d.addCallback(_got_size)
1299         return d
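
    # Worked example: a 1000-byte file with k=3 and the usual 128KiB
    # default max segment size gets
    #     segsize = min(131072, 1000) = 1000
    #     segsize = mathutil.next_multiple(1000, 3) = 1002
    # so with happy=7 and n=10 this returns (3, 7, 10, 1002).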
1300
1301 class FileHandle(BaseUploadable):
1302     implements(IUploadable)
1303
1304     def __init__(self, filehandle, convergence):
1305         """
1306         Upload the data from the filehandle.  If convergence is None, a
1307         random encryption key will be used; otherwise the plaintext will be
1308         hashed, and that hash will be hashed together with the string in
1309         the "convergence" argument to form the encryption key.
1310         """
1311         assert convergence is None or isinstance(convergence, str), (convergence, type(convergence))
1312         self._filehandle = filehandle
1313         self._key = None
1314         self.convergence = convergence
1315         self._size = None
1316
1317     def _get_encryption_key_convergent(self):
1318         if self._key is not None:
1319             return defer.succeed(self._key)
1320
1321         d = self.get_size()
1322         # that sets self._size as a side-effect
1323         d.addCallback(lambda size: self.get_all_encoding_parameters())
1324         def _got(params):
1325             k, happy, n, segsize = params
1326             f = self._filehandle
1327             enckey_hasher = convergence_hasher(k, n, segsize, self.convergence)
1328             f.seek(0)
1329             BLOCKSIZE = 64*1024
1330             bytes_read = 0
1331             while True:
1332                 data = f.read(BLOCKSIZE)
1333                 if not data:
1334                     break
1335                 enckey_hasher.update(data)
1336                 # TODO: setting progress in a non-yielding loop is kind of
1337                 # pointless, but I'm anticipating (perhaps prematurely) the
1338                 # day when we use a slowjob or twisted's CooperatorService to
1339                 # make this yield time to other jobs.
1340                 bytes_read += len(data)
1341                 if self._status:
1342                     self._status.set_progress(0, float(bytes_read)/self._size)
1343             f.seek(0)
1344             self._key = enckey_hasher.digest()
1345             if self._status:
1346                 self._status.set_progress(0, 1.0)
1347             assert len(self._key) == 16
1348             return self._key
1349         d.addCallback(_got)
1350         return d
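
    # A hedged sketch of the property this buys us: equal plaintexts hashed
    # with the same convergence secret and encoding parameters always yield
    # the same 16-byte AES key (and thus the same storage index), so
    # re-uploading an identical file deduplicates:
    #
    #     h1 = convergence_hasher(3, 10, 1002, secret); h1.update(text)
    #     h2 = convergence_hasher(3, 10, 1002, secret); h2.update(text)
    #     assert h1.digest() == h2.digest()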
1351
1352     def _get_encryption_key_random(self):
1353         if self._key is None:
1354             self._key = os.urandom(16)
1355         return defer.succeed(self._key)
1356
1357     def get_encryption_key(self):
1358         if self.convergence is not None:
1359             return self._get_encryption_key_convergent()
1360         else:
1361             return self._get_encryption_key_random()
1362
1363     def get_size(self):
1364         if self._size is not None:
1365             return defer.succeed(self._size)
1366         self._filehandle.seek(0, 2) # seek to end-of-file to learn the size
1367         size = self._filehandle.tell()
1368         self._size = size
1369         self._filehandle.seek(0)
1370         return defer.succeed(size)
1371
1372     def read(self, length):
1373         return defer.succeed([self._filehandle.read(length)])
1374
1375     def close(self):
1376         # the originator of the filehandle reserves the right to close it
1377         pass
1378
1379 class FileName(FileHandle):
1380     def __init__(self, filename, convergence):
1381         """
1382         Upload the data from the named file.  If convergence is None, a
1383         random encryption key will be used; otherwise the plaintext will be
1384         hashed, and that hash will be hashed together with the string in
1385         the "convergence" argument to form the encryption key.
1386         """
1387         assert convergence is None or isinstance(convergence, str), (convergence, type(convergence))
1388         FileHandle.__init__(self, open(filename, "rb"), convergence=convergence)
1389     def close(self):
1390         FileHandle.close(self)
1391         self._filehandle.close()
1392
1393 class Data(FileHandle):
1394     def __init__(self, data, convergence):
1395         """
1396         Upload the data from the 'data' argument.  If convergence is None, a
1397         random encryption key will be used; otherwise the plaintext will be
1398         hashed, and that hash will be hashed together with the string in
1399         the "convergence" argument to form the encryption key.
1400         """
1401         assert convergence is None or isinstance(convergence, str), (convergence, type(convergence))
1402         FileHandle.__init__(self, StringIO(data), convergence=convergence)
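
# Hedged usage sketch: FileHandle, FileName, and Data all present the same
# IUploadable surface and differ only in where their bytes come from (the
# path and convergence secret below are illustrative):
#
#     u1 = Data("some bytes", convergence=None)             # random key
#     u2 = FileName("/tmp/example.dat", convergence="s"*16) # convergent
#     u3 = FileHandle(open("/tmp/example.dat", "rb"), convergence=None)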
1403
1404 class Uploader(service.MultiService, log.PrefixingLogMixin):
1405     """I am a service that allows file uploading. I am a service-child of the
1406     Client.
1407     """
1408     implements(IUploader)
1409     name = "uploader"
1410     URI_LIT_SIZE_THRESHOLD = 55
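    # files of URI_LIT_SIZE_THRESHOLD bytes or fewer skip encoding entirely:
    # upload() below hands them to LiteralUploader, which embeds the data
    # directly in a "URI:LIT:..." cap instead of placing shares on servers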
1411
1412     def __init__(self, helper_furl=None, stats_provider=None):
1413         self._helper_furl = helper_furl
1414         self.stats_provider = stats_provider
1415         self._helper = None
1416         self._all_uploads = weakref.WeakKeyDictionary() # for debugging
1417         log.PrefixingLogMixin.__init__(self, facility="tahoe.immutable.upload")
1418         service.MultiService.__init__(self)
1419
1420     def startService(self):
1421         service.MultiService.startService(self)
1422         if self._helper_furl:
1423             self.parent.tub.connectTo(self._helper_furl,
1424                                       self._got_helper)
1425
1426     def _got_helper(self, helper):
1427         self.log("got helper connection, getting versions")
1428         default = {
1429             "http://allmydata.org/tahoe/protocols/helper/v1": {},
1430             "application-version": "unknown: no get_version()",
1431             }
1432         d = add_version_to_remote_reference(helper, default)
1433         d.addCallback(self._got_versioned_helper)
1434
1435     def _got_versioned_helper(self, helper):
1436         needed = "http://allmydata.org/tahoe/protocols/helper/v1"
1437         if needed not in helper.version:
1438             raise InsufficientVersionError(needed, helper.version)
1439         self._helper = helper
1440         helper.notifyOnDisconnect(self._lost_helper)
1441
1442     def _lost_helper(self):
1443         self._helper = None
1444
1445     def get_helper_info(self):
1446         # return a tuple of (helper_furl_or_None, connected_bool)
1447         return (self._helper_furl, bool(self._helper))
1448
1449
1450     def upload(self, uploadable, history=None):
1451         """
1452         Returns a Deferred that will fire with the UploadResults instance.
1453         """
1454         assert self.parent
1455         assert self.running
1456
1457         uploadable = IUploadable(uploadable)
1458         d = uploadable.get_size()
1459         def _got_size(size):
1460             default_params = self.parent.get_encoding_parameters()
1461             precondition(isinstance(default_params, dict), default_params)
1462             precondition("max_segment_size" in default_params, default_params)
1463             uploadable.set_default_encoding_parameters(default_params)
1464
1465             if self.stats_provider:
1466                 self.stats_provider.count('uploader.files_uploaded', 1)
1467                 self.stats_provider.count('uploader.bytes_uploaded', size)
1468
1469             if size <= self.URI_LIT_SIZE_THRESHOLD:
1470                 uploader = LiteralUploader()
1471                 return uploader.start(uploadable)
1472             else:
1473                 eu = EncryptAnUploadable(uploadable, self._parentmsgid)
1474                 d2 = defer.succeed(None)
1475                 if self._helper:
1476                     uploader = AssistedUploader(self._helper)
1477                     d2.addCallback(lambda x: eu.get_storage_index())
1478                     d2.addCallback(lambda si: uploader.start(eu, si))
1479                 else:
1480                     storage_broker = self.parent.get_storage_broker()
1481                     secret_holder = self.parent._secret_holder
1482                     uploader = CHKUploader(storage_broker, secret_holder)
1483                     d2.addCallback(lambda x: uploader.start(eu))
1484
1485                 self._all_uploads[uploader] = None
1486                 if history:
1487                     history.add_upload(uploader.get_upload_status())
1488                 def turn_verifycap_into_read_cap(uploadresults):
1489                     # Generate the uri from the verifycap plus the key.
1490                     d3 = uploadable.get_encryption_key()
1491                     def put_readcap_into_results(key):
1492                         v = uri.from_string(uploadresults.verifycapstr)
1493                         r = uri.CHKFileURI(key, v.uri_extension_hash, v.needed_shares, v.total_shares, v.size)
1494                         uploadresults.uri = r.to_string()
1495                         return uploadresults
1496                     d3.addCallback(put_readcap_into_results)
1497                     return d3
1498                 d2.addCallback(turn_verifycap_into_read_cap)
1499                 return d2
1500         d.addCallback(_got_size)
1501         def _done(res):
1502             uploadable.close()
1503             return res
1504         d.addBoth(_done)
1505         return d
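
# Hedged end-to-end sketch, assuming `client` is the parent Client service
# this Uploader is attached to (the fixture names are hypothetical):

def _example_store_bytes(client, data):
    # Hand the bytes to the uploader service and extract the readcap from
    # the UploadResults. Returns a Deferred firing with a "URI:CHK:..."
    # cap string (or "URI:LIT:..." for files of 55 bytes or fewer).
    uploader = client.getServiceNamed("uploader")
    d = uploader.upload(Data(data, convergence=None))
    d.addCallback(lambda results: results.uri)
    return d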