import os.path
from zope.interface import implements
from twisted.application import service
-from twisted.internet import reactor
+from twisted.internet import defer, reactor
from twisted.python.failure import Failure
from foolscap.api import Referenceable, fireEventually, RemoteException
from base64 import b32encode
from allmydata.util import fileutil, idlib, hashutil
from allmydata.util.hashutil import sha1
from allmydata.test.common_web import HTTPClientGETFactory
-from allmydata.interfaces import IStorageBroker
+from allmydata.interfaces import IStorageBroker, IServer
+from allmydata.test.common import TEST_RSA_KEY_SIZE
+
class IntentionalError(Exception):
pass
def __init__(self, original):
self.original = original
self.broken = False
+ self.hung_until = None
self.post_call_notifier = None
self.disconnectors = {}
+ self.counter_by_methname = {}
+
+ def _clear_counters(self):
+ self.counter_by_methname = {}
def callRemoteOnly(self, methname, *args, **kwargs):
d = self.callRemote(methname, *args, **kwargs)
+ del d # explicitly ignored
return None
def callRemote(self, methname, *args, **kwargs):
return a
args = tuple([wrap(a) for a in args])
kwargs = dict([(k,wrap(kwargs[k])) for k in kwargs])
+
+ def _really_call():
+ def incr(d, k): d[k] = d.setdefault(k, 0) + 1
+ incr(self.counter_by_methname, methname)
+ meth = getattr(self.original, "remote_" + methname)
+ return meth(*args, **kwargs)
+
def _call():
if self.broken:
+ if self.broken is not True: # a counter, not boolean
+ self.broken -= 1
raise IntentionalError("I was asked to break")
- meth = getattr(self.original, "remote_" + methname)
- return meth(*args, **kwargs)
+ if self.hung_until:
+ d2 = defer.Deferred()
+ self.hung_until.addCallback(lambda ign: _really_call())
+ self.hung_until.addCallback(lambda res: d2.callback(res))
+ def _err(res):
+ d2.errback(res)
+ return res
+ self.hung_until.addErrback(_err)
+ return d2
+ return _really_call()
+
d = fireEventually()
d.addCallback(lambda res: _call())
def _wrap_exception(f):
wrapper.version = original.remote_get_version()
return wrapper
+class NoNetworkServer:
+ implements(IServer)
+ def __init__(self, serverid, rref):
+ self.serverid = serverid
+ self.rref = rref
+ def __repr__(self):
+ return "<NoNetworkServer for %s>" % self.get_name()
+ # Special method used by copy.copy() and copy.deepcopy(). When those are
+ # used in allmydata.immutable.filenode to copy CheckResults during
+ # repair, we want it to treat the IServer instances as singletons.
+ def __copy__(self):
+ return self
+ def __deepcopy__(self, memodict):
+ return self
+ def get_serverid(self):
+ return self.serverid
+ def get_permutation_seed(self):
+ return self.serverid
+ def get_lease_seed(self):
+ return self.serverid
+ def get_foolscap_write_enabler_seed(self):
+ return self.serverid
+
+ def get_name(self):
+ return idlib.shortnodeid_b2a(self.serverid)
+ def get_longname(self):
+ return idlib.nodeid_b2a(self.serverid)
+ def get_nickname(self):
+ return "nickname"
+ def get_rref(self):
+ return self.rref
+ def get_version(self):
+ return self.rref.version
+
class NoNetworkStorageBroker:
implements(IStorageBroker)
- def get_servers_for_index(self, key):
- return sorted(self.client._servers,
- key=lambda x: sha1(key+x[0]).digest())
- def get_all_servers(self):
- return frozenset(self.client._servers)
+ def get_servers_for_psi(self, peer_selection_index):
+ def _permuted(server):
+ seed = server.get_permutation_seed()
+ return sha1(peer_selection_index + seed).digest()
+ return sorted(self.get_connected_servers(), key=_permuted)
+ def get_connected_servers(self):
+ return self.client._servers
def get_nickname_for_serverid(self, serverid):
return None
self.basedir = basedir
fileutil.make_dirs(basedir)
- self.servers_by_number = {}
- self.servers_by_id = {}
+ self.servers_by_number = {} # maps to StorageServer instance
+ self.wrappers_by_id = {} # maps to wrapped StorageServer instance
+ self.proxies_by_id = {} # maps to IServer on which .rref is a wrapped
+ # StorageServer
self.clients = []
for i in range(num_servers):
c = client_config_hooks[i](clientdir)
if not c:
c = NoNetworkClient(clientdir)
- c.set_default_mutable_keysize(522)
+ c.set_default_mutable_keysize(TEST_RSA_KEY_SIZE)
c.nodeid = clientid
c.short_nodeid = b32encode(clientid).lower()[:8]
c._servers = self.all_servers # can be updated later
c.setServiceParent(self)
self.clients.append(c)
- def make_server(self, i):
+ def make_server(self, i, readonly=False):
serverid = hashutil.tagged_hash("serverid", str(i))[:20]
serverdir = os.path.join(self.basedir, "servers",
- idlib.shortnodeid_b2a(serverid))
+ idlib.shortnodeid_b2a(serverid), "storage")
fileutil.make_dirs(serverdir)
- ss = StorageServer(serverdir, serverid, stats_provider=SimpleStats())
+ ss = StorageServer(serverdir, serverid, stats_provider=SimpleStats(),
+ readonly_storage=readonly)
+ ss._no_network_server_number = i
return ss
def add_server(self, i, ss):
ss.setServiceParent(middleman)
serverid = ss.my_nodeid
self.servers_by_number[i] = ss
- self.servers_by_id[serverid] = wrap_storage_server(ss)
+ wrapper = wrap_storage_server(ss)
+ self.wrappers_by_id[serverid] = wrapper
+ self.proxies_by_id[serverid] = NoNetworkServer(serverid, wrapper)
self.rebuild_serverlist()
+ def get_all_serverids(self):
+ return self.proxies_by_id.keys()
+
def rebuild_serverlist(self):
- self.all_servers = frozenset(self.servers_by_id.items())
+ self.all_servers = frozenset(self.proxies_by_id.values())
for c in self.clients:
c._servers = self.all_servers
if ss.my_nodeid == serverid:
del self.servers_by_number[i]
break
- del self.servers_by_id[serverid]
+ del self.wrappers_by_id[serverid]
+ del self.proxies_by_id[serverid]
self.rebuild_serverlist()
+ return ss
- def break_server(self, serverid):
+ def break_server(self, serverid, count=True):
# mark the given server as broken, so it will throw exceptions when
- # asked to hold a share
- self.servers_by_id[serverid].broken = True
+ # asked to hold a share or serve a share. If count= is a number,
+ # throw that many exceptions before starting to work again.
+ self.wrappers_by_id[serverid].broken = count
+
+ def hang_server(self, serverid):
+ # hang the given server
+ ss = self.wrappers_by_id[serverid]
+ assert ss.hung_until is None
+ ss.hung_until = defer.Deferred()
+
+ def unhang_server(self, serverid):
+ # unhang the given server
+ ss = self.wrappers_by_id[serverid]
+ assert ss.hung_until is not None
+ ss.hung_until.callback(None)
+ ss.hung_until = None
+
class GridTestMixin:
def setUp(self):
num_servers=num_servers,
client_config_hooks=client_config_hooks)
self.g.setServiceParent(self.s)
- self.client_webports = [c.getServiceNamed("webish").listener._port.getHost().port
+ self.client_webports = [c.getServiceNamed("webish").getPortnum()
+ for c in self.g.clients]
+ self.client_baseurls = [c.getServiceNamed("webish").getURL()
for c in self.g.clients]
- self.client_baseurls = ["http://localhost:%d/" % p
- for p in self.client_webports]
def get_clientdir(self, i=0):
return self.g.clients[i].basedir
ss = self.g.servers_by_number[i]
yield (i, ss, ss.storedir)
- def find_shares(self, uri):
+ def find_uri_shares(self, uri):
si = tahoe_uri.from_string(uri).get_storage_index()
prefixdir = storage_index_to_dir(si)
shares = []
for i,ss in self.g.servers_by_number.items():
serverid = ss.my_nodeid
- basedir = os.path.join(ss.storedir, "shares", prefixdir)
+ basedir = os.path.join(ss.sharedir, prefixdir)
if not os.path.exists(basedir):
continue
for f in os.listdir(basedir):
pass
return sorted(shares)
+ def copy_shares(self, uri):
+ shares = {}
+ for (shnum, serverid, sharefile) in self.find_uri_shares(uri):
+ shares[sharefile] = open(sharefile, "rb").read()
+ return shares
+
+ def restore_all_shares(self, shares):
+ for sharefile, data in shares.items():
+ open(sharefile, "wb").write(data)
+
def delete_share(self, (shnum, serverid, sharefile)):
os.unlink(sharefile)
def delete_shares_numbered(self, uri, shnums):
- for (i_shnum, i_serverid, i_sharefile) in self.find_shares(uri):
+ for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
if i_shnum in shnums:
os.unlink(i_sharefile)
corruptdata = corruptor_function(sharedata)
open(sharefile, "wb").write(corruptdata)
- def corrupt_shares_numbered(self, uri, shnums, corruptor):
- for (i_shnum, i_serverid, i_sharefile) in self.find_shares(uri):
+ def corrupt_shares_numbered(self, uri, shnums, corruptor, debug=False):
+ for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
if i_shnum in shnums:
sharedata = open(i_sharefile, "rb").read()
- corruptdata = corruptor(sharedata)
+ corruptdata = corruptor(sharedata, debug=debug)
open(i_sharefile, "wb").write(corruptdata)
+ def corrupt_all_shares(self, uri, corruptor, debug=False):
+ for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
+ sharedata = open(i_sharefile, "rb").read()
+ corruptdata = corruptor(sharedata, debug=debug)
+ open(i_sharefile, "wb").write(corruptdata)
+
def GET(self, urlpath, followRedirect=False, return_response=False,
method="GET", clientnum=0, **kwargs):
# if return_response=True, this fires with (data, statuscode,
if return_response:
d.addCallback(_got_data)
return factory.deferred
+
+ def PUT(self, urlpath, **kwargs):
+ return self.GET(urlpath, method="PUT", **kwargs)