]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blob - src/allmydata/test/no_network.py
immutable: extend the tests to check that the shares that got uploaded really do...
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / test / no_network.py
1
2 # This contains a test harness that creates a full Tahoe grid in a single
3 # process (actually in a single MultiService) which does not use the network.
4 # It does not use an Introducer, and there are no foolscap Tubs. Each storage
5 # server puts real shares on disk, but is accessed through loopback
6 # RemoteReferences instead of over serialized SSL. It is not as complete as
7 # the common.SystemTestMixin framework (which does use the network), but
8 # should be considerably faster: on my laptop, it takes 50-80ms to start up,
9 # whereas SystemTestMixin takes close to 2s.
10
11 # This should be useful for tests which want to examine and/or manipulate the
12 # uploaded shares, checker/verifier/repairer tests, etc. The clients have no
13 # Tubs, so it is not useful for tests that involve a Helper, a KeyGenerator,
14 # or the control.furl .
15
16 import os.path
17 from zope.interface import implements
18 from twisted.application import service
19 from twisted.internet import defer, reactor
20 from twisted.python.failure import Failure
21 from foolscap.api import Referenceable, fireEventually, RemoteException
22 from base64 import b32encode
23 from allmydata import uri as tahoe_uri
24 from allmydata.client import Client
25 from allmydata.storage.server import StorageServer, storage_index_to_dir
26 from allmydata.util import fileutil, idlib, hashutil
27 from allmydata.util.hashutil import sha1
28 from allmydata.test.common_web import HTTPClientGETFactory
29 from allmydata.interfaces import IStorageBroker
30
class IntentionalError(Exception):
    """Raised by LocalWrapper.callRemote when the wrapper is marked broken."""
33
class Marker:
    """Opaque token identifying one notifyOnDisconnect registration.

    LocalWrapper.notifyOnDisconnect returns a fresh instance, and
    LocalWrapper.dontNotifyOnDisconnect takes it back to drop the entry.
    """
36
class LocalWrapper:
    """In-process substitute for a remote reference.

    callRemote("foo", ...) looks up self.original.remote_foo and invokes it
    locally (after a fireEventually() turn), returning a Deferred for the
    result. Tests can set the `broken`, `hung_until` and
    `post_call_notifier` attributes to inject failures, delays and
    observation hooks.
    """
    def __init__(self, original):
        # the object whose remote_* methods get invoked
        self.original = original
        # when true, every call fails with IntentionalError
        self.broken = False
        # None, or a Deferred: while set, calls are queued on it and only
        # run once it fires
        self.hung_until = None
        # optional callable added to each call's Deferred as
        # f(result, wrapper, methname); its return value becomes the result
        self.post_call_notifier = None
        # Marker -> (f, args, kwargs), as recorded by notifyOnDisconnect
        self.disconnectors = {}

    def callRemoteOnly(self, methname, *args, **kwargs):
        """Issue the call like callRemote, discard its Deferred, return None."""
        d = self.callRemote(methname, *args, **kwargs)
        del d # explicitly ignored
        return None

    def callRemote(self, methname, *args, **kwargs):
        """Invoke self.original.remote_<methname>(*args, **kwargs) locally.

        Returns a Deferred that fires with the method's result. Any failure
        (including the IntentionalError raised when `broken` is set) is
        re-wrapped in a RemoteException before it reaches the caller.
        """
        # this is ideally a Membrane, but that's too hard. We do a shallow
        # wrapping of inbound arguments, and per-methodname wrapping of
        # selected return values.
        def wrap(a):
            if isinstance(a, Referenceable):
                return LocalWrapper(a)
            return a
        args = tuple([wrap(a) for a in args])
        kwargs = dict([(k, wrap(v)) for (k, v) in kwargs.items()])

        def _really_call():
            meth = getattr(self.original, "remote_" + methname)
            return meth(*args, **kwargs)

        def _call():
            if self.broken:
                raise IntentionalError("I was asked to break")
            if self.hung_until:
                # Queue the call on the shared hung_until Deferred and hand
                # the caller a private Deferred (d2) for its outcome.
                d2 = defer.Deferred()
                self.hung_until.addCallback(lambda ign: _really_call())
                def _deliver(res):
                    if isinstance(res, Failure):
                        d2.errback(res)
                    else:
                        d2.callback(res)
                    # Always return None: hung_until is shared by every call
                    # queued while the server is hung, so a failure from
                    # this call must not leak down the chain, where it would
                    # skip the next call's _really_call and be reported as
                    # that call's error.
                    return None
                self.hung_until.addBoth(_deliver)
                return d2
            return _really_call()

        d = fireEventually()
        d.addCallback(lambda res: _call())
        def _wrap_exception(f):
            return Failure(RemoteException(f))
        d.addErrback(_wrap_exception)
        def _return_membrane(res):
            # rather than complete the difficult task of building a
            # fully-general Membrane (which would locate all Referenceable
            # objects that cross the simulated wire and replace them with
            # wrappers), we special-case certain methods that we happen to
            # know will return Referenceables.
            if methname == "allocate_buckets":
                (alreadygot, allocated) = res
                for shnum in allocated:
                    allocated[shnum] = LocalWrapper(allocated[shnum])
            if methname == "get_buckets":
                for shnum in res:
                    res[shnum] = LocalWrapper(res[shnum])
            return res
        d.addCallback(_return_membrane)
        if self.post_call_notifier:
            d.addCallback(self.post_call_notifier, self, methname)
        return d

    def notifyOnDisconnect(self, f, *args, **kwargs):
        """Record (f, args, kwargs) under a new Marker and return the Marker.

        Nothing in this class ever invokes the recorded callables.
        """
        m = Marker()
        self.disconnectors[m] = (f, args, kwargs)
        return m

    def dontNotifyOnDisconnect(self, marker):
        """Forget the registration for `marker`; KeyError if it is unknown."""
        del self.disconnectors[marker]
