# Source file: src/allmydata/test/no_network.py (tahoe-lafs)

# This contains a test harness that creates a full Tahoe grid in a single
# process (actually in a single MultiService) which does not use the network.
# It does not use an Introducer, and there are no foolscap Tubs. Each storage
# server puts real shares on disk, but is accessed through loopback
# RemoteReferences instead of over serialized SSL. It is not as complete as
# the common.SystemTestMixin framework (which does use the network), but
# should be considerably faster: on my laptop, it takes 50-80ms to start up,
# whereas SystemTestMixin takes close to 2s.

# This should be useful for tests which want to examine and/or manipulate the
# uploaded shares, checker/verifier/repairer tests, etc. The clients have no
# Tubs, so it is not useful for tests that involve a Helper, a KeyGenerator,
# or the control.furl .

16 import os.path
17 from zope.interface import implements
18 from twisted.application import service
19 from twisted.internet import defer, reactor
20 from twisted.python.failure import Failure
21 from foolscap.api import Referenceable, fireEventually, RemoteException
22 from base64 import b32encode
23 from allmydata import uri as tahoe_uri
24 from allmydata.client import Client
25 from allmydata.storage.server import StorageServer, storage_index_to_dir
26 from allmydata.util import fileutil, idlib, hashutil
27 from allmydata.util.hashutil import sha1
28 from allmydata.test.common_web import HTTPClientGETFactory
29 from allmydata.interfaces import IStorageBroker, IServer
30 from allmydata.test.common import TEST_RSA_KEY_SIZE
31
32
class IntentionalError(Exception):
    """Raised by a LocalWrapper that has been deliberately marked broken."""
35
class Marker:
    """Opaque, hashable token handed out by notifyOnDisconnect()."""
38
class LocalWrapper:
    """
    A loopback stand-in for a foolscap RemoteReference.

    callRemote("name", ...) dispatches to ``remote_name()`` on the wrapped
    object, but delivers the result asynchronously (via fireEventually) and
    wraps any failure in a RemoteException, so callers see foolscap-like
    semantics without a network. Tests can set .broken to make every call
    fail, or set .hung_until (a Deferred) to queue calls until it fires.
    """
    def __init__(self, original):
        self.original = original          # object providing remote_* methods
        self.broken = False               # True => calls raise IntentionalError
        self.hung_until = None            # Deferred: calls queue until it fires
        self.post_call_notifier = None    # called as f(result, self, methname)
        self.disconnectors = {}           # Marker -> (f, args, kwargs)

    def callRemoteOnly(self, methname, *args, **kwargs):
        # fire-and-forget variant: errors and results are discarded, matching
        # foolscap's callRemoteOnly contract
        d = self.callRemote(methname, *args, **kwargs)
        del d # explicitly ignored
        return None

    def callRemote(self, methname, *args, **kwargs):
        # this is ideally a Membrane, but that's too hard. We do a shallow
        # wrapping of inbound arguments, and per-methodname wrapping of
        # selected return values.
        def wrap(a):
            # inbound Referenceables must cross the simulated wire wrapped,
            # just as foolscap would deliver a RemoteReference
            if isinstance(a, Referenceable):
                return LocalWrapper(a)
            else:
                return a
        args = tuple([wrap(a) for a in args])
        kwargs = dict([(k,wrap(kwargs[k])) for k in kwargs])

        def _really_call():
            # look up and invoke the remote_* method on the wrapped object
            meth = getattr(self.original, "remote_" + methname)
            return meth(*args, **kwargs)

        def _call():
            if self.broken:
                # simulate a server that errors on every request
                raise IntentionalError("I was asked to break")
            if self.hung_until:
                # simulate a hung server: chain the real call behind the
                # hung_until Deferred, and hand the caller a fresh Deferred
                # (d2) that fires with that call's eventual result
                d2 = defer.Deferred()
                self.hung_until.addCallback(lambda ign: _really_call())
                self.hung_until.addCallback(lambda res: d2.callback(res))
                def _err(res):
                    d2.errback(res)
                    return res  # pass the failure along hung_until too
                self.hung_until.addErrback(_err)
                return d2
            return _really_call()

        # fireEventually ensures the call happens in a later reactor turn,
        # like a real network round-trip would
        d = fireEventually()
        d.addCallback(lambda res: _call())
        def _wrap_exception(f):
            # callers expect remote failures to arrive as RemoteException
            return Failure(RemoteException(f))
        d.addErrback(_wrap_exception)
        def _return_membrane(res):
            # rather than complete the difficult task of building a
            # fully-general Membrane (which would locate all Referenceable
            # objects that cross the simulated wire and replace them with
            # wrappers), we special-case certain methods that we happen to
            # know will return Referenceables.
            if methname == "allocate_buckets":
                # res is (alreadygot, {shnum: BucketWriter}); wrap the writers
                (alreadygot, allocated) = res
                for shnum in allocated:
                    allocated[shnum] = LocalWrapper(allocated[shnum])
            if methname == "get_buckets":
                # res is {shnum: BucketReader}; wrap the readers
                for shnum in res:
                    res[shnum] = LocalWrapper(res[shnum])
            return res
        d.addCallback(_return_membrane)
        if self.post_call_notifier:
            # notifier fires after the call completes, with (res, self, methname)
            d.addCallback(self.post_call_notifier, self, methname)
        return d

    def notifyOnDisconnect(self, f, *args, **kwargs):
        # record the callback; NOTE(review): nothing in this class ever
        # invokes the registered disconnectors -- with no network there is
        # no disconnection event to deliver
        m = Marker()
        self.disconnectors[m] = (f, args, kwargs)
        return m
    def dontNotifyOnDisconnect(self, marker):
        del self.disconnectors[marker]
