# Source: src/allmydata/test/no_network.py (tahoe-lafs/tahoe-lafs.git)
# From commit: "Tolerate Twisted-10.2's endpoints, patch by David-Sarah. Closes #1286."
1
2 # This contains a test harness that creates a full Tahoe grid in a single
3 # process (actually in a single MultiService) which does not use the network.
4 # It does not use an Introducer, and there are no foolscap Tubs. Each storage
5 # server puts real shares on disk, but is accessed through loopback
6 # RemoteReferences instead of over serialized SSL. It is not as complete as
7 # the common.SystemTestMixin framework (which does use the network), but
8 # should be considerably faster: on my laptop, it takes 50-80ms to start up,
9 # whereas SystemTestMixin takes close to 2s.
10
11 # This should be useful for tests which want to examine and/or manipulate the
12 # uploaded shares, checker/verifier/repairer tests, etc. The clients have no
13 # Tubs, so it is not useful for tests that involve a Helper, a KeyGenerator,
14 # or the control.furl .
15
16 import os.path
17 from zope.interface import implements
18 from twisted.application import service
19 from twisted.internet import defer, reactor
20 from twisted.python.failure import Failure
21 from foolscap.api import Referenceable, fireEventually, RemoteException
22 from base64 import b32encode
23 from allmydata import uri as tahoe_uri
24 from allmydata.client import Client
25 from allmydata.storage.server import StorageServer, storage_index_to_dir
26 from allmydata.util import fileutil, idlib, hashutil
27 from allmydata.util.hashutil import sha1
28 from allmydata.test.common_web import HTTPClientGETFactory
29 from allmydata.interfaces import IStorageBroker
30
class IntentionalError(Exception):
    """Raised by a LocalWrapper that has been deliberately marked broken."""
33
class Marker:
    """Opaque token returned by LocalWrapper.notifyOnDisconnect; used only
    as a dictionary key to later cancel the notification."""
36
class LocalWrapper:
    """Simulate a foolscap RemoteReference for an in-process Referenceable.

    callRemote("foo", ...) dispatches (via an eventual-send) to
    original.remote_foo(...). The wrapper can be marked .broken (every
    call then fails with IntentionalError) or hung (calls queue on the
    .hung_until Deferred until it fires).
    """
    def __init__(self, original):
        self.original = original        # the wrapped Referenceable
        self.broken = False             # True => every call fails with IntentionalError
        self.hung_until = None          # Deferred; while set, calls wait on it
        self.post_call_notifier = None  # optional f(res, wrapper, methname) run after each call
        self.disconnectors = {}         # Marker -> (f, args, kwargs); bookkeeping only

    def callRemoteOnly(self, methname, *args, **kwargs):
        # fire-and-forget flavor: the result (and any failure) is discarded
        d = self.callRemote(methname, *args, **kwargs)
        del d # explicitly ignored
        return None

    def callRemote(self, methname, *args, **kwargs):
        # this is ideally a Membrane, but that's too hard. We do a shallow
        # wrapping of inbound arguments, and per-methodname wrapping of
        # selected return values.
        def wrap(a):
            # inbound Referenceables get wrapped too, so the callee sees
            # something RemoteReference-shaped rather than a live object
            if isinstance(a, Referenceable):
                return LocalWrapper(a)
            else:
                return a
        args = tuple([wrap(a) for a in args])
        kwargs = dict([(k,wrap(kwargs[k])) for k in kwargs])

        def _really_call():
            meth = getattr(self.original, "remote_" + methname)
            return meth(*args, **kwargs)

        def _call():
            if self.broken:
                raise IntentionalError("I was asked to break")
            if self.hung_until:
                # queue this call behind the hang: chain it onto hung_until
                # so it runs (in arrival order) when the hang is released,
                # and relay the result or failure through d2 to our caller
                d2 = defer.Deferred()
                self.hung_until.addCallback(lambda ign: _really_call())
                self.hung_until.addCallback(lambda res: d2.callback(res))
                def _err(res):
                    d2.errback(res)
                    return res
                self.hung_until.addErrback(_err)
                return d2
            return _really_call()

        # fireEventually gives the same turn-of-the-event-loop delay a real
        # remote call would have
        d = fireEventually()
        d.addCallback(lambda res: _call())
        def _wrap_exception(f):
            # real remote calls deliver failures as RemoteException; mimic that
            return Failure(RemoteException(f))
        d.addErrback(_wrap_exception)
        def _return_membrane(res):
            # rather than complete the difficult task of building a
            # fully-general Membrane (which would locate all Referenceable
            # objects that cross the simulated wire and replace them with
            # wrappers), we special-case certain methods that we happen to
            # know will return Referenceables.
            if methname == "allocate_buckets":
                (alreadygot, allocated) = res
                for shnum in allocated:
                    allocated[shnum] = LocalWrapper(allocated[shnum])
            if methname == "get_buckets":
                for shnum in res:
                    res[shnum] = LocalWrapper(res[shnum])
            return res
        d.addCallback(_return_membrane)
        if self.post_call_notifier:
            d.addCallback(self.post_call_notifier, self, methname)
        return d

    def notifyOnDisconnect(self, f, *args, **kwargs):
        # record the callback; this no-network wrapper never actually
        # disconnects, so the entry is only ever bookkeeping
        m = Marker()
        self.disconnectors[m] = (f, args, kwargs)
        return m
    def dontNotifyOnDisconnect(self, marker):
        del self.disconnectors[marker]
