]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blob - src/allmydata/test/no_network.py
8dd9a2f957a2a91469590db66849ea5984ba45ec
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / test / no_network.py
1
2 # This contains a test harness that creates a full Tahoe grid in a single
3 # process (actually in a single MultiService) which does not use the network.
4 # It does not use an Introducer, and there are no foolscap Tubs. Each storage
5 # server puts real shares on disk, but is accessed through loopback
6 # RemoteReferences instead of over serialized SSL. It is not as complete as
7 # the common.SystemTestMixin framework (which does use the network), but
8 # should be considerably faster: on my laptop, it takes 50-80ms to start up,
9 # whereas SystemTestMixin takes close to 2s.
10
11 # This should be useful for tests which want to examine and/or manipulate the
12 # uploaded shares, checker/verifier/repairer tests, etc. The clients have no
13 # Tubs, so it is not useful for tests that involve a Helper, a KeyGenerator,
14 # or the control.furl .
15
16 import os
17 from zope.interface import implements
18 from twisted.application import service
19 from twisted.internet import defer, reactor
20 from twisted.python.failure import Failure
21 from foolscap.api import Referenceable, fireEventually, RemoteException
22 from base64 import b32encode
23 from allmydata import uri as tahoe_uri
24 from allmydata.client import Client
25 from allmydata.storage.server import StorageServer, storage_index_to_dir
26 from allmydata.util import fileutil, idlib, hashutil
27 from allmydata.util.hashutil import sha1
28 from allmydata.test.common_web import HTTPClientGETFactory
29 from allmydata.interfaces import IStorageBroker, IServer
30 from allmydata.test.common import TEST_RSA_KEY_SIZE
31
32
33 class IntentionalError(Exception):
34     pass
35
36 class Marker:
37     pass
38
39 class LocalWrapper:
40     def __init__(self, original):
41         self.original = original
42         self.broken = False
43         self.hung_until = None
44         self.post_call_notifier = None
45         self.disconnectors = {}
46         self.counter_by_methname = {}
47
48     def _clear_counters(self):
49         self.counter_by_methname = {}
50
51     def callRemoteOnly(self, methname, *args, **kwargs):
52         d = self.callRemote(methname, *args, **kwargs)
53         del d # explicitly ignored
54         return None
55
56     def callRemote(self, methname, *args, **kwargs):
57         # this is ideally a Membrane, but that's too hard. We do a shallow
58         # wrapping of inbound arguments, and per-methodname wrapping of
59         # selected return values.
60         def wrap(a):
61             if isinstance(a, Referenceable):
62                 return LocalWrapper(a)
63             else:
64                 return a
65         args = tuple([wrap(a) for a in args])
66         kwargs = dict([(k,wrap(kwargs[k])) for k in kwargs])
67
68         def _really_call():
69             def incr(d, k): d[k] = d.setdefault(k, 0) + 1
70             incr(self.counter_by_methname, methname)
71             meth = getattr(self.original, "remote_" + methname)
72             return meth(*args, **kwargs)
73
74         def _call():
75             if self.broken:
76                 if self.broken is not True: # a counter, not boolean
77                     self.broken -= 1
78                 raise IntentionalError("I was asked to break")
79             if self.hung_until:
80                 d2 = defer.Deferred()
81                 self.hung_until.addCallback(lambda ign: _really_call())
82                 self.hung_until.addCallback(lambda res: d2.callback(res))
83                 def _err(res):
84                     d2.errback(res)
85                     return res
86                 self.hung_until.addErrback(_err)
87                 return d2
88             return _really_call()
89
90         d = fireEventually()
91         d.addCallback(lambda res: _call())
92         def _wrap_exception(f):
93             return Failure(RemoteException(f))
94         d.addErrback(_wrap_exception)
95         def _return_membrane(res):
96             # rather than complete the difficult task of building a
97             # fully-general Membrane (which would locate all Referenceable
98             # objects that cross the simulated wire and replace them with
99             # wrappers), we special-case certain methods that we happen to
100             # know will return Referenceables.
101             if methname == "allocate_buckets":
102                 (alreadygot, allocated) = res
103                 for shnum in allocated:
104                     allocated[shnum] = LocalWrapper(allocated[shnum])
105             if methname == "get_buckets":
106                 for shnum in res:
107                     res[shnum] = LocalWrapper(res[shnum])
108             return res
109         d.addCallback(_return_membrane)
110         if self.post_call_notifier:
111             d.addCallback(self.post_call_notifier, self, methname)
112         return d
113
114     def notifyOnDisconnect(self, f, *args, **kwargs):
115         m = Marker()
116         self.disconnectors[m] = (f, args, kwargs)
117         return m
118     def dontNotifyOnDisconnect(self, marker):
119         del self.disconnectors[marker]
120
121 def wrap_storage_server(original):
122     # Much of the upload/download code uses rref.version (which normally
123     # comes from rrefutil.add_version_to_remote_reference). To avoid using a
124     # network, we want a LocalWrapper here. Try to satisfy all these
125     # constraints at the same time.
126     wrapper = LocalWrapper(original)
127     wrapper.version = original.remote_get_version()
128     return wrapper
129
130 class NoNetworkServer:
131     implements(IServer)
132     def __init__(self, serverid, rref):
133         self.serverid = serverid
134         self.rref = rref
135     def __repr__(self):
136         return "<NoNetworkServer for %s>" % self.get_name()
137     # Special method used by copy.copy() and copy.deepcopy(). When those are
138     # used in allmydata.immutable.filenode to copy CheckResults during
139     # repair, we want it to treat the IServer instances as singletons.
140     def __copy__(self):
141         return self
142     def __deepcopy__(self, memodict):
143         return self
144     def get_serverid(self):
145         return self.serverid
146     def get_permutation_seed(self):
147         return self.serverid
148     def get_lease_seed(self):
149         return self.serverid
150     def get_foolscap_write_enabler_seed(self):
151         return self.serverid
152
153     def get_name(self):
154         return idlib.shortnodeid_b2a(self.serverid)
155     def get_longname(self):
156         return idlib.nodeid_b2a(self.serverid)
157     def get_nickname(self):
158         return "nickname"
159     def get_rref(self):
160         return self.rref
161     def get_version(self):
162         return self.rref.version
163
164 class NoNetworkStorageBroker:
165     implements(IStorageBroker)
166     def get_servers_for_psi(self, peer_selection_index):
167         def _permuted(server):
168             seed = server.get_permutation_seed()
169             return sha1(peer_selection_index + seed).digest()
170         return sorted(self.get_connected_servers(), key=_permuted)
171     def get_connected_servers(self):
172         return self.client._servers
173     def get_nickname_for_serverid(self, serverid):
174         return None
175
176 class NoNetworkClient(Client):
177     def create_tub(self):
178         pass
179     def init_introducer_client(self):
180         pass
181     def setup_logging(self):
182         pass
183     def startService(self):
184         service.MultiService.startService(self)
185     def stopService(self):
186         service.MultiService.stopService(self)
187     def when_tub_ready(self):
188         raise NotImplementedError("NoNetworkClient has no Tub")
189     def init_control(self):
190         pass
191     def init_helper(self):
192         pass
193     def init_key_gen(self):
194         pass
195     def init_storage(self):
196         pass
197     def init_client_storage_broker(self):
198         self.storage_broker = NoNetworkStorageBroker()
199         self.storage_broker.client = self
200     def init_stub_client(self):
201         pass
202     #._servers will be set by the NoNetworkGrid which creates us
203
204 class SimpleStats:
205     def __init__(self):
206         self.counters = {}
207         self.stats_producers = []
208
209     def count(self, name, delta=1):
210         val = self.counters.setdefault(name, 0)
211         self.counters[name] = val + delta
212
213     def register_producer(self, stats_producer):
214         self.stats_producers.append(stats_producer)
215
216     def get_stats(self):
217         stats = {}
218         for sp in self.stats_producers:
219             stats.update(sp.get_stats())
220         ret = { 'counters': self.counters, 'stats': stats }
221         return ret
222
223 class NoNetworkGrid(service.MultiService):
224     def __init__(self, basedir, num_clients=1, num_servers=10,
225                  client_config_hooks={}):
226         service.MultiService.__init__(self)
227         self.basedir = basedir
228         fileutil.make_dirs(basedir)
229
230         self.servers_by_number = {} # maps to StorageServer instance
231         self.wrappers_by_id = {} # maps to wrapped StorageServer instance
232         self.proxies_by_id = {} # maps to IServer on which .rref is a wrapped
233                                 # StorageServer
234         self.clients = []
235
236         for i in range(num_servers):
237             ss = self.make_server(i)
238             self.add_server(i, ss)
239         self.rebuild_serverlist()
240
241         for i in range(num_clients):
242             clientid = hashutil.tagged_hash("clientid", str(i))[:20]
243             clientdir = os.path.join(basedir, "clients",
244                                      idlib.shortnodeid_b2a(clientid))
245             fileutil.make_dirs(clientdir)
246             f = open(os.path.join(clientdir, "tahoe.cfg"), "w")
247             f.write("[node]\n")
248             f.write("nickname = client-%d\n" % i)
249             f.write("web.port = tcp:0:interface=127.0.0.1\n")
250             f.write("[storage]\n")
251             f.write("enabled = false\n")
252             f.close()
253             c = None
254             if i in client_config_hooks:
255                 # this hook can either modify tahoe.cfg, or return an
256                 # entirely new Client instance
257                 c = client_config_hooks[i](clientdir)
258             if not c:
259                 c = NoNetworkClient(clientdir)
260                 c.set_default_mutable_keysize(TEST_RSA_KEY_SIZE)
261             c.nodeid = clientid
262             c.short_nodeid = b32encode(clientid).lower()[:8]
263             c._servers = self.all_servers # can be updated later
264             c.setServiceParent(self)
265             self.clients.append(c)
266
267     def make_server(self, i, readonly=False):
268         serverid = hashutil.tagged_hash("serverid", str(i))[:20]
269         serverdir = os.path.join(self.basedir, "servers",
270                                  idlib.shortnodeid_b2a(serverid), "storage")
271         fileutil.make_dirs(serverdir)
272         ss = StorageServer(serverdir, serverid, stats_provider=SimpleStats(),
273                            readonly_storage=readonly)
274         ss._no_network_server_number = i
275         return ss
276
277     def add_server(self, i, ss):
278         # to deal with the fact that all StorageServers are named 'storage',
279         # we interpose a middleman
280         middleman = service.MultiService()
281         middleman.setServiceParent(self)
282         ss.setServiceParent(middleman)
283         serverid = ss.my_nodeid
284         self.servers_by_number[i] = ss
285         wrapper = wrap_storage_server(ss)
286         self.wrappers_by_id[serverid] = wrapper
287         self.proxies_by_id[serverid] = NoNetworkServer(serverid, wrapper)
288         self.rebuild_serverlist()
289
290     def get_all_serverids(self):
291         return self.proxies_by_id.keys()
292
293     def rebuild_serverlist(self):
294         self.all_servers = frozenset(self.proxies_by_id.values())
295         for c in self.clients:
296             c._servers = self.all_servers
297
298     def remove_server(self, serverid):
299         # it's enough to remove the server from c._servers (we don't actually
300         # have to detach and stopService it)
301         for i,ss in self.servers_by_number.items():
302             if ss.my_nodeid == serverid:
303                 del self.servers_by_number[i]
304                 break
305         del self.wrappers_by_id[serverid]
306         del self.proxies_by_id[serverid]
307         self.rebuild_serverlist()
308         return ss
309
310     def break_server(self, serverid, count=True):
311         # mark the given server as broken, so it will throw exceptions when
312         # asked to hold a share or serve a share. If count= is a number,
313         # throw that many exceptions before starting to work again.
314         self.wrappers_by_id[serverid].broken = count
315
316     def hang_server(self, serverid):
317         # hang the given server
318         ss = self.wrappers_by_id[serverid]
319         assert ss.hung_until is None
320         ss.hung_until = defer.Deferred()
321
322     def unhang_server(self, serverid):
323         # unhang the given server
324         ss = self.wrappers_by_id[serverid]
325         assert ss.hung_until is not None
326         ss.hung_until.callback(None)
327         ss.hung_until = None
328
329     def nuke_from_orbit(self):
330         """ Empty all share directories in this grid. It's the only way to be sure ;-) """
331         for server in self.servers_by_number.values():
332             for prefixdir in os.listdir(server.sharedir):
333                 if prefixdir != 'incoming':
334                     fileutil.rm_dir(os.path.join(server.sharedir, prefixdir))
335
336
337 class GridTestMixin:
338     def setUp(self):
339         self.s = service.MultiService()
340         self.s.startService()
341
342     def tearDown(self):
343         return self.s.stopService()
344
345     def set_up_grid(self, num_clients=1, num_servers=10,
346                     client_config_hooks={}):
347         # self.basedir must be set
348         self.g = NoNetworkGrid(self.basedir,
349                                num_clients=num_clients,
350                                num_servers=num_servers,
351                                client_config_hooks=client_config_hooks)
352         self.g.setServiceParent(self.s)
353         self.client_webports = [c.getServiceNamed("webish").getPortnum()
354                                 for c in self.g.clients]
355         self.client_baseurls = [c.getServiceNamed("webish").getURL()
356                                 for c in self.g.clients]
357
358     def get_clientdir(self, i=0):
359         return self.g.clients[i].basedir
360
361     def get_serverdir(self, i):
362         return self.g.servers_by_number[i].storedir
363
364     def iterate_servers(self):
365         for i in sorted(self.g.servers_by_number.keys()):
366             ss = self.g.servers_by_number[i]
367             yield (i, ss, ss.storedir)
368
369     def find_uri_shares(self, uri):
370         si = tahoe_uri.from_string(uri).get_storage_index()
371         prefixdir = storage_index_to_dir(si)
372         shares = []
373         for i,ss in self.g.servers_by_number.items():
374             serverid = ss.my_nodeid
375             basedir = os.path.join(ss.sharedir, prefixdir)
376             if not os.path.exists(basedir):
377                 continue
378             for f in os.listdir(basedir):
379                 try:
380                     shnum = int(f)
381                     shares.append((shnum, serverid, os.path.join(basedir, f)))
382                 except ValueError:
383                     pass
384         return sorted(shares)
385
386     def copy_shares(self, uri):
387         shares = {}
388         for (shnum, serverid, sharefile) in self.find_uri_shares(uri):
389             shares[sharefile] = open(sharefile, "rb").read()
390         return shares
391
392     def restore_all_shares(self, shares):
393         for sharefile, data in shares.items():
394             open(sharefile, "wb").write(data)
395
396     def delete_share(self, (shnum, serverid, sharefile)):
397         os.unlink(sharefile)
398
399     def delete_shares_numbered(self, uri, shnums):
400         for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
401             if i_shnum in shnums:
402                 os.unlink(i_sharefile)
403
404     def delete_all_shares(self, serverdir):
405         sharedir = os.path.join(serverdir, "shares")
406         for prefixdir in os.listdir(sharedir):
407             if prefixdir != 'incoming':
408                 fileutil.rm_dir(os.path.join(sharedir, prefixdir))
409
410     def corrupt_share(self, (shnum, serverid, sharefile), corruptor_function):
411         sharedata = open(sharefile, "rb").read()
412         corruptdata = corruptor_function(sharedata)
413         open(sharefile, "wb").write(corruptdata)
414
415     def corrupt_shares_numbered(self, uri, shnums, corruptor, debug=False):
416         for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
417             if i_shnum in shnums:
418                 sharedata = open(i_sharefile, "rb").read()
419                 corruptdata = corruptor(sharedata, debug=debug)
420                 open(i_sharefile, "wb").write(corruptdata)
421
422     def corrupt_all_shares(self, uri, corruptor, debug=False):
423         for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
424             sharedata = open(i_sharefile, "rb").read()
425             corruptdata = corruptor(sharedata, debug=debug)
426             open(i_sharefile, "wb").write(corruptdata)
427
428     def GET(self, urlpath, followRedirect=False, return_response=False,
429             method="GET", clientnum=0, **kwargs):
430         # if return_response=True, this fires with (data, statuscode,
431         # respheaders) instead of just data.
432         assert not isinstance(urlpath, unicode)
433         url = self.client_baseurls[clientnum] + urlpath
434         factory = HTTPClientGETFactory(url, method=method,
435                                        followRedirect=followRedirect, **kwargs)
436         reactor.connectTCP("localhost", self.client_webports[clientnum],factory)
437         d = factory.deferred
438         def _got_data(data):
439             return (data, factory.status, factory.response_headers)
440         if return_response:
441             d.addCallback(_got_data)
442         return factory.deferred
443
444     def PUT(self, urlpath, **kwargs):
445         return self.GET(urlpath, method="PUT", **kwargs)