# src/allmydata/test/no_network.py
1
2 # This contains a test harness that creates a full Tahoe grid in a single
3 # process (actually in a single MultiService) which does not use the network.
4 # It does not use an Introducer, and there are no foolscap Tubs. Each storage
5 # server puts real shares on disk, but is accessed through loopback
6 # RemoteReferences instead of over serialized SSL. It is not as complete as
7 # the common.SystemTestMixin framework (which does use the network), but
8 # should be considerably faster: on my laptop, it takes 50-80ms to start up,
9 # whereas SystemTestMixin takes close to 2s.
10
11 # This should be useful for tests which want to examine and/or manipulate the
12 # uploaded shares, checker/verifier/repairer tests, etc. The clients have no
13 # Tubs, so it is not useful for tests that involve a Helper, a KeyGenerator,
14 # or the control.furl .
15
16 import os.path
17 from zope.interface import implements
18 from twisted.application import service
19 from twisted.internet import defer, reactor
20 from twisted.python.failure import Failure
21 from foolscap.api import Referenceable, fireEventually, RemoteException
22 from base64 import b32encode
23 from allmydata import uri as tahoe_uri
24 from allmydata.client import Client
25 from allmydata.storage.server import StorageServer, storage_index_to_dir
26 from allmydata.util import fileutil, idlib, hashutil
27 from allmydata.util.hashutil import sha1
28 from allmydata.test.common_web import HTTPClientGETFactory
29 from allmydata.interfaces import IStorageBroker
30
class IntentionalError(Exception):
    """Raised by a LocalWrapper that has been deliberately marked broken."""
33
class Marker:
    """Opaque, identity-keyed handle returned by notifyOnDisconnect."""
36
class LocalWrapper:
    """
    A loopback stand-in for a foolscap RemoteReference.

    callRemote() dispatches to self.original's ``remote_*`` methods through
    fireEventually(), so results arrive on a later turn, as they would over
    a real wire. Test hooks: set .broken to make every call raise
    IntentionalError, set .hung_until to a Deferred to queue calls until it
    fires, and set .post_call_notifier to observe each completed call.
    """
    def __init__(self, original):
        self.original = original           # the object holding remote_* methods
        self.broken = False                # True => every call fails
        self.hung_until = None             # Deferred => calls queue behind it
        self.post_call_notifier = None     # f(res, wrapper, methname) if set
        self.disconnectors = {}            # Marker -> (f, args, kwargs)

    def callRemoteOnly(self, methname, *args, **kwargs):
        # fire-and-forget variant: the Deferred (and any eventual failure)
        # is deliberately discarded, mirroring foolscap's callRemoteOnly
        d = self.callRemote(methname, *args, **kwargs)
        del d # explicitly ignored
        return None

    def callRemote(self, methname, *args, **kwargs):
        # this is ideally a Membrane, but that's too hard. We do a shallow
        # wrapping of inbound arguments, and per-methodname wrapping of
        # selected return values.
        def wrap(a):
            # inbound Referenceables get wrapped so the server side also
            # sees loopback references, not raw objects
            if isinstance(a, Referenceable):
                return LocalWrapper(a)
            else:
                return a
        args = tuple([wrap(a) for a in args])
        kwargs = dict([(k,wrap(kwargs[k])) for k in kwargs])

        def _really_call():
            meth = getattr(self.original, "remote_" + methname)
            return meth(*args, **kwargs)

        def _call():
            if self.broken:
                raise IntentionalError("I was asked to break")
            if self.hung_until:
                # queue this call on hung_until's callback chain: it runs
                # (in arrival order) once unhang fires the Deferred, and d2
                # relays the outcome to this call's own Deferred
                d2 = defer.Deferred()
                self.hung_until.addCallback(lambda ign: _really_call())
                self.hung_until.addCallback(lambda res: d2.callback(res))
                def _err(res):
                    d2.errback(res)
                    # returning the Failure keeps hung_until's chain in the
                    # errback state for anything added after us
                    return res
                self.hung_until.addErrback(_err)
                return d2
            return _really_call()

        d = fireEventually()
        d.addCallback(lambda res: _call())
        def _wrap_exception(f):
            # callers expect remote failures, not local ones
            return Failure(RemoteException(f))
        d.addErrback(_wrap_exception)
        def _return_membrane(res):
            # rather than complete the difficult task of building a
            # fully-general Membrane (which would locate all Referenceable
            # objects that cross the simulated wire and replace them with
            # wrappers), we special-case certain methods that we happen to
            # know will return Referenceables.
            if methname == "allocate_buckets":
                (alreadygot, allocated) = res
                for shnum in allocated:
                    allocated[shnum] = LocalWrapper(allocated[shnum])
            if methname == "get_buckets":
                for shnum in res:
                    res[shnum] = LocalWrapper(res[shnum])
            return res
        d.addCallback(_return_membrane)
        if self.post_call_notifier:
            d.addCallback(self.post_call_notifier, self, methname)
        return d

    def notifyOnDisconnect(self, f, *args, **kwargs):
        # record the callback and return a Marker handle for unregistering.
        # (Nothing in this class ever fires these; there is no real
        # connection to lose.)
        m = Marker()
        self.disconnectors[m] = (f, args, kwargs)
        return m
    def dontNotifyOnDisconnect(self, marker):
        del self.disconnectors[marker]