110
def wrap_storage_server(original):
    """Return a LocalWrapper for a StorageServer, with .version pre-filled.

    Much of the upload/download code uses rref.version (which normally
    comes from rrefutil.add_version_to_remote_reference). To avoid using
    a network, we want a LocalWrapper here; satisfy both constraints by
    caching the server's version dict on the wrapper up front.
    """
    wrapped = LocalWrapper(original)
    wrapped.version = original.remote_get_version()
    return wrapped
119
class NoNetworkStorageBroker:
    """IStorageBroker stand-in backed by the grid's in-process server list.

    The NoNetworkClient that owns this broker assigns .client after
    construction; the server list is read from client._servers.
    """
    implements(IStorageBroker)

    def get_servers_for_index(self, key):
        # permuted serverlist: order servers by the hash of (key + serverid)
        def _permuted(server):
            return sha1(key + server[0]).digest()
        return sorted(self.client._servers, key=_permuted)

    def get_all_servers(self):
        return frozenset(self.client._servers)

    def get_nickname_for_serverid(self, serverid):
        # no introducer, so no nicknames are known
        return None
129
class NoNetworkClient(Client):
    """Client subclass with all network-facing setup stubbed out: no Tub,
    no introducer, no helper, no control port, no key generator, and no
    local storage server. Storage access goes through a
    NoNetworkStorageBroker consulting ._servers instead."""
    def create_tub(self):
        pass  # no foolscap Tub in this grid
    def init_introducer_client(self):
        pass  # no Introducer either
    def setup_logging(self):
        pass
    def startService(self):
        # bypass Client.startService (presumably Tub-related work in the
        # base class, which is not visible here — confirm against Client)
        service.MultiService.startService(self)
    def stopService(self):
        service.MultiService.stopService(self)
    def when_tub_ready(self):
        # anything that waits for the Tub cannot work with this client
        raise NotImplementedError("NoNetworkClient has no Tub")
    def init_control(self):
        pass
    def init_helper(self):
        pass
    def init_key_gen(self):
        pass
    def init_storage(self):
        pass  # storage servers live in the NoNetworkGrid, not in the client
    def init_client_storage_broker(self):
        # replace the normal broker with one that reads ._servers directly
        self.storage_broker = NoNetworkStorageBroker()
        self.storage_broker.client = self
    def init_stub_client(self):
        pass
    #._servers will be set by the NoNetworkGrid which creates us
157
class SimpleStats:
    """Minimal stand-in for the stats provider handed to StorageServer.

    Tracks named counters and a list of stats-producer objects whose
    get_stats() output is merged into the 'stats' section on demand.
    """
    def __init__(self):
        self.counters = {}          # name -> accumulated count
        self.stats_producers = []   # objects providing get_stats()

    def count(self, name, delta=1):
        # missing counters start at zero
        self.counters[name] = self.counters.get(name, 0) + delta

    def register_producer(self, stats_producer):
        self.stats_producers.append(stats_producer)

    def get_stats(self):
        merged = {}
        for producer in self.stats_producers:
            merged.update(producer.get_stats())
        return {'counters': self.counters, 'stats': merged}
