Fixes to test infrastructure.
authorDaira Hopwood <daira@jacaranda.org>
Fri, 17 Apr 2015 19:39:20 +0000 (20:39 +0100)
committerDaira Hopwood <daira@jacaranda.org>
Fri, 17 Apr 2015 21:31:02 +0000 (22:31 +0100)
Signed-off-by: Daira Hopwood <daira@jacaranda.org>
src/allmydata/test/common.py
src/allmydata/test/common_web.py
src/allmydata/test/no_network.py

index 0d581755282c036f9c7d86398f22635185d65871..7d6b794a1ac92c3ebe47792a52bb1c53b4117869 100644 (file)
@@ -35,11 +35,32 @@ def flush_but_dont_ignore(res):
     d.addCallback(_done)
     return d
 
+
 class DummyProducer:
     implements(IPullProducer)
     def resumeProducing(self):
         pass
 
+
+class Marker:
+    pass
+
+class FakeCanary:
+    def __init__(self, ignore_disconnectors=False):
+        self.ignore = ignore_disconnectors
+        self.disconnectors = {}
+    def notifyOnDisconnect(self, f, *args, **kwargs):
+        if self.ignore:
+            return
+        m = Marker()
+        self.disconnectors[m] = (f, args, kwargs)
+        return m
+    def dontNotifyOnDisconnect(self, marker):
+        if self.ignore:
+            return
+        del self.disconnectors[marker]
+
+
 class FakeCHKFileNode:
     """I provide IImmutableFileNode, but all of my data is stored in a
     class-level dictionary."""
@@ -485,6 +506,9 @@ class SystemTestMixin(pollmixin.PollMixin, testutil.StallMixin):
         d.addBoth(flush_but_dont_ignore)
         return d
 
+    def workdir(self, name):
+        return os.path.join("system", self.__class__.__name__, name)
+
     def getdir(self, subdir):
         return os.path.join(self.basedir, subdir)
 
@@ -603,11 +627,10 @@ class SystemTestMixin(pollmixin.PollMixin, testutil.StallMixin):
             else:
                 config += nodeconfig
 
-            fileutil.write(os.path.join(basedir, 'tahoe.cfg'), config)
+            # give subclasses a chance to append lines to the nodes' tahoe.cfg files.
+            config += self._get_extra_config(i)
 
-        # give subclasses a chance to append lines to the node's tahoe.cfg
-        # files before they are launched.
-        self._set_up_nodes_extra_config()
+            fileutil.write(os.path.join(basedir, 'tahoe.cfg'), config)
 
         # start clients[0], wait for it's tub to be ready (at which point it
         # will have registered the helper furl).
@@ -645,9 +668,9 @@ class SystemTestMixin(pollmixin.PollMixin, testutil.StallMixin):
         d.addCallback(_connected)
         return d
 
-    def _set_up_nodes_extra_config(self):
+    def _get_extra_config(self, i):
         # for overriding by subclasses
-        pass
+        return ""
 
     def _grab_stats(self, res):
         d = self.stats_gatherer.poll()
index f6e0ac7a1840d530e6e9f7daf52ca5c793c7b58b..1b064f151e8eb0d858ca9b17720ced0ea0ea930b 100644 (file)
@@ -51,8 +51,9 @@ class WebRenderingMixin:
         ctx = self.make_context(req)
         return page.renderSynchronously(ctx)
 
-    def failUnlessIn(self, substring, s):
-        self.failUnless(substring in s, s)
+    def render_json(self, page):
+        d = self.render1(page, args={"t": ["json"]})
+        return d
 
     def remove_tags(self, s):
         s = re.sub(r'<[^>]*>', ' ', s)
index 98215ddda17779c04841e2c90f2daa233618c924..8c95f0966884649bbdeced6bfc70a0b1eb58898c 100644 (file)
 # Tubs, so it is not useful for tests that involve a Helper, a KeyGenerator,
 # or the control.furl .
 
-import os
+import os, shutil
+
 from zope.interface import implements
 from twisted.application import service
 from twisted.internet import defer, reactor
 from twisted.python.failure import Failure
 from foolscap.api import Referenceable, fireEventually, RemoteException
 from base64 import b32encode
+
 from allmydata import uri as tahoe_uri
 from allmydata.client import Client