110
def wrap_storage_server(original):
    """Return a LocalWrapper around `original` that also has `.version`.

    Much of the upload/download code reads rref.version, which normally
    comes from rrefutil.add_version_to_remote_reference. Since no network
    is used here, the value is taken straight from
    original.remote_get_version() and attached to the wrapper.
    """
    rref = LocalWrapper(original)
    version_info = original.remote_get_version()
    rref.version = version_info
    return rref
119
class NoNetworkStorageBroker:
    """IStorageBroker that answers from the owning client's `_servers`.

    The `client` attribute is not set here; whoever instantiates the broker
    assigns it afterwards.
    """
    implements(IStorageBroker)

    def get_servers_for_index(self, key):
        """Return the client's servers sorted by sha1(key + entry[0])."""
        def _permuted_position(entry):
            # entry[0] is concatenated with the key and hashed, so the
            # resulting order depends on the key
            return sha1(key + entry[0]).digest()
        return sorted(self.client._servers, key=_permuted_position)

    def get_all_servers(self):
        """Return a frozenset of every entry in the client's `_servers`."""
        return frozenset(self.client._servers)

    def get_nickname_for_serverid(self, serverid):
        """Always return None: no nicknames are tracked here."""
        return None
129
class NoNetworkClient(Client):
    """Client subclass whose network-facing setup steps are no-ops.

    The `_servers` attribute is not set here: it is assigned by the
    NoNetworkGrid which creates us.
    """

    # --- setup hooks overridden to do nothing ---

    def create_tub(self):
        """No-op override."""

    def init_introducer_client(self):
        """No-op override."""

    def setup_logging(self):
        """No-op override."""

    def init_control(self):
        """No-op override."""

    def init_helper(self):
        """No-op override."""

    def init_key_gen(self):
        """No-op override."""

    def init_storage(self):
        """No-op override."""

    def init_stub_client(self):
        """No-op override."""

    # --- service lifecycle: call MultiService's implementation directly ---

    def startService(self):
        service.MultiService.startService(self)

    def stopService(self):
        service.MultiService.stopService(self)

    def when_tub_ready(self):
        """Always raises NotImplementedError."""
        raise NotImplementedError("NoNetworkClient has no Tub")

    def init_client_storage_broker(self):
        """Install a NoNetworkStorageBroker that points back at this client."""
        broker = NoNetworkStorageBroker()
        broker.client = self
        self.storage_broker = broker
157
class SimpleStats:
    """Minimal stats provider: named counters plus registered producers."""

    def __init__(self):
        # objects with a get_stats() method returning a dict
        self.stats_producers = []
        # counter name -> running total
        self.counters = {}

    def count(self, name, delta=1):
        """Add `delta` to counter `name`, creating it at 0 if needed."""
        counters = self.counters
        counters[name] = counters.setdefault(name, 0) + delta

    def register_producer(self, stats_producer):
        """Remember `stats_producer` so get_stats() will consult it."""
        self.stats_producers.append(stats_producer)

    def get_stats(self):
        """Return {'counters': ..., 'stats': ...}.

        'counters' is the live counters dict (not a copy). 'stats' merges
        each producer's get_stats() in registration order, so later
        producers override earlier ones on key collisions.
        """
        merged = {}
        for producer in self.stats_producers:
            merged.update(producer.get_stats())
        return { 'counters': self.counters, 'stats': merged }
