git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/commitdiff
test_mutable: factor out common setup code
author    Brian Warner <warner@lothar.com>
          Wed, 6 Aug 2008 17:38:04 +0000 (10:38 -0700)
committer Brian Warner <warner@lothar.com>
          Wed, 6 Aug 2008 17:38:04 +0000 (10:38 -0700)
src/allmydata/test/test_mutable.py

index d4bc5a334b02e527b8374012ef3595c40d3a2d50..efa8fcf64075d5d864b74fd01651559fe057d13d 100644
@@ -614,19 +614,80 @@ class MakeShares(unittest.TestCase):
     # when we publish to 3 peers, we should get either 3 or 4 shares per peer
     # when we publish to zero peers, we should get a NotEnoughSharesError
 
-class Servermap(unittest.TestCase):
-    def setUp(self):
+class PublishMixin:
+    def publish_one(self):
         # publish a file and create shares, which can then be manipulated
         # later.
+        self.CONTENTS = "New contents go here" * 1000
         num_peers = 20
         self._client = FakeClient(num_peers)
         self._storage = self._client._storage
-        d = self._client.create_mutable_file("New contents go here")
+        d = self._client.create_mutable_file(self.CONTENTS)
         def _created(node):
             self._fn = node
             self._fn2 = self._client.create_node_from_uri(node.get_uri())
         d.addCallback(_created)
         return d
+    def publish_multiple(self):
+        self.CONTENTS = ["Contents 0",
+                         "Contents 1",
+                         "Contents 2",
+                         "Contents 3a",
+                         "Contents 3b"]
+        self._copied_shares = {}
+        num_peers = 20
+        self._client = FakeClient(num_peers)
+        self._storage = self._client._storage
+        d = self._client.create_mutable_file(self.CONTENTS[0]) # seqnum=1
+        def _created(node):
+            self._fn = node
+            # now create multiple versions of the same file, and accumulate
+            # their shares, so we can mix and match them later.
+            d = defer.succeed(None)
+            d.addCallback(self._copy_shares, 0)
+            d.addCallback(lambda res: node.overwrite(self.CONTENTS[1])) #s2
+            d.addCallback(self._copy_shares, 1)
+            d.addCallback(lambda res: node.overwrite(self.CONTENTS[2])) #s3
+            d.addCallback(self._copy_shares, 2)
+            d.addCallback(lambda res: node.overwrite(self.CONTENTS[3])) #s4a
+            d.addCallback(self._copy_shares, 3)
+            # now we replace all the shares with version s3, and upload a new
+            # version to get s4b.
+            rollback = dict([(i,2) for i in range(10)])
+            d.addCallback(lambda res: self._set_versions(rollback))
+            d.addCallback(lambda res: node.overwrite(self.CONTENTS[4])) #s4b
+            d.addCallback(self._copy_shares, 4)
+            # we leave the storage in state 4
+            return d
+        d.addCallback(_created)
+        return d
+
+    def _copy_shares(self, ignored, index):
+        shares = self._client._storage._peers
+        # we need a deep copy
+        new_shares = {}
+        for peerid in shares:
+            new_shares[peerid] = {}
+            for shnum in shares[peerid]:
+                new_shares[peerid][shnum] = shares[peerid][shnum]
+        self._copied_shares[index] = new_shares
+
+    def _set_versions(self, versionmap):
+        # versionmap maps shnums to which version (0,1,2,3,4) we want the
+        # share to be at. Any shnum which is left out of the map will stay at
+        # its current version.
+        shares = self._client._storage._peers
+        oldshares = self._copied_shares
+        for peerid in shares:
+            for shnum in shares[peerid]:
+                if shnum in versionmap:
+                    index = versionmap[shnum]
+                    shares[peerid][shnum] = oldshares[index][peerid][shnum]
+
+
+class Servermap(unittest.TestCase, PublishMixin):
+    def setUp(self):
+        return self.publish_one()
 
     def make_servermap(self, mode=MODE_CHECK, fn=None):
         if fn is None:
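
[Editorial aside, between hunks: the rest of the diff shows the existing test classes consuming the new PublishMixin. As a quick orientation, this is the reuse pattern it enables. The class name ExampleTest and the test body are illustrative sketches, not code from the commit; they assume download_best_version() fires its Deferred with the plaintext, as the MultipleVersions comment at the bottom of the diff suggests.]

    class ExampleTest(unittest.TestCase, PublishMixin):
        def setUp(self):
            # publishes "New contents go here"*1000 to 20 fake peers, leaving
            # the node in self._fn (and a second node for the same URI in
            # self._fn2) and the fake storage in self._storage
            return self.publish_one()

        def test_read_back(self):
            # hypothetical test: read the file back through the second node
            d = self._fn2.download_best_version()
            d.addCallback(lambda contents:
                          self.failUnlessEqual(contents, self.CONTENTS))
            return d
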
@@ -803,19 +864,9 @@ class Servermap(unittest.TestCase):
 
 
 
-class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin):
+class Roundtrip(unittest.TestCase, testutil.ShouldFailMixin, PublishMixin):
     def setUp(self):
-        # publish a file and create shares, which can then be manipulated
-        # later.
-        self.CONTENTS = "New contents go here" * 1000
-        num_peers = 20
-        self._client = FakeClient(num_peers)
-        self._storage = self._client._storage
-        d = self._client.create_mutable_file(self.CONTENTS)
-        def _created(node):
-            self._fn = node
-        d.addCallback(_created)
-        return d
+        return self.publish_one()
 
     def make_servermap(self, mode=MODE_READ, oldmap=None):
         if oldmap is None:
@@ -1141,19 +1192,9 @@ class CheckerMixin:
                   (where, expected_exception, r.problems))
 
 
-class Checker(unittest.TestCase, CheckerMixin):
+class Checker(unittest.TestCase, CheckerMixin, PublishMixin):
     def setUp(self):
-        # publish a file and create shares, which can then be manipulated
-        # later.
-        self.CONTENTS = "New contents go here" * 1000
-        num_peers = 20
-        self._client = FakeClient(num_peers)
-        self._storage = self._client._storage
-        d = self._client.create_mutable_file(self.CONTENTS)
-        def _created(node):
-            self._fn = node
-        d.addCallback(_created)
-        return d
+        return self.publish_one()
 
 
     def test_check_good(self):
@@ -1244,19 +1285,7 @@ class Checker(unittest.TestCase, CheckerMixin):
                       "test_verify_one_bad_encprivkey_uncheckable")
         return d
 
-class Repair(unittest.TestCase, CheckerMixin):
-    def setUp(self):
-        # publish a file and create shares, which can then be manipulated
-        # later.
-        self.CONTENTS = "New contents go here" * 1000
-        num_peers = 20
-        self._client = FakeClient(num_peers)
-        self._storage = self._client._storage
-        d = self._client.create_mutable_file(self.CONTENTS)
-        def _created(node):
-            self._fn = node
-        d.addCallback(_created)
-        return d
+class Repair(unittest.TestCase, PublishMixin):
 
     def get_shares(self, s):
         all_shares = {} # maps (peerid, shnum) to share data
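
[Editorial aside: the body of get_shares() falls in the gap between this hunk and the next. Judging from the storage layout that _copy_shares() walks above (_peers[peerid][shnum] -> share data), the elided loop presumably looks roughly like the following sketch; it is a reconstruction, not the commit's code.]

        for peerid in s._peers:
            shares = s._peers[peerid]
            for shnum in shares:
                # key each copy by (peerid, shnum), as the comment above says
                all_shares[ (peerid, shnum) ] = shares[shnum]
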
@@ -1268,27 +1297,29 @@ class Repair(unittest.TestCase, CheckerMixin):
         return all_shares
 
     def test_repair_nop(self):
-        initial_shares = self.get_shares(self._storage)
-
-        d = self._fn.check()
-        d.addCallback(self._fn.repair)
+        d = self.publish_one()
+        def _published(res):
+            self.initial_shares = self.get_shares(self._storage)
+        d.addCallback(_published)
+        d.addCallback(lambda res: self._fn.check())
+        d.addCallback(lambda check_results: self._fn.repair(check_results))
         def _check_results(rres):
             self.failUnless(IRepairResults.providedBy(rres))
             # TODO: examine results
 
             new_shares = self.get_shares(self._storage)
             # all shares should be in the same place as before
-            self.failUnlessEqual(set(initial_shares.keys()),
+            self.failUnlessEqual(set(self.initial_shares.keys()),
                                  set(new_shares.keys()))
             # but they should all be at a newer seqnum. The IV will be
             # different, so the roothash will be too.
-            for key in initial_shares:
+            for key in self.initial_shares:
                 (version0,
                  seqnum0,
                  root_hash0,
                  IV0,
                  k0, N0, segsize0, datalen0,
-                 o0) = unpack_header(initial_shares[key])
+                 o0) = unpack_header(self.initial_shares[key])
                 (version1,
                  seqnum1,
                  root_hash1,
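
[Editorial aside: the hunk is cut off while unpacking the new share's header. The comment above ("they should all be at a newer seqnum. The IV will be different, so the roothash will be too.") implies the loop goes on to assert roughly the following; this is a hedged sketch, not the commit's exact code.]

                # continuing inside the for-loop over self.initial_shares:
                # the repaired share must carry a newer seqnum and, because
                # the IV changed, a different root hash
                self.failUnless(seqnum1 > seqnum0)
                self.failIfEqual(root_hash0, root_hash1)
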
@@ -1447,62 +1478,11 @@ class MultipleEncodings(unittest.TestCase):
         d.addCallback(_retrieved)
         return d
 
-class MultipleVersions(unittest.TestCase, CheckerMixin):
-    def setUp(self):
-        self.CONTENTS = ["Contents 0",
-                         "Contents 1",
-                         "Contents 2",
-                         "Contents 3a",
-                         "Contents 3b"]
-        self._copied_shares = {}
-        num_peers = 20
-        self._client = FakeClient(num_peers)
-        self._storage = self._client._storage
-        d = self._client.create_mutable_file(self.CONTENTS[0]) # seqnum=1
-        def _created(node):
-            self._fn = node
-            # now create multiple versions of the same file, and accumulate
-            # their shares, so we can mix and match them later.
-            d = defer.succeed(None)
-            d.addCallback(self._copy_shares, 0)
-            d.addCallback(lambda res: node.overwrite(self.CONTENTS[1])) #s2
-            d.addCallback(self._copy_shares, 1)
-            d.addCallback(lambda res: node.overwrite(self.CONTENTS[2])) #s3
-            d.addCallback(self._copy_shares, 2)
-            d.addCallback(lambda res: node.overwrite(self.CONTENTS[3])) #s4a
-            d.addCallback(self._copy_shares, 3)
-            # now we replace all the shares with version s3, and upload a new
-            # version to get s4b.
-            rollback = dict([(i,2) for i in range(10)])
-            d.addCallback(lambda res: self._set_versions(rollback))
-            d.addCallback(lambda res: node.overwrite(self.CONTENTS[4])) #s4b
-            d.addCallback(self._copy_shares, 4)
-            # we leave the storage in state 4
-            return d
-        d.addCallback(_created)
-        return d
 
-    def _copy_shares(self, ignored, index):
-        shares = self._client._storage._peers
-        # we need a deep copy
-        new_shares = {}
-        for peerid in shares:
-            new_shares[peerid] = {}
-            for shnum in shares[peerid]:
-                new_shares[peerid][shnum] = shares[peerid][shnum]
-        self._copied_shares[index] = new_shares
+class MultipleVersions(unittest.TestCase, PublishMixin, CheckerMixin):
 
-    def _set_versions(self, versionmap):
-        # versionmap maps shnums to which version (0,1,2,3,4) we want the
-        # share to be at. Any shnum which is left out of the map will stay at
-        # its current version.
-        shares = self._client._storage._peers
-        oldshares = self._copied_shares
-        for peerid in shares:
-            for shnum in shares[peerid]:
-                if shnum in versionmap:
-                    index = versionmap[shnum]
-                    shares[peerid][shnum] = oldshares[index][peerid][shnum]
+    def setUp(self):
+        return self.publish_multiple()
 
     def test_multiple_versions(self):
         # if we see a mix of versions in the grid, download_best_version
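
[Editorial aside: the comment is truncated by the end of the diff; the surviving fragment says download_best_version should cope with a grid holding a mix of versions. A hedged sketch of the kind of scenario publish_multiple() and _set_versions() make possible follows. The test name, the choice of which shares to roll back, and the expected result (which assumes the default 3-of-10 encoding keeps the newest version recoverable) are illustrative assumptions, not the commit's code.]

        def test_mixed_grid_sketch(self):
            # pin shares 0-4 back to version index 2 (the s3 publish);
            # shares 5-9 keep their current s4b data
            self._set_versions(dict([(i, 2) for i in range(5)]))
            d = self._fn.download_best_version()
            # assumption: five s4b shares are still enough to recover the
            # newest version, so we expect CONTENTS[4] back
            d.addCallback(lambda contents:
                          self.failUnlessEqual(contents, self.CONTENTS[4]))
            return d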