]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blobdiff - src/allmydata/test/no_network.py
test_mutable.Version: exercise 'tahoe debug find-shares' on MDMF. refs #1507
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / test / no_network.py
index 2b075883b42d7986dbcb967b406571082cf0a3a5..abc87b182d76eb08bb67ea17ee44b846369a7407 100644 (file)
 # or the control.furl .
 
 import os.path
-import sha
+from zope.interface import implements
 from twisted.application import service
-from foolscap import Referenceable
-from foolscap.eventual import fireEventually
+from twisted.internet import defer, reactor
+from twisted.python.failure import Failure
+from foolscap.api import Referenceable, fireEventually, RemoteException
 from base64 import b32encode
+from allmydata import uri as tahoe_uri
 from allmydata.client import Client
-from allmydata.storage import StorageServer
-from allmydata.util import fileutil, idlib, hashutil, rrefutil
-from allmydata.introducer.client import RemoteServiceConnector
+from allmydata.storage.server import StorageServer, storage_index_to_dir
+from allmydata.util import fileutil, idlib, hashutil
+from allmydata.util.hashutil import sha1
+from allmydata.test.common_web import HTTPClientGETFactory
+from allmydata.interfaces import IStorageBroker
+from allmydata.test.common import TEST_RSA_KEY_SIZE
+
 
 class IntentionalError(Exception):
     pass
@@ -34,11 +40,13 @@ class LocalWrapper:
     def __init__(self, original):
         self.original = original
         self.broken = False
+        self.hung_until = None
         self.post_call_notifier = None
         self.disconnectors = {}
 
     def callRemoteOnly(self, methname, *args, **kwargs):
         d = self.callRemote(methname, *args, **kwargs)
+        del d # explicitly ignored
         return None
 
     def callRemote(self, methname, *args, **kwargs):
@@ -52,13 +60,30 @@ class LocalWrapper:
                 return a
         args = tuple([wrap(a) for a in args])
         kwargs = dict([(k,wrap(kwargs[k])) for k in kwargs])
+
+        def _really_call():
+            meth = getattr(self.original, "remote_" + methname)
+            return meth(*args, **kwargs)
+
         def _call():
             if self.broken:
                 raise IntentionalError("I was asked to break")
-            meth = getattr(self.original, "remote_" + methname)
-            return meth(*args, **kwargs)
+            if self.hung_until:
+                d2 = defer.Deferred()
+                self.hung_until.addCallback(lambda ign: _really_call())
+                self.hung_until.addCallback(lambda res: d2.callback(res))
+                def _err(res):
+                    d2.errback(res)
+                    return res
+                self.hung_until.addErrback(_err)
+                return d2
+            return _really_call()
+
         d = fireEventually()
         d.addCallback(lambda res: _call())
+        def _wrap_exception(f):
+            return Failure(RemoteException(f))
+        d.addErrback(_wrap_exception)
         def _return_membrane(res):
             # rather than complete the difficult task of building a
             # fully-general Membrane (which would locate all Referenceable
@@ -75,7 +100,7 @@ class LocalWrapper:
             return res
         d.addCallback(_return_membrane)
         if self.post_call_notifier:
-            d.addCallback(self.post_call_notifier, methname)
+            d.addCallback(self.post_call_notifier, self, methname)
         return d
 
     def notifyOnDisconnect(self, f, *args, **kwargs):
@@ -85,24 +110,51 @@ class LocalWrapper:
     def dontNotifyOnDisconnect(self, marker):
         del self.disconnectors[marker]
 
