# src/allmydata/test/no_network.py (tahoe-lafs)
# Change under review: replace the hard-coded 522-bit RSA key size used for
# tests with a TEST_RSA_KEY_SIZE constant.

2 # This contains a test harness that creates a full Tahoe grid in a single
3 # process (actually in a single MultiService) which does not use the network.
4 # It does not use an Introducer, and there are no foolscap Tubs. Each storage
5 # server puts real shares on disk, but is accessed through loopback
6 # RemoteReferences instead of over serialized SSL. It is not as complete as
7 # the common.SystemTestMixin framework (which does use the network), but
8 # should be considerably faster: on my laptop, it takes 50-80ms to start up,
9 # whereas SystemTestMixin takes close to 2s.
10
11 # This should be useful for tests which want to examine and/or manipulate the
12 # uploaded shares, checker/verifier/repairer tests, etc. The clients have no
13 # Tubs, so it is not useful for tests that involve a Helper, a KeyGenerator,
14 # or the control.furl .
15
import os.path
from base64 import b32encode

from zope.interface import implements
from twisted.application import service
from twisted.internet import defer, reactor
from twisted.python.failure import Failure
from foolscap.api import Referenceable, fireEventually, RemoteException

from allmydata import uri as tahoe_uri
from allmydata.client import Client
from allmydata.interfaces import IStorageBroker
from allmydata.storage.server import StorageServer, storage_index_to_dir
from allmydata.util import fileutil, idlib, hashutil
from allmydata.util.hashutil import sha1
# TEST_RSA_KEY_SIZE is referenced below but was never imported (NameError);
# TODO(review): confirm it lives in common_util in this tree
from allmydata.test.common_util import TEST_RSA_KEY_SIZE
from allmydata.test.common_web import HTTPClientGETFactory
30
class IntentionalError(Exception):
    """Raised by LocalWrapper.callRemote when the wrapper has been marked
    broken (see NoNetworkGrid.break_server)."""
33
class Marker:
    """Opaque handle returned by LocalWrapper.notifyOnDisconnect and later
    accepted by dontNotifyOnDisconnect."""
36
class LocalWrapper:
    """Make a local Referenceable look like a foolscap RemoteReference:
    calls go through callRemote(methname, ...) but are dispatched directly
    to the wrapped object's remote_* methods, with no network involved.

    Test hooks (set externally, see NoNetworkGrid):
      .broken             -- when True, every call raises IntentionalError
      .hung_until         -- a Deferred; while set, calls queue up behind it
      .post_call_notifier -- f(result, wrapper, methname) run after each call
    """
    def __init__(self, original):
        self.original = original    # the real (local) Referenceable
        self.broken = False         # set by NoNetworkGrid.break_server
        self.hung_until = None      # managed by hang_server/unhang_server
        self.post_call_notifier = None
        self.disconnectors = {}     # Marker -> (f, args, kwargs)

    def callRemoteOnly(self, methname, *args, **kwargs):
        # fire-and-forget variant, mimicking foolscap's callRemoteOnly: the
        # Deferred is deliberately dropped and None is returned
        d = self.callRemote(methname, *args, **kwargs)
        del d # explicitly ignored
        return None

    def callRemote(self, methname, *args, **kwargs):
        # this is ideally a Membrane, but that's too hard. We do a shallow
        # wrapping of inbound arguments, and per-methodname wrapping of
        # selected return values.
        def wrap(a):
            if isinstance(a, Referenceable):
                return LocalWrapper(a)
            else:
                return a
        args = tuple([wrap(a) for a in args])
        kwargs = dict([(k,wrap(kwargs[k])) for k in kwargs])

        def _really_call():
            # dispatch to the wrapped object's remote_* method
            meth = getattr(self.original, "remote_" + methname)
            return meth(*args, **kwargs)

        def _call():
            if self.broken:
                raise IntentionalError("I was asked to break")
            if self.hung_until:
                # queue this call behind the hang: _really_call runs (and d2
                # fires) only after hung_until is fired by unhang_server()
                d2 = defer.Deferred()
                self.hung_until.addCallback(lambda ign: _really_call())
                self.hung_until.addCallback(lambda res: d2.callback(res))
                def _err(res):
                    d2.errback(res)
                    return res
                self.hung_until.addErrback(_err)
                return d2
            return _really_call()

        # fireEventually() defers the call to a later turn, so callers see
        # the same asynchronous behavior as a real remote call
        d = fireEventually()
        d.addCallback(lambda res: _call())
        def _wrap_exception(f):
            # callers expect failures to arrive as RemoteException, just as
            # they would over a real foolscap connection
            return Failure(RemoteException(f))
        d.addErrback(_wrap_exception)
        def _return_membrane(res):
            # rather than complete the difficult task of building a
            # fully-general Membrane (which would locate all Referenceable
            # objects that cross the simulated wire and replace them with
            # wrappers), we special-case certain methods that we happen to
            # know will return Referenceables.
            if methname == "allocate_buckets":
                (alreadygot, allocated) = res
                for shnum in allocated:
                    allocated[shnum] = LocalWrapper(allocated[shnum])
            if methname == "get_buckets":
                for shnum in res:
                    res[shnum] = LocalWrapper(res[shnum])
            return res
        d.addCallback(_return_membrane)
        if self.post_call_notifier:
            d.addCallback(self.post_call_notifier, self, methname)
        return d

    def notifyOnDisconnect(self, f, *args, **kwargs):
        # record the callback under a fresh Marker; nothing in this harness
        # ever disconnects, so these callbacks are stored but never fired
        m = Marker()
        self.disconnectors[m] = (f, args, kwargs)
        return m
    def dontNotifyOnDisconnect(self, marker):
        del self.disconnectors[marker]