112
def wrap_storage_server(original):
    """
    Wrap a StorageServer in a LocalWrapper and give it a .version attribute.

    Much of the upload/download code uses rref.version (which normally
    comes from rrefutil.add_version_to_remote_reference). To avoid using a
    network, we want a LocalWrapper here. Try to satisfy all these
    constraints at the same time.
    """
    w = LocalWrapper(original)
    w.version = original.remote_get_version()
    return w
121
class NoNetworkServer:
    """
    IServer implementation for the no-network grid: a serverid plus a
    loopback rref, with all seed values derived from the serverid.
    """
    implements(IServer)

    def __init__(self, serverid, rref):
        self.serverid = serverid
        self.rref = rref

    def __repr__(self):
        return "<NoNetworkServer for %s>" % self.get_name()

    # Special method used by copy.copy() and copy.deepcopy(). When those are
    # used in allmydata.immutable.filenode to copy CheckResults during
    # repair, we want it to treat the IServer instances as singletons.
    def __copy__(self):
        return self

    def __deepcopy__(self, memodict):
        return self

    def get_serverid(self):
        return self.serverid

    def get_permutation_seed(self):
        # the serverid doubles as every seed in this test harness
        return self.serverid

    def get_lease_seed(self):
        return self.serverid

    def get_foolscap_write_enabler_seed(self):
        return self.serverid

    def get_name(self):
        """Short printable form of the serverid."""
        return idlib.shortnodeid_b2a(self.serverid)

    def get_longname(self):
        """Full printable form of the serverid."""
        return idlib.nodeid_b2a(self.serverid)

    def get_nickname(self):
        return "nickname"

    def get_rref(self):
        return self.rref

    def get_version(self):
        # .version is attached by wrap_storage_server()
        return self.rref.version
155
class NoNetworkStorageBroker:
    """
    IStorageBroker that serves the grid's server set directly, with the
    usual SHA1-based permuted ordering but no connection management.
    """
    implements(IStorageBroker)

    def get_servers_for_psi(self, peer_selection_index):
        """Return connected servers in permuted order for this index."""
        def _permuted_order(server):
            seed = server.get_permutation_seed()
            return sha1(peer_selection_index + seed).digest()
        return sorted(self.get_connected_servers(), key=_permuted_order)

    def get_connected_servers(self):
        # .client is assigned by NoNetworkClient.init_client_storage_broker
        return self.client._servers

    def get_nickname_for_serverid(self, serverid):
        # no introducer, so no nicknames are known
        return None
