From c7c57bd85c5fbfda00d06266dc4c1997137a03f5 Mon Sep 17 00:00:00 2001
From: Brian Warner <warner@lothar.com>
Date: Wed, 6 Aug 2008 12:06:07 -0700
Subject: [PATCH] mutable: more repair tests, one with force=True to check out
 merging

---
 src/allmydata/mutable/node.py      |  4 +-
 src/allmydata/test/test_mutable.py | 81 +++++++++++++++++++++++++++---
 2 files changed, 75 insertions(+), 10 deletions(-)

diff --git a/src/allmydata/mutable/node.py b/src/allmydata/mutable/node.py
index 2d4b2026..e850c675 100644
--- a/src/allmydata/mutable/node.py
+++ b/src/allmydata/mutable/node.py
@@ -259,10 +259,10 @@ class MutableFileNode:
     #################################
     # IRepairable
 
-    def repair(self, checker_results):
+    def repair(self, checker_results, force=False):
         assert ICheckerResults(checker_results)
         r = Repairer(self, checker_results)
-        d = r.start()
+        d = r.start(force)
         return d
 
 
diff --git a/src/allmydata/test/test_mutable.py b/src/allmydata/test/test_mutable.py
index 33f33c7e..6abfe1bd 100644
--- a/src/allmydata/test/test_mutable.py
+++ b/src/allmydata/test/test_mutable.py
@@ -13,6 +13,7 @@ from allmydata.util.hashutil import tagged_hash
 from allmydata.util.fileutil import make_dirs
 from allmydata.interfaces import IURI, IMutableFileURI, IUploadable, \
      FileTooLargeError, IRepairResults
+from allmydata.test.common import ShouldFailMixin
 from foolscap.eventual import eventually, fireEventually
 from foolscap.logging import log
 import sha
@@ -26,6 +27,7 @@ from allmydata.mutable.retrieve import Retrieve
 from allmydata.mutable.publish import Publish
 from allmydata.mutable.servermap import ServerMap, ServermapUpdater
 from allmydata.mutable.layout import unpack_header, unpack_share
+from allmydata.mutable.repair import MustForceRepairError
 
 # this "FastMutableFileNode" exists solely to speed up tests by using smaller
 # public/private keys. Once we switch to fast DSA-based keys, we can get rid
@@ -1285,7 +1287,7 @@ class Checker(unittest.TestCase, CheckerMixin, PublishMixin):
                       "test_verify_one_bad_encprivkey_uncheckable")
         return d
 
-class Repair(unittest.TestCase, PublishMixin):
+class Repair(unittest.TestCase, PublishMixin, ShouldFailMixin):
 
     def get_shares(self, s):
         all_shares = {} # maps (peerid, shnum) to share data
@@ -1296,34 +1298,39 @@ class Repair(unittest.TestCase, PublishMixin):
                 all_shares[ (peerid, shnum) ] = data
         return all_shares
 
+    def copy_shares(self, ignored=None):
+        self.old_shares.append(self.get_shares(self._storage))
+
     def test_repair_nop(self):
+        self.old_shares = []
         d = self.publish_one()
-        def _published(res):
-            self.initial_shares = self.get_shares(self._storage)
-        d.addCallback(_published)
+        d.addCallback(self.copy_shares)
         d.addCallback(lambda res: self._fn.check())
         d.addCallback(lambda check_results: self._fn.repair(check_results))
         def _check_results(rres):
             self.failUnless(IRepairResults.providedBy(rres))
             # TODO: examine results
 
-            new_shares = self.get_shares(self._storage)
+            self.copy_shares()
+
+            initial_shares = self.old_shares[0]
+            new_shares = self.old_shares[1]
             # TODO: this really shouldn't change anything. When we implement
             # a "minimal-bandwidth" repairer", change this test to assert:
             #self.failUnlessEqual(new_shares, initial_shares)
 
             # all shares should be in the same place as before
-            self.failUnlessEqual(set(self.initial_shares.keys()),
+            self.failUnlessEqual(set(initial_shares.keys()),
                                  set(new_shares.keys()))
             # but they should all be at a newer seqnum. The IV will be
             # different, so the roothash will be too.
-            for key in self.initial_shares:
+            for key in initial_shares:
                 (version0,
                  seqnum0,
                  root_hash0,
                  IV0,
                  k0, N0, segsize0, datalen0,
-                 o0) = unpack_header(self.initial_shares[key])
+                 o0) = unpack_header(initial_shares[key])
                 (version1,
                  seqnum1,
                  root_hash1,
@@ -1339,6 +1346,64 @@ class Repair(unittest.TestCase, PublishMixin):
         d.addCallback(_check_results)
         return d
 
+    def failIfSharesChanged(self, ignored=None):
+        old_shares = self.old_shares[-2]
+        current_shares = self.old_shares[-1]
+        self.failUnlessEqual(old_shares, current_shares)
+
+    def test_merge(self):
+        self.old_shares = []
+        d = self.publish_multiple()
+        # repair will refuse to merge multiple highest seqnums unless you
+        # pass force=True
+        d.addCallback(lambda res:
+                      self._set_versions({0:3,2:3,4:3,6:3,8:3,
+                                          1:4,3:4,5:4,7:4,9:4}))
+        d.addCallback(self.copy_shares)
+        d.addCallback(lambda res: self._fn.check())
+        def _try_repair(check_results):
+            ex = "There were multiple recoverable versions with identical seqnums, so force=True must be passed to the repair() operation"
+            d2 = self.shouldFail(MustForceRepairError, "test_merge", ex,
+                                 self._fn.repair, check_results)
+            d2.addCallback(self.copy_shares)
+            d2.addCallback(self.failIfSharesChanged)
+            d2.addCallback(lambda res: check_results)
+            return d2
+        d.addCallback(_try_repair)
+        d.addCallback(lambda check_results:
+                      self._fn.repair(check_results, force=True))
+        # this should give us 10 shares of the highest roothash
+        def _check_repair_results(rres):
+            pass # TODO
+        d.addCallback(_check_repair_results)
+        d.addCallback(lambda res: self._fn.get_servermap(MODE_CHECK))
+        def _check_smap(smap):
+            self.failUnlessEqual(len(smap.recoverable_versions()), 1)
+            self.failIf(smap.unrecoverable_versions())
+            # now, which should have won?
+            roothash_s4a = self.get_roothash_for(3)
+            roothash_s4b = self.get_roothash_for(4)
+            if roothash_s4b > roothash_s4a:
+                expected_contents = self.CONTENTS[4]
+            else:
+                expected_contents = self.CONTENTS[3]
+            new_versionid = smap.best_recoverable_version()
+            self.failUnlessEqual(new_versionid[0], 5) # seqnum 5
+            d2 = self._fn.download_version(smap, new_versionid)
+            d2.addCallback(self.failUnlessEqual, expected_contents)
+            return d2
+        d.addCallback(_check_smap)
+        return d
+
+    def get_roothash_for(self, index):
+        # return the roothash for the first share we see in the saved set
+        shares = self._copied_shares[index]
+        for peerid in shares:
+            for shnum in shares[peerid]:
+                share = shares[peerid][shnum]
+                (version, seqnum, root_hash, IV, k, N, segsize, datalen, o) = \
+                          unpack_header(share)
+                return root_hash
 
 class MultipleEncodings(unittest.TestCase):
     def setUp(self):
-- 
2.45.2