110
def wrap_storage_server(original):
    """Return a LocalWrapper around `original` (a StorageServer), with a
    .version attribute pre-filled from remote_get_version().

    Much of the upload/download code uses rref.version (which normally
    comes from rrefutil.add_version_to_remote_reference). To avoid using a
    network, we want a LocalWrapper here. Try to satisfy all these
    constraints at the same time.
    """
    wrapped = LocalWrapper(original)
    wrapped.version = original.remote_get_version()
    return wrapped
119
class NoNetworkServer:
    """IServer-style proxy for one no-network storage server. The raw
    serverid doubles as both the permutation seed and the lease seed."""
    def __init__(self, serverid, rref):
        self.serverid = serverid
        self.rref = rref

    def __repr__(self):
        return "<NoNetworkServer for %s>" % (self.get_name(),)

    def get_serverid(self):
        return self.serverid

    def get_permutation_seed(self):
        # the bare serverid serves as the permutation seed
        return self.serverid

    def get_lease_seed(self):
        return self.serverid

    def get_name(self):
        return idlib.shortnodeid_b2a(self.serverid)

    def get_longname(self):
        return idlib.nodeid_b2a(self.serverid)

    def get_nickname(self):
        # fixed placeholder; nicknames are not tracked in this harness
        return "nickname"

    def get_rref(self):
        return self.rref

    def get_version(self):
        return self.rref.version
142
class NoNetworkStorageBroker:
    """Just enough of IStorageBroker for the no-network Client: serves the
    grid's server set out of self.client._servers (set by NoNetworkGrid)."""
    implements(IStorageBroker)

    def get_servers_for_psi(self, peer_selection_index):
        # permuted-ring order: sort servers by sha1(psi + permutation seed)
        def _ring_position(server):
            seed = server.get_permutation_seed()
            return sha1(peer_selection_index + seed).digest()
        return sorted(self.get_connected_servers(), key=_ring_position)

    def get_connected_servers(self):
        return self.client._servers

    def get_nickname_for_serverid(self, serverid):
        # nicknames are not tracked in this harness
        return None
154
class NoNetworkClient(Client):
    """Client subclass for the no-network harness: every network-related
    init step is overridden with a no-op, and the storage broker is the
    local NoNetworkStorageBroker instead."""
    def create_tub(self):
        pass # no foolscap Tub in this harness
    def init_introducer_client(self):
        pass # no Introducer
    def setup_logging(self):
        pass
    def startService(self):
        # bypass Client.startService and use the plain MultiService version
        service.MultiService.startService(self)
    def stopService(self):
        # likewise, plain MultiService shutdown
        service.MultiService.stopService(self)
    def when_tub_ready(self):
        raise NotImplementedError("NoNetworkClient has no Tub")
    def init_control(self):
        pass # no control.furl
    def init_helper(self):
        pass # no Helper
    def init_key_gen(self):
        pass # no KeyGenerator
    def init_storage(self):
        pass # clients do not run a storage server here
    def init_client_storage_broker(self):
        # use the local broker; it reaches back into this client for
        # ._servers, which the NoNetworkGrid keeps up to date
        self.storage_broker = NoNetworkStorageBroker()
        self.storage_broker.client = self
    def init_stub_client(self):
        pass
    #._servers will be set by the NoNetworkGrid which creates us
182
class SimpleStats:
    """Minimal stand-in for the stats provider that StorageServer expects:
    named counters plus a list of pollable stats producers."""
    def __init__(self):
        self.counters = {}        # name -> accumulated count
        self.stats_producers = [] # objects with a get_stats() method

    def count(self, name, delta=1):
        """Add `delta` to the named counter, creating it at zero."""
        self.counters[name] = self.counters.get(name, 0) + delta

    def register_producer(self, stats_producer):
        """Remember a producer whose stats get_stats() will merge in."""
        self.stats_producers.append(stats_producer)

    def get_stats(self):
        """Return {'counters': ..., 'stats': ...}, where 'stats' is the
        merged output of every registered producer."""
        merged = {}
        for producer in self.stats_producers:
            merged.update(producer.get_stats())
        return {'counters': self.counters, 'stats': merged}