167
class NoNetworkClient(Client):
    """
    A Client subclass with every network-facing subsystem stubbed out:
    no Tub, no introducer, no helper, no key generator, no local storage.
    Its storage broker is a NoNetworkStorageBroker fed by the grid.
    """
    def create_tub(self):
        pass  # no foolscap Tub in this harness

    def init_introducer_client(self):
        pass  # no introducer

    def setup_logging(self):
        pass

    def startService(self):
        # skip Client's startup (which assumes a Tub); start children only
        service.MultiService.startService(self)

    def stopService(self):
        service.MultiService.stopService(self)

    def when_tub_ready(self):
        raise NotImplementedError("NoNetworkClient has no Tub")

    def init_control(self):
        pass

    def init_helper(self):
        pass

    def init_key_gen(self):
        pass

    def init_storage(self):
        pass  # storage servers live in the NoNetworkGrid, not the client

    def init_client_storage_broker(self):
        self.storage_broker = NoNetworkStorageBroker()
        self.storage_broker.client = self

    def init_stub_client(self):
        pass
    #._servers will be set by the NoNetworkGrid which creates us
195
class SimpleStats:
    """
    Minimal in-memory stats provider: named counters plus a list of
    producer objects whose get_stats() dicts are merged on demand.
    """
    def __init__(self):
        self.counters = {}          # name -> accumulated count
        self.stats_producers = []   # objects with a get_stats() method

    def count(self, name, delta=1):
        """Add 'delta' (default 1) to the counter called 'name'."""
        self.counters[name] = self.counters.get(name, 0) + delta

    def register_producer(self, stats_producer):
        """Include 'stats_producer'.get_stats() in future get_stats() output."""
        self.stats_producers.append(stats_producer)

    def get_stats(self):
        """Return {'counters': ..., 'stats': merged-producer-stats}."""
        merged = {}
        for producer in self.stats_producers:
            merged.update(producer.get_stats())
        return {'counters': self.counters, 'stats': merged}
