]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blobdiff - src/allmydata/immutable/upload.py
Add assertions to make sure that set_default_encoding_parameters is always called...
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / immutable / upload.py
index ccab2aabac673f92707312aa6b5eb39b366f8436..c63240463cd1a510d411e8939d746fb3a5cd0b4f 100644 (file)
@@ -16,7 +16,7 @@ from allmydata.util import base32, dictutil, idlib, log, mathutil
 from allmydata.util.happinessutil import servers_of_happiness, \
                                          shares_by_server, merge_servers, \
                                          failure_message
-from allmydata.util.assertutil import precondition
+from allmydata.util.assertutil import precondition, _assert
 from allmydata.util.rrefutil import add_version_to_remote_reference
 from allmydata.interfaces import IUploadable, IUploader, IUploadResults, \
      IEncryptedUploadable, RIEncryptedUploadable, IUploadStatus, \
@@ -64,27 +64,46 @@ class UploadResults:
                  ciphertext_fetched, # how much the helper fetched
                  preexisting_shares, # count of shares already present
                  pushed_shares, # count of shares we pushed
-                 sharemap, # {shnum: set(serverid)}
-                 servermap, # {serverid: set(shnum)}
+                 sharemap, # {shnum: set(server)}
+                 servermap, # {server: set(shnum)}
                  timings, # dict of name to number of seconds
                  uri_extension_data,
                  uri_extension_hash,
                  verifycapstr):
-        self.file_size = file_size
-        self.ciphertext_fetched = ciphertext_fetched
-        self.preexisting_shares = preexisting_shares
-        self.pushed_shares = pushed_shares
-        self.sharemap = sharemap
-        self.servermap = servermap
-        self.timings = timings
-        self.uri_extension_data = uri_extension_data
-        self.uri_extension_hash = uri_extension_hash
-        self.verifycapstr = verifycapstr
-        self.uri = None
+        self._file_size = file_size
+        self._ciphertext_fetched = ciphertext_fetched
+        self._preexisting_shares = preexisting_shares
+        self._pushed_shares = pushed_shares
+        self._sharemap = sharemap
+        self._servermap = servermap
+        self._timings = timings
+        self._uri_extension_data = uri_extension_data
+        self._uri_extension_hash = uri_extension_hash
+        self._verifycapstr = verifycapstr
 
     def set_uri(self, uri):
-        self.uri = uri
-
+        self._uri = uri
+
+    def get_file_size(self):
+        return self._file_size
+    def get_uri(self):
+        return self._uri
+    def get_ciphertext_fetched(self):
+        return self._ciphertext_fetched
+    def get_preexisting_shares(self):
+        return self._preexisting_shares
+    def get_pushed_shares(self):
+        return self._pushed_shares
+    def get_sharemap(self):
+        return self._sharemap
+    def get_servermap(self):
+        return self._servermap
+    def get_timings(self):
+        return self._timings
+    def get_uri_extension_data(self):
+        return self._uri_extension_data
+    def get_verifycapstr(self):
+        return self._verifycapstr
 
 # our current uri_extension is 846 bytes for small files, a few bytes
 # more for larger ones (since the filesize is encoded in decimal in a
@@ -125,6 +144,8 @@ class ServerTracker:
         return ("<ServerTracker for server %s and SI %s>"
                 % (self._server.get_name(), si_b2a(self.storage_index)[:5]))
 
+    def get_server(self):
+        return self._server
     def get_serverid(self):
         return self._server.get_serverid()
     def get_name(self):
@@ -603,6 +624,8 @@ class EncryptAnUploadable:
     CHUNKSIZE = 50*1024
 
     def __init__(self, original, log_parent=None):
+        precondition(original.default_params_set,
+                     "set_default_encoding_parameters not called on %r before wrapping with EncryptAnUploadable" % (original,))
         self.original = IUploadable(original)
         self._log_number = log_parent
         self._encryptor = None
@@ -1006,10 +1029,9 @@ class CHKUploader:
         sharemap = dictutil.DictOfSets()
         servermap = dictutil.DictOfSets()
         for shnum in e.get_shares_placed():
-            server_tracker = self._server_trackers[shnum]
-            serverid = server_tracker.get_serverid()
-            sharemap.add(shnum, serverid)
-            servermap.add(serverid, shnum)
+            server = self._server_trackers[shnum].get_server()
+            sharemap.add(shnum, server)
+            servermap.add(server, shnum)
         now = time.time()
         timings = {}
         timings["total"] = now - self._started
@@ -1168,8 +1190,9 @@ class RemoteEncryptedUploadable(Referenceable):
 
 class AssistedUploader:
 
-    def __init__(self, helper):
+    def __init__(self, helper, storage_broker):
         self._helper = helper
+        self._storage_broker = storage_broker
         self._log_number = log.msg("AssistedUploader starting")
         self._storage_index = None
         self._upload_status = s = UploadStatus()
@@ -1288,13 +1311,22 @@ class AssistedUploader:
         now = time.time()
         timings["total"] = now - self._started
 
+        gss = self._storage_broker.get_stub_server
+        sharemap = {}
+        servermap = {}
+        for shnum, serverids in hur.sharemap.items():
+            sharemap[shnum] = set([gss(serverid) for serverid in serverids])
+        # if the file was already in the grid, hur.servermap is an empty dict
+        for serverid, shnums in hur.servermap.items():
+            servermap[gss(serverid)] = set(shnums)
+
         ur = UploadResults(file_size=self._size,
                            # not if already found
                            ciphertext_fetched=hur.ciphertext_fetched,
                            preexisting_shares=hur.preexisting_shares,
                            pushed_shares=hur.pushed_shares,
-                           sharemap=hur.sharemap,
-                           servermap=hur.servermap, # not if already found
+                           sharemap=sharemap,
+                           servermap=servermap,
                            timings=timings,
                            uri_extension_data=hur.uri_extension_data,
                            uri_extension_hash=hur.uri_extension_hash,
@@ -1310,9 +1342,7 @@ class AssistedUploader:
 class BaseUploadable:
     # this is overridden by max_segment_size
     default_max_segment_size = DEFAULT_MAX_SEGMENT_SIZE
-    default_encoding_param_k = 3 # overridden by encoding_parameters
-    default_encoding_param_happy = 7
-    default_encoding_param_n = 10
+    default_params_set = False
 
     max_segment_size = None
     encoding_param_k = None
@@ -1338,8 +1368,10 @@ class BaseUploadable:
             self.default_encoding_param_n = default_params["n"]
         if "max_segment_size" in default_params:
             self.default_max_segment_size = default_params["max_segment_size"]
+        self.default_params_set = True
 
     def get_all_encoding_parameters(self):
+        _assert(self.default_params_set, "set_default_encoding_parameters not called on %r" % (self,))
         if self._all_encoding_parameters:
             return defer.succeed(self._all_encoding_parameters)
 
@@ -1535,8 +1567,9 @@ class Uploader(service.MultiService, log.PrefixingLogMixin):
             else:
                 eu = EncryptAnUploadable(uploadable, self._parentmsgid)
                 d2 = defer.succeed(None)
+                storage_broker = self.parent.get_storage_broker()
                 if self._helper:
-                    uploader = AssistedUploader(self._helper)
+                    uploader = AssistedUploader(self._helper, storage_broker)
                     d2.addCallback(lambda x: eu.get_storage_index())
                     d2.addCallback(lambda si: uploader.start(eu, si))
                 else:
@@ -1552,7 +1585,7 @@ class Uploader(service.MultiService, log.PrefixingLogMixin):
                     # Generate the uri from the verifycap plus the key.
                     d3 = uploadable.get_encryption_key()
                     def put_readcap_into_results(key):
-                        v = uri.from_string(uploadresults.verifycapstr)
+                        v = uri.from_string(uploadresults.get_verifycapstr())
                         r = uri.CHKFileURI(key, v.uri_extension_hash, v.needed_shares, v.total_shares, v.size)
                         uploadresults.set_uri(r.to_string())
                         return uploadresults