import os, time, weakref, itertools
from zope.interface import implements
from twisted.python import failure
from twisted.internet import defer
from twisted.application import service
from foolscap.api import Referenceable, Copyable, RemoteCopy, fireEventually

from allmydata.util.hashutil import file_renewal_secret_hash, \
     file_cancel_secret_hash, bucket_renewal_secret_hash, \
     bucket_cancel_secret_hash, plaintext_hasher, \
     storage_index_hash, plaintext_segment_hasher, convergence_hasher
from allmydata import hashtree, uri
from allmydata.storage.server import si_b2a
from allmydata.immutable import encode
from allmydata.util import base32, dictutil, idlib, log, mathutil
from allmydata.util.happinessutil import servers_of_happiness, \
                                         shares_by_server, merge_servers, \
                                         failure_message
from allmydata.util.assertutil import precondition, _assert
from allmydata.util.rrefutil import add_version_to_remote_reference
from allmydata.interfaces import IUploadable, IUploader, IUploadResults, \
     IEncryptedUploadable, RIEncryptedUploadable, IUploadStatus, \
     NoServersError, InsufficientVersionError, UploadUnhappinessError, \
     DEFAULT_MAX_SEGMENT_SIZE, IProgress
from allmydata.immutable import layout
from pycryptopp.cipher.aes import AES

from cStringIO import StringIO


# this wants to live in storage, not here
class TooFullError(Exception):
    pass

# HelperUploadResults are what we get from the Helper, and to retain
# backwards compatibility with old Helpers we can't change the format. We
# convert them into a local UploadResults upon receipt.
class HelperUploadResults(Copyable, RemoteCopy):
    # note: don't change this string, it needs to match the value used on the
    # helper, and it does *not* need to match the fully-qualified
    # package/module/class name
    typeToCopy = "allmydata.upload.UploadResults.tahoe.allmydata.com"
    copytype = typeToCopy

    # also, think twice about changing the shape of any existing attribute,
    # because instances of this class are sent from the helper to its client,
    # so changing this may break compatibility. Consider adding new fields
    # instead of modifying existing ones.

    def __init__(self):
        self.timings = {} # dict of name to number of seconds
        self.sharemap = dictutil.DictOfSets() # {shnum: set(serverid)}
        self.servermap = dictutil.DictOfSets() # {serverid: set(shnum)}
        self.file_size = None
        self.ciphertext_fetched = None # how much the helper fetched
        self.uri = None
        self.preexisting_shares = None # count of shares already present
        self.pushed_shares = None # count of shares we pushed

class UploadResults:
    implements(IUploadResults)

    def __init__(self, file_size,
                 ciphertext_fetched, # how much the helper fetched
                 preexisting_shares, # count of shares already present
                 pushed_shares, # count of shares we pushed
                 sharemap, # {shnum: set(server)}
                 servermap, # {server: set(shnum)}
                 timings, # dict of name to number of seconds
                 uri_extension_data,
                 uri_extension_hash,
                 verifycapstr):
        self._file_size = file_size
        self._ciphertext_fetched = ciphertext_fetched
        self._preexisting_shares = preexisting_shares
        self._pushed_shares = pushed_shares
        self._sharemap = sharemap
        self._servermap = servermap
        self._timings = timings
        self._uri_extension_data = uri_extension_data
        self._uri_extension_hash = uri_extension_hash
        self._verifycapstr = verifycapstr

    def set_uri(self, uri):
        self._uri = uri

    def get_file_size(self):
        return self._file_size
    def get_uri(self):
        return self._uri
    def get_ciphertext_fetched(self):
        return self._ciphertext_fetched
    def get_preexisting_shares(self):
        return self._preexisting_shares
    def get_pushed_shares(self):
        return self._pushed_shares
    def get_sharemap(self):
        return self._sharemap
    def get_servermap(self):
        return self._servermap
    def get_timings(self):
        return self._timings
    def get_uri_extension_data(self):
        return self._uri_extension_data
    def get_verifycapstr(self):
        return self._verifycapstr

# our current uri_extension is 846 bytes for small files, a few bytes
# more for larger ones (since the filesize is encoded in decimal in a
# few places). Ask for a little bit more just in case we need it. If
# the extension changes size, we can change EXTENSION_SIZE to
# allocate a more accurate amount of space.
EXTENSION_SIZE = 1000
# TODO: actual extensions are closer to 419 bytes, so we can probably lower
# this.

def pretty_print_shnum_to_servers(s):
    return ', '.join([ "sh%s: %s" % (k, '+'.join([idlib.shortnodeid_b2a(x) for x in v])) for k, v in s.iteritems() ])
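# For illustration (the node ids here are made up): given a map like
# {0: set([serverid_a]), 1: set([serverid_a, serverid_b])}, this returns a
# string shaped like "sh0: aaaaaaaa, sh1: aaaaaaaa+bbbbbbbb", using the
# 8-character short base32 form of each node id.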

class ServerTracker:
    def __init__(self, server,
                 sharesize, blocksize, num_segments, num_share_hashes,
                 storage_index,
                 bucket_renewal_secret, bucket_cancel_secret):
        self._server = server
        self.buckets = {} # k: shareid, v: IRemoteBucketWriter
        self.sharesize = sharesize

        wbp = layout.make_write_bucket_proxy(None, None, sharesize,
                                             blocksize, num_segments,
                                             num_share_hashes,
                                             EXTENSION_SIZE)
        self.wbp_class = wbp.__class__ # to create more of them
        self.allocated_size = wbp.get_allocated_size()
        self.blocksize = blocksize
        self.num_segments = num_segments
        self.num_share_hashes = num_share_hashes
        self.storage_index = storage_index

        self.renew_secret = bucket_renewal_secret
        self.cancel_secret = bucket_cancel_secret

    def __repr__(self):
        return ("<ServerTracker for server %s and SI %s>"
                % (self._server.get_name(), si_b2a(self.storage_index)[:5]))

    def get_server(self):
        return self._server
    def get_serverid(self):
        return self._server.get_serverid()
    def get_name(self):
        return self._server.get_name()

    def query(self, sharenums):
        rref = self._server.get_rref()
        d = rref.callRemote("allocate_buckets",
                            self.storage_index,
                            self.renew_secret,
                            self.cancel_secret,
                            sharenums,
                            self.allocated_size,
                            canary=Referenceable())
        d.addCallback(self._got_reply)
        return d

    def ask_about_existing_shares(self):
        rref = self._server.get_rref()
        return rref.callRemote("get_buckets", self.storage_index)

    def _got_reply(self, (alreadygot, buckets)):
        #log.msg("%s._got_reply(%s)" % (self, (alreadygot, buckets)))
        b = {}
        for sharenum, rref in buckets.iteritems():
            bp = self.wbp_class(rref, self._server, self.sharesize,
                                self.blocksize,
                                self.num_segments,
                                self.num_share_hashes,
                                EXTENSION_SIZE)
            b[sharenum] = bp
        self.buckets.update(b)
        return (alreadygot, set(b.keys()))


    def abort(self):
        """
        I abort the remote bucket writers for all shares. This is a good idea
        to conserve space on the storage server.
        """
        self.abort_some_buckets(self.buckets.keys())

    def abort_some_buckets(self, sharenums):
        """
        I abort the remote bucket writers for the share numbers in sharenums.
        """
        for sharenum in sharenums:
            if sharenum in self.buckets:
                self.buckets[sharenum].abort()
                del self.buckets[sharenum]


def str_shareloc(shnum, bucketwriter):
    return "%s: %s" % (shnum, bucketwriter.get_servername(),)

class Tahoe2ServerSelector(log.PrefixingLogMixin):

    def __init__(self, upload_id, logparent=None, upload_status=None):
        self.upload_id = upload_id
        self.query_count, self.good_query_count, self.bad_query_count = 0,0,0
        # Servers that are working normally, but full.
        self.full_count = 0
        self.error_count = 0
        self.num_servers_contacted = 0
        self.last_failure_msg = None
        self._status = IUploadStatus(upload_status)
        log.PrefixingLogMixin.__init__(self, 'tahoe.immutable.upload', logparent, prefix=upload_id)
        self.log("starting", level=log.OPERATIONAL)

    def __repr__(self):
        return "<Tahoe2ServerSelector for upload %s>" % self.upload_id

    def get_shareholders(self, storage_broker, secret_holder,
                         storage_index, share_size, block_size,
                         num_segments, total_shares, needed_shares,
                         servers_of_happiness):
        """
        @return: (upload_trackers, already_serverids), where upload_trackers
                 is a set of ServerTracker instances that have agreed to hold
                 some shares for us (the shareids are stashed inside the
                 ServerTracker), and already_serverids is a dict mapping
                 shnum to a set of serverids for servers which claim to
                 already have the share.
        """

        if self._status:
            self._status.set_status("Contacting Servers..")

        self.total_shares = total_shares
        self.servers_of_happiness = servers_of_happiness
        self.needed_shares = needed_shares

        self.homeless_shares = set(range(total_shares))
        self.use_trackers = set() # ServerTrackers that have shares assigned
                                  # to them
        self.preexisting_shares = {} # shareid => set(serverids) holding shareid

        # These servers have shares -- any shares -- for our SI. We keep
        # track of these to write an error message with them later.
        self.serverids_with_shares = set()

        # this needed_hashes computation should mirror
        # Encoder.send_all_share_hash_trees. We use an IncompleteHashTree
        # (instead of a HashTree) because we don't require actual hashing
        # just to count the levels.
        ht = hashtree.IncompleteHashTree(total_shares)
        num_share_hashes = len(ht.needed_hashes(0, include_leaf=True))
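        # (Illustrative example: with total_shares=10 the IncompleteHashTree
        # pads out to 16 leaves, so validating one leaf needs its chain of 4
        # uncle hashes plus the leaf hash itself, giving num_share_hashes of
        # 5. Only the count matters here, not the hash values.)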

        # figure out how much space to ask for
        wbp = layout.make_write_bucket_proxy(None, None,
                                             share_size, 0, num_segments,
                                             num_share_hashes, EXTENSION_SIZE)
        allocated_size = wbp.get_allocated_size()
        all_servers = storage_broker.get_servers_for_psi(storage_index)
        if not all_servers:
            raise NoServersError("client gave us zero servers")

        # filter the list of servers according to which ones can accommodate
        # this request. This excludes older servers (which used a 4-byte size
        # field) from getting large shares (for files larger than about
        # 12GiB). See #439 for details.
        def _get_maxsize(server):
            v0 = server.get_rref().version
            v1 = v0["http://allmydata.org/tahoe/protocols/storage/v1"]
            return v1["maximum-immutable-share-size"]
        writeable_servers = [server for server in all_servers
                            if _get_maxsize(server) >= allocated_size]
        readonly_servers = set(all_servers[:2*total_shares]) - set(writeable_servers)

        # decide upon the renewal/cancel secrets, to include them in the
        # allocate_buckets query.
        client_renewal_secret = secret_holder.get_renewal_secret()
        client_cancel_secret = secret_holder.get_cancel_secret()

        file_renewal_secret = file_renewal_secret_hash(client_renewal_secret,
                                                       storage_index)
        file_cancel_secret = file_cancel_secret_hash(client_cancel_secret,
                                                     storage_index)
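        # The secrets form a derivation chain: the client-wide secret is
        # hashed with the storage index to get a per-file secret, and
        # _make_trackers below hashes that with each server's lease seed to
        # get per-bucket secrets, so no server learns another's secrets.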
        def _make_trackers(servers):
            trackers = []
            for s in servers:
                seed = s.get_lease_seed()
                renew = bucket_renewal_secret_hash(file_renewal_secret, seed)
                cancel = bucket_cancel_secret_hash(file_cancel_secret, seed)
                st = ServerTracker(s,
                                   share_size, block_size,
                                   num_segments, num_share_hashes,
                                   storage_index,
                                   renew, cancel)
                trackers.append(st)
            return trackers

        # We assign the servers/trackers to one of three lists. They all
        # start in the "first pass" list. During the first pass, as we ask
        # each one to hold a share, we move their tracker to the "second
        # pass" list, until the first-pass list is empty. Then during the
        # second pass, as we ask each to hold more shares, we move their
        # tracker to the "next pass" list, until the second-pass list is
        # empty. Then we move everybody from the next-pass list back to the
        # second-pass list and repeat the "second" pass (really the third,
        # fourth, etc pass), until all shares are assigned, or we've run out
        # of potential servers.
        self.first_pass_trackers = _make_trackers(writeable_servers)
        self.second_pass_trackers = [] # servers worth asking again
        self.next_pass_trackers = [] # servers that we have asked again
        self._started_second_pass = False

        # We don't try to allocate shares to these servers, since they've
        # said that they're incapable of storing shares of the size that we'd
        # want to store. We ask them about existing shares for this storage
        # index, which we want to know about for accurate
        # servers_of_happiness accounting, then we forget about them.
        readonly_trackers = _make_trackers(readonly_servers)

        # We now ask servers that can't hold any new shares about existing
        # shares that they might have for our SI. Once this is done, we
        # start placing the shares that we haven't already accounted
        # for.
        ds = []
        if self._status and readonly_trackers:
            self._status.set_status("Contacting readonly servers to find "
                                    "any existing shares")
        for tracker in readonly_trackers:
            assert isinstance(tracker, ServerTracker)
            d = tracker.ask_about_existing_shares()
            d.addBoth(self._handle_existing_response, tracker)
            ds.append(d)
            self.num_servers_contacted += 1
            self.query_count += 1
            self.log("asking server %s for any existing shares" %
                     (tracker.get_name(),), level=log.NOISY)
        dl = defer.DeferredList(ds)
        dl.addCallback(lambda ign: self._loop())
        return dl


    def _handle_existing_response(self, res, tracker):
        """
        I handle responses to the get_buckets queries that get_shareholders
        sends to the read-only servers.
        """
        serverid = tracker.get_serverid()
        if isinstance(res, failure.Failure):
            self.log("%s got error during existing shares check: %s"
                    % (tracker.get_name(), res), level=log.UNUSUAL)
            self.error_count += 1
            self.bad_query_count += 1
        else:
            buckets = res
            if buckets:
                self.serverids_with_shares.add(serverid)
            self.log("response to get_buckets() from server %s: alreadygot=%s"
                    % (tracker.get_name(), tuple(sorted(buckets))),
                    level=log.NOISY)
            for bucket in buckets:
                self.preexisting_shares.setdefault(bucket, set()).add(serverid)
                self.homeless_shares.discard(bucket)
            self.full_count += 1
            self.bad_query_count += 1


    def _get_progress_message(self):
        if not self.homeless_shares:
            msg = "placed all %d shares, " % (self.total_shares)
        else:
            msg = ("placed %d shares out of %d total (%d homeless), " %
                   (self.total_shares - len(self.homeless_shares),
                    self.total_shares,
                    len(self.homeless_shares)))
        return (msg + "want to place shares on at least %d servers such that "
                      "any %d of them have enough shares to recover the file, "
                      "sent %d queries to %d servers, "
                      "%d queries placed some shares, %d placed none "
                      "(of which %d placed none due to the server being"
                      " full and %d placed none due to an error)" %
                        (self.servers_of_happiness, self.needed_shares,
                         self.query_count, self.num_servers_contacted,
                         self.good_query_count, self.bad_query_count,
                         self.full_count, self.error_count))


    def _loop(self):
        if not self.homeless_shares:
            merged = merge_servers(self.preexisting_shares, self.use_trackers)
            effective_happiness = servers_of_happiness(merged)
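            # (Informal gloss: servers_of_happiness computes, roughly, the
            # size of a maximum matching between servers and distinct
            # shares; see allmydata.util.happinessutil for the real
            # definition.)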
            if self.servers_of_happiness <= effective_happiness:
                msg = ("server selection successful for %s: %s: pretty_print_merged: %s, "
                       "self.use_trackers: %s, self.preexisting_shares: %s") \
                       % (self, self._get_progress_message(),
                          pretty_print_shnum_to_servers(merged),
                          [', '.join([str_shareloc(k,v)
                                      for k,v in st.buckets.iteritems()])
                           for st in self.use_trackers],
                          pretty_print_shnum_to_servers(self.preexisting_shares))
                self.log(msg, level=log.OPERATIONAL)
                return (self.use_trackers, self.preexisting_shares)
            else:
                # We're not okay right now, but maybe we can fix it by
                # redistributing some shares. In cases where one or two
                # servers have, before the upload, all or most of the
                # shares for a given SI, this can work by giving _loop
                # a chance to spread those out over the other servers.
                delta = self.servers_of_happiness - effective_happiness
                shares = shares_by_server(self.preexisting_shares)
                # Each server in shares maps to a set of shares stored on it.
                # Since we want to keep at least one share on each server
                # that has one (otherwise we'd only be making
                # the situation worse by removing distinct servers),
                # each server has len(its shares) - 1 to spread around.
                shares_to_spread = sum([len(list(sharelist)) - 1
                                        for (server, sharelist)
                                        in shares.items()])
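                # (Illustrative example: if server A holds shares {0,1,2}
                # and server B holds {3}, then shares_to_spread is
                # (3-1) + (1-1) = 2.)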
                if delta <= len(self.first_pass_trackers) and \
                   shares_to_spread >= delta:
                    items = shares.items()
                    while len(self.homeless_shares) < delta:
                        # Loop through the allocated shares, removing
                        # one from each server that has more than one
                        # and putting it back into self.homeless_shares
                        # until we've done this delta times.
                        server, sharelist = items.pop()
                        if len(sharelist) > 1:
                            share = sharelist.pop()
                            self.homeless_shares.add(share)
                            self.preexisting_shares[share].remove(server)
                            if not self.preexisting_shares[share]:
                                del self.preexisting_shares[share]
                            items.append((server, sharelist))
                        for writer in self.use_trackers:
                            writer.abort_some_buckets(self.homeless_shares)
                    return self._loop()
                else:
                    # Redistribution won't help us; fail.
                    server_count = len(self.serverids_with_shares)
                    failmsg = failure_message(server_count,
                                              self.needed_shares,
                                              self.servers_of_happiness,
                                              effective_happiness)
                    servmsgtempl = "server selection unsuccessful for %r: %s (%s), merged=%s"
                    servmsg = servmsgtempl % (
                        self,
                        failmsg,
                        self._get_progress_message(),
                        pretty_print_shnum_to_servers(merged)
                        )
                    self.log(servmsg, level=log.INFREQUENT)
                    return self._failed("%s (%s)" % (failmsg, self._get_progress_message()))

        if self.first_pass_trackers:
            tracker = self.first_pass_trackers.pop(0)
            # TODO: don't pre-convert all serverids to ServerTrackers
            assert isinstance(tracker, ServerTracker)

            shares_to_ask = set(sorted(self.homeless_shares)[:1])
            self.homeless_shares -= shares_to_ask
            self.query_count += 1
            self.num_servers_contacted += 1
            if self._status:
                self._status.set_status("Contacting Servers [%s] (first query),"
                                        " %d shares left.."
                                        % (tracker.get_name(),
                                           len(self.homeless_shares)))
            d = tracker.query(shares_to_ask)
            d.addBoth(self._got_response, tracker, shares_to_ask,
                      self.second_pass_trackers)
            return d
        elif self.second_pass_trackers:
            # ask a server that we've already asked.
            if not self._started_second_pass:
                self.log("starting second pass",
                        level=log.NOISY)
                self._started_second_pass = True
            num_shares = mathutil.div_ceil(len(self.homeless_shares),
                                           len(self.second_pass_trackers))
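            # (e.g. 7 homeless shares across 3 remaining trackers gives
            # div_ceil(7, 3) == 3 shares to ask of this server.)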
            tracker = self.second_pass_trackers.pop(0)
            shares_to_ask = set(sorted(self.homeless_shares)[:num_shares])
            self.homeless_shares -= shares_to_ask
            self.query_count += 1
            if self._status:
                self._status.set_status("Contacting Servers [%s] (second query),"
                                        " %d shares left.."
                                        % (tracker.get_name(),
                                           len(self.homeless_shares)))
            d = tracker.query(shares_to_ask)
            d.addBoth(self._got_response, tracker, shares_to_ask,
                      self.next_pass_trackers)
            return d
        elif self.next_pass_trackers:
            # we've finished the second-or-later pass. Move all the remaining
            # servers back into self.second_pass_trackers for the next pass.
            self.second_pass_trackers.extend(self.next_pass_trackers)
            self.next_pass_trackers[:] = []
            return self._loop()
        else:
            # no more servers. If we haven't placed enough shares, we fail.
            merged = merge_servers(self.preexisting_shares, self.use_trackers)
            effective_happiness = servers_of_happiness(merged)
            if effective_happiness < self.servers_of_happiness:
                msg = failure_message(len(self.serverids_with_shares),
                                      self.needed_shares,
                                      self.servers_of_happiness,
                                      effective_happiness)
                msg = ("server selection failed for %s: %s (%s)" %
                       (self, msg, self._get_progress_message()))
                if self.last_failure_msg:
                    msg += " (%s)" % (self.last_failure_msg,)
                self.log(msg, level=log.UNUSUAL)
                return self._failed(msg)
            else:
                # we placed enough to be happy, so we're done
                if self._status:
                    self._status.set_status("Placed all shares")
                msg = ("server selection successful (no more servers) for %s: %s: %s" % (self,
                            self._get_progress_message(), pretty_print_shnum_to_servers(merged)))
                self.log(msg, level=log.OPERATIONAL)
                return (self.use_trackers, self.preexisting_shares)

    def _got_response(self, res, tracker, shares_to_ask, put_tracker_here):
        if isinstance(res, failure.Failure):
            # This is unusual, and probably indicates a bug or a network
            # problem.
            self.log("%s got error during server selection: %s" % (tracker, res),
                    level=log.UNUSUAL)
            self.error_count += 1
            self.bad_query_count += 1
            self.homeless_shares |= shares_to_ask
            if (self.first_pass_trackers
                or self.second_pass_trackers
                or self.next_pass_trackers):
                # there is still hope, so just loop
                pass
            else:
                # No more servers, so this upload might fail (it depends upon
                # whether we've hit servers_of_happiness or not). Log the last
                # failure we got: if a coding error causes all servers to fail
                # in the same way, this allows the common failure to be seen
                # by the uploader and should help with debugging
                msg = ("last failure (from %s) was: %s" % (tracker, res))
                self.last_failure_msg = msg
        else:
            (alreadygot, allocated) = res
            self.log("response to allocate_buckets() from server %s: alreadygot=%s, allocated=%s"
                    % (tracker.get_name(),
                       tuple(sorted(alreadygot)), tuple(sorted(allocated))),
                    level=log.NOISY)
            progress = False
            for s in alreadygot:
                self.preexisting_shares.setdefault(s, set()).add(tracker.get_serverid())
                if s in self.homeless_shares:
                    self.homeless_shares.remove(s)
                    progress = True
                elif s in shares_to_ask:
                    progress = True

            # the ServerTracker will remember which shares were allocated on
            # that peer. We just have to remember to use them.
            if allocated:
                self.use_trackers.add(tracker)
                progress = True

            if allocated or alreadygot:
                self.serverids_with_shares.add(tracker.get_serverid())

            not_yet_present = set(shares_to_ask) - set(alreadygot)
            still_homeless = not_yet_present - set(allocated)

            if progress:
                # They accepted at least one of the shares that we asked
                # them to accept, or they had a share that we didn't ask
                # them to accept but that we hadn't placed yet, so this
                # was a productive query
                self.good_query_count += 1
            else:
                self.bad_query_count += 1
                self.full_count += 1

            if still_homeless:
                # In networks with lots of space, this is very unusual and
                # probably indicates an error. In networks with servers that
                # are full, it is merely unusual. In networks that are very
                # full, it is common, and many uploads will fail. In most
                # cases, this is obviously not fatal, and we'll just use some
                # other servers.

                # some shares are still homeless, keep trying to find them a
                # home. The ones that were rejected get first priority.
                self.homeless_shares |= still_homeless
                # Since they were unable to accept all of our requests, it
                # is safe to assume that asking them again won't help.
            else:
                # if they *were* able to accept everything, they might be
                # willing to accept even more.
                put_tracker_here.append(tracker)

        # now loop
        return self._loop()


    def _failed(self, msg):
        """
        I am called when server selection fails. I first abort all of the
        remote buckets that I allocated during my unsuccessful attempt to
        place shares for this file. I then raise an
        UploadUnhappinessError with my msg argument.
        """
        for tracker in self.use_trackers:
            assert isinstance(tracker, ServerTracker)
            tracker.abort()
        raise UploadUnhappinessError(msg)


class EncryptAnUploadable:
    """This is a wrapper that takes an IUploadable and provides
    IEncryptedUploadable."""
    implements(IEncryptedUploadable)
    CHUNKSIZE = 50*1024
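    # (50 kB chunks bound memory usage: read_encrypted() may be asked for
    # very large lengths, and _read_encrypted below fetches and processes
    # the plaintext one CHUNKSIZE piece at a time.)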

    def __init__(self, original, log_parent=None, progress=None):
        precondition(original.default_params_set,
                     "set_default_encoding_parameters not called on %r before wrapping with EncryptAnUploadable" % (original,))
        self.original = IUploadable(original)
        self._log_number = log_parent
        self._encryptor = None
        self._plaintext_hasher = plaintext_hasher()
        self._plaintext_segment_hasher = None
        self._plaintext_segment_hashes = []
        self._encoding_parameters = None
        self._file_size = None
        self._ciphertext_bytes_read = 0
        self._status = None
        self._progress = progress

    def set_upload_status(self, upload_status):
        self._status = IUploadStatus(upload_status)
        self.original.set_upload_status(upload_status)

    def log(self, *args, **kwargs):
        if "facility" not in kwargs:
            kwargs["facility"] = "upload.encryption"
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_number
        return log.msg(*args, **kwargs)

    def get_size(self):
        if self._file_size is not None:
            return defer.succeed(self._file_size)
        d = self.original.get_size()
        def _got_size(size):
            self._file_size = size
            if self._status:
                self._status.set_size(size)
            if self._progress:
                self._progress.set_progress_total(size)
            return size
        d.addCallback(_got_size)
        return d

    def get_all_encoding_parameters(self):
        if self._encoding_parameters is not None:
            return defer.succeed(self._encoding_parameters)
        d = self.original.get_all_encoding_parameters()
        def _got(encoding_parameters):
            (k, happy, n, segsize) = encoding_parameters
            self._segment_size = segsize # used by segment hashers
            self._encoding_parameters = encoding_parameters
            self.log("my encoding parameters: %s" % (encoding_parameters,),
                     level=log.NOISY)
            return encoding_parameters
        d.addCallback(_got)
        return d

    def _get_encryptor(self):
        if self._encryptor:
            return defer.succeed(self._encryptor)

        d = self.original.get_encryption_key()
        def _got(key):
            e = AES(key)
            self._encryptor = e

            storage_index = storage_index_hash(key)
            assert isinstance(storage_index, str)
            # There's no point to having the SI be longer than the key, so we
            # specify that it is truncated to the same 128 bits as the AES key.
            assert len(storage_index) == 16  # SHA-256 truncated to 128b
            self._storage_index = storage_index
            if self._status:
                self._status.set_storage_index(storage_index)
            return e
        d.addCallback(_got)
        return d

    def get_storage_index(self):
        d = self._get_encryptor()
        d.addCallback(lambda res: self._storage_index)
        return d

    def _get_segment_hasher(self):
        p = self._plaintext_segment_hasher
        if p:
            left = self._segment_size - self._plaintext_segment_hashed_bytes
            return p, left
        p = plaintext_segment_hasher()
        self._plaintext_segment_hasher = p
        self._plaintext_segment_hashed_bytes = 0
        return p, self._segment_size

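    # _update_segment_hash feeds plaintext into per-segment hashers, closing
    # each hasher as its segment fills. Worked example (illustrative): with
    # _segment_size=4, a 6-byte chunk "abcdef" sends "abcd" to the current
    # hasher, whose digest is then recorded, and "ef" to a fresh one.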
    def _update_segment_hash(self, chunk):
        offset = 0
        while offset < len(chunk):
            p, segment_left = self._get_segment_hasher()
            chunk_left = len(chunk) - offset
            this_segment = min(chunk_left, segment_left)
            p.update(chunk[offset:offset+this_segment])
            self._plaintext_segment_hashed_bytes += this_segment

            if self._plaintext_segment_hashed_bytes == self._segment_size:
                # we've filled this segment
                self._plaintext_segment_hashes.append(p.digest())
                self._plaintext_segment_hasher = None
                self.log("closed hash [%d]: %dB" %
                         (len(self._plaintext_segment_hashes)-1,
                          self._plaintext_segment_hashed_bytes),
                         level=log.NOISY)
                self.log(format="plaintext leaf hash [%(segnum)d] is %(hash)s",
                         segnum=len(self._plaintext_segment_hashes)-1,
                         hash=base32.b2a(p.digest()),
                         level=log.NOISY)

            offset += this_segment


    def read_encrypted(self, length, hash_only):
        # make sure our parameters have been set up first
        d = self.get_all_encoding_parameters()
        # and size
        d.addCallback(lambda ignored: self.get_size())
        d.addCallback(lambda ignored: self._get_encryptor())
        # then fetch and encrypt the plaintext. The unusual structure here
        # (passing a Deferred *into* a function) is needed to avoid
        # overflowing the stack: Deferreds don't optimize out tail recursion.
        # We also pass in a list, to which _read_encrypted will append
        # ciphertext.
        ciphertext = []
        d2 = defer.Deferred()
        d.addCallback(lambda ignored:
                      self._read_encrypted(length, ciphertext, hash_only, d2))
        d.addCallback(lambda ignored: d2)
        return d

    def _read_encrypted(self, remaining, ciphertext, hash_only, fire_when_done):
        if not remaining:
            fire_when_done.callback(ciphertext)
            return None
        # tolerate large length= values without consuming a lot of RAM by
        # reading just a chunk (say 50kB) at a time. This only really matters
        # when hash_only==True (i.e. resuming an interrupted upload), since
        # that's the case where we will be skipping over a lot of data.
        size = min(remaining, self.CHUNKSIZE)
        remaining = remaining - size
        # read a chunk of plaintext..
        d = defer.maybeDeferred(self.original.read, size)
        # N.B.: if read() is synchronous, then since everything else is
        # actually synchronous too, we'd blow the stack unless we stall for a
        # tick. Once you accept a Deferred from IUploadable.read(), you must
        # be prepared to have it fire immediately too.
        d.addCallback(fireEventually)
        def _good(plaintext):
            # and encrypt it..
            # o/' over the fields we go, hashing all the way, sHA! sHA! sHA! o/'
            ct = self._hash_and_encrypt_plaintext(plaintext, hash_only)
            ciphertext.extend(ct)
            self._read_encrypted(remaining, ciphertext, hash_only,
                                 fire_when_done)
        def _err(why):
            fire_when_done.errback(why)
        d.addCallback(_good)
        d.addErrback(_err)
        return None

    def _hash_and_encrypt_plaintext(self, data, hash_only):
        assert isinstance(data, (tuple, list)), type(data)
        data = list(data)
        cryptdata = []
        # we use data.pop(0) instead of 'for chunk in data' to save
        # memory: each chunk is destroyed as soon as we're done with it.
        bytes_processed = 0
        while data:
            chunk = data.pop(0)
            self.log(" read_encrypted handling %dB-sized chunk" % len(chunk),
                     level=log.NOISY)
            bytes_processed += len(chunk)
            self._plaintext_hasher.update(chunk)
            self._update_segment_hash(chunk)
            # TODO: we have to encrypt the data (even if hash_only==True)
            # because pycryptopp's AES-CTR implementation doesn't offer a
            # way to change the counter value. Once pycryptopp acquires
            # this ability, change this to simply update the counter
            # before each call to (hash_only==False) _encryptor.process()
            ciphertext = self._encryptor.process(chunk)
            if hash_only:
                self.log("  skipping encryption", level=log.NOISY)
            else:
                cryptdata.append(ciphertext)
            del ciphertext
            del chunk
        self._ciphertext_bytes_read += bytes_processed
        if self._status:
            progress = float(self._ciphertext_bytes_read) / self._file_size
            self._status.set_progress(1, progress)
        return cryptdata


    def get_plaintext_hashtree_leaves(self, first, last, num_segments):
        # this is currently unused, but will live again when we fix #453
        if len(self._plaintext_segment_hashes) < num_segments:
            # close out the last one
            assert len(self._plaintext_segment_hashes) == num_segments-1
            p, segment_left = self._get_segment_hasher()
            self._plaintext_segment_hashes.append(p.digest())
            del self._plaintext_segment_hasher
            self.log("closing plaintext leaf hasher, hashed %d bytes" %
                     self._plaintext_segment_hashed_bytes,
                     level=log.NOISY)
            self.log(format="plaintext leaf hash [%(segnum)d] is %(hash)s",
                     segnum=len(self._plaintext_segment_hashes)-1,
                     hash=base32.b2a(p.digest()),
                     level=log.NOISY)
        assert len(self._plaintext_segment_hashes) == num_segments
        return defer.succeed(tuple(self._plaintext_segment_hashes[first:last]))

    def get_plaintext_hash(self):
        h = self._plaintext_hasher.digest()
        return defer.succeed(h)

    def close(self):
        return self.original.close()

class UploadStatus:
    implements(IUploadStatus)
    statusid_counter = itertools.count(0)

    def __init__(self):
        self.storage_index = None
        self.size = None
        self.helper = False
        self.status = "Not started"
        self.progress = [0.0, 0.0, 0.0]
        self.active = True
        self.results = None
        self.counter = self.statusid_counter.next()
        self.started = time.time()

    def get_started(self):
        return self.started
    def get_storage_index(self):
        return self.storage_index
    def get_size(self):
        return self.size
    def using_helper(self):
        return self.helper
    def get_status(self):
        return self.status
    def get_progress(self):
        return tuple(self.progress)
    def get_active(self):
        return self.active
    def get_results(self):
        return self.results
    def get_counter(self):
        return self.counter

    def set_storage_index(self, si):
        self.storage_index = si
    def set_size(self, size):
        self.size = size
    def set_helper(self, helper):
        self.helper = helper
    def set_status(self, status):
        self.status = status
    def set_progress(self, which, value):
        # [0]: chk, [1]: ciphertext, [2]: encode+push
        self.progress[which] = value
    def set_active(self, value):
        self.active = value
    def set_results(self, value):
        self.results = value

class CHKUploader:
    server_selector_class = Tahoe2ServerSelector

    def __init__(self, storage_broker, secret_holder, progress=None):
        # server_selector needs storage_broker and secret_holder
        self._storage_broker = storage_broker
        self._secret_holder = secret_holder
        self._log_number = self.log("CHKUploader starting", parent=None)
        self._encoder = None
        self._storage_index = None
        self._upload_status = UploadStatus()
        self._upload_status.set_helper(False)
        self._upload_status.set_active(True)
        self._progress = progress

        # locate_all_shareholders() will create the following attribute:
        # self._server_trackers = {} # k: shnum, v: instance of ServerTracker

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_number
        if "facility" not in kwargs:
            kwargs["facility"] = "tahoe.upload"
        return log.msg(*args, **kwargs)

    def start(self, encrypted_uploadable):
        """Start uploading the file.

        Returns a Deferred that will fire with the UploadResults instance.
        """

        self._started = time.time()
        eu = IEncryptedUploadable(encrypted_uploadable)
        self.log("starting upload of %s" % eu)

        eu.set_upload_status(self._upload_status)
        d = self.start_encrypted(eu)
        def _done(uploadresults):
            self._upload_status.set_active(False)
            return uploadresults
        d.addBoth(_done)
        return d

    def abort(self):
        """Call this if the upload must be abandoned before it completes.
        This will tell the shareholders to delete their partial shares. I
        return a Deferred that fires when these messages have been acked."""
        if not self._encoder:
            # how did you call abort() before calling start() ?
            return defer.succeed(None)
        return self._encoder.abort()

    def start_encrypted(self, encrypted):
        """ Returns a Deferred that will fire with the UploadResults instance. """
        eu = IEncryptedUploadable(encrypted)

        started = time.time()
        self._encoder = e = encode.Encoder(
            self._log_number,
            self._upload_status,
            progress=self._progress,
        )
        d = e.set_encrypted_uploadable(eu)
        d.addCallback(self.locate_all_shareholders, started)
        d.addCallback(self.set_shareholders, e)
        d.addCallback(lambda res: e.start())
        d.addCallback(self._encrypted_done)
        return d

    def locate_all_shareholders(self, encoder, started):
        server_selection_started = now = time.time()
        self._storage_index_elapsed = now - started
        storage_broker = self._storage_broker
        secret_holder = self._secret_holder
        storage_index = encoder.get_param("storage_index")
        self._storage_index = storage_index
        upload_id = si_b2a(storage_index)[:5]
        self.log("using storage index %s" % upload_id)
        server_selector = self.server_selector_class(upload_id,
                                                     self._log_number,
                                                     self._upload_status)

        share_size = encoder.get_param("share_size")
        block_size = encoder.get_param("block_size")
        num_segments = encoder.get_param("num_segments")
        k,desired,n = encoder.get_param("share_counts")

        self._server_selection_started = time.time()
        d = server_selector.get_shareholders(storage_broker, secret_holder,
                                             storage_index,
                                             share_size, block_size,
                                             num_segments, n, k, desired)
        def _done(res):
            self._server_selection_elapsed = time.time() - server_selection_started
            return res
        d.addCallback(_done)
        return d

    def set_shareholders(self, (upload_trackers, already_serverids), encoder):
        """
        @param upload_trackers: a sequence of ServerTracker objects that
                                have agreed to hold some shares for us (the
                                shareids are stashed inside the ServerTracker)

        @param already_serverids: a dict mapping sharenum to a set of
                                  serverids for servers that claim to already
                                  have this share
        """
        msgtempl = "set_shareholders; upload_trackers is %s, already_serverids is %s"
        values = ([', '.join([str_shareloc(k,v)
                              for k,v in st.buckets.iteritems()])
                   for st in upload_trackers], already_serverids)
        self.log(msgtempl % values, level=log.OPERATIONAL)
        # record already-present shares in self._results
        self._count_preexisting_shares = len(already_serverids)

        self._server_trackers = {} # k: shnum, v: instance of ServerTracker
        for tracker in upload_trackers:
            assert isinstance(tracker, ServerTracker)
        buckets = {}
        servermap = already_serverids.copy()
        for tracker in upload_trackers:
            buckets.update(tracker.buckets)
            for shnum in tracker.buckets:
                self._server_trackers[shnum] = tracker
                servermap.setdefault(shnum, set()).add(tracker.get_serverid())
        assert len(buckets) == sum([len(tracker.buckets)
                                    for tracker in upload_trackers]), \
            "%s (%s) != %s (%s)" % (
                len(buckets),
                buckets,
                sum([len(tracker.buckets) for tracker in upload_trackers]),
                [(t.buckets, t.get_serverid()) for t in upload_trackers]
                )
        encoder.set_shareholders(buckets, servermap)

    def _encrypted_done(self, verifycap):
        """Returns a Deferred that will fire with the UploadResults instance."""
        e = self._encoder
        sharemap = dictutil.DictOfSets()
        servermap = dictutil.DictOfSets()
        for shnum in e.get_shares_placed():
            server = self._server_trackers[shnum].get_server()
            sharemap.add(shnum, server)
            servermap.add(server, shnum)
        now = time.time()
        timings = {}
        timings["total"] = now - self._started
        timings["storage_index"] = self._storage_index_elapsed
        timings["peer_selection"] = self._server_selection_elapsed
        timings.update(e.get_times())
        ur = UploadResults(file_size=e.file_size,
                           ciphertext_fetched=0,
                           preexisting_shares=self._count_preexisting_shares,
                           pushed_shares=len(e.get_shares_placed()),
                           sharemap=sharemap,
                           servermap=servermap,
                           timings=timings,
                           uri_extension_data=e.get_uri_extension_data(),
                           uri_extension_hash=e.get_uri_extension_hash(),
                           verifycapstr=verifycap.to_string())
        self._upload_status.set_results(ur)
        return ur

    def get_upload_status(self):
        return self._upload_status

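# Hypothetical illustration of the recursion below: asking for 100 bytes
# from an uploadable whose read() happens to return one 40-byte piece per
# call takes three calls (40 + 40 + 20), with the pieces accumulating in
# prepend_data until the full size has been read.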
def read_this_many_bytes(uploadable, size, prepend_data=[]):
    if size == 0:
        return defer.succeed([])
    d = uploadable.read(size)
    def _got(data):
        assert isinstance(data, list)
        bytes = sum([len(piece) for piece in data])
        assert bytes > 0
        assert bytes <= size
        remaining = size - bytes
        if remaining:
            return read_this_many_bytes(uploadable, remaining,
                                        prepend_data + data)
        return prepend_data + data
    d.addCallback(_got)
    return d

class LiteralUploader:

    def __init__(self, progress=None):
        self._status = s = UploadStatus()
        s.set_storage_index(None)
        s.set_helper(False)
        s.set_progress(0, 1.0)
        s.set_active(False)
        self._progress = progress

    def start(self, uploadable):
        uploadable = IUploadable(uploadable)
        d = uploadable.get_size()
        def _got_size(size):
            self._size = size
            self._status.set_size(size)
            if self._progress:
                self._progress.set_progress_total(size)
            return read_this_many_bytes(uploadable, size)
        d.addCallback(_got_size)
        d.addCallback(lambda data: uri.LiteralFileURI("".join(data)))
        d.addCallback(lambda u: u.to_string())
        d.addCallback(self._build_results)
        return d

    def _build_results(self, uri):
        ur = UploadResults(file_size=self._size,
                           ciphertext_fetched=0,
                           preexisting_shares=0,
                           pushed_shares=0,
                           sharemap={},
                           servermap={},
                           timings={},
                           uri_extension_data=None,
                           uri_extension_hash=None,
                           verifycapstr=None)
        ur.set_uri(uri)
        self._status.set_status("Finished")
        self._status.set_progress(1, 1.0)
        self._status.set_progress(2, 1.0)
        self._status.set_results(ur)
        if self._progress:
            self._progress.set_progress(self._size)
        return ur

    def close(self):
        pass

    def get_upload_status(self):
        return self._status

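# RemoteEncryptedUploadable exposes an IEncryptedUploadable over foolscap so
# a Helper can pull ciphertext from us. Reads must proceed forwards; the
# helper may skip ahead past ciphertext it already holds, in which case we
# still read and hash the skipped plaintext (hash_only=True) to keep the
# hash trees consistent.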
1132 class RemoteEncryptedUploadable(Referenceable):
1133     implements(RIEncryptedUploadable)
1134
1135     def __init__(self, encrypted_uploadable, upload_status):
1136         self._eu = IEncryptedUploadable(encrypted_uploadable)
1137         self._offset = 0
1138         self._bytes_sent = 0
1139         self._status = IUploadStatus(upload_status)
1140         # we are responsible for updating the status string while we run, and
1141         # for setting the ciphertext-fetch progress.
1142         self._size = None
1143
1144     def get_size(self):
1145         if self._size is not None:
1146             return defer.succeed(self._size)
1147         d = self._eu.get_size()
1148         def _got_size(size):
1149             self._size = size
1150             return size
1151         d.addCallback(_got_size)
1152         return d
1153
1154     def remote_get_size(self):
1155         return self.get_size()
1156     def remote_get_all_encoding_parameters(self):
1157         return self._eu.get_all_encoding_parameters()
1158
1159     def _read_encrypted(self, length, hash_only):
1160         d = self._eu.read_encrypted(length, hash_only)
1161         def _read(strings):
1162             if hash_only:
1163                 self._offset += length
1164             else:
1165                 size = sum([len(data) for data in strings])
1166                 self._offset += size
1167             return strings
1168         d.addCallback(_read)
1169         return d
1170
1171     def remote_read_encrypted(self, offset, length):
1172         # we don't support seek backwards, but we allow skipping forwards
1173         precondition(offset >= 0, offset)
1174         precondition(length >= 0, length)
1175         lp = log.msg("remote_read_encrypted(%d-%d)" % (offset, offset+length),
1176                      level=log.NOISY)
1177         precondition(offset >= self._offset, offset, self._offset)
1178         if offset > self._offset:
1179             # read the data from disk anyways, to build up the hash tree
1180             skip = offset - self._offset
1181             log.msg("remote_read_encrypted skipping ahead from %d to %d, skip=%d" %
1182                     (self._offset, offset, skip), level=log.UNUSUAL, parent=lp)
1183             d = self._read_encrypted(skip, hash_only=True)
1184         else:
1185             d = defer.succeed(None)
1186
1187         def _at_correct_offset(res):
1188             assert offset == self._offset, "%d != %d" % (offset, self._offset)
1189             return self._read_encrypted(length, hash_only=False)
1190         d.addCallback(_at_correct_offset)
1191
1192         def _read(strings):
1193             size = sum([len(data) for data in strings])
1194             self._bytes_sent += size
1195             return strings
1196         d.addCallback(_read)
1197         return d
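    # Sketch of the skip-forward contract above (the offsets are made up): if
    # bytes 0-99 have already been delivered (self._offset == 100) and the
    # helper asks for (offset=200, length=50), the 100 skipped bytes are still
    # read with hash_only=True so the hash trees cover every byte, and then
    # bytes 200-249 are returned for real:
    #
    #   d = reu.remote_read_encrypted(200, 50)
    #   # afterwards self._offset == 250, but _bytes_sent grew by only 50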
1198
1199     def remote_close(self):
1200         return self._eu.close()
1201
1202
1203 class AssistedUploader:
1204
1205     def __init__(self, helper, storage_broker):
1206         self._helper = helper
1207         self._storage_broker = storage_broker
1208         self._log_number = log.msg("AssistedUploader starting")
1209         self._storage_index = None
1210         self._upload_status = s = UploadStatus()
1211         s.set_helper(True)
1212         s.set_active(True)
1213
1214     def log(self, *args, **kwargs):
1215         if "parent" not in kwargs:
1216             kwargs["parent"] = self._log_number
1217         return log.msg(*args, **kwargs)
1218
1219     def start(self, encrypted_uploadable, storage_index):
1220         """Start uploading the file.
1221
1222         Returns a Deferred that will fire with the UploadResults instance.
1223         """
1224         precondition(isinstance(storage_index, str), storage_index)
1225         self._started = time.time()
1226         eu = IEncryptedUploadable(encrypted_uploadable)
1227         eu.set_upload_status(self._upload_status)
1228         self._encuploadable = eu
1229         self._storage_index = storage_index
1230         d = eu.get_size()
1231         d.addCallback(self._got_size)
1232         d.addCallback(lambda res: eu.get_all_encoding_parameters())
1233         d.addCallback(self._got_all_encoding_parameters)
1234         d.addCallback(self._contact_helper)
1235         d.addCallback(self._build_verifycap)
1236         def _done(res):
1237             self._upload_status.set_active(False)
1238             return res
1239         d.addBoth(_done)
1240         return d
1241
1242     def _got_size(self, size):
1243         self._size = size
1244         self._upload_status.set_size(size)
1245
1246     def _got_all_encoding_parameters(self, params):
1247         k, happy, n, segment_size = params
1248         # stash these for URI generation later
1249         self._needed_shares = k
1250         self._total_shares = n
1251         self._segment_size = segment_size
1252
1253     def _contact_helper(self, res):
1254         now = self._time_contacting_helper_start = time.time()
1255         self._storage_index_elapsed = now - self._started
1256         self.log(format="contacting helper for SI %(si)s...",
1257                  si=si_b2a(self._storage_index), level=log.NOISY)
1258         self._upload_status.set_status("Contacting Helper")
1259         d = self._helper.callRemote("upload_chk", self._storage_index)
1260         d.addCallback(self._contacted_helper)
1261         return d
1262
1263     def _contacted_helper(self, (helper_upload_results, upload_helper)):
1264         now = time.time()
1265         elapsed = now - self._time_contacting_helper_start
1266         self._elapsed_time_contacting_helper = elapsed
1267         if upload_helper:
1268             self.log("helper says we need to upload", level=log.NOISY)
1269             self._upload_status.set_status("Uploading Ciphertext")
1270             # we need to upload the file
1271             reu = RemoteEncryptedUploadable(self._encuploadable,
1272                                             self._upload_status)
1273             # let it pre-compute the size for progress purposes
1274             d = reu.get_size()
1275             d.addCallback(lambda ignored:
1276                           upload_helper.callRemote("upload", reu))
1277             # this Deferred will fire with the upload results
1278             return d
1279         self.log("helper says file is already uploaded", level=log.OPERATIONAL)
1280         self._upload_status.set_progress(1, 1.0)
1281         return helper_upload_results
1282
1283     def _convert_old_upload_results(self, upload_results):
1284         # pre-1.3.0 helpers return upload results which contain a mapping
1285         # from shnum to a single human-readable string, containing things
1286         # like "Found on [x],[y],[z]" (for healthy files that were already in
1287         # the grid), "Found on [x]" (for files that needed upload but which
1288         # discovered pre-existing shares), and "Placed on [x]" (for newly
1289         # uploaded shares). The 1.3.0 helper returns a mapping from shnum to
1290         # set of binary serverid strings.
1291
1292         # the old results are too hard to deal with (they don't even contain
1293         # as much information as the new results, since the nodeids are
1294         # abbreviated), so if we detect old results, just clobber them.
1295
1296         sharemap = upload_results.sharemap
1297         if any(isinstance(v, str) for v in sharemap.values()):
1298             upload_results.sharemap = None
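    # For illustration only (the serverids and strings are made up): an old
    # helper might send sharemap={0: "Found on [xgru5]"}, while a 1.3.0+
    # helper sends sharemap={0: set(["\x9a\x14..."])}; only the new shape
    # survives the check above.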
1299
1300     def _build_verifycap(self, helper_upload_results):
1301         self.log("upload finished, building readcap", level=log.OPERATIONAL)
1302         self._convert_old_upload_results(helper_upload_results)
1303         self._upload_status.set_status("Building Readcap")
1304         hur = helper_upload_results
1305         assert hur.uri_extension_data["needed_shares"] == self._needed_shares
1306         assert hur.uri_extension_data["total_shares"] == self._total_shares
1307         assert hur.uri_extension_data["segment_size"] == self._segment_size
1308         assert hur.uri_extension_data["size"] == self._size
1309
1310         # hur.verifycap is absent when the file was already present, so rebuild the verifier cap locally
1311         v = uri.CHKFileVerifierURI(self._storage_index,
1312                                    uri_extension_hash=hur.uri_extension_hash,
1313                                    needed_shares=self._needed_shares,
1314                                    total_shares=self._total_shares,
1315                                    size=self._size)
1316         timings = {}
1317         timings["storage_index"] = self._storage_index_elapsed
1318         timings["contacting_helper"] = self._elapsed_time_contacting_helper
1319         for key,val in hur.timings.items():
1320             if key == "total":
1321                 key = "helper_total"
1322             timings[key] = val
1323         now = time.time()
1324         timings["total"] = now - self._started
1325
1326         gss = self._storage_broker.get_stub_server
1327         sharemap = {}
1328         servermap = {}
1329         for shnum, serverids in hur.sharemap.items():
1330             sharemap[shnum] = set([gss(serverid) for serverid in serverids])
1331         # if the file was already in the grid, hur.servermap is an empty dict
1332         for serverid, shnums in hur.servermap.items():
1333             servermap[gss(serverid)] = set(shnums)
1334
1335         ur = UploadResults(file_size=self._size,
1336                            # zero if the file was already present
1337                            ciphertext_fetched=hur.ciphertext_fetched,
1338                            preexisting_shares=hur.preexisting_shares,
1339                            pushed_shares=hur.pushed_shares,
1340                            sharemap=sharemap,
1341                            servermap=servermap,
1342                            timings=timings,
1343                            uri_extension_data=hur.uri_extension_data,
1344                            uri_extension_hash=hur.uri_extension_hash,
1345                            verifycapstr=v.to_string())
1346
1347         self._upload_status.set_status("Finished")
1348         self._upload_status.set_results(ur)
1349         return ur
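    # For reference, the verifier cap built above serializes to a string of
    # this general shape (fields elided, not a real cap):
    #   URI:CHK-Verifier:<storage_index>:<uri_extension_hash>:<k>:<N>:<size>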
1350
1351     def get_upload_status(self):
1352         return self._upload_status
1353
1354 class BaseUploadable:
1355     # this default is overridden by max_segment_size, when that is set
1356     default_max_segment_size = DEFAULT_MAX_SEGMENT_SIZE
1357     default_params_set = False
1358
1359     max_segment_size = None
1360     encoding_param_k = None
1361     encoding_param_happy = None
1362     encoding_param_n = None
1363
1364     _all_encoding_parameters = None
1365     _status = None
1366
1367     def set_upload_status(self, upload_status):
1368         self._status = IUploadStatus(upload_status)
1369
1370     def set_default_encoding_parameters(self, default_params):
1371         assert isinstance(default_params, dict)
1372         for k,v in default_params.items():
1373             precondition(isinstance(k, str), k, v)
1374             precondition(isinstance(v, int), k, v)
1375         if "k" in default_params:
1376             self.default_encoding_param_k = default_params["k"]
1377         if "happy" in default_params:
1378             self.default_encoding_param_happy = default_params["happy"]
1379         if "n" in default_params:
1380             self.default_encoding_param_n = default_params["n"]
1381         if "max_segment_size" in default_params:
1382             self.default_max_segment_size = default_params["max_segment_size"]
1383         self.default_params_set = True
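    # A minimal sketch of the expected shape, using tahoe's customary
    # 3-of-10 encoding (the exact numbers are illustrative, not mandatory):
    #
    #   u = Data("example bytes", convergence="")
    #   u.set_default_encoding_parameters({"k": 3, "happy": 7, "n": 10,
    #                                      "max_segment_size": 128*1024})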
1384
1385     def get_all_encoding_parameters(self):
1386         _assert(self.default_params_set, "set_default_encoding_parameters not called on %r" % (self,))
1387         if self._all_encoding_parameters:
1388             return defer.succeed(self._all_encoding_parameters)
1389
1390         max_segsize = self.max_segment_size or self.default_max_segment_size
1391         k = self.encoding_param_k or self.default_encoding_param_k
1392         happy = self.encoding_param_happy or self.default_encoding_param_happy
1393         n = self.encoding_param_n or self.default_encoding_param_n
1394
1395         d = self.get_size()
1396         def _got_size(file_size):
1397             # for small files, shrink the segment size to avoid wasting space
1398             segsize = min(max_segsize, file_size)
1399             # this must be a multiple of 'required_shares'==k
1400             segsize = mathutil.next_multiple(segsize, k)
1401             encoding_parameters = (k, happy, n, segsize)
1402             self._all_encoding_parameters = encoding_parameters
1403             return encoding_parameters
1404         d.addCallback(_got_size)
1405         return d
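    # Worked example of the segment-size math above, assuming k=3 and a
    # 128KiB max_segment_size: a 1000-byte file yields
    # segsize = min(131072, 1000) == 1000, which next_multiple rounds up to
    # 1002 (the next multiple of 3), so the file is encoded as a single
    # 1002-byte segment.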
1406
1407 class FileHandle(BaseUploadable):
1408     implements(IUploadable)
1409
1410     def __init__(self, filehandle, convergence):
1411         """
1412         Upload the data from the filehandle.  If convergence is None, a
1413         random encryption key will be used; otherwise the plaintext will be
1414         hashed, and that hash will be hashed together with the string in the
1415         "convergence" argument to form the encryption key.
1416         """
1417         assert convergence is None or isinstance(convergence, str), (convergence, type(convergence))
1418         self._filehandle = filehandle
1419         self._key = None
1420         self.convergence = convergence
1421         self._size = None
1422
1423     def _get_encryption_key_convergent(self):
1424         if self._key is not None:
1425             return defer.succeed(self._key)
1426
1427         d = self.get_size()
1428         # that sets self._size as a side-effect
1429         d.addCallback(lambda size: self.get_all_encoding_parameters())
1430         def _got(params):
1431             k, happy, n, segsize = params
1432             f = self._filehandle
1433             enckey_hasher = convergence_hasher(k, n, segsize, self.convergence)
1434             f.seek(0)
1435             BLOCKSIZE = 64*1024
1436             bytes_read = 0
1437             while True:
1438                 data = f.read(BLOCKSIZE)
1439                 if not data:
1440                     break
1441                 enckey_hasher.update(data)
1442                 # TODO: setting progress in a non-yielding loop is kind of
1443                 # pointless, but I'm anticipating (perhaps prematurely) the
1444                 # day when we use a slowjob or twisted's CooperatorService to
1445                 # make this yield time to other jobs.
1446                 bytes_read += len(data)
1447                 if self._status:
1448                     self._status.set_progress(0, float(bytes_read)/self._size)
1449             f.seek(0)
1450             self._key = enckey_hasher.digest()
1451             if self._status:
1452                 self._status.set_progress(0, 1.0)
1453             assert len(self._key) == 16
1454             return self._key
1455         d.addCallback(_got)
1456         return d
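    # Convergence in practice (a sketch; no real caps are shown): two uploads
    # of identical plaintext with the same convergence secret derive the same
    # 16-byte AES key, hence the same storage index, so the second upload can
    # be deduplicated against the first. A different secret, or None, yields
    # an unrelated key:
    #
    #   d1 = Data("same bytes", convergence="secret").get_encryption_key()
    #   d2 = Data("same bytes", convergence="secret").get_encryption_key()
    #   # both Deferreds fire with the same 16-byte key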
1457
1458     def _get_encryption_key_random(self):
1459         if self._key is None:
1460             self._key = os.urandom(16)
1461         return defer.succeed(self._key)
1462
1463     def get_encryption_key(self):
1464         if self.convergence is not None:
1465             return self._get_encryption_key_convergent()
1466         else:
1467             return self._get_encryption_key_random()
1468
1469     def get_size(self):
1470         if self._size is not None:
1471             return defer.succeed(self._size)
1472         self._filehandle.seek(0, os.SEEK_END)
1473         size = self._filehandle.tell()
1474         self._size = size
1475         self._filehandle.seek(0)
1476         return defer.succeed(size)
1477
1478     def read(self, length):
1479         return defer.succeed([self._filehandle.read(length)])
1480
1481     def close(self):
1482         # the originator of the filehandle reserves the right to close it
1483         pass
1484
1485 class FileName(FileHandle):
1486     def __init__(self, filename, convergence):
1487         """
1488         Upload the data from the filename.  If convergence is None, a
1489         random encryption key will be used; otherwise the plaintext will be
1490         hashed, and that hash will be hashed together with the string in the
1491         "convergence" argument to form the encryption key.
1492         """
1493         assert convergence is None or isinstance(convergence, str), (convergence, type(convergence))
1494         FileHandle.__init__(self, open(filename, "rb"), convergence=convergence)
1495     def close(self):
1496         FileHandle.close(self)
1497         self._filehandle.close()
1498
1499 class Data(FileHandle):
1500     def __init__(self, data, convergence):
1501         """
1502         Upload the data from the data argument.  If convergence is None, a
1503         random encryption key will be used; otherwise the plaintext will be
1504         hashed, and that hash will be hashed together with the string in the
1505         "convergence" argument to form the encryption key.
1506         """
1507         assert convergence is None or isinstance(convergence, str), (convergence, type(convergence))
1508         FileHandle.__init__(self, StringIO(data), convergence=convergence)
1509
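# The three concrete uploadables side by side (a sketch; "path/to/file" is a
# placeholder, and an empty convergence string selects plain convergent
# encryption):
#
#   FileHandle(open("path/to/file", "rb"), convergence="")  # caller closes it
#   FileName("path/to/file", convergence="")                # closed for you
#   Data("raw bytes already in memory", convergence="")
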
1510 class Uploader(service.MultiService, log.PrefixingLogMixin):
1511     """I am a service that allows file uploading. I am a service-child of the
1512     Client.
1513     """
1514     implements(IUploader)
1515     name = "uploader"
1516     URI_LIT_SIZE_THRESHOLD = 55
1517
1518     def __init__(self, helper_furl=None, stats_provider=None, history=None, progress=None):
1519         self._helper_furl = helper_furl
1520         self.stats_provider = stats_provider
1521         self._history = history
1522         self._helper = None
1523         self._all_uploads = weakref.WeakKeyDictionary() # for debugging
1524         self._progress = progress
1525         log.PrefixingLogMixin.__init__(self, facility="tahoe.immutable.upload")
1526         service.MultiService.__init__(self)
1527
1528     def startService(self):
1529         service.MultiService.startService(self)
1530         if self._helper_furl:
1531             self.parent.tub.connectTo(self._helper_furl,
1532                                       self._got_helper)
1533
1534     def _got_helper(self, helper):
1535         self.log("got helper connection, getting versions")
1536         default = { "http://allmydata.org/tahoe/protocols/helper/v1" :
1537                     { },
1538                     "application-version": "unknown: no get_version()",
1539                     }
1540         d = add_version_to_remote_reference(helper, default)
1541         d.addCallback(self._got_versioned_helper)
1542
1543     def _got_versioned_helper(self, helper):
1544         needed = "http://allmydata.org/tahoe/protocols/helper/v1"
1545         if needed not in helper.version:
1546             raise InsufficientVersionError(needed, helper.version)
1547         self._helper = helper
1548         helper.notifyOnDisconnect(self._lost_helper)
1549
1550     def _lost_helper(self):
1551         self._helper = None
1552
1553     def get_helper_info(self):
1554         # return a tuple of (helper_furl_or_None, connected_bool)
1555         return (self._helper_furl, bool(self._helper))
1556
1557
1558     def upload(self, uploadable, progress=None):
1559         """
1560         Returns a Deferred that will fire with the UploadResults instance.
1561         """
1562         assert self.parent
1563         assert self.running
1564         assert progress is None or IProgress.providedBy(progress)
1565
1566         uploadable = IUploadable(uploadable)
1567         d = uploadable.get_size()
1568         def _got_size(size):
1569             default_params = self.parent.get_encoding_parameters()
1570             precondition(isinstance(default_params, dict), default_params)
1571             precondition("max_segment_size" in default_params, default_params)
1572             uploadable.set_default_encoding_parameters(default_params)
1573             if progress:
1574                 progress.set_progress_total(size)
1575
1576             if self.stats_provider:
1577                 self.stats_provider.count('uploader.files_uploaded', 1)
1578                 self.stats_provider.count('uploader.bytes_uploaded', size)
1579
1580             if size <= self.URI_LIT_SIZE_THRESHOLD:
1581                 uploader = LiteralUploader(progress=progress)
1582                 return uploader.start(uploadable)
1583             else:
1584                 eu = EncryptAnUploadable(uploadable, self._parentmsgid)
1585                 d2 = defer.succeed(None)
1586                 storage_broker = self.parent.get_storage_broker()
1587                 if self._helper:
1588                     uploader = AssistedUploader(self._helper, storage_broker)
1589                     d2.addCallback(lambda x: eu.get_storage_index())
1590                     d2.addCallback(lambda si: uploader.start(eu, si))
1591                 else:
1593                     secret_holder = self.parent._secret_holder
1594                     uploader = CHKUploader(storage_broker, secret_holder, progress=progress)
1595                     d2.addCallback(lambda x: uploader.start(eu))
1596
1597                 self._all_uploads[uploader] = None
1598                 if self._history:
1599                     self._history.add_upload(uploader.get_upload_status())
1600                 def turn_verifycap_into_read_cap(uploadresults):
1601                     # Generate the uri from the verifycap plus the key.
1602                     d3 = uploadable.get_encryption_key()
1603                     def put_readcap_into_results(key):
1604                         v = uri.from_string(uploadresults.get_verifycapstr())
1605                         r = uri.CHKFileURI(key, v.uri_extension_hash, v.needed_shares, v.total_shares, v.size)
1606                         uploadresults.set_uri(r.to_string())
1607                         return uploadresults
1608                     d3.addCallback(put_readcap_into_results)
1609                     return d3
1610                 d2.addCallback(turn_verifycap_into_read_cap)
1611                 return d2
1612         d.addCallback(_got_size)
1613         def _done(res):
1614             uploadable.close()
1615             return res
1616         d.addBoth(_done)
1617         return d
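
# A minimal end-to-end sketch; "client" is a hypothetical handle for the
# parent Client this service is attached to, and get_uri() comes from
# IUploadResults. Note that the readcap is just the encryption key combined
# with the verifycap, as turn_verifycap_into_read_cap shows above:
#
#   d = client.upload(Data("hello world", convergence=""))
#   d.addCallback(lambda results: results.get_uri())
#   # 11 bytes <= URI_LIT_SIZE_THRESHOLD (55), so this yields a URI:LIT: cap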