]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blobdiff - src/allmydata/test/no_network.py
Improvements to no_network test harness.
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / test / no_network.py
index 42a90dd5db18f7c6c5c70b45d5754ada3a4f90c5..2dc59381ba57ff559c36338bcb1861429aefe629 100644 (file)
 # Tubs, so it is not useful for tests that involve a Helper, a KeyGenerator,
 # or the control.furl .
 
-import os.path
+import os
 from zope.interface import implements
 from twisted.application import service
 from twisted.internet import defer, reactor
 from twisted.python.failure import Failure
 from foolscap.api import Referenceable, fireEventually, RemoteException
 from base64 import b32encode
+
+from allmydata.util.assertutil import _assert
+
 from allmydata import uri as tahoe_uri
 from allmydata.client import Client
 from allmydata.storage.server import StorageServer, storage_index_to_dir
 from allmydata.util import fileutil, idlib, hashutil
 from allmydata.util.hashutil import sha1
 from allmydata.test.common_web import HTTPClientGETFactory
-from allmydata.interfaces import IStorageBroker
+from allmydata.interfaces import IStorageBroker, IServer
+from allmydata.test.common import TEST_RSA_KEY_SIZE
+
 
 class IntentionalError(Exception):
     pass
@@ -41,6 +46,10 @@ class LocalWrapper:
         self.hung_until = None
         self.post_call_notifier = None
         self.disconnectors = {}
+        self.counter_by_methname = {}
+
+    def _clear_counters(self):
+        self.counter_by_methname = {}
 
     def callRemoteOnly(self, methname, *args, **kwargs):
         d = self.callRemote(methname, *args, **kwargs)
@@ -60,11 +69,15 @@ class LocalWrapper:
         kwargs = dict([(k,wrap(kwargs[k])) for k in kwargs])
 
         def _really_call():
+            def incr(d, k): d[k] = d.setdefault(k, 0) + 1
+            incr(self.counter_by_methname, methname)
             meth = getattr(self.original, "remote_" + methname)
             return meth(*args, **kwargs)
 
         def _call():
             if self.broken:
+                if self.broken is not True: # a counter, not boolean
+                    self.broken -= 1
                 raise IntentionalError("I was asked to break")
             if self.hung_until:
                 d2 = defer.Deferred()
@@ -118,20 +131,31 @@ def wrap_storage_server(original):
     return wrapper
 
 class NoNetworkServer:
+    implements(IServer)
     def __init__(self, serverid, rref):
         self.serverid = serverid
         self.rref = rref
     def __repr__(self):
-        return "<NoNetworkServer for %s>" % self.name()
+        return "<NoNetworkServer for %s>" % self.get_name()
+    # Special method used by copy.copy() and copy.deepcopy(). When those are
+    # used in allmydata.immutable.filenode to copy CheckResults during
+    # repair, we want it to treat the IServer instances as singletons.
+    def __copy__(self):
+        return self
+    def __deepcopy__(self, memodict):
+        return self
     def get_serverid(self):
         return self.serverid
     def get_permutation_seed(self):
         return self.serverid
     def get_lease_seed(self):
         return self.serverid
-    def name(self):
+    def get_foolscap_write_enabler_seed(self):
+        return self.serverid
+
+    def get_name(self):
         return idlib.shortnodeid_b2a(self.serverid)
-    def longname(self):
+    def get_longname(self):
         return idlib.nodeid_b2a(self.serverid)
     def get_nickname(self):
         return "nickname"
@@ -153,6 +177,9 @@ class NoNetworkStorageBroker:
         return None
 
 class NoNetworkClient(Client):
+
+    def disownServiceParent(self):
+        self.disownServiceParent()
     def create_tub(self):
         pass
     def init_introducer_client(self):
@@ -211,6 +238,7 @@ class NoNetworkGrid(service.MultiService):
         self.proxies_by_id = {} # maps to IServer on which .rref is a wrapped
                                 # StorageServer
         self.clients = []
+        self.client_config_hooks = client_config_hooks
 
         for i in range(num_servers):
             ss = self.make_server(i)
@@ -218,35 +246,47 @@ class NoNetworkGrid(service.MultiService):
         self.rebuild_serverlist()
 
         for i in range(num_clients):
-            clientid = hashutil.tagged_hash("clientid", str(i))[:20]
-            clientdir = os.path.join(basedir, "clients",
-                                     idlib.shortnodeid_b2a(clientid))
-            fileutil.make_dirs(clientdir)
-            f = open(os.path.join(clientdir, "tahoe.cfg"), "w")
+            c = self.make_client(i)
+            self.clients.append(c)
+
+    def make_client(self, i, write_config=True):
+        clientid = hashutil.tagged_hash("clientid", str(i))[:20]
+        clientdir = os.path.join(self.basedir, "clients",
+                                 idlib.shortnodeid_b2a(clientid))
+        fileutil.make_dirs(clientdir)
+
+        tahoe_cfg_path = os.path.join(clientdir, "tahoe.cfg")
+        if write_config:
+            f = open(tahoe_cfg_path, "w")
             f.write("[node]\n")
             f.write("nickname = client-%d\n" % i)
             f.write("web.port = tcp:0:interface=127.0.0.1\n")
             f.write("[storage]\n")
             f.write("enabled = false\n")
             f.close()
