]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blob - src/allmydata/test/no_network.py
Improvements to no_network test harness.
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / test / no_network.py
1
2 # This contains a test harness that creates a full Tahoe grid in a single
3 # process (actually in a single MultiService) which does not use the network.
4 # It does not use an Introducer, and there are no foolscap Tubs. Each storage
5 # server puts real shares on disk, but is accessed through loopback
6 # RemoteReferences instead of over serialized SSL. It is not as complete as
7 # the common.SystemTestMixin framework (which does use the network), but
8 # should be considerably faster: on my laptop, it takes 50-80ms to start up,
9 # whereas SystemTestMixin takes close to 2s.
10
11 # This should be useful for tests which want to examine and/or manipulate the
12 # uploaded shares, checker/verifier/repairer tests, etc. The clients have no
13 # Tubs, so it is not useful for tests that involve a Helper, a KeyGenerator,
14 # or the control.furl .
15
16 import os
17 from zope.interface import implements
18 from twisted.application import service
19 from twisted.internet import defer, reactor
20 from twisted.python.failure import Failure
21 from foolscap.api import Referenceable, fireEventually, RemoteException
22 from base64 import b32encode
23
24 from allmydata.util.assertutil import _assert
25
26 from allmydata import uri as tahoe_uri
27 from allmydata.client import Client
28 from allmydata.storage.server import StorageServer, storage_index_to_dir
29 from allmydata.util import fileutil, idlib, hashutil
30 from allmydata.util.hashutil import sha1
31 from allmydata.test.common_web import HTTPClientGETFactory
32 from allmydata.interfaces import IStorageBroker, IServer
33 from allmydata.test.common import TEST_RSA_KEY_SIZE
34
35
36 class IntentionalError(Exception):
37     pass
38
39 class Marker:
40     pass
41
42 class LocalWrapper:
43     def __init__(self, original):
44         self.original = original
45         self.broken = False
46         self.hung_until = None
47         self.post_call_notifier = None
48         self.disconnectors = {}
49         self.counter_by_methname = {}
50
51     def _clear_counters(self):
52         self.counter_by_methname = {}
53
54     def callRemoteOnly(self, methname, *args, **kwargs):
55         d = self.callRemote(methname, *args, **kwargs)
56         del d # explicitly ignored
57         return None
58
59     def callRemote(self, methname, *args, **kwargs):
60         # this is ideally a Membrane, but that's too hard. We do a shallow
61         # wrapping of inbound arguments, and per-methodname wrapping of
62         # selected return values.
63         def wrap(a):
64             if isinstance(a, Referenceable):
65                 return LocalWrapper(a)
66             else:
67                 return a
68         args = tuple([wrap(a) for a in args])
69         kwargs = dict([(k,wrap(kwargs[k])) for k in kwargs])
70
71         def _really_call():
72             def incr(d, k): d[k] = d.setdefault(k, 0) + 1
73             incr(self.counter_by_methname, methname)
74             meth = getattr(self.original, "remote_" + methname)
75             return meth(*args, **kwargs)
76
77         def _call():
78             if self.broken:
79                 if self.broken is not True: # a counter, not boolean
80                     self.broken -= 1
81                 raise IntentionalError("I was asked to break")
82             if self.hung_until:
83                 d2 = defer.Deferred()
84                 self.hung_until.addCallback(lambda ign: _really_call())
85                 self.hung_until.addCallback(lambda res: d2.callback(res))
86                 def _err(res):
87                     d2.errback(res)
88                     return res
89                 self.hung_until.addErrback(_err)
90                 return d2
91             return _really_call()
92
93         d = fireEventually()
94         d.addCallback(lambda res: _call())
95         def _wrap_exception(f):
96             return Failure(RemoteException(f))
97         d.addErrback(_wrap_exception)
98         def _return_membrane(res):
99             # rather than complete the difficult task of building a
100             # fully-general Membrane (which would locate all Referenceable
101             # objects that cross the simulated wire and replace them with
102             # wrappers), we special-case certain methods that we happen to
103             # know will return Referenceables.
104             if methname == "allocate_buckets":
105                 (alreadygot, allocated) = res
106                 for shnum in allocated:
107                     allocated[shnum] = LocalWrapper(allocated[shnum])
108             if methname == "get_buckets":
109                 for shnum in res:
110                     res[shnum] = LocalWrapper(res[shnum])
111             return res
112         d.addCallback(_return_membrane)
113         if self.post_call_notifier:
114             d.addCallback(self.post_call_notifier, self, methname)
115         return d
116
117     def notifyOnDisconnect(self, f, *args, **kwargs):
118         m = Marker()
119         self.disconnectors[m] = (f, args, kwargs)
120         return m
121     def dontNotifyOnDisconnect(self, marker):
122         del self.disconnectors[marker]
123
124 def wrap_storage_server(original):
125     # Much of the upload/download code uses rref.version (which normally
126     # comes from rrefutil.add_version_to_remote_reference). To avoid using a
127     # network, we want a LocalWrapper here. Try to satisfy all these
128     # constraints at the same time.
129     wrapper = LocalWrapper(original)
130     wrapper.version = original.remote_get_version()
131     return wrapper
132
133 class NoNetworkServer:
134     implements(IServer)
135     def __init__(self, serverid, rref):
136         self.serverid = serverid
137         self.rref = rref
138     def __repr__(self):
139         return "<NoNetworkServer for %s>" % self.get_name()
140     # Special method used by copy.copy() and copy.deepcopy(). When those are
141     # used in allmydata.immutable.filenode to copy CheckResults during
142     # repair, we want it to treat the IServer instances as singletons.
143     def __copy__(self):
144         return self
145     def __deepcopy__(self, memodict):
146         return self
147     def get_serverid(self):
148         return self.serverid
149     def get_permutation_seed(self):
150         return self.serverid
151     def get_lease_seed(self):
152         return self.serverid
153     def get_foolscap_write_enabler_seed(self):
154         return self.serverid
155
156     def get_name(self):
157         return idlib.shortnodeid_b2a(self.serverid)
158     def get_longname(self):
159         return idlib.nodeid_b2a(self.serverid)
160     def get_nickname(self):
161         return "nickname"
162     def get_rref(self):
163         return self.rref
164     def get_version(self):
165         return self.rref.version
166
167 class NoNetworkStorageBroker:
168     implements(IStorageBroker)
169     def get_servers_for_psi(self, peer_selection_index):
170         def _permuted(server):
171             seed = server.get_permutation_seed()
172             return sha1(peer_selection_index + seed).digest()
173         return sorted(self.get_connected_servers(), key=_permuted)
174     def get_connected_servers(self):
175         return self.client._servers
176     def get_nickname_for_serverid(self, serverid):
177         return None
178
179 class NoNetworkClient(Client):
180
181     def disownServiceParent(self):
182         self.disownServiceParent()
183     def create_tub(self):
184         pass
185     def init_introducer_client(self):
186         pass
187     def setup_logging(self):
188         pass
189     def startService(self):
190         service.MultiService.startService(self)
191     def stopService(self):
192         service.MultiService.stopService(self)
193     def when_tub_ready(self):
194         raise NotImplementedError("NoNetworkClient has no Tub")
195     def init_control(self):
196         pass
197     def init_helper(self):
198         pass
199     def init_key_gen(self):
200         pass
201     def init_storage(self):
202         pass
203     def init_client_storage_broker(self):
204         self.storage_broker = NoNetworkStorageBroker()
205         self.storage_broker.client = self
206     def init_stub_client(self):
207         pass
208     #._servers will be set by the NoNetworkGrid which creates us
209
210 class SimpleStats:
211     def __init__(self):
212         self.counters = {}
213         self.stats_producers = []
214
215     def count(self, name, delta=1):
216         val = self.counters.setdefault(name, 0)
217         self.counters[name] = val + delta
218
219     def register_producer(self, stats_producer):
220         self.stats_producers.append(stats_producer)
221
222     def get_stats(self):
223         stats = {}
224         for sp in self.stats_producers:
225             stats.update(sp.get_stats())
226         ret = { 'counters': self.counters, 'stats': stats }
227         return ret
228
229 class NoNetworkGrid(service.MultiService):
230     def __init__(self, basedir, num_clients=1, num_servers=10,
231                  client_config_hooks={}):
232         service.MultiService.__init__(self)
233         self.basedir = basedir
234         fileutil.make_dirs(basedir)
235
236         self.servers_by_number = {} # maps to StorageServer instance
237         self.wrappers_by_id = {} # maps to wrapped StorageServer instance
238         self.proxies_by_id = {} # maps to IServer on which .rref is a wrapped
239                                 # StorageServer
240         self.clients = []
241         self.client_config_hooks = client_config_hooks
242
243         for i in range(num_servers):
244             ss = self.make_server(i)
245             self.add_server(i, ss)
246         self.rebuild_serverlist()
247
248         for i in range(num_clients):
249             c = self.make_client(i)
250             self.clients.append(c)
251
252     def make_client(self, i, write_config=True):
253         clientid = hashutil.tagged_hash("clientid", str(i))[:20]
254         clientdir = os.path.join(self.basedir, "clients",
255                                  idlib.shortnodeid_b2a(clientid))
256         fileutil.make_dirs(clientdir)
257
258         tahoe_cfg_path = os.path.join(clientdir, "tahoe.cfg")
259         if write_config:
260             f = open(tahoe_cfg_path, "w")
261             f.write("[node]\n")
262             f.write("nickname = client-%d\n" % i)
263             f.write("web.port = tcp:0:interface=127.0.0.1\n")
264             f.write("[storage]\n")
265             f.write("enabled = false\n")
266             f.close()
267         else:
268             _assert(os.path.exists(tahoe_cfg_path), tahoe_cfg_path=tahoe_cfg_path)
269
270         c = None
271         if i in self.client_config_hooks:
272             # this hook can either modify tahoe.cfg, or return an
273             # entirely new Client instance
274             c = self.client_config_hooks[i](clientdir)
275
276         if not c:
277             c = NoNetworkClient(clientdir)
278             c.set_default_mutable_keysize(TEST_RSA_KEY_SIZE)
279
280         c.nodeid = clientid
281         c.short_nodeid = b32encode(clientid).lower()[:8]
282         c._servers = self.all_servers # can be updated later
283         c.setServiceParent(self)
284         return c
285
286     def make_server(self, i, readonly=False):
287         serverid = hashutil.tagged_hash("serverid", str(i))[:20]
288         serverdir = os.path.join(self.basedir, "servers",
289                                  idlib.shortnodeid_b2a(serverid), "storage")
290         fileutil.make_dirs(serverdir)
291         ss = StorageServer(serverdir, serverid, stats_provider=SimpleStats(),
292                            readonly_storage=readonly)
293         ss._no_network_server_number = i
294         return ss
295
296     def add_server(self, i, ss):
297         # to deal with the fact that all StorageServers are named 'storage',
298         # we interpose a middleman
299         middleman = service.MultiService()
300         middleman.setServiceParent(self)
301         ss.setServiceParent(middleman)
302         serverid = ss.my_nodeid
303         self.servers_by_number[i] = ss
304         wrapper = wrap_storage_server(ss)
305         self.wrappers_by_id[serverid] = wrapper
306         self.proxies_by_id[serverid] = NoNetworkServer(serverid, wrapper)
307         self.rebuild_serverlist()
308
309     def get_all_serverids(self):
310         return self.proxies_by_id.keys()
311
312     def rebuild_serverlist(self):
313         self.all_servers = frozenset(self.proxies_by_id.values())
314         for c in self.clients:
315             c._servers = self.all_servers
316
317     def remove_server(self, serverid):
318         # it's enough to remove the server from c._servers (we don't actually
319         # have to detach and stopService it)
320         for i,ss in self.servers_by_number.items():
321             if ss.my_nodeid == serverid:
322                 del self.servers_by_number[i]
323                 break
324         del self.wrappers_by_id[serverid]
325         del self.proxies_by_id[serverid]
326         self.rebuild_serverlist()
327         return ss
328
329     def break_server(self, serverid, count=True):
330         # mark the given server as broken, so it will throw exceptions when
331         # asked to hold a share or serve a share. If count= is a number,
332         # throw that many exceptions before starting to work again.
333         self.wrappers_by_id[serverid].broken = count
334
335     def hang_server(self, serverid):
336         # hang the given server
337         ss = self.wrappers_by_id[serverid]
338         assert ss.hung_until is None
339         ss.hung_until = defer.Deferred()
340
341     def unhang_server(self, serverid):
342         # unhang the given server
343         ss = self.wrappers_by_id[serverid]
344         assert ss.hung_until is not None
345         ss.hung_until.callback(None)
346         ss.hung_until = None
347
348     def nuke_from_orbit(self):
349         """ Empty all share directories in this grid. It's the only way to be sure ;-) """
350         for server in self.servers_by_number.values():
351             for prefixdir in os.listdir(server.sharedir):
352                 if prefixdir != 'incoming':
353                     fileutil.rm_dir(os.path.join(server.sharedir, prefixdir))
354
355
356 class GridTestMixin:
357     def setUp(self):
358         self.s = service.MultiService()
359         self.s.startService()
360
361     def tearDown(self):
362         return self.s.stopService()
363
364     def set_up_grid(self, num_clients=1, num_servers=10,
365                     client_config_hooks={}):
366         # self.basedir must be set
367         self.g = NoNetworkGrid(self.basedir,
368                                num_clients=num_clients,
369                                num_servers=num_servers,
370                                client_config_hooks=client_config_hooks)
371         self.g.setServiceParent(self.s)
372         self._record_webports_and_baseurls()
373
374     def _record_webports_and_baseurls(self):
375         self.client_webports = [c.getServiceNamed("webish").getPortnum()
376                                 for c in self.g.clients]
377         self.client_baseurls = [c.getServiceNamed("webish").getURL()
378                                 for c in self.g.clients]
379
380     def get_clientdir(self, i=0):
381         return self.g.clients[i].basedir
382
383     def set_clientdir(self, basedir, i=0):
384         self.g.clients[i].basedir = basedir
385
386     def get_client(self, i=0):
387         return self.g.clients[i]
388
389     def restart_client(self, i=0):
390         client = self.g.clients[i]
391         d = defer.succeed(None)
392         d.addCallback(lambda ign: self.g.removeService(client))
393         def _make_client(ign):
394             c = self.g.make_client(i, write_config=False)
395             self.g.clients[i] = c
396             self._record_webports_and_baseurls()
397         d.addCallback(_make_client)
398         return d
399
400     def get_serverdir(self, i):
401         return self.g.servers_by_number[i].storedir
402
403     def iterate_servers(self):
404         for i in sorted(self.g.servers_by_number.keys()):
405             ss = self.g.servers_by_number[i]
406             yield (i, ss, ss.storedir)
407
408     def find_uri_shares(self, uri):
409         si = tahoe_uri.from_string(uri).get_storage_index()
410         prefixdir = storage_index_to_dir(si)
411         shares = []
412         for i,ss in self.g.servers_by_number.items():
413             serverid = ss.my_nodeid
414             basedir = os.path.join(ss.sharedir, prefixdir)
415             if not os.path.exists(basedir):
416                 continue
417             for f in os.listdir(basedir):
418                 try:
419                     shnum = int(f)
420                     shares.append((shnum, serverid, os.path.join(basedir, f)))
421                 except ValueError:
422                     pass
423         return sorted(shares)
424
425     def copy_shares(self, uri):
426         shares = {}
427         for (shnum, serverid, sharefile) in self.find_uri_shares(uri):
428             shares[sharefile] = open(sharefile, "rb").read()
429         return shares
430
431     def restore_all_shares(self, shares):
432         for sharefile, data in shares.items():
433             open(sharefile, "wb").write(data)
434
435     def delete_share(self, (shnum, serverid, sharefile)):
436         os.unlink(sharefile)
437
438     def delete_shares_numbered(self, uri, shnums):
439         for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
440             if i_shnum in shnums:
441                 os.unlink(i_sharefile)
442
443     def delete_all_shares(self, serverdir):
444         sharedir = os.path.join(serverdir, "shares")
445         for prefixdir in os.listdir(sharedir):
446             if prefixdir != 'incoming':
447                 fileutil.rm_dir(os.path.join(sharedir, prefixdir))
448
449     def corrupt_share(self, (shnum, serverid, sharefile), corruptor_function):
450         sharedata = open(sharefile, "rb").read()
451         corruptdata = corruptor_function(sharedata)
452         open(sharefile, "wb").write(corruptdata)
453
454     def corrupt_shares_numbered(self, uri, shnums, corruptor, debug=False):
455         for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
456             if i_shnum in shnums:
457                 sharedata = open(i_sharefile, "rb").read()
458                 corruptdata = corruptor(sharedata, debug=debug)
459                 open(i_sharefile, "wb").write(corruptdata)
460
461     def corrupt_all_shares(self, uri, corruptor, debug=False):
462         for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
463             sharedata = open(i_sharefile, "rb").read()
464             corruptdata = corruptor(sharedata, debug=debug)
465             open(i_sharefile, "wb").write(corruptdata)
466
467     def GET(self, urlpath, followRedirect=False, return_response=False,
468             method="GET", clientnum=0, **kwargs):
469         # if return_response=True, this fires with (data, statuscode,
470         # respheaders) instead of just data.
471         assert not isinstance(urlpath, unicode)
472         url = self.client_baseurls[clientnum] + urlpath
473         factory = HTTPClientGETFactory(url, method=method,
474                                        followRedirect=followRedirect, **kwargs)
475         reactor.connectTCP("localhost", self.client_webports[clientnum],factory)
476         d = factory.deferred
477         def _got_data(data):
478             return (data, factory.status, factory.response_headers)
479         if return_response:
480             d.addCallback(_got_data)
481         return factory.deferred
482
483     def PUT(self, urlpath, **kwargs):
484         return self.GET(urlpath, method="PUT", **kwargs)