From: david-sarah Date: Fri, 15 Jun 2012 03:44:37 +0000 (+0000) Subject: Fix a bug in mutable publish that could cause an IndexError when a writer is removed... X-Git-Url: https://git.rkrishnan.org/?p=tahoe-lafs%2Ftahoe-lafs.git;a=commitdiff_plain;h=635e87bd7b460324631f1f6b464fc65009d3cb36 Fix a bug in mutable publish that could cause an IndexError when a writer is removed in Publish._connection_problem. This version uses DictOfSets as suggested by warner. fixes #1749 --- diff --git a/src/allmydata/mutable/publish.py b/src/allmydata/mutable/publish.py index 3b7c6e05..c87e8978 100644 --- a/src/allmydata/mutable/publish.py +++ b/src/allmydata/mutable/publish.py @@ -253,7 +253,9 @@ class Publish: # updating, we ignore damaged and missing shares -- callers must # do a repair to repair and recreate these. self.goal = set(self._servermap.get_known_shares()) - self.writers = {} + + # shnum -> set of IMutableSlotWriter + self.writers = DictOfSets() # SDMF files are updated differently. self._version = MDMF_VERSION @@ -278,7 +280,7 @@ class Publish: self.segment_size, self.datalength) - self.writers.setdefault(shnum, []).append(writer) + self.writers.add(shnum, writer) writer.server = server known_shares = self._servermap.get_known_shares() assert (server, shnum) in known_shares @@ -294,7 +296,7 @@ class Publish: # after we are done writing share data and have started to write # blocks. In the meantime, we need to know what to look for when # writing, so that we can detect UncoordinatedWriteErrors. - self._checkstring = self.writers.values()[0][0].get_checkstring() + self._checkstring = self._get_some_writer().get_checkstring() # Now, we start pushing shares. self._status.timings["setup"] = time.time() - self._started @@ -452,7 +454,10 @@ class Publish: # TODO: Make this part do server selection. self.update_goal() - self.writers = {} + + # shnum -> set of IMutableSlotWriter + self.writers = DictOfSets() + if self._version == MDMF_VERSION: writer_class = MDMFSlotWriteProxy else: @@ -476,7 +481,7 @@ class Publish: self.total_shares, self.segment_size, self.datalength) - self.writers.setdefault(shnum, []).append(writer) + self.writers.add(shnum, writer) writer.server = server known_shares = self._servermap.get_known_shares() if (server, shnum) in known_shares: @@ -495,7 +500,7 @@ class Publish: # after we are done writing share data and have started to write # blocks. In the meantime, we need to know what to look for when # writing, so that we can detect UncoordinatedWriteErrors. - self._checkstring = self.writers.values()[0][0].get_checkstring() + self._checkstring = self._get_some_writer().get_checkstring() # Now, we start pushing shares. self._status.timings["setup"] = time.time() - self._started @@ -521,6 +526,8 @@ class Publish: return self.done_deferred + def _get_some_writer(self): + return list(self.writers.values()[0])[0] def _update_status(self): self._status.set_status("Sending Shares: %d placed out of %d, " @@ -622,9 +629,8 @@ class Publish: # Can we still successfully publish this file? # TODO: Keep track of outstanding queries before aborting the # process. - all_shnums = filter(lambda sh: len(self.writers[sh]) > 0, - self.writers.iterkeys()) - if len(all_shnums) < self.required_shares or self.surprised: + num_shnums = len(self.writers) + if num_shnums < self.required_shares or self.surprised: return self._failure() # Figure out what we need to do next. Each of these needs to @@ -835,7 +841,7 @@ class Publish: uncoordinated writes. SDMF files will have the same checkstring, so we need not do anything. """ - self._checkstring = self.writers.values()[0][0].get_checkstring() + self._checkstring = self._get_some_writer().get_checkstring() def _make_and_place_signature(self): @@ -844,7 +850,7 @@ class Publish: """ started = time.time() self._status.set_status("Signing prefix") - signable = self.writers.values()[0][0].get_signable() + signable = self._get_some_writer().get_signable() self.signature = self._privkey.sign(signable) for (shnum, writers) in self.writers.iteritems(): @@ -881,7 +887,7 @@ class Publish: def _record_verinfo(self): - self.versioninfo = self.writers.values()[0][0].get_verinfo() + self.versioninfo = self._get_some_writer().get_verinfo() def _connection_problem(self, f, writer): @@ -891,7 +897,7 @@ class Publish: """ self.log("found problem: %s" % str(f)) self._last_failure = f - self.writers[writer.shnum].remove(writer) + self.writers.discard(writer.shnum, writer) def log_goal(self, goal, message=""):