-from allmydata.storage.server import StorageServer, storage_index_to_dir
-from allmydata.util import fileutil, idlib, hashutil
+from allmydata.storage.server import StorageServer
+from allmydata.storage.backends.disk.disk_backend import DiskBackend
+from allmydata.util import fileutil, idlib, hashutil, log
 from allmydata.util.hashutil import sha1
 from allmydata.test.common_web import HTTPClientGETFactory
 from allmydata.interfaces import IStorageBroker, IServer
 from allmydata.test.common import TEST_RSA_KEY_SIZE
 
 
+PRINT_TRACEBACKS = False
+
 class IntentionalError(Exception):
     pass
 
@@ -87,23 +92,34 @@ class LocalWrapper:
                 return d2
             return _really_call()
 
+        if PRINT_TRACEBACKS:
+            import traceback
+            tb = traceback.extract_stack()
         d = fireEventually()
         d.addCallback(lambda res: _call())
         def _wrap_exception(f):
+            if PRINT_TRACEBACKS and not f.check(NameError):
+                print ">>>" + ">>>".join(traceback.format_list(tb))
+                print "+++ %s%r %r: %s" % (methname, args, kwargs, f)
+                #f.printDetailedTraceback()
             return Failure(RemoteException(f))
         d.addErrback(_wrap_exception)
         def _return_membrane(res):
-            # rather than complete the difficult task of building a
+            # Rather than complete the difficult task of building a
             # fully-general Membrane (which would locate all Referenceable
             # objects that cross the simulated wire and replace them with
             # wrappers), we special-case certain methods that we happen to
             # know will return Referenceables.
+            # The outer return value of such a method may be Deferred, but
+            # its components must not be.
             if methname == "allocate_buckets":
                 (alreadygot, allocated) = res
                 for shnum in allocated:
+                    assert not isinstance(allocated[shnum], defer.Deferred), (methname, allocated)
                     allocated[shnum] = LocalWrapper(allocated[shnum])
             if methname == "get_buckets":
                 for shnum in res:
+                    assert not isinstance(res[shnum], defer.Deferred), (methname, res)
                     res[shnum] = LocalWrapper(res[shnum])
             return res
         d.addCallback(_return_membrane)
@@ -168,11 +184,20 @@ class NoNetworkStorageBroker:
             seed = server.get_permutation_seed()
             return sha1(peer_selection_index + seed).digest()
         return sorted(self.get_connected_servers(), key=_permuted)
+
     def get_connected_servers(self):
         return self.client._servers
+
     def get_nickname_for_serverid(self, serverid):
         return None
 
+    def get_known_servers(self):
+        return self.get_connected_servers()
+
+    def get_all_serverids(self):
+        return self.client.get_all_serverids()
+
+
 class NoNetworkClient(Client):
     def create_tub(self):
         pass
@@ -234,8 +259,8 @@ class NoNetworkGrid(service.MultiService):
         self.clients = []
 
         for i in range(num_servers):
-            ss = self.make_server(i)
-            self.add_server(i, ss)
+            server = self.make_server(i)
+            self.add_server(i, server)
         self.rebuild_serverlist()
 
         for i in range(num_clients):
@@ -266,23 +291,24 @@ class NoNetworkGrid(service.MultiService):
 
     def make_server(self, i, readonly=False):
         serverid = hashutil.tagged_hash("serverid", str(i))[:20]
-        serverdir = os.path.join(self.basedir, "servers",
-                                 idlib.shortnodeid_b2a(serverid), "storage")
-        fileutil.make_dirs(serverdir)
-        ss = StorageServer(serverdir, serverid, stats_provider=SimpleStats(),
-                           readonly_storage=readonly)
-        ss._no_network_server_number = i
-        return ss
-
-    def add_server(self, i, ss):
+        storagedir = os.path.join(self.basedir, "servers",
+                                  idlib.shortnodeid_b2a(serverid), "storage")
+
+        # The backend will make the storage directory and any necessary parents.
+        backend = DiskBackend(storagedir, readonly=readonly)
+        server = StorageServer(serverid, backend, storagedir, stats_provider=SimpleStats())
+        server._no_network_server_number = i
+        return server
+
+    def add_server(self, i, server):
         # to deal with the fact that all StorageServers are named 'storage',
         # we interpose a middleman
         middleman = service.MultiService()
         middleman.setServiceParent(self)