176
class NoNetworkGrid(service.MultiService):
    """A MultiService holding storage servers and clients in one process.

    Attributes set up here:
      servers_by_number: server number -> StorageServer
      servers_by_id: serverid -> LocalWrapper around that StorageServer
      all_servers: frozenset of servers_by_id.items(), shared with clients
      clients: list of client instances, in creation order
    """
    def __init__(self, basedir, num_clients=1, num_servers=10,
                 client_config_hooks={}):
        """Create `num_servers` servers, then `num_clients` clients.

        `client_config_hooks` maps a client number to a callable that is
        given that client's directory after its tahoe.cfg is written.
        """
        service.MultiService.__init__(self)
        self.basedir = basedir
        fileutil.make_dirs(basedir)

        self.servers_by_number = {}
        self.servers_by_id = {}
        self.clients = []

        for servernum in range(num_servers):
            server = self.make_server(servernum)
            self.add_server(servernum, server)
        # also covers num_servers == 0, where add_server never ran
        self.rebuild_serverlist()

        for clientnum in range(num_clients):
            clientid = hashutil.tagged_hash("clientid", str(clientnum))[:20]
            clientdir = os.path.join(basedir, "clients",
                                     idlib.shortnodeid_b2a(clientid))
            fileutil.make_dirs(clientdir)
            config_lines = [
                "[node]\n",
                "nickname = client-%d\n" % clientnum,
                "web.port = tcp:0:interface=127.0.0.1\n",
                "[storage]\n",
                "enabled = false\n",
                ]
            cfg = open(os.path.join(clientdir, "tahoe.cfg"), "w")
            cfg.writelines(config_lines)
            cfg.close()

            client = None
            if clientnum in client_config_hooks:
                # this hook can either modify tahoe.cfg, or return an
                # entirely new Client instance
                client = client_config_hooks[clientnum](clientdir)
            if not client:
                client = NoNetworkClient(clientdir)
                client.set_default_mutable_keysize(522)
            client.nodeid = clientid
            client.short_nodeid = b32encode(clientid).lower()[:8]
            client._servers = self.all_servers # can be updated later
            client.setServiceParent(self)
            self.clients.append(client)

    def make_server(self, i, readonly=False):
        """Build (but do not attach) the StorageServer numbered `i`."""
        serverid = hashutil.tagged_hash("serverid", str(i))[:20]
        shortid = idlib.shortnodeid_b2a(serverid)
        serverdir = os.path.join(self.basedir, "servers", shortid)
        fileutil.make_dirs(serverdir)
        return StorageServer(serverdir, serverid,
                             stats_provider=SimpleStats(),
                             readonly_storage=readonly)

    def add_server(self, i, ss):
        """Attach StorageServer `ss` as server number `i` and publish it."""
        # to deal with the fact that all StorageServers are named 'storage',
        # we interpose a middleman
        go_between = service.MultiService()
        go_between.setServiceParent(self)
        ss.setServiceParent(go_between)
        self.servers_by_number[i] = ss
        self.servers_by_id[ss.my_nodeid] = wrap_storage_server(ss)
        self.rebuild_serverlist()

    def rebuild_serverlist(self):
        """Recompute all_servers and hand the new set to every client."""
        current = frozenset(self.servers_by_id.items())
        self.all_servers = current
        for client in self.clients:
            client._servers = current

    def remove_server(self, serverid):
        """Drop the server with `serverid` from both indexes."""
        # it's enough to remove the server from c._servers (we don't actually
        # have to detach and stopService it)
        for num, server in self.servers_by_number.items():
            if server.my_nodeid == serverid:
                del self.servers_by_number[num]
                break
        del self.servers_by_id[serverid]
        self.rebuild_serverlist()

    def break_server(self, serverid):
        """Mark the server broken, so it will throw exceptions when asked
        to hold a share or serve a share."""
        wrapper = self.servers_by_id[serverid]
        wrapper.broken = True

    def hang_server(self, serverid):
        """Hang the server: install a fresh unfired Deferred as hung_until."""
        wrapper = self.servers_by_id[serverid]
        assert wrapper.hung_until is None
        wrapper.hung_until = defer.Deferred()

    def unhang_server(self, serverid):
        """Unhang the server: fire its hung_until Deferred, then clear it."""
        wrapper = self.servers_by_id[serverid]
        assert wrapper.hung_until is not None
        wrapper.hung_until.callback(None)
        wrapper.hung_until = None