-def wrap(original, service_name):
-    # The code in immutable.checker insists upon asserting the truth of
-    # isinstance(rref, rrefutil.WrappedRemoteReference). Much of the
-    # upload/download code uses rref.version (which normally comes from
-    # rrefutil.VersionedRemoteReference). To avoid using a network, we want a
-    # LocalWrapper here. Try to satisfy all these constraints at the same
-    # time.
-    local = LocalWrapper(original)
-    wrapped = rrefutil.WrappedRemoteReference(local)
-    try:
-        version = original.remote_get_version()
-    except AttributeError:
-        version = RemoteServiceConnector.VERSION_DEFAULTS[service_name]
-    wrapped.version = version
-    return wrapped
+def wrap_storage_server(original):
+    # Much of the upload/download code uses rref.version (which normally
+    # comes from rrefutil.add_version_to_remote_reference). To avoid using a
+    # network, we want a LocalWrapper here. Try to satisfy all these
+    # constraints at the same time.
+    wrapper = LocalWrapper(original)
+    wrapper.version = original.remote_get_version()
+    return wrapper
 
-class NoNetworkClient(Client):
+class NoNetworkServer:
+    def __init__(self, serverid, rref):
+        self.serverid = serverid
+        self.rref = rref
+    def __repr__(self):
+        return "<NoNetworkServer for %s>" % self.get_name()
+    def get_serverid(self):
+        return self.serverid
+    def get_permutation_seed(self):
+        return self.serverid
+    def get_lease_seed(self):
+        return self.serverid
+    def get_name(self):
+        return idlib.shortnodeid_b2a(self.serverid)
+    def get_longname(self):
+        return idlib.nodeid_b2a(self.serverid)
+    def get_nickname(self):
+        return "nickname"
+    def get_rref(self):
+        return self.rref
+    def get_version(self):
+        return self.rref.version
+
+class NoNetworkStorageBroker:
+    implements(IStorageBroker)
+    def get_servers_for_psi(self, peer_selection_index):
+        def _permuted(server):
+            seed = server.get_permutation_seed()
+            return sha1(peer_selection_index + seed).digest()
+        return sorted(self.get_connected_servers(), key=_permuted)
+    def get_connected_servers(self):
+        return self.client._servers
+    def get_nickname_for_serverid(self, serverid):
+        return None
 
+class NoNetworkClient(Client):
     def create_tub(self):
         pass
     def init_introducer_client(self):
@@ -114,7 +166,7 @@ class NoNetworkClient(Client):
     def stopService(self):
         service.MultiService.stopService(self)
     def when_tub_ready(self):
-        raise RuntimeError("NoNetworkClient has no Tub")
+        raise NotImplementedError("NoNetworkClient has no Tub")
     def init_control(self):
         pass
     def init_helper(self):
@@ -123,15 +175,31 @@ class NoNetworkClient(Client):
         pass
     def init_storage(self):
         pass
+    def init_client_storage_broker(self):
+        self.storage_broker = NoNetworkStorageBroker()
+        self.storage_broker.client = self
     def init_stub_client(self):
         pass
+    #._servers will be set by the NoNetworkGrid which creates us
 
-    def get_servers(self, service_name):
-        return self._servers
+class SimpleStats:
+    def __init__(self):
+        self.counters = {}
+        self.stats_producers = []
 
-    def get_permuted_peers(self, service_name, key):
-        return sorted(self._servers, key=lambda x: sha.new(key+x[0]).digest())
+    def count(self, name, delta=1):
+        val = self.counters.setdefault(name, 0)
+        self.counters[name] = val + delta
 