110
def wrap_storage_server(original):
    """Wrap a StorageServer in a LocalWrapper and pre-load .version.

    Much of the upload/download code uses rref.version (which normally
    comes from rrefutil.add_version_to_remote_reference). To avoid using
    a network, we want a LocalWrapper here. Try to satisfy all these
    constraints at the same time.
    """
    w = LocalWrapper(original)
    w.version = original.remote_get_version()
    return w
119
class NoNetworkServer:
    """Minimal IServer-alike: pairs a server id with its loopback rref."""

    def __init__(self, serverid, rref):
        # .rref is read directly by grid code, so keep the attribute names
        self.serverid = serverid
        self.rref = rref

    def get_serverid(self):
        return self.serverid

    def get_permutation_seed(self):
        # the serverid doubles as the permutation seed
        return self.serverid

    def get_rref(self):
        return self.rref
130
class NoNetworkStorageBroker:
    """IStorageBroker backed directly by the grid's client._servers set."""
    implements(IStorageBroker)

    def get_servers_for_psi(self, peer_selection_index):
        # permuted order: sort by sha1(psi + permutation_seed), same
        # ranking a real broker would produce
        def _permute(s):
            return sha1(peer_selection_index + s.get_permutation_seed()).digest()
        return sorted(self.get_connected_servers(), key=_permute)

    def get_connected_servers(self):
        # .client is attached by NoNetworkClient.init_client_storage_broker
        return self.client._servers

    def get_nickname_for_serverid(self, serverid):
        return None
142
class NoNetworkClient(Client):
    """
    A tub-less Client for the no-network grid.

    Every initializer that would touch the network (tub, introducer,
    helper, key generator, storage server, control port, stub client)
    is overridden to do nothing, and the storage broker is replaced by
    a NoNetworkStorageBroker wired back to this client.
    """
    # ._servers will be set by the NoNetworkGrid which creates us

    def startService(self):
        # skip Client.startService and go straight to MultiService: there
        # is no Tub to wait for
        service.MultiService.startService(self)
    def stopService(self):
        service.MultiService.stopService(self)
    def when_tub_ready(self):
        raise NotImplementedError("NoNetworkClient has no Tub")
    def create_tub(self):
        pass
    def init_introducer_client(self):
        pass
    def setup_logging(self):
        pass
    def init_control(self):
        pass
    def init_helper(self):
        pass
    def init_key_gen(self):
        pass
    def init_storage(self):
        pass
    def init_stub_client(self):
        pass
    def init_client_storage_broker(self):
        broker = NoNetworkStorageBroker()
        broker.client = self
        self.storage_broker = broker
170
class SimpleStats:
    """In-memory stand-in for a stats provider: counters plus producers."""

    def __init__(self):
        self.counters = {}          # name -> running total
        self.stats_producers = []   # objects with a get_stats() method

    def count(self, name, delta=1):
        # missing counters start at zero
        self.counters[name] = self.counters.get(name, 0) + delta

    def register_producer(self, stats_producer):
        self.stats_producers.append(stats_producer)

    def get_stats(self):
        """Return {'counters': ..., 'stats': merged producer stats}."""
        merged = {}
        for producer in self.stats_producers:
            merged.update(producer.get_stats())
        return {'counters': self.counters, 'stats': merged}