271
272
class GridTestMixin:
    """Test-case mixin that runs a NoNetworkGrid and pokes at its shares.

    setUp/tearDown manage a parent MultiService in self.s. A test calls
    set_up_grid() (after setting self.basedir) to create the grid as self.g.
    Shares are described throughout as (shnum, serverid, sharefile) tuples,
    as returned by find_uri_shares().
    """
    def setUp(self):
        self.s = service.MultiService()
        self.s.startService()

    def tearDown(self):
        return self.s.stopService()

    def set_up_grid(self, num_clients=1, num_servers=10,
                    client_config_hooks={}):
        """Create self.g and record each client's web port and base URL."""
        # self.basedir must be set
        self.g = NoNetworkGrid(self.basedir,
                               num_clients=num_clients,
                               num_servers=num_servers,
                               client_config_hooks=client_config_hooks)
        self.g.setServiceParent(self.s)
        self.client_webports = [c.getServiceNamed("webish").listener._port.getHost().port
                                for c in self.g.clients]
        self.client_baseurls = ["http://localhost:%d/" % p
                                for p in self.client_webports]

    def get_clientdir(self, i=0):
        """Return the basedir of client number `i`."""
        return self.g.clients[i].basedir

    def get_serverdir(self, i):
        """Return the storedir of server number `i`."""
        return self.g.servers_by_number[i].storedir

    def iterate_servers(self):
        """Yield (number, server, storedir) in ascending server number."""
        for i in sorted(self.g.servers_by_number.keys()):
            ss = self.g.servers_by_number[i]
            yield (i, ss, ss.storedir)

    def find_uri_shares(self, uri):
        """Return a sorted list of (shnum, serverid, sharefile) for `uri`.

        Looks in each server's "shares" directory under the storage-index
        prefix; directory entries whose names are not integers are skipped.
        """
        si = tahoe_uri.from_string(uri).get_storage_index()
        prefixdir = storage_index_to_dir(si)
        shares = []
        for i,ss in self.g.servers_by_number.items():
            serverid = ss.my_nodeid
            basedir = os.path.join(ss.storedir, "shares", prefixdir)
            if not os.path.exists(basedir):
                continue
            for f in os.listdir(basedir):
                try:
                    shnum = int(f)
                except ValueError:
                    # not a share file
                    continue
                shares.append((shnum, serverid, os.path.join(basedir, f)))
        return sorted(shares)

    def _rewrite_share_file(self, sharefile, transform):
        """Replace the contents of `sharefile` with transform(old_contents).

        Both file handles are closed explicitly, so the new data is on disk
        when this returns instead of whenever the file object is collected.
        """
        f = open(sharefile, "rb")
        try:
            sharedata = f.read()
        finally:
            f.close()
        corruptdata = transform(sharedata)
        f = open(sharefile, "wb")
        try:
            f.write(corruptdata)
        finally:
            f.close()

    def delete_share(self, share):
        """Delete the share file named by a (shnum, serverid, sharefile) tuple."""
        # unpacked here rather than in the signature: tuple parameters are
        # Python-2-only syntax (PEP 3113)
        (shnum, serverid, sharefile) = share
        os.unlink(sharefile)

    def delete_shares_numbered(self, uri, shnums):
        """Delete every share of `uri` whose share number is in `shnums`."""
        for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
            if i_shnum in shnums:
                os.unlink(i_sharefile)

    def corrupt_share(self, share, corruptor_function):
        """Rewrite one share file with corruptor_function(sharedata).

        `share` is a (shnum, serverid, sharefile) tuple.
        """
        (shnum, serverid, sharefile) = share
        self._rewrite_share_file(sharefile, corruptor_function)

    def corrupt_shares_numbered(self, uri, shnums, corruptor, debug=False):
        """Rewrite each share of `uri` numbered in `shnums` with
        corruptor(sharedata, debug=debug)."""
        def _corrupt(sharedata):
            return corruptor(sharedata, debug=debug)
        for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
            if i_shnum in shnums:
                self._rewrite_share_file(i_sharefile, _corrupt)

    def GET(self, urlpath, followRedirect=False, return_response=False,
            method="GET", clientnum=0, **kwargs):
        """Issue an HTTP request to client `clientnum`'s web port.

        Returns a Deferred that fires with the response data, or with
        (data, statuscode, respheaders) if return_response=True.
        """
        assert not isinstance(urlpath, unicode)
        url = self.client_baseurls[clientnum] + urlpath
        factory = HTTPClientGETFactory(url, method=method,
                                       followRedirect=followRedirect, **kwargs)
        reactor.connectTCP("localhost", self.client_webports[clientnum],factory)
        d = factory.deferred
        def _got_data(data):
            return (data, factory.status, factory.response_headers)
        if return_response:
            d.addCallback(_got_data)
        return d