+    def register_producer(self, stats_producer):
+        self.stats_producers.append(stats_producer)
+
+    def get_stats(self):
+        stats = {}
+        for sp in self.stats_producers:
+            stats.update(sp.get_stats())
+        ret = { 'counters': self.counters, 'stats': stats }
+        return ret
 
 class NoNetworkGrid(service.MultiService):
     def __init__(self, basedir, num_clients=1, num_servers=10,
@@ -140,17 +208,16 @@ class NoNetworkGrid(service.MultiService):
         self.basedir = basedir
         fileutil.make_dirs(basedir)
 
-        self.servers_by_number = {}
-        self.servers_by_id = {}
+        self.servers_by_number = {} # maps to StorageServer instance
+        self.wrappers_by_id = {} # maps to wrapped StorageServer instance
+        self.proxies_by_id = {} # maps to IServer on which .rref is a wrapped
+                                # StorageServer
         self.clients = []
 
         for i in range(num_servers):
-            serverid = hashutil.tagged_hash("serverid", str(i))[:20]
-            serverdir = os.path.join(basedir, "servers",
-                                     idlib.shortnodeid_b2a(serverid))
-            fileutil.make_dirs(serverdir)
-            ss = StorageServer(serverdir)
-            self.add_server(i, serverid, ss)
+            ss = self.make_server(i)
+            self.add_server(i, ss)
+        self.rebuild_serverlist()
 
         for i in range(num_clients):
             clientid = hashutil.tagged_hash("clientid", str(i))[:20]
@@ -171,24 +238,74 @@ class NoNetworkGrid(service.MultiService):
                 c = client_config_hooks[i](clientdir)
             if not c:
                 c = NoNetworkClient(clientdir)
+                c.set_default_mutable_keysize(TEST_RSA_KEY_SIZE)
             c.nodeid = clientid
             c.short_nodeid = b32encode(clientid).lower()[:8]
             c._servers = self.all_servers # can be updated later
             c.setServiceParent(self)
             self.clients.append(c)
 
-    def add_server(self, i, serverid, ss):
-        # TODO: ss.setServiceParent(self), but first remove the goofy
-        # self.parent.nodeid from Storage.startService . At the moment,
-        # Storage doesn't really need to be startService'd, but it will in
-        # the future.
-        ss.setNodeID(serverid)
+    def make_server(self, i, readonly=False):
+        serverid = hashutil.tagged_hash("serverid", str(i))[:20]
+        serverdir = os.path.join(self.basedir, "servers",
+                                 idlib.shortnodeid_b2a(serverid), "storage")
+        fileutil.make_dirs(serverdir)
+        ss = StorageServer(serverdir, serverid, stats_provider=SimpleStats(),
+                           readonly_storage=readonly)
+        ss._no_network_server_number = i
+        return ss
+
+    def add_server(self, i, ss):
+        # to deal with the fact that all StorageServers are named 'storage',
+        # we interpose a middleman
+        middleman = service.MultiService()
+        middleman.setServiceParent(self)
+        ss.setServiceParent(middleman)
+        serverid = ss.my_nodeid
         self.servers_by_number[i] = ss
-        self.servers_by_id[serverid] = wrap(ss, "storage")
-        self.all_servers = frozenset(self.servers_by_id.items())
+        wrapper = wrap_storage_server(ss)
+        self.wrappers_by_id[serverid] = wrapper
+        self.proxies_by_id[serverid] = NoNetworkServer(serverid, wrapper)
+        self.rebuild_serverlist()
+
+    def get_all_serverids(self):
+        return self.proxies_by_id.keys()
+
+    def rebuild_serverlist(self):
+        self.all_servers = frozenset(self.proxies_by_id.values())
         for c in self.clients:
             c._servers = self.all_servers
 
+    def remove_server(self, serverid):
+        # it's enough to remove the server from c._servers (we don't actually
+        # have to detach and stopService it)
+        for i,ss in self.servers_by_number.items():
+            if ss.my_nodeid == serverid:
+                del self.servers_by_number[i]
+                break
+        del self.wrappers_by_id[serverid]
+        del self.proxies_by_id[serverid]
+        self.rebuild_serverlist()
+
+    def break_server(self, serverid):
+        # mark the given server as broken, so it will throw exceptions when
+        # asked to hold a share or serve a share
+        self.wrappers_by_id[serverid].broken = True
+
+    def hang_server(self, serverid):
+        # hang the given server
+        ss = self.wrappers_by_id[serverid]
+        assert ss.hung_until is None
+        ss.hung_until = defer.Deferred()
+
+    def unhang_server(self, serverid):
+        # unhang the given server
+        ss = self.wrappers_by_id[serverid]
+        assert ss.hung_until is not None
+        ss.hung_until.callback(None)
+        ss.hung_until = None
+
+
 class GridTestMixin:
     def setUp(self):
         self.s = service.MultiService()
@@ -197,11 +314,18 @@ class GridTestMixin:
     def tearDown(self):
         return self.s.stopService()
 
-    def set_up_grid(self, client_config_hooks={}):
+    def set_up_grid(self, num_clients=1, num_servers=10,
+                    client_config_hooks={}):
         # self.basedir must be set
         self.g = NoNetworkGrid(self.basedir,
+                               num_clients=num_clients,
+                               num_servers=num_servers,
                                client_config_hooks=client_config_hooks)
         self.g.setServiceParent(self.s)
+        self.client_webports = [c.getServiceNamed("webish").getPortnum()
+                                for c in self.g.clients]
+        self.client_baseurls = [c.getServiceNamed("webish").getURL()
+                                for c in self.g.clients]
 
     def get_clientdir(self, i=0):
         return self.g.clients[i].basedir
@@ -213,3 +337,75 @@ class GridTestMixin:
         for i in sorted(self.g.servers_by_number.keys()):
             ss = self.g.servers_by_number[i]
             yield (i, ss, ss.storedir)
+
+    def find_uri_shares(self, uri):
+        si = tahoe_uri.from_string(uri).get_storage_index()
+        prefixdir = storage_index_to_dir(si)
+        shares = []
+        for i,ss in self.g.servers_by_number.items():
+            serverid = ss.my_nodeid
+            basedir = os.path.join(ss.sharedir, prefixdir)
+            if not os.path.exists(basedir):
+                continue
+            for f in os.listdir(basedir):
+                try:
+                    shnum = int(f)
+                    shares.append((shnum, serverid, os.path.join(basedir, f)))
+                except ValueError:
+                    pass
+        return sorted(shares)
+
+    def copy_shares(self, uri):
+        shares = {}
+        for (shnum, serverid, sharefile) in self.find_uri_shares(uri):
+            shares[sharefile] = open(sharefile, "rb").read()
+        return shares
+
+    def restore_all_shares(self, shares):
+        for sharefile, data in shares.items():
+            open(sharefile, "wb").write(data)
+
+    def delete_share(self, (shnum, serverid, sharefile)):
+        os.unlink(sharefile)
+
+    def delete_shares_numbered(self, uri, shnums):
+        for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
+            if i_shnum in shnums:
+                os.unlink(i_sharefile)
+
+    def corrupt_share(self, (shnum, serverid, sharefile), corruptor_function):
+        sharedata = open(sharefile, "rb").read()
+        corruptdata = corruptor_function(sharedata)
+        open(sharefile, "wb").write(corruptdata)
+
+    def corrupt_shares_numbered(self, uri, shnums, corruptor, debug=False):
+        for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
+            if i_shnum in shnums:
+                sharedata = open(i_sharefile, "rb").read()
+                corruptdata = corruptor(sharedata, debug=debug)
+                open(i_sharefile, "wb").write(corruptdata)
+
+    def corrupt_all_shares(self, uri, corruptor, debug=False):
+        for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
+            sharedata = open(i_sharefile, "rb").read()
+            corruptdata = corruptor(sharedata, debug=debug)
+            open(i_sharefile, "wb").write(corruptdata)
+
+    def GET(self, urlpath, followRedirect=False, return_response=False,
+            method="GET", clientnum=0, **kwargs):
+        # if return_response=True, this fires with (data, statuscode,
+        # respheaders) instead of just data.
+        assert not isinstance(urlpath, unicode)
+        url = self.client_baseurls[clientnum] + urlpath
+        factory = HTTPClientGETFactory(url, method=method,
+                                       followRedirect=followRedirect, **kwargs)
+        reactor.connectTCP("localhost", self.client_webports[clientnum],factory)
+        d = factory.deferred
+        def _got_data(data):
+            return (data, factory.status, factory.response_headers)
+        if return_response:
+            d.addCallback(_got_data)
+        return factory.deferred
+
+    def PUT(self, urlpath, **kwargs):
+        return self.GET(urlpath, method="PUT", **kwargs)