201
class NoNetworkGrid(service.MultiService):
    """A full grid in one process: `num_servers` StorageServers plus
    `num_clients` NoNetworkClients, connected through LocalWrapper proxies
    instead of a network.

    client_config_hooks maps client-number -> hook(clientdir); a hook may
    edit the freshly-written tahoe.cfg in place, or return a fully-formed
    Client instance to use instead of a NoNetworkClient.
    """
    def __init__(self, basedir, num_clients=1, num_servers=10,
                 client_config_hooks=None):
        service.MultiService.__init__(self)
        # default changed from a shared mutable {} to None; the dict was
        # only read, but a per-call default is the safe idiom
        if client_config_hooks is None:
            client_config_hooks = {}
        self.basedir = basedir
        fileutil.make_dirs(basedir)

        self.servers_by_number = {} # maps to StorageServer instance
        self.wrappers_by_id = {} # maps to wrapped StorageServer instance
        self.proxies_by_id = {} # maps to IServer on which .rref is a wrapped
                                # StorageServer
        self.clients = []

        for i in range(num_servers):
            ss = self.make_server(i)
            self.add_server(i, ss) # add_server() rebuilds the serverlist too
        self.rebuild_serverlist()

        for i in range(num_clients):
            # deterministic per-client nodeid, so directories are stable
            clientid = hashutil.tagged_hash("clientid", str(i))[:20]
            clientdir = os.path.join(basedir, "clients",
                                     idlib.shortnodeid_b2a(clientid))
            fileutil.make_dirs(clientdir)
            f = open(os.path.join(clientdir, "tahoe.cfg"), "w")
            try:
                f.write("[node]\n")
                f.write("nickname = client-%d\n" % i)
                f.write("web.port = tcp:0:interface=127.0.0.1\n")
                f.write("[storage]\n")
                f.write("enabled = false\n")
            finally:
                # close even if a write fails (the original leaked the handle)
                f.close()
            c = None
            if i in client_config_hooks:
                # this hook can either modify tahoe.cfg, or return an
                # entirely new Client instance
                c = client_config_hooks[i](clientdir)
            if not c:
                c = NoNetworkClient(clientdir)
                # small RSA keys keep mutable-file tests fast
                c.set_default_mutable_keysize(TEST_RSA_KEY_SIZE)
            c.nodeid = clientid
            c.short_nodeid = b32encode(clientid).lower()[:8]
            c._servers = self.all_servers # can be updated later
            c.setServiceParent(self)
            self.clients.append(c)

    def make_server(self, i, readonly=False):
        """Create (but do not register) StorageServer #i, with a
        deterministic serverid derived from i."""
        serverid = hashutil.tagged_hash("serverid", str(i))[:20]
        serverdir = os.path.join(self.basedir, "servers",
                                 idlib.shortnodeid_b2a(serverid))
        fileutil.make_dirs(serverdir)
        ss = StorageServer(serverdir, serverid, stats_provider=SimpleStats(),
                           readonly_storage=readonly)
        ss._no_network_server_number = i
        return ss

    def add_server(self, i, ss):
        """Register server #i: parent it, wrap it, and rebuild the server
        list that clients see."""
        # to deal with the fact that all StorageServers are named 'storage',
        # we interpose a middleman
        middleman = service.MultiService()
        middleman.setServiceParent(self)
        ss.setServiceParent(middleman)
        serverid = ss.my_nodeid
        self.servers_by_number[i] = ss
        wrapper = wrap_storage_server(ss)
        self.wrappers_by_id[serverid] = wrapper
        self.proxies_by_id[serverid] = NoNetworkServer(serverid, wrapper)
        self.rebuild_serverlist()

    def get_all_serverids(self):
        return self.proxies_by_id.keys()

    def rebuild_serverlist(self):
        """Recompute the frozenset of server proxies and push it to every
        existing client."""
        self.all_servers = frozenset(self.proxies_by_id.values())
        for c in self.clients:
            c._servers = self.all_servers

    def remove_server(self, serverid):
        # it's enough to remove the server from c._servers (we don't actually
        # have to detach and stopService it)
        # iterate over a snapshot, since we delete from the dict in the loop
        for i, ss in list(self.servers_by_number.items()):
            if ss.my_nodeid == serverid:
                del self.servers_by_number[i]
                break
        del self.wrappers_by_id[serverid]
        del self.proxies_by_id[serverid]
        self.rebuild_serverlist()

    def break_server(self, serverid):
        # mark the given server as broken, so it will throw exceptions when
        # asked to hold a share or serve a share
        self.wrappers_by_id[serverid].broken = True

    def hang_server(self, serverid):
        """Stall all calls to the given server until unhang_server()."""
        ss = self.wrappers_by_id[serverid]
        assert ss.hung_until is None
        ss.hung_until = defer.Deferred()

    def unhang_server(self, serverid):
        """Release a server hung by hang_server(); queued calls proceed."""
        ss = self.wrappers_by_id[serverid]
        assert ss.hung_until is not None
        ss.hung_until.callback(None)
        ss.hung_until = None