189
class NoNetworkGrid(service.MultiService):
    """
    An entire Tahoe grid in one process: real StorageServers writing real
    share files to disk, reached by NoNetworkClients through LocalWrapper
    loopback references instead of the network.
    """
    def __init__(self, basedir, num_clients=1, num_servers=10,
                 client_config_hooks=None):
        """
        basedir: directory under which all server/client dirs are created.
        client_config_hooks: optional dict mapping client number -> hook;
          each hook is called with the client's dir and may edit tahoe.cfg
          in place, or return a replacement Client instance.
        """
        service.MultiService.__init__(self)
        # default to None (not {}) to avoid sharing one mutable dict
        # across all instances
        if client_config_hooks is None:
            client_config_hooks = {}
        self.basedir = basedir
        fileutil.make_dirs(basedir)

        self.servers_by_number = {} # maps to StorageServer instance
        self.wrappers_by_id = {} # maps to wrapped StorageServer instance
        self.proxies_by_id = {} # maps to IServer on which .rref is a wrapped
                                # StorageServer
        self.clients = []

        for i in range(num_servers):
            ss = self.make_server(i)
            self.add_server(i, ss)
        self.rebuild_serverlist()

        for i in range(num_clients):
            clientid = hashutil.tagged_hash("clientid", str(i))[:20]
            clientdir = os.path.join(basedir, "clients",
                                     idlib.shortnodeid_b2a(clientid))
            fileutil.make_dirs(clientdir)
            f = open(os.path.join(clientdir, "tahoe.cfg"), "w")
            try:
                f.write("[node]\n")
                f.write("nickname = client-%d\n" % i)
                f.write("web.port = tcp:0:interface=127.0.0.1\n")
                f.write("[storage]\n")
                f.write("enabled = false\n")
            finally:
                # close promptly even if a write raises
                f.close()
            c = None
            if i in client_config_hooks:
                # this hook can either modify tahoe.cfg, or return an
                # entirely new Client instance
                c = client_config_hooks[i](clientdir)
            if not c:
                c = NoNetworkClient(clientdir)
                c.set_default_mutable_keysize(522)
            c.nodeid = clientid
            c.short_nodeid = b32encode(clientid).lower()[:8]
            c._servers = self.all_servers # can be updated later
            c.setServiceParent(self)
            self.clients.append(c)

    def make_server(self, i, readonly=False):
        """Create (but do not register) StorageServer #i under basedir."""
        serverid = hashutil.tagged_hash("serverid", str(i))[:20]
        serverdir = os.path.join(self.basedir, "servers",
                                 idlib.shortnodeid_b2a(serverid))
        fileutil.make_dirs(serverdir)
        ss = StorageServer(serverdir, serverid, stats_provider=SimpleStats(),
                           readonly_storage=readonly)
        ss._no_network_server_number = i
        return ss

    def add_server(self, i, ss):
        """Register StorageServer ss as server #i and refresh all clients."""
        # to deal with the fact that all StorageServers are named 'storage',
        # we interpose a middleman
        middleman = service.MultiService()
        middleman.setServiceParent(self)
        ss.setServiceParent(middleman)
        serverid = ss.my_nodeid
        self.servers_by_number[i] = ss
        wrapper = wrap_storage_server(ss)
        self.wrappers_by_id[serverid] = wrapper
        self.proxies_by_id[serverid] = NoNetworkServer(serverid, wrapper)
        self.rebuild_serverlist()

    def get_all_serverids(self):
        return self.proxies_by_id.keys()

    def rebuild_serverlist(self):
        """Recompute self.all_servers and push it into every client."""
        self.all_servers = frozenset(self.proxies_by_id.values())
        for c in self.clients:
            c._servers = self.all_servers

    def remove_server(self, serverid):
        """Forget the given server. It's enough to remove it from
        c._servers; we don't actually have to detach and stopService it."""
        for i,ss in self.servers_by_number.items():
            if ss.my_nodeid == serverid:
                del self.servers_by_number[i]
                break
        del self.wrappers_by_id[serverid]
        del self.proxies_by_id[serverid]
        self.rebuild_serverlist()

    def break_server(self, serverid):
        # mark the given server as broken, so it will throw exceptions when
        # asked to hold a share or serve a share
        self.wrappers_by_id[serverid].broken = True

    def hang_server(self, serverid):
        """Hang the given server: its calls queue until unhang_server."""
        ss = self.wrappers_by_id[serverid]
        assert ss.hung_until is None
        ss.hung_until = defer.Deferred()

    def unhang_server(self, serverid):
        """Release a previously hung server, running its queued calls."""
        ss = self.wrappers_by_id[serverid]
        assert ss.hung_until is not None
        ss.hung_until.callback(None)
        ss.hung_until = None