176
class NoNetworkGrid(service.MultiService):
    """An in-process grid: num_servers StorageServers plus num_clients
    NoNetworkClients, wired together through LocalWrappers instead of a
    network.

    client_config_hooks maps client-number -> hook(clientdir); a hook may
    modify the freshly written tahoe.cfg in place, or return an entirely
    new Client instance to use for that slot.
    """
    def __init__(self, basedir, num_clients=1, num_servers=10,
                 client_config_hooks=None):
        service.MultiService.__init__(self)
        # default to None rather than {}: a mutable default argument is
        # shared across all instantiations (classic Python pitfall)
        if client_config_hooks is None:
            client_config_hooks = {}
        self.basedir = basedir
        fileutil.make_dirs(basedir)

        self.servers_by_number = {}  # int -> StorageServer
        self.servers_by_id = {}      # serverid -> LocalWrapper(StorageServer)
        self.clients = []

        for i in range(num_servers):
            ss = self.make_server(i)
            self.add_server(i, ss)
        self.rebuild_serverlist()

        for i in range(num_clients):
            # deterministic per-slot clientid, so test runs are repeatable
            clientid = hashutil.tagged_hash("clientid", str(i))[:20]
            clientdir = os.path.join(basedir, "clients",
                                     idlib.shortnodeid_b2a(clientid))
            fileutil.make_dirs(clientdir)
            f = open(os.path.join(clientdir, "tahoe.cfg"), "w")
            try:
                # ensure the handle is closed even if a write fails
                f.write("[node]\n")
                f.write("nickname = client-%d\n" % i)
                f.write("web.port = tcp:0:interface=127.0.0.1\n")
                f.write("[storage]\n")
                f.write("enabled = false\n")
            finally:
                f.close()
            c = None
            if i in client_config_hooks:
                # this hook can either modify tahoe.cfg, or return an
                # entirely new Client instance
                c = client_config_hooks[i](clientdir)
            if not c:
                c = NoNetworkClient(clientdir)
                c.set_default_mutable_keysize(522)
            c.nodeid = clientid
            c.short_nodeid = b32encode(clientid).lower()[:8]
            c._servers = self.all_servers # can be updated later
            c.setServiceParent(self)
            self.clients.append(c)

    def make_server(self, i, readonly=False):
        """Create (but do not register) StorageServer #i, with a
        deterministic serverid derived from i."""
        serverid = hashutil.tagged_hash("serverid", str(i))[:20]
        serverdir = os.path.join(self.basedir, "servers",
                                 idlib.shortnodeid_b2a(serverid))
        fileutil.make_dirs(serverdir)
        ss = StorageServer(serverdir, serverid, stats_provider=SimpleStats(),
                           readonly_storage=readonly)
        ss._no_network_server_number = i
        return ss

    def add_server(self, i, ss):
        """Attach StorageServer ss as server #i and refresh all clients."""
        # to deal with the fact that all StorageServers are named 'storage',
        # we interpose a middleman
        middleman = service.MultiService()
        middleman.setServiceParent(self)
        ss.setServiceParent(middleman)
        serverid = ss.my_nodeid
        self.servers_by_number[i] = ss
        self.servers_by_id[serverid] = wrap_storage_server(ss)
        self.rebuild_serverlist()

    def rebuild_serverlist(self):
        """Recompute .all_servers and push it into every client."""
        self.all_servers = frozenset(self.servers_by_id.items())
        for c in self.clients:
            c._servers = self.all_servers

    def remove_server(self, serverid):
        """Drop the given server from the grid's tables and from every
        client's view of the grid."""
        # it's enough to remove the server from c._servers (we don't actually
        # have to detach and stopService it)
        for i, ss in self.servers_by_number.items():
            if ss.my_nodeid == serverid:
                del self.servers_by_number[i]
                break
        del self.servers_by_id[serverid]
        self.rebuild_serverlist()

    def break_server(self, serverid):
        # mark the given server as broken, so it will throw exceptions when
        # asked to hold a share or serve a share
        self.servers_by_id[serverid].broken = True

    def hang_server(self, serverid):
        """Park all future calls to the given server until unhang_server()."""
        ss = self.servers_by_id[serverid]
        assert ss.hung_until is None
        ss.hung_until = defer.Deferred()

    def unhang_server(self, serverid):
        """Release the calls parked by hang_server()."""
        ss = self.servers_by_id[serverid]
        assert ss.hung_until is not None
        ss.hung_until.callback(None)
        ss.hung_until = None