-        ss.setServiceParent(middleman)
-        serverid = ss.my_nodeid
-        self.servers_by_number[i] = ss
-        aa = ss.get_accountant().get_anonymous_account()
+        server.setServiceParent(middleman)
+        serverid = server.get_serverid()
+        self.servers_by_number[i] = server
+        aa = server.get_accountant().get_anonymous_account()
         wrapper = wrap_storage_server(aa)
         self.wrappers_by_id[serverid] = wrapper
         self.proxies_by_id[serverid] = NoNetworkServer(serverid, wrapper)
@@ -299,14 +325,14 @@ class NoNetworkGrid(service.MultiService):
     def remove_server(self, serverid):
         # it's enough to remove the server from c._servers (we don't actually
         # have to detach and stopService it)
-        for i,ss in self.servers_by_number.items():
-            if ss.my_nodeid == serverid:
+        for i, server in self.servers_by_number.items():
+            if server.get_serverid() == serverid:
                 del self.servers_by_number[i]
                 break
         del self.wrappers_by_id[serverid]
         del self.proxies_by_id[serverid]
         self.rebuild_serverlist()
-        return ss
+        return server
 
     def break_server(self, serverid, count=True):
         # mark the given server as broken, so it will throw exceptions when
@@ -316,16 +342,16 @@ class NoNetworkGrid(service.MultiService):
 
     def hang_server(self, serverid):
         # hang the given server
-        ss = self.wrappers_by_id[serverid]
-        assert ss.hung_until is None
-        ss.hung_until = defer.Deferred()
+        server = self.wrappers_by_id[serverid]
+        assert server.hung_until is None
+        server.hung_until = defer.Deferred()
 
     def unhang_server(self, serverid):
         # unhang the given server
