# This contains a test harness that creates a full Tahoe grid in a single
# process (actually in a single MultiService) which does not use the network.
# It does not use an Introducer, and there are no foolscap Tubs. Each storage
# server puts real shares on disk, but is accessed through loopback
# RemoteReferences instead of over serialized SSL. It is not as complete as
# the common.SystemTestMixin framework (which does use the network), but
# should be considerably faster: on my laptop, it takes 50-80ms to start up,
# whereas SystemTestMixin takes close to 2s.

# This should be useful for tests which want to examine and/or manipulate the
# uploaded shares, checker/verifier/repairer tests, etc. The clients have no
# Tubs, so it is not useful for tests that involve a Helper, a KeyGenerator,
# or the control.furl.

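# A minimal usage sketch (a hypothetical test module, not part of this file),
# assuming a trial-style TestCase that mixes in GridTestMixin (defined below)
# and sets self.basedir before calling set_up_grid():
#
#   from twisted.trial import unittest
#   from allmydata.test.no_network import GridTestMixin
#
#   class ExampleTest(GridTestMixin, unittest.TestCase):
#       def test_welcome_page(self):
#           self.basedir = "no_network/ExampleTest/test_welcome_page"
#           self.set_up_grid(num_clients=1, num_servers=5)
#           d = self.GET("/", clientnum=0)  # fetch the client's welcome page
#           d.addCallback(lambda res: self.failUnless(res))
#           return d
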
import os.path
from zope.interface import implements
from twisted.application import service
from twisted.internet import defer, reactor
from twisted.python.failure import Failure
from foolscap.api import Referenceable, fireEventually, RemoteException
from base64 import b32encode
from allmydata import uri as tahoe_uri
from allmydata.client import Client
from allmydata.storage.server import StorageServer, storage_index_to_dir
from allmydata.util import fileutil, idlib, hashutil
from allmydata.util.hashutil import sha1
from allmydata.test.common_web import HTTPClientGETFactory
from allmydata.interfaces import IStorageBroker
from allmydata.test.common import TEST_RSA_KEY_SIZE


class IntentionalError(Exception):
    pass

class Marker:
    pass

class LocalWrapper:
    def __init__(self, original):
        self.original = original
        self.broken = False
        self.hung_until = None
        self.post_call_notifier = None
        self.disconnectors = {}

    def callRemoteOnly(self, methname, *args, **kwargs):
        d = self.callRemote(methname, *args, **kwargs)
        del d # explicitly ignored
        return None

    def callRemote(self, methname, *args, **kwargs):
        # this is ideally a Membrane, but that's too hard. We do a shallow
        # wrapping of inbound arguments, and per-methodname wrapping of
        # selected return values.
        def wrap(a):
            if isinstance(a, Referenceable):
                return LocalWrapper(a)
            else:
                return a
        args = tuple([wrap(a) for a in args])
        kwargs = dict([(k,wrap(kwargs[k])) for k in kwargs])

        def _really_call():
            meth = getattr(self.original, "remote_" + methname)
            return meth(*args, **kwargs)

        def _call():
            if self.broken:
                raise IntentionalError("I was asked to break")
            if self.hung_until:
                d2 = defer.Deferred()
                self.hung_until.addCallback(lambda ign: _really_call())
                self.hung_until.addCallback(lambda res: d2.callback(res))
                def _err(res):
                    d2.errback(res)
                    return res
                self.hung_until.addErrback(_err)
                return d2
            return _really_call()

        d = fireEventually()
        d.addCallback(lambda res: _call())
        def _wrap_exception(f):
            return Failure(RemoteException(f))
        d.addErrback(_wrap_exception)
        def _return_membrane(res):
            # rather than complete the difficult task of building a
            # fully-general Membrane (which would locate all Referenceable
            # objects that cross the simulated wire and replace them with
            # wrappers), we special-case certain methods that we happen to
            # know will return Referenceables.
            if methname == "allocate_buckets":
                (alreadygot, allocated) = res
                for shnum in allocated:
                    allocated[shnum] = LocalWrapper(allocated[shnum])
            if methname == "get_buckets":
                for shnum in res:
                    res[shnum] = LocalWrapper(res[shnum])
            return res
        d.addCallback(_return_membrane)
        if self.post_call_notifier:
            d.addCallback(self.post_call_notifier, self, methname)
        return d

    def notifyOnDisconnect(self, f, *args, **kwargs):
        m = Marker()
        self.disconnectors[m] = (f, args, kwargs)
        return m
    def dontNotifyOnDisconnect(self, marker):
        del self.disconnectors[marker]

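# Illustration of the LocalWrapper calling convention (a sketch, not executed
# at import time): callRemote() turns a synchronous remote_* method on the
# wrapped object into a Deferred-returning call, and any exception arrives
# wrapped in foolscap's RemoteException, just as it would over a real wire.
#
#   w = LocalWrapper(storage_server)   # storage_server: any object with
#                                      # remote_* methods, e.g. a StorageServer
#   d = w.callRemote("get_version")    # eventually invokes remote_get_version()
#   w.broken = True                    # later calls errback with a
#                                      # RemoteException wrapping IntentionalError
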
def wrap_storage_server(original):
    # Much of the upload/download code uses rref.version (which normally
    # comes from rrefutil.add_version_to_remote_reference). To avoid using a
    # network, we want a LocalWrapper here. Try to satisfy all these
    # constraints at the same time.
    wrapper = LocalWrapper(original)
    wrapper.version = original.remote_get_version()
    return wrapper

class NoNetworkServer:
    def __init__(self, serverid, rref):
        self.serverid = serverid
        self.rref = rref
    def __repr__(self):
        return "<NoNetworkServer for %s>" % self.get_name()
    def get_serverid(self):
        return self.serverid
    def get_permutation_seed(self):
        return self.serverid
    def get_lease_seed(self):
        return self.serverid
    def get_name(self):
        return idlib.shortnodeid_b2a(self.serverid)
    def get_longname(self):
        return idlib.nodeid_b2a(self.serverid)
    def get_nickname(self):
        return "nickname"
    def get_rref(self):
        return self.rref
    def get_version(self):
        return self.rref.version

class NoNetworkStorageBroker:
    implements(IStorageBroker)
    def get_servers_for_psi(self, peer_selection_index):
        def _permuted(server):
            seed = server.get_permutation_seed()
            return sha1(peer_selection_index + seed).digest()
        return sorted(self.get_connected_servers(), key=_permuted)
    def get_connected_servers(self):
        return self.client._servers
    def get_nickname_for_serverid(self, serverid):
        return None

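# Illustration of the permuted server ordering used above (a sketch): each
# connected server is keyed by sha1(peer_selection_index + permutation_seed),
# so different storage indices see the servers in different but stable orders,
# approximating the permuted ring used by the real storage broker.
#
#   broker = NoNetworkStorageBroker()
#   broker.client = some_no_network_client   # hypothetical wiring; the
#                                             # NoNetworkClient below does this
#   servers = broker.get_servers_for_psi("\x00" * 16)
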
class NoNetworkClient(Client):
    def create_tub(self):
        pass
    def init_introducer_client(self):
        pass
    def setup_logging(self):
        pass
    def startService(self):
        service.MultiService.startService(self)
    def stopService(self):
        service.MultiService.stopService(self)
    def when_tub_ready(self):
        raise NotImplementedError("NoNetworkClient has no Tub")
    def init_control(self):
        pass
    def init_helper(self):
        pass
    def init_key_gen(self):
        pass
    def init_storage(self):
        pass
    def init_client_storage_broker(self):
        self.storage_broker = NoNetworkStorageBroker()
        self.storage_broker.client = self
    def init_stub_client(self):
        pass
    #._servers will be set by the NoNetworkGrid which creates us

class SimpleStats:
    def __init__(self):
        self.counters = {}
        self.stats_producers = []

    def count(self, name, delta=1):
        val = self.counters.setdefault(name, 0)
        self.counters[name] = val + delta

    def register_producer(self, stats_producer):
        self.stats_producers.append(stats_producer)

    def get_stats(self):
        stats = {}
        for sp in self.stats_producers:
            stats.update(sp.get_stats())
        ret = { 'counters': self.counters, 'stats': stats }
        return ret

class NoNetworkGrid(service.MultiService):
    def __init__(self, basedir, num_clients=1, num_servers=10,
                 client_config_hooks={}):
        service.MultiService.__init__(self)
        self.basedir = basedir
        fileutil.make_dirs(basedir)

        self.servers_by_number = {} # maps to StorageServer instance
        self.wrappers_by_id = {} # maps to wrapped StorageServer instance
        self.proxies_by_id = {} # maps to IServer on which .rref is a wrapped
                                # StorageServer
        self.clients = []

        for i in range(num_servers):
            ss = self.make_server(i)
            self.add_server(i, ss)
        self.rebuild_serverlist()

        for i in range(num_clients):
            clientid = hashutil.tagged_hash("clientid", str(i))[:20]
            clientdir = os.path.join(basedir, "clients",
                                     idlib.shortnodeid_b2a(clientid))
            fileutil.make_dirs(clientdir)
            f = open(os.path.join(clientdir, "tahoe.cfg"), "w")
            f.write("[node]\n")
            f.write("nickname = client-%d\n" % i)
            f.write("web.port = tcp:0:interface=127.0.0.1\n")
            f.write("[storage]\n")
            f.write("enabled = false\n")
            f.close()
            c = None
            if i in client_config_hooks:
                # this hook can either modify tahoe.cfg, or return an
                # entirely new Client instance
                c = client_config_hooks[i](clientdir)
            if not c:
                c = NoNetworkClient(clientdir)
                c.set_default_mutable_keysize(TEST_RSA_KEY_SIZE)
            c.nodeid = clientid
            c.short_nodeid = b32encode(clientid).lower()[:8]
            c._servers = self.all_servers # can be updated later
            c.setServiceParent(self)
            self.clients.append(c)

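    # Example of a client_config_hooks entry for __init__ above (a sketch):
    # the hook receives the client's basedir after tahoe.cfg has been written,
    # and may either edit that file and return None (a NoNetworkClient is then
    # created normally) or return a fully-constructed client of its own.
    #
    #   def _hook_client_zero(clientdir):           # hypothetical hook
    #       f = open(os.path.join(clientdir, "tahoe.cfg"), "a")
    #       f.write("[client]\nshares.needed = 2\nshares.total = 3\n")
    #       f.close()
    #       return None                             # fall back to NoNetworkClient
    #   grid = NoNetworkGrid(basedir, client_config_hooks={0: _hook_client_zero})
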
    def make_server(self, i, readonly=False):
        serverid = hashutil.tagged_hash("serverid", str(i))[:20]
        serverdir = os.path.join(self.basedir, "servers",
                                 idlib.shortnodeid_b2a(serverid))
        fileutil.make_dirs(serverdir)
        ss = StorageServer(serverdir, serverid, stats_provider=SimpleStats(),
                           readonly_storage=readonly)
        ss._no_network_server_number = i
        return ss

    def add_server(self, i, ss):
        # to deal with the fact that all StorageServers are named 'storage',
        # we interpose a middleman
        middleman = service.MultiService()
        middleman.setServiceParent(self)
        ss.setServiceParent(middleman)
        serverid = ss.my_nodeid
        self.servers_by_number[i] = ss
        wrapper = wrap_storage_server(ss)
        self.wrappers_by_id[serverid] = wrapper
        self.proxies_by_id[serverid] = NoNetworkServer(serverid, wrapper)
        self.rebuild_serverlist()

    def get_all_serverids(self):
        return self.proxies_by_id.keys()

    def rebuild_serverlist(self):
        self.all_servers = frozenset(self.proxies_by_id.values())
        for c in self.clients:
            c._servers = self.all_servers

    def remove_server(self, serverid):
        # it's enough to remove the server from c._servers (we don't actually
        # have to detach and stopService it)
        for i,ss in self.servers_by_number.items():
            if ss.my_nodeid == serverid:
                del self.servers_by_number[i]
                break
        del self.wrappers_by_id[serverid]
        del self.proxies_by_id[serverid]
        self.rebuild_serverlist()

    def break_server(self, serverid):
        # mark the given server as broken, so it will throw exceptions when
        # asked to hold a share or serve a share
        self.wrappers_by_id[serverid].broken = True

    def hang_server(self, serverid):
        # hang the given server
        ss = self.wrappers_by_id[serverid]
        assert ss.hung_until is None
        ss.hung_until = defer.Deferred()

    def unhang_server(self, serverid):
        # unhang the given server
        ss = self.wrappers_by_id[serverid]
        assert ss.hung_until is not None
        ss.hung_until.callback(None)
        ss.hung_until = None

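# Illustration of fault injection with a NoNetworkGrid (a sketch): a test can
# make one server fail every remote call, and stall another until the test
# decides to release it.
#
#   s0, s1 = sorted(grid.get_all_serverids())[:2]
#   grid.break_server(s0)     # s0's calls now errback with RemoteException
#   grid.hang_server(s1)      # s1's calls queue behind a Deferred ...
#   grid.unhang_server(s1)    # ... and are released here
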

class GridTestMixin:
    def setUp(self):
        self.s = service.MultiService()
        self.s.startService()

    def tearDown(self):
        return self.s.stopService()

    def set_up_grid(self, num_clients=1, num_servers=10,
                    client_config_hooks={}):
        # self.basedir must be set
        self.g = NoNetworkGrid(self.basedir,
                               num_clients=num_clients,
                               num_servers=num_servers,
                               client_config_hooks=client_config_hooks)
        self.g.setServiceParent(self.s)
        self.client_webports = [c.getServiceNamed("webish").getPortnum()
                                for c in self.g.clients]
        self.client_baseurls = [c.getServiceNamed("webish").getURL()
                                for c in self.g.clients]

    def get_clientdir(self, i=0):
        return self.g.clients[i].basedir

    def get_serverdir(self, i):
        return self.g.servers_by_number[i].storedir

    def iterate_servers(self):
        for i in sorted(self.g.servers_by_number.keys()):
            ss = self.g.servers_by_number[i]
            yield (i, ss, ss.storedir)

    def find_uri_shares(self, uri):
        si = tahoe_uri.from_string(uri).get_storage_index()
        prefixdir = storage_index_to_dir(si)
        shares = []
        for i,ss in self.g.servers_by_number.items():
            serverid = ss.my_nodeid
            basedir = os.path.join(ss.storedir, "shares", prefixdir)
            if not os.path.exists(basedir):
                continue
            for f in os.listdir(basedir):
                try:
                    shnum = int(f)
                    shares.append((shnum, serverid, os.path.join(basedir, f)))
                except ValueError:
                    pass
        return sorted(shares)

    def copy_shares(self, uri):
        shares = {}
        for (shnum, serverid, sharefile) in self.find_uri_shares(uri):
            shares[sharefile] = open(sharefile, "rb").read()
        return shares

    def restore_all_shares(self, shares):
        for sharefile, data in shares.items():
            open(sharefile, "wb").write(data)

    def delete_share(self, (shnum, serverid, sharefile)):
        os.unlink(sharefile)

    def delete_shares_numbered(self, uri, shnums):
        for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
            if i_shnum in shnums:
                os.unlink(i_sharefile)

    def corrupt_share(self, (shnum, serverid, sharefile), corruptor_function):
        sharedata = open(sharefile, "rb").read()
        corruptdata = corruptor_function(sharedata)
        open(sharefile, "wb").write(corruptdata)

    def corrupt_shares_numbered(self, uri, shnums, corruptor, debug=False):
        for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
            if i_shnum in shnums:
                sharedata = open(i_sharefile, "rb").read()
                corruptdata = corruptor(sharedata, debug=debug)
                open(i_sharefile, "wb").write(corruptdata)

    def corrupt_all_shares(self, uri, corruptor, debug=False):
        for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
            sharedata = open(i_sharefile, "rb").read()
            corruptdata = corruptor(sharedata, debug=debug)
            open(i_sharefile, "wb").write(corruptdata)

    def GET(self, urlpath, followRedirect=False, return_response=False,
            method="GET", clientnum=0, **kwargs):
        # if return_response=True, this fires with (data, statuscode,
        # respheaders) instead of just data.
        assert not isinstance(urlpath, unicode)
        url = self.client_baseurls[clientnum] + urlpath
        factory = HTTPClientGETFactory(url, method=method,
                                       followRedirect=followRedirect, **kwargs)
        reactor.connectTCP("localhost", self.client_webports[clientnum], factory)
        d = factory.deferred
        def _got_data(data):
            return (data, factory.status, factory.response_headers)
        if return_response:
            d.addCallback(_got_data)
        return factory.deferred

    def PUT(self, urlpath, **kwargs):
        return self.GET(urlpath, method="PUT", **kwargs)
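
# Illustration of the share-manipulation helpers (a sketch, written as if
# inside a GridTestMixin-based test, assuming something has already been
# uploaded and its URI captured as self.uri):
#
#   shares = self.find_uri_shares(self.uri)        # [(shnum, serverid, path)]
#   saved = self.copy_shares(self.uri)             # snapshot share contents
#   self.delete_shares_numbered(self.uri, [0, 1])  # damage the grid ...
#   self.restore_all_shares(saved)                 # ... and undo it
#
#   def _flip_bit(data, debug=False):              # hypothetical corruptor
#       return data[:100] + chr(ord(data[100]) ^ 0x01) + data[101:]
#   self.corrupt_shares_numbered(self.uri, [2], _flip_bit)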