272
273
class GridTestMixin:
    """Test mixin that owns a NoNetworkGrid and provides helpers for
    locating, copying, deleting, and corrupting the shares it stores.

    Subclasses must set self.basedir before calling set_up_grid().
    """
    def setUp(self):
        self.s = service.MultiService()
        self.s.startService()

    def tearDown(self):
        return self.s.stopService()

    def set_up_grid(self, num_clients=1, num_servers=10,
                    client_config_hooks=None):
        """Create the grid and record each client's webport and base URL.

        client_config_hooks maps client-number -> hook(clientdir); None
        means no hooks (avoids the shared-mutable-default pitfall of a
        {} default argument).
        """
        if client_config_hooks is None:
            client_config_hooks = {}
        # self.basedir must be set
        self.g = NoNetworkGrid(self.basedir,
                               num_clients=num_clients,
                               num_servers=num_servers,
                               client_config_hooks=client_config_hooks)
        self.g.setServiceParent(self.s)
        self.client_webports = [c.getServiceNamed("webish").getPortnum()
                                for c in self.g.clients]
        self.client_baseurls = [c.getServiceNamed("webish").getURL()
                                for c in self.g.clients]

    def get_clientdir(self, i=0):
        return self.g.clients[i].basedir

    def get_serverdir(self, i):
        return self.g.servers_by_number[i].storedir

    def iterate_servers(self):
        """Yield (server_number, StorageServer, storedir) in numeric order."""
        for i in sorted(self.g.servers_by_number.keys()):
            ss = self.g.servers_by_number[i]
            yield (i, ss, ss.storedir)

    def _read_file(self, path):
        # read a whole file, closing the handle promptly (the previous
        # open(...).read() idiom left the handle open until GC)
        f = open(path, "rb")
        try:
            return f.read()
        finally:
            f.close()

    def _write_file(self, path, data):
        # write a whole file, closing the handle promptly
        f = open(path, "wb")
        try:
            f.write(data)
        finally:
            f.close()

    def find_uri_shares(self, uri):
        """Return a sorted list of (shnum, serverid, sharefile) for every
        share of the given URI stored anywhere in the grid."""
        si = tahoe_uri.from_string(uri).get_storage_index()
        prefixdir = storage_index_to_dir(si)
        shares = []
        for i, ss in self.g.servers_by_number.items():
            serverid = ss.my_nodeid
            basedir = os.path.join(ss.storedir, "shares", prefixdir)
            if not os.path.exists(basedir):
                continue
            for f in os.listdir(basedir):
                try:
                    shnum = int(f)
                    shares.append((shnum, serverid, os.path.join(basedir, f)))
                except ValueError:
                    # non-numeric entries are not share files; skip them
                    pass
        return sorted(shares)

    def copy_shares(self, uri):
        """Snapshot all share contents: {sharefile_path: bytes}."""
        shares = {}
        for (shnum, serverid, sharefile) in self.find_uri_shares(uri):
            shares[sharefile] = self._read_file(sharefile)
        return shares

    def restore_all_shares(self, shares):
        """Write back a snapshot produced by copy_shares()."""
        for sharefile, data in shares.items():
            self._write_file(sharefile, data)

    def delete_share(self, sharething):
        # accepts the (shnum, serverid, sharefile) tuples produced by
        # find_uri_shares(); the Python-2-only tuple-parameter unpacking
        # was replaced by an ordinary parameter (call sites are unchanged)
        (shnum, serverid, sharefile) = sharething
        os.unlink(sharefile)

    def delete_shares_numbered(self, uri, shnums):
        """Delete every share of uri whose shnum is in shnums."""
        for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
            if i_shnum in shnums:
                os.unlink(i_sharefile)

    def corrupt_share(self, sharething, corruptor_function):
        """Rewrite one share's file through corruptor_function(bytes)."""
        (shnum, serverid, sharefile) = sharething
        sharedata = self._read_file(sharefile)
        self._write_file(sharefile, corruptor_function(sharedata))

    def corrupt_shares_numbered(self, uri, shnums, corruptor, debug=False):
        """Corrupt every share of uri whose shnum is in shnums, using
        corruptor(bytes, debug=debug)."""
        for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
            if i_shnum in shnums:
                sharedata = self._read_file(i_sharefile)
                self._write_file(i_sharefile, corruptor(sharedata, debug=debug))

    def corrupt_all_shares(self, uri, corruptor, debug=False):
        """Corrupt every share of uri, using corruptor(bytes, debug=debug)."""
        for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
            sharedata = self._read_file(i_sharefile)
            self._write_file(i_sharefile, corruptor(sharedata, debug=debug))

    def GET(self, urlpath, followRedirect=False, return_response=False,
            method="GET", clientnum=0, **kwargs):
        """Issue an HTTP request against client #clientnum's web port.

        If return_response=True, the returned Deferred fires with
        (data, statuscode, respheaders) instead of just data.
        """
        assert not isinstance(urlpath, unicode)
        url = self.client_baseurls[clientnum] + urlpath
        factory = HTTPClientGETFactory(url, method=method,
                                       followRedirect=followRedirect, **kwargs)
        reactor.connectTCP("localhost", self.client_webports[clientnum],
                           factory)
        d = factory.deferred
        def _got_data(data):
            return (data, factory.status, factory.response_headers)
        if return_response:
            d.addCallback(_got_data)
        # addCallback mutates the Deferred in place, so d is still the
        # same object as factory.deferred
        return d