-        ss = self.wrappers_by_id[serverid]
-        assert ss.hung_until is not None
-        ss.hung_until.callback(None)
-        ss.hung_until = None
+        server = self.wrappers_by_id[serverid]
+        assert server.hung_until is not None
+        server.hung_until.callback(None)
+        server.hung_until = None
 
     def nuke_from_orbit(self):
         """Empty all share directories in this grid. It's the only way to be sure ;-)
@@ -361,48 +387,95 @@ class GridTestMixin:
     def get_clientdir(self, i=0):
         return self.g.clients[i].basedir
 
+    def get_server(self, i):
+        return self.g.servers_by_number[i]
+
     def get_serverdir(self, i):
-        return self.g.servers_by_number[i].storedir
+        return self.g.servers_by_number[i].backend._storedir
+
+    def remove_server(self, i):
+        self.g.remove_server(self.g.servers_by_number[i].get_serverid())
 
     def iterate_servers(self):
         for i in sorted(self.g.servers_by_number.keys()):
-            ss = self.g.servers_by_number[i]
-            yield (i, ss, ss.storedir)
+            server = self.g.servers_by_number[i]
+            yield (i, server, server.backend._storedir)
 
     def find_uri_shares(self, uri):
         si = tahoe_uri.from_string(uri).get_storage_index()
-        prefixdir = storage_index_to_dir(si)
-        shares = []
-        for i,ss in self.g.servers_by_number.items():
-            serverid = ss.my_nodeid
-            basedir = os.path.join(ss.sharedir, prefixdir)
-            if not os.path.exists(basedir):
-                continue
-            for f in os.listdir(basedir):
-                try:
-                    shnum = int(f)
-                    shares.append((shnum, serverid, os.path.join(basedir, f)))
-                except ValueError:
-                    pass
-        return sorted(shares)
+        sharelist = []
+        d = defer.succeed(None)
+        for i, server in self.g.servers_by_number.items():
+            d.addCallback(lambda ign, server=server: server.backend.get_shareset(si).get_shares())
+            def _append_shares( (shares_for_server, corrupted), server=server):
+                assert len(corrupted) == 0, (shares_for_server, corrupted)
+                for share in shares_for_server:
+                    assert not isinstance(share, defer.Deferred), share
+                    sharelist.append( (share.get_shnum(), server.get_serverid(), share._get_path()) )
+            d.addCallback(_append_shares)
+
+        d.addCallback(lambda ign: sorted(sharelist))
+        return d
+
+    def add_server(self, server_number, readonly=False):
+        assert self.g, "I tried to find a grid at self.g, but failed"
+        ss = self.g.make_server(server_number, readonly)
+        log.msg("just created a server, number: %s => %s" % (server_number, ss,))
+        self.g.add_server(server_number, ss)
+
+    def add_server_with_share(self, uri, server_number, share_number=None,
+                              readonly=False):
+        self.add_server(server_number, readonly)
+        if share_number is not None:
+            self.copy_share_to_server(uri, server_number, share_number)
+
+    def copy_share_to_server(self, uri, server_number, share_number):
+        ss = self.g.servers_by_number[server_number]
+        self.copy_share(self.shares[share_number], uri, ss)
 
     def copy_shares(self, uri):
         shares = {}
-        for (shnum, serverid, sharefile) in self.find_uri_shares(uri):
-            shares[sharefile] = open(sharefile, "rb").read()
-        return shares
+        d = self.find_uri_shares(uri)
+        def _got_shares(sharelist):
+            for (shnum, serverid, sharefile) in sharelist:
+                shares[sharefile] = fileutil.read(sharefile)
+
+            return shares
+        d.addCallback(_got_shares)
+        return d
+
+    def copy_share(self, from_share, uri, to_server):
+        si = tahoe_uri.from_string(uri).get_storage_index()
+        (i_shnum, i_serverid, i_sharefile) = from_share
+        shares_dir = to_server.backend.get_shareset(si)._get_sharedir()
+        new_sharefile = os.path.join(shares_dir, str(i_shnum))
+        fileutil.make_dirs(shares_dir)
+        if os.path.normpath(i_sharefile) != os.path.normpath(new_sharefile):
+            shutil.copy(i_sharefile, new_sharefile)
 
     def restore_all_shares(self, shares):
         for sharefile, data in shares.items():
-            open(sharefile, "wb").write(data)
+            fileutil.write(sharefile, data)
 
     def delete_share(self, (shnum, serverid, sharefile)):
-        os.unlink(sharefile)
+        fileutil.remove(sharefile)
 
     def delete_shares_numbered(self, uri, shnums):
-        for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
-            if i_shnum in shnums:
-                os.unlink(i_sharefile)
+        d = self.find_uri_shares(uri)
+        def _got_shares(sharelist):
+            for (i_shnum, i_serverid, i_sharefile) in sharelist:
+                if i_shnum in shnums:
+                    fileutil.remove(i_sharefile)
+        d.addCallback(_got_shares)
+        return d
+
+    def delete_all_shares(self, uri):
+        d = self.find_uri_shares(uri)
+        def _got_shares(shares):
+            for sh in shares:
+                self.delete_share(sh)
+        d.addCallback(_got_shares)
+        return d
 
     def empty_sharedir(self, serverdir):
         sharedir = os.path.join(serverdir, "shares")
@@ -410,23 +483,27 @@ class GridTestMixin:
             if prefixdir != 'incoming':
                 fileutil.rm_dir(os.path.join(sharedir, prefixdir))
 
-    def corrupt_share(self, (shnum, serverid, sharefile), corruptor_function):
-        sharedata = open(sharefile, "rb").read()
-        corruptdata = corruptor_function(sharedata)
-        open(sharefile, "wb").write(corruptdata)
+    def corrupt_share(self, (shnum, serverid, sharefile), corruptor_function, debug=False):
+        sharedata = fileutil.read(sharefile)
+        corruptdata = corruptor_function(sharedata, debug=debug)
+        fileutil.write(sharefile, corruptdata)
 
     def corrupt_shares_numbered(self, uri, shnums, corruptor, debug=False):
-        for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
-            if i_shnum in shnums:
-                sharedata = open(i_sharefile, "rb").read()
-                corruptdata = corruptor(sharedata, debug=debug)
-                open(i_sharefile, "wb").write(corruptdata)
+        d = self.find_uri_shares(uri)
+        def _got_shares(sharelist):
+            for (i_shnum, i_serverid, i_sharefile) in sharelist:
+                if i_shnum in shnums:
+                    self.corrupt_share((i_shnum, i_serverid, i_sharefile), corruptor, debug=debug)
+        d.addCallback(_got_shares)
+        return d
 
     def corrupt_all_shares(self, uri, corruptor, debug=False):
-        for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
-            sharedata = open(i_sharefile, "rb").read()
-            corruptdata = corruptor(sharedata, debug=debug)
-            open(i_sharefile, "wb").write(corruptdata)
+        d = self.find_uri_shares(uri)
+        def _got_shares(sharelist):
+            for (i_shnum, i_serverid, i_sharefile) in sharelist:
+                self.corrupt_share((i_shnum, i_serverid, i_sharefile), corruptor, debug=debug)
+        d.addCallback(_got_shares)
+        return d
 
     def GET(self, urlpath, followRedirect=False, return_response=False,
             method="GET", clientnum=0, **kwargs):