294
295 class GridTestMixin:
296     def setUp(self):
297         self.s = service.MultiService()
298         self.s.startService()
299
300     def tearDown(self):
301         return self.s.stopService()
302
303     def set_up_grid(self, num_clients=1, num_servers=10,
304                     client_config_hooks={}):
305         # self.basedir must be set
306         self.g = NoNetworkGrid(self.basedir,
307                                num_clients=num_clients,
308                                num_servers=num_servers,
309                                client_config_hooks=client_config_hooks)
310         self.g.setServiceParent(self.s)
311         self.client_webports = [c.getServiceNamed("webish").getPortnum()
312                                 for c in self.g.clients]
313         self.client_baseurls = [c.getServiceNamed("webish").getURL()
314                                 for c in self.g.clients]
315
316     def get_clientdir(self, i=0):
317         return self.g.clients[i].basedir
318
319     def get_serverdir(self, i):
320         return self.g.servers_by_number[i].storedir
321
322     def iterate_servers(self):
323         for i in sorted(self.g.servers_by_number.keys()):
324             ss = self.g.servers_by_number[i]
325             yield (i, ss, ss.storedir)
326
327     def find_uri_shares(self, uri):
328         si = tahoe_uri.from_string(uri).get_storage_index()
329         prefixdir = storage_index_to_dir(si)
330         shares = []
331         for i,ss in self.g.servers_by_number.items():
332             serverid = ss.my_nodeid
333             basedir = os.path.join(ss.storedir, "shares", prefixdir)
334             if not os.path.exists(basedir):
335                 continue
336             for f in os.listdir(basedir):
337                 try:
338                     shnum = int(f)
339                     shares.append((shnum, serverid, os.path.join(basedir, f)))
340                 except ValueError:
341                     pass
342         return sorted(shares)
343
344     def copy_shares(self, uri):
345         shares = {}
346         for (shnum, serverid, sharefile) in self.find_uri_shares(uri):
347             shares[sharefile] = open(sharefile, "rb").read()
348         return shares
349
350     def restore_all_shares(self, shares):
351         for sharefile, data in shares.items():
352             open(sharefile, "wb").write(data)
353
354     def delete_share(self, (shnum, serverid, sharefile)):
355         os.unlink(sharefile)
356
357     def delete_shares_numbered(self, uri, shnums):
358         for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
359             if i_shnum in shnums:
360                 os.unlink(i_sharefile)
361
362     def corrupt_share(self, (shnum, serverid, sharefile), corruptor_function):
363         sharedata = open(sharefile, "rb").read()
364         corruptdata = corruptor_function(sharedata)
365         open(sharefile, "wb").write(corruptdata)
366
367     def corrupt_shares_numbered(self, uri, shnums, corruptor, debug=False):
368         for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
369             if i_shnum in shnums:
370                 sharedata = open(i_sharefile, "rb").read()
371                 corruptdata = corruptor(sharedata, debug=debug)
372                 open(i_sharefile, "wb").write(corruptdata)
373
374     def corrupt_all_shares(self, uri, corruptor, debug=False):
375         for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
376             sharedata = open(i_sharefile, "rb").read()
377             corruptdata = corruptor(sharedata, debug=debug)
378             open(i_sharefile, "wb").write(corruptdata)
379
380     def GET(self, urlpath, followRedirect=False, return_response=False,
381             method="GET", clientnum=0, **kwargs):
382         # if return_response=True, this fires with (data, statuscode,
383         # respheaders) instead of just data.
384         assert not isinstance(urlpath, unicode)
385         url = self.client_baseurls[clientnum] + urlpath
386         factory = HTTPClientGETFactory(url, method=method,
387                                        followRedirect=followRedirect, **kwargs)
388         reactor.connectTCP("localhost", self.client_webports[clientnum],factory)
389         d = factory.deferred
390         def _got_data(data):
391             return (data, factory.status, factory.response_headers)
392         if return_response:
393             d.addCallback(_got_data)
394         return factory.deferred