305
306
307 class GridTestMixin:
    def setUp(self):
        # parent MultiService; set_up_grid() attaches the grid beneath it,
        # so tearDown() can shut everything down in one place
        self.s = service.MultiService()
        self.s.startService()
311
    def tearDown(self):
        # stopping the parent service stops the grid too; the returned
        # Deferred lets trial wait for shutdown to finish
        return self.s.stopService()
314
315     def set_up_grid(self, num_clients=1, num_servers=10,
316                     client_config_hooks={}):
317         # self.basedir must be set
318         self.g = NoNetworkGrid(self.basedir,
319                                num_clients=num_clients,
320                                num_servers=num_servers,
321                                client_config_hooks=client_config_hooks)
322         self.g.setServiceParent(self.s)
323         self.client_webports = [c.getServiceNamed("webish").getPortnum()
324                                 for c in self.g.clients]
325         self.client_baseurls = [c.getServiceNamed("webish").getURL()
326                                 for c in self.g.clients]
327
    def get_clientdir(self, i=0):
        # base directory of client i (where its tahoe.cfg was written)
        return self.g.clients[i].basedir
330
    def get_serverdir(self, i):
        # storage directory of server i (its shares live under here)
        return self.g.servers_by_number[i].storedir
333
334     def iterate_servers(self):
335         for i in sorted(self.g.servers_by_number.keys()):
336             ss = self.g.servers_by_number[i]
337             yield (i, ss, ss.storedir)
338
339     def find_uri_shares(self, uri):
340         si = tahoe_uri.from_string(uri).get_storage_index()
341         prefixdir = storage_index_to_dir(si)
342         shares = []
343         for i,ss in self.g.servers_by_number.items():
344             serverid = ss.my_nodeid
345             basedir = os.path.join(ss.storedir, "shares", prefixdir)
346             if not os.path.exists(basedir):
347                 continue
348             for f in os.listdir(basedir):
349                 try:
350                     shnum = int(f)
351                     shares.append((shnum, serverid, os.path.join(basedir, f)))
352                 except ValueError:
353                     pass
354         return sorted(shares)
355
356     def copy_shares(self, uri):
357         shares = {}
358         for (shnum, serverid, sharefile) in self.find_uri_shares(uri):
359             shares[sharefile] = open(sharefile, "rb").read()
360         return shares
361
362     def restore_all_shares(self, shares):
363         for sharefile, data in shares.items():
364             open(sharefile, "wb").write(data)
365
366     def delete_share(self, (shnum, serverid, sharefile)):
367         os.unlink(sharefile)
368
369     def delete_shares_numbered(self, uri, shnums):
370         for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
371             if i_shnum in shnums:
372                 os.unlink(i_sharefile)
373
374     def corrupt_share(self, (shnum, serverid, sharefile), corruptor_function):
375         sharedata = open(sharefile, "rb").read()
376         corruptdata = corruptor_function(sharedata)
377         open(sharefile, "wb").write(corruptdata)
378
379     def corrupt_shares_numbered(self, uri, shnums, corruptor, debug=False):
380         for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
381             if i_shnum in shnums:
382                 sharedata = open(i_sharefile, "rb").read()
383                 corruptdata = corruptor(sharedata, debug=debug)
384                 open(i_sharefile, "wb").write(corruptdata)
385
386     def corrupt_all_shares(self, uri, corruptor, debug=False):
387         for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
388             sharedata = open(i_sharefile, "rb").read()
389             corruptdata = corruptor(sharedata, debug=debug)
390             open(i_sharefile, "wb").write(corruptdata)
391
392     def GET(self, urlpath, followRedirect=False, return_response=False,
393             method="GET", clientnum=0, **kwargs):
394         # if return_response=True, this fires with (data, statuscode,
395         # respheaders) instead of just data.
396         assert not isinstance(urlpath, unicode)
397         url = self.client_baseurls[clientnum] + urlpath
398         factory = HTTPClientGETFactory(url, method=method,
399                                        followRedirect=followRedirect, **kwargs)
400         reactor.connectTCP("localhost", self.client_webports[clientnum],factory)
401         d = factory.deferred
402         def _got_data(data):
403             return (data, factory.status, factory.response_headers)
404         if return_response:
405             d.addCallback(_got_data)
406         return factory.deferred