-            c = None
-            if i in client_config_hooks:
-                # this hook can either modify tahoe.cfg, or return an
-                # entirely new Client instance
-                c = client_config_hooks[i](clientdir)
-            if not c:
-                c = NoNetworkClient(clientdir)
-                c.set_default_mutable_keysize(522)
-            c.nodeid = clientid
-            c.short_nodeid = b32encode(clientid).lower()[:8]
-            c._servers = self.all_servers # can be updated later
-            c.setServiceParent(self)
-            self.clients.append(c)
+        else:
+            _assert(os.path.exists(tahoe_cfg_path), tahoe_cfg_path=tahoe_cfg_path)
+
+        c = None
+        if i in self.client_config_hooks:
+            # this hook can either modify tahoe.cfg, or return an
+            # entirely new Client instance
+            c = self.client_config_hooks[i](clientdir)
+
+        if not c:
+            c = NoNetworkClient(clientdir)
+            c.set_default_mutable_keysize(TEST_RSA_KEY_SIZE)
+
+        c.nodeid = clientid
+        c.short_nodeid = b32encode(clientid).lower()[:8]
+        c._servers = self.all_servers # can be updated later
+        c.setServiceParent(self)
+        return c
 
     def make_server(self, i, readonly=False):
         serverid = hashutil.tagged_hash("serverid", str(i))[:20]
         serverdir = os.path.join(self.basedir, "servers",
-                                 idlib.shortnodeid_b2a(serverid))
+                                 idlib.shortnodeid_b2a(serverid), "storage")
         fileutil.make_dirs(serverdir)
         ss = StorageServer(serverdir, serverid, stats_provider=SimpleStats(),
                            readonly_storage=readonly)
@@ -284,11 +324,13 @@ class NoNetworkGrid(service.MultiService):
         del self.wrappers_by_id[serverid]
         del self.proxies_by_id[serverid]
         self.rebuild_serverlist()
+        return ss
 
-    def break_server(self, serverid):
+    def break_server(self, serverid, count=True):
         # mark the given server as broken, so it will throw exceptions when
-        # asked to hold a share or serve a share
-        self.wrappers_by_id[serverid].broken = True
+        # asked to hold a share or serve a share. If count= is a number,
+        # throw that many exceptions before starting to work again.
+        self.wrappers_by_id[serverid].broken = count
 
     def hang_server(self, serverid):
         # hang the given server
@@ -303,6 +345,13 @@ class NoNetworkGrid(service.MultiService):
         ss.hung_until.callback(None)
         ss.hung_until = None
 
+    def nuke_from_orbit(self):
+        """ Empty all share directories in this grid. It's the only way to be sure ;-) """
+        for server in self.servers_by_number.values():
+            for prefixdir in os.listdir(server.sharedir):
+                if prefixdir != 'incoming':
+                    fileutil.rm_dir(os.path.join(server.sharedir, prefixdir))
+
 
 class GridTestMixin:
     def setUp(self):
@@ -320,6 +369,9 @@ class GridTestMixin:
                                num_servers=num_servers,
                                client_config_hooks=client_config_hooks)
         self.g.setServiceParent(self.s)
+        self._record_webports_and_baseurls()
+
+    def _record_webports_and_baseurls(self):
         self.client_webports = [c.getServiceNamed("webish").getPortnum()
                                 for c in self.g.clients]
         self.client_baseurls = [c.getServiceNamed("webish").getURL()
@@ -328,6 +380,23 @@ class GridTestMixin:
     def get_clientdir(self, i=0):
         return self.g.clients[i].basedir
 
+    def set_clientdir(self, basedir, i=0):
+        self.g.clients[i].basedir = basedir
+
+    def get_client(self, i=0):
+        return self.g.clients[i]
+
+    def restart_client(self, i=0):
+        client = self.g.clients[i]
+        d = defer.succeed(None)
+        d.addCallback(lambda ign: self.g.removeService(client))
+        def _make_client(ign):
+            c = self.g.make_client(i, write_config=False)
+            self.g.clients[i] = c
+            self._record_webports_and_baseurls()
+        d.addCallback(_make_client)
+        return d
+
     def get_serverdir(self, i):
         return self.g.servers_by_number[i].storedir
 
@@ -342,7 +411,7 @@ class GridTestMixin:
         shares = []
         for i,ss in self.g.servers_by_number.items():
             serverid = ss.my_nodeid
-            basedir = os.path.join(ss.storedir, "shares", prefixdir)
+            basedir = os.path.join(ss.sharedir, prefixdir)
             if not os.path.exists(basedir):
                 continue
             for f in os.listdir(basedir):
@@ -371,6 +440,12 @@ class GridTestMixin:
             if i_shnum in shnums:
                 os.unlink(i_sharefile)
 
+    def delete_all_shares(self, serverdir):
+        sharedir = os.path.join(serverdir, "shares")
+        for prefixdir in os.listdir(sharedir):
+            if prefixdir != 'incoming':
+                fileutil.rm_dir(os.path.join(sharedir, prefixdir))
+
     def corrupt_share(self, (shnum, serverid, sharefile), corruptor_function):
         sharedata = open(sharefile, "rb").read()
         corruptdata = corruptor_function(sharedata)
@@ -404,3 +479,6 @@ class GridTestMixin:
         if return_response:
             d.addCallback(_got_data)
         return factory.deferred
+
+    def PUT(self, urlpath, **kwargs):
+        return self.GET(urlpath, method="PUT", **kwargs)