214
class NoNetworkGrid(service.MultiService):
    """
    A MultiService holding some number of StorageServers (reachable through
    LocalWrapper/NoNetworkServer proxies) and NoNetworkClients wired
    directly to those proxies. No network connections are made.
    """
    def __init__(self, basedir, num_clients=1, num_servers=10,
                 client_config_hooks=None):
        # client_config_hooks defaults to None instead of a mutable {} so
        # the default dict cannot be shared (and mutated) across instances
        service.MultiService.__init__(self)
        if client_config_hooks is None:
            client_config_hooks = {}
        self.basedir = basedir
        fileutil.make_dirs(basedir)

        self.servers_by_number = {} # maps to StorageServer instance
        self.wrappers_by_id = {} # maps to wrapped StorageServer instance
        self.proxies_by_id = {} # maps to IServer on which .rref is a wrapped
                                # StorageServer
        self.clients = []

        for i in range(num_servers):
            ss = self.make_server(i)
            self.add_server(i, ss)
        self.rebuild_serverlist()

        for i in range(num_clients):
            clientid = hashutil.tagged_hash("clientid", str(i))[:20]
            clientdir = os.path.join(basedir, "clients",
                                     idlib.shortnodeid_b2a(clientid))
            fileutil.make_dirs(clientdir)
            # 'with' guarantees tahoe.cfg is flushed and closed before the
            # client (or a config hook) tries to read it
            with open(os.path.join(clientdir, "tahoe.cfg"), "w") as f:
                f.write("[node]\n")
                f.write("nickname = client-%d\n" % i)
                f.write("web.port = tcp:0:interface=127.0.0.1\n")
                f.write("[storage]\n")
                f.write("enabled = false\n")
            c = None
            if i in client_config_hooks:
                # this hook can either modify tahoe.cfg, or return an
                # entirely new Client instance
                c = client_config_hooks[i](clientdir)
            if not c:
                c = NoNetworkClient(clientdir)
                c.set_default_mutable_keysize(TEST_RSA_KEY_SIZE)
            c.nodeid = clientid
            c.short_nodeid = b32encode(clientid).lower()[:8]
            c._servers = self.all_servers # can be updated later
            c.setServiceParent(self)
            self.clients.append(c)

    def make_server(self, i, readonly=False):
        """Create (but do not register) StorageServer number i."""
        serverid = hashutil.tagged_hash("serverid", str(i))[:20]
        serverdir = os.path.join(self.basedir, "servers",
                                 idlib.shortnodeid_b2a(serverid), "storage")
        fileutil.make_dirs(serverdir)
        ss = StorageServer(serverdir, serverid, stats_provider=SimpleStats(),
                           readonly_storage=readonly)
        ss._no_network_server_number = i
        return ss

    def add_server(self, i, ss):
        """Attach StorageServer 'ss' as server number i and index it."""
        # to deal with the fact that all StorageServers are named 'storage',
        # we interpose a middleman
        middleman = service.MultiService()
        middleman.setServiceParent(self)
        ss.setServiceParent(middleman)
        serverid = ss.my_nodeid
        self.servers_by_number[i] = ss
        wrapper = wrap_storage_server(ss)
        self.wrappers_by_id[serverid] = wrapper
        self.proxies_by_id[serverid] = NoNetworkServer(serverid, wrapper)
        self.rebuild_serverlist()

    def get_all_serverids(self):
        """Return the serverids of all known servers."""
        return self.proxies_by_id.keys()

    def rebuild_serverlist(self):
        """Recompute self.all_servers and push it to every client."""
        self.all_servers = frozenset(self.proxies_by_id.values())
        for c in self.clients:
            c._servers = self.all_servers

    def remove_server(self, serverid):
        """Drop the server with 'serverid' from the grid; return its
        StorageServer. Raises KeyError if the serverid is unknown."""
        # it's enough to remove the server from c._servers (we don't actually
        # have to detach and stopService it)
        for i,ss in self.servers_by_number.items():
            if ss.my_nodeid == serverid:
                del self.servers_by_number[i]
                break
        del self.wrappers_by_id[serverid]
        del self.proxies_by_id[serverid]
        self.rebuild_serverlist()
        return ss

    def break_server(self, serverid):
        """Mark the given server as broken, so it will throw exceptions
        when asked to hold a share or serve a share."""
        self.wrappers_by_id[serverid].broken = True

    def hang_server(self, serverid):
        """Hang the given server: queue its calls until unhang_server()."""
        ss = self.wrappers_by_id[serverid]
        assert ss.hung_until is None
        ss.hung_until = defer.Deferred()

    def unhang_server(self, serverid):
        """Unhang the given server, releasing any queued calls."""
        ss = self.wrappers_by_id[serverid]
        assert ss.hung_until is not None
        ss.hung_until.callback(None)
        ss.hung_until = None
319
320
class GridTestMixin:
    """
    Test-case mixin that builds a NoNetworkGrid (self.g) and provides
    helpers for issuing webapi requests and for finding, copying, deleting,
    and corrupting share files on disk.
    """
    def setUp(self):
        self.s = service.MultiService()
        self.s.startService()

    def tearDown(self):
        return self.s.stopService()

    def set_up_grid(self, num_clients=1, num_servers=10,
                    client_config_hooks=None):
        """Create self.g plus per-client webapi port/url lists.
        self.basedir must be set before calling this."""
        # default changed from a shared mutable {} to None
        if client_config_hooks is None:
            client_config_hooks = {}
        self.g = NoNetworkGrid(self.basedir,
                               num_clients=num_clients,
                               num_servers=num_servers,
                               client_config_hooks=client_config_hooks)
        self.g.setServiceParent(self.s)
        self.client_webports = [c.getServiceNamed("webish").getPortnum()
                                for c in self.g.clients]
        self.client_baseurls = [c.getServiceNamed("webish").getURL()
                                for c in self.g.clients]

    def get_clientdir(self, i=0):
        """Return the basedir of client number i."""
        return self.g.clients[i].basedir

    def get_serverdir(self, i):
        """Return the storage directory of server number i."""
        return self.g.servers_by_number[i].storedir

    def iterate_servers(self):
        """Yield (number, StorageServer, storedir) for each server, in order."""
        for i in sorted(self.g.servers_by_number.keys()):
            ss = self.g.servers_by_number[i]
            yield (i, ss, ss.storedir)

    def find_uri_shares(self, uri):
        """Return a sorted list of (shnum, serverid, sharefile-path) tuples
        for every share of 'uri' found on any server."""
        si = tahoe_uri.from_string(uri).get_storage_index()
        prefixdir = storage_index_to_dir(si)
        shares = []
        for i,ss in self.g.servers_by_number.items():
            serverid = ss.my_nodeid
            basedir = os.path.join(ss.sharedir, prefixdir)
            if not os.path.exists(basedir):
                continue
            for f in os.listdir(basedir):
                try:
                    shnum = int(f)
                    shares.append((shnum, serverid, os.path.join(basedir, f)))
                except ValueError:
                    # non-numeric entries are not share files; skip them
                    pass
        return sorted(shares)

    def copy_shares(self, uri):
        """Return {sharefile-path: contents} for every share of 'uri'."""
        shares = {}
        for (shnum, serverid, sharefile) in self.find_uri_shares(uri):
            # close each file promptly rather than relying on refcounting
            with open(sharefile, "rb") as f:
                shares[sharefile] = f.read()
        return shares

    def restore_all_shares(self, shares):
        """Write back share contents captured earlier by copy_shares()."""
        for sharefile, data in shares.items():
            with open(sharefile, "wb") as f:
                f.write(data)

    def delete_share(self, share):
        """Delete one share, given as a (shnum, serverid, sharefile) tuple."""
        # explicit unpacking replaces the py2-only tuple-parameter syntax;
        # callers still pass the same single 3-tuple argument
        (shnum, serverid, sharefile) = share
        os.unlink(sharefile)

    def delete_shares_numbered(self, uri, shnums):
        """Delete every share of 'uri' whose share number is in 'shnums'."""
        for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
            if i_shnum in shnums:
                os.unlink(i_sharefile)

    def corrupt_share(self, share, corruptor_function):
        """Rewrite one share's file through corruptor_function(data)->data.
        'share' is a (shnum, serverid, sharefile) tuple."""
        (shnum, serverid, sharefile) = share
        with open(sharefile, "rb") as f:
            sharedata = f.read()
        corruptdata = corruptor_function(sharedata)
        with open(sharefile, "wb") as f:
            f.write(corruptdata)

    def corrupt_shares_numbered(self, uri, shnums, corruptor, debug=False):
        """Corrupt the shares of 'uri' listed in 'shnums' via
        corruptor(data, debug=debug)."""
        for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
            if i_shnum in shnums:
                with open(i_sharefile, "rb") as f:
                    sharedata = f.read()
                corruptdata = corruptor(sharedata, debug=debug)
                with open(i_sharefile, "wb") as f:
                    f.write(corruptdata)

    def corrupt_all_shares(self, uri, corruptor, debug=False):
        """Corrupt every share of 'uri' via corruptor(data, debug=debug)."""
        for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
            with open(i_sharefile, "rb") as f:
                sharedata = f.read()
            corruptdata = corruptor(sharedata, debug=debug)
            with open(i_sharefile, "wb") as f:
                f.write(corruptdata)

    def GET(self, urlpath, followRedirect=False, return_response=False,
            method="GET", clientnum=0, **kwargs):
        """Issue an HTTP request to client 'clientnum's webapi.

        Fires with the response body, or, if return_response=True, with
        (data, statuscode, respheaders) instead of just data.
        """
        assert not isinstance(urlpath, unicode)
        url = self.client_baseurls[clientnum] + urlpath
        factory = HTTPClientGETFactory(url, method=method,
                                       followRedirect=followRedirect, **kwargs)
        reactor.connectTCP("localhost", self.client_webports[clientnum],
                           factory)
        d = factory.deferred
        def _got_data(data):
            return (data, factory.status, factory.response_headers)
        if return_response:
            d.addCallback(_got_data)
        # return the deferred we (possibly) attached the callback to
        return d

    def PUT(self, urlpath, **kwargs):
        """Issue a PUT request; same conventions as GET()."""
        return self.GET(urlpath, method="PUT", **kwargs)