]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/commitdiff
Fix a bug in mutable publish that could cause an IndexError when a writer is removed...
authordavid-sarah <david-sarah@jacaranda.org>
Fri, 15 Jun 2012 03:44:37 +0000 (03:44 +0000)
committerdavid-sarah <david-sarah@jacaranda.org>
Fri, 15 Jun 2012 03:44:37 +0000 (03:44 +0000)
src/allmydata/mutable/publish.py

index 3b7c6e0595da443a62f6fbd841c74d28abf7d8d8..c87e89781b40782e412f5f3a4536d2ac29cce31a 100644 (file)
@@ -253,7 +253,9 @@ class Publish:
         # updating, we ignore damaged and missing shares -- callers must
         # do a repair to repair and recreate these.
         self.goal = set(self._servermap.get_known_shares())
-        self.writers = {}
+
+        # shnum -> set of IMutableSlotWriter
+        self.writers = DictOfSets()
 
         # SDMF files are updated differently.
         self._version = MDMF_VERSION
@@ -278,7 +280,7 @@ class Publish:
                                   self.segment_size,
                                   self.datalength)
 
-            self.writers.setdefault(shnum, []).append(writer)
+            self.writers.add(shnum, writer)
             writer.server = server
             known_shares = self._servermap.get_known_shares()
             assert (server, shnum) in known_shares
@@ -294,7 +296,7 @@ class Publish:
         # after we are done writing share data and have started to write
         # blocks. In the meantime, we need to know what to look for when
         # writing, so that we can detect UncoordinatedWriteErrors.
-        self._checkstring = self.writers.values()[0][0].get_checkstring()
+        self._checkstring = self._get_some_writer().get_checkstring()
 
         # Now, we start pushing shares.
         self._status.timings["setup"] = time.time() - self._started
@@ -452,7 +454,10 @@ class Publish:
 
         # TODO: Make this part do server selection.
         self.update_goal()
-        self.writers = {}
+
+        # shnum -> set of IMutableSlotWriter
+        self.writers = DictOfSets()
+
         if self._version == MDMF_VERSION:
             writer_class = MDMFSlotWriteProxy
         else:
@@ -476,7 +481,7 @@ class Publish:
                                    self.total_shares,
                                    self.segment_size,
                                    self.datalength)
-            self.writers.setdefault(shnum, []).append(writer)
+            self.writers.add(shnum, writer)
             writer.server = server
             known_shares = self._servermap.get_known_shares()
             if (server, shnum) in known_shares:
@@ -495,7 +500,7 @@ class Publish:
         # after we are done writing share data and have started to write
         # blocks. In the meantime, we need to know what to look for when
         # writing, so that we can detect UncoordinatedWriteErrors.
-        self._checkstring = self.writers.values()[0][0].get_checkstring()
+        self._checkstring = self._get_some_writer().get_checkstring()
 
         # Now, we start pushing shares.
         self._status.timings["setup"] = time.time() - self._started
@@ -521,6 +526,8 @@ class Publish:
 
         return self.done_deferred
 
+    def _get_some_writer(self):
+        return list(self.writers.values()[0])[0]
 
     def _update_status(self):
         self._status.set_status("Sending Shares: %d placed out of %d, "
@@ -622,9 +629,8 @@ class Publish:
         # Can we still successfully publish this file?
         # TODO: Keep track of outstanding queries before aborting the
         #       process.
-        all_shnums = filter(lambda sh: len(self.writers[sh]) > 0,
-                            self.writers.iterkeys())
-        if len(all_shnums) < self.required_shares or self.surprised:
+        num_shnums = len(self.writers)
+        if num_shnums < self.required_shares or self.surprised:
             return self._failure()
 
         # Figure out what we need to do next. Each of these needs to
@@ -835,7 +841,7 @@ class Publish:
         uncoordinated writes. SDMF files will have the same checkstring,
         so we need not do anything.
         """
-        self._checkstring = self.writers.values()[0][0].get_checkstring()
+        self._checkstring = self._get_some_writer().get_checkstring()
 
 
     def _make_and_place_signature(self):
@@ -844,7 +850,7 @@ class Publish:
         """
         started = time.time()
         self._status.set_status("Signing prefix")
-        signable = self.writers.values()[0][0].get_signable()
+        signable = self._get_some_writer().get_signable()
         self.signature = self._privkey.sign(signable)
 
         for (shnum, writers) in self.writers.iteritems():
@@ -881,7 +887,7 @@ class Publish:
 
 
     def _record_verinfo(self):
-        self.versioninfo = self.writers.values()[0][0].get_verinfo()
+        self.versioninfo = self._get_some_writer().get_verinfo()
 
 
     def _connection_problem(self, f, writer):
@@ -891,7 +897,7 @@ class Publish:
         """
         self.log("found problem: %s" % str(f))
         self._last_failure = f
-        self.writers[writer.shnum].remove(writer)
+        self.writers.discard(writer.shnum, writer)
 
 
     def log_goal(self, goal, message=""):