]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blob - src/allmydata/test/no_network.py
switch all foolscap imports to use foolscap.api or foolscap.logging
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / test / no_network.py
1
2 # This contains a test harness that creates a full Tahoe grid in a single
3 # process (actually in a single MultiService) which does not use the network.
4 # It does not use an Introducer, and there are no foolscap Tubs. Each storage
5 # server puts real shares on disk, but is accessed through loopback
6 # RemoteReferences instead of over serialized SSL. It is not as complete as
7 # the common.SystemTestMixin framework (which does use the network), but
8 # should be considerably faster: on my laptop, it takes 50-80ms to start up,
9 # whereas SystemTestMixin takes close to 2s.
10
11 # This should be useful for tests which want to examine and/or manipulate the
12 # uploaded shares, checker/verifier/repairer tests, etc. The clients have no
13 # Tubs, so it is not useful for tests that involve a Helper, a KeyGenerator,
14 # or the control.furl .
15
16 import os.path
17 import sha
18 from twisted.application import service
19 from twisted.internet import reactor
20 from foolscap.api import Referenceable, fireEventually
21 from base64 import b32encode
22 from allmydata import uri as tahoe_uri
23 from allmydata.client import Client
24 from allmydata.storage.server import StorageServer, storage_index_to_dir
25 from allmydata.util import fileutil, idlib, hashutil, rrefutil
26 from allmydata.introducer.client import RemoteServiceConnector
27 from allmydata.test.common_web import HTTPClientGETFactory
28
29 class IntentionalError(Exception):
30     pass
31
32 class Marker:
33     pass
34
35 class LocalWrapper:
36     def __init__(self, original):
37         self.original = original
38         self.broken = False
39         self.post_call_notifier = None
40         self.disconnectors = {}
41
42     def callRemoteOnly(self, methname, *args, **kwargs):
43         d = self.callRemote(methname, *args, **kwargs)
44         return None
45
46     def callRemote(self, methname, *args, **kwargs):
47         # this is ideally a Membrane, but that's too hard. We do a shallow
48         # wrapping of inbound arguments, and per-methodname wrapping of
49         # selected return values.
50         def wrap(a):
51             if isinstance(a, Referenceable):
52                 return LocalWrapper(a)
53             else:
54                 return a
55         args = tuple([wrap(a) for a in args])
56         kwargs = dict([(k,wrap(kwargs[k])) for k in kwargs])
57         def _call():
58             if self.broken:
59                 raise IntentionalError("I was asked to break")
60             meth = getattr(self.original, "remote_" + methname)
61             return meth(*args, **kwargs)
62         d = fireEventually()
63         d.addCallback(lambda res: _call())
64         def _return_membrane(res):
65             # rather than complete the difficult task of building a
66             # fully-general Membrane (which would locate all Referenceable
67             # objects that cross the simulated wire and replace them with
68             # wrappers), we special-case certain methods that we happen to
69             # know will return Referenceables.
70             if methname == "allocate_buckets":
71                 (alreadygot, allocated) = res
72                 for shnum in allocated:
73                     allocated[shnum] = LocalWrapper(allocated[shnum])
74             if methname == "get_buckets":
75                 for shnum in res:
76                     res[shnum] = LocalWrapper(res[shnum])
77             return res
78         d.addCallback(_return_membrane)
79         if self.post_call_notifier:
80             d.addCallback(self.post_call_notifier, methname)
81         return d
82
83     def notifyOnDisconnect(self, f, *args, **kwargs):
84         m = Marker()
85         self.disconnectors[m] = (f, args, kwargs)
86         return m
87     def dontNotifyOnDisconnect(self, marker):
88         del self.disconnectors[marker]
89
90 def wrap(original, service_name):
91     # The code in immutable.checker insists upon asserting the truth of
92     # isinstance(rref, rrefutil.WrappedRemoteReference). Much of the
93     # upload/download code uses rref.version (which normally comes from
94     # rrefutil.VersionedRemoteReference). To avoid using a network, we want a
95     # LocalWrapper here. Try to satisfy all these constraints at the same
96     # time.
97     local = LocalWrapper(original)
98     wrapped = rrefutil.WrappedRemoteReference(local)
99     try:
100         version = original.remote_get_version()
101     except AttributeError:
102         version = RemoteServiceConnector.VERSION_DEFAULTS[service_name]
103     wrapped.version = version
104     return wrapped
105
106 class NoNetworkClient(Client):
107
108     def create_tub(self):
109         pass
110     def init_introducer_client(self):
111         pass
112     def setup_logging(self):
113         pass
114     def startService(self):
115         service.MultiService.startService(self)
116     def stopService(self):
117         service.MultiService.stopService(self)
118     def when_tub_ready(self):
119         raise NotImplementedError("NoNetworkClient has no Tub")
120     def init_control(self):
121         pass
122     def init_helper(self):
123         pass
124     def init_key_gen(self):
125         pass
126     def init_storage(self):
127         pass
128     def init_stub_client(self):
129         pass
130
131     def get_servers(self, service_name):
132         return self._servers
133
134     def get_permuted_peers(self, service_name, key):
135         return sorted(self._servers, key=lambda x: sha.new(key+x[0]).digest())
136     def get_nickname_for_peerid(self, peerid):
137         return None
138
139 class SimpleStats:
140     def __init__(self):
141         self.counters = {}
142         self.stats_producers = []
143
144     def count(self, name, delta=1):
145         val = self.counters.setdefault(name, 0)
146         self.counters[name] = val + delta
147
148     def register_producer(self, stats_producer):
149         self.stats_producers.append(stats_producer)
150
151     def get_stats(self):
152         stats = {}
153         for sp in self.stats_producers:
154             stats.update(sp.get_stats())
155         ret = { 'counters': self.counters, 'stats': stats }
156         return ret
157
158 class NoNetworkGrid(service.MultiService):
159     def __init__(self, basedir, num_clients=1, num_servers=10,
160                  client_config_hooks={}):
161         service.MultiService.__init__(self)
162         self.basedir = basedir
163         fileutil.make_dirs(basedir)
164
165         self.servers_by_number = {}
166         self.servers_by_id = {}
167         self.clients = []
168
169         for i in range(num_servers):
170             ss = self.make_server(i)
171             self.add_server(i, ss)
172
173         for i in range(num_clients):
174             clientid = hashutil.tagged_hash("clientid", str(i))[:20]
175             clientdir = os.path.join(basedir, "clients",
176                                      idlib.shortnodeid_b2a(clientid))
177             fileutil.make_dirs(clientdir)
178             f = open(os.path.join(clientdir, "tahoe.cfg"), "w")
179             f.write("[node]\n")
180             f.write("nickname = client-%d\n" % i)
181             f.write("web.port = tcp:0:interface=127.0.0.1\n")
182             f.write("[storage]\n")
183             f.write("enabled = false\n")
184             f.close()
185             c = None
186             if i in client_config_hooks:
187                 # this hook can either modify tahoe.cfg, or return an
188                 # entirely new Client instance
189                 c = client_config_hooks[i](clientdir)
190             if not c:
191                 c = NoNetworkClient(clientdir)
192             c.nodeid = clientid
193             c.short_nodeid = b32encode(clientid).lower()[:8]
194             c._servers = self.all_servers # can be updated later
195             c.setServiceParent(self)
196             self.clients.append(c)
197
198     def make_server(self, i):
199         serverid = hashutil.tagged_hash("serverid", str(i))[:20]
200         serverdir = os.path.join(self.basedir, "servers",
201                                  idlib.shortnodeid_b2a(serverid))
202         fileutil.make_dirs(serverdir)
203         ss = StorageServer(serverdir, serverid, stats_provider=SimpleStats())
204         return ss
205
206     def add_server(self, i, ss):
207         # to deal with the fact that all StorageServers are named 'storage',
208         # we interpose a middleman
209         middleman = service.MultiService()
210         middleman.setServiceParent(self)
211         ss.setServiceParent(middleman)
212         serverid = ss.my_nodeid
213         self.servers_by_number[i] = ss
214         self.servers_by_id[serverid] = wrap(ss, "storage")
215         self.all_servers = frozenset(self.servers_by_id.items())
216         for c in self.clients:
217             c._servers = self.all_servers
218
219 class GridTestMixin:
220     def setUp(self):
221         self.s = service.MultiService()
222         self.s.startService()
223
224     def tearDown(self):
225         return self.s.stopService()
226
227     def set_up_grid(self, num_clients=1, num_servers=10,
228                     client_config_hooks={}):
229         # self.basedir must be set
230         self.g = NoNetworkGrid(self.basedir,
231                                num_clients=num_clients,
232                                num_servers=num_servers,
233                                client_config_hooks=client_config_hooks)
234         self.g.setServiceParent(self.s)
235         self.client_webports = [c.getServiceNamed("webish").listener._port.getHost().port
236                                 for c in self.g.clients]
237         self.client_baseurls = ["http://localhost:%d/" % p
238                                 for p in self.client_webports]
239
240     def get_clientdir(self, i=0):
241         return self.g.clients[i].basedir
242
243     def get_serverdir(self, i):
244         return self.g.servers_by_number[i].storedir
245
246     def iterate_servers(self):
247         for i in sorted(self.g.servers_by_number.keys()):
248             ss = self.g.servers_by_number[i]
249             yield (i, ss, ss.storedir)
250
251     def find_shares(self, uri):
252         si = tahoe_uri.from_string(uri).get_storage_index()
253         prefixdir = storage_index_to_dir(si)
254         shares = []
255         for i,ss in self.g.servers_by_number.items():
256             serverid = ss.my_nodeid
257             basedir = os.path.join(ss.storedir, "shares", prefixdir)
258             if not os.path.exists(basedir):
259                 continue
260             for f in os.listdir(basedir):
261                 try:
262                     shnum = int(f)
263                     shares.append((shnum, serverid, os.path.join(basedir, f)))
264                 except ValueError:
265                     pass
266         return sorted(shares)
267
268     def delete_share(self, (shnum, serverid, sharefile)):
269         os.unlink(sharefile)
270
271     def delete_shares_numbered(self, uri, shnums):
272         for (i_shnum, i_serverid, i_sharefile) in self.find_shares(uri):
273             if i_shnum in shnums:
274                 os.unlink(i_sharefile)
275
276     def corrupt_share(self, (shnum, serverid, sharefile), corruptor_function):
277         sharedata = open(sharefile, "rb").read()
278         corruptdata = corruptor_function(sharedata)
279         open(sharefile, "wb").write(corruptdata)
280
281     def corrupt_shares_numbered(self, uri, shnums, corruptor):
282         for (i_shnum, i_serverid, i_sharefile) in self.find_shares(uri):
283             if i_shnum in shnums:
284                 sharedata = open(i_sharefile, "rb").read()
285                 corruptdata = corruptor(sharedata)
286                 open(i_sharefile, "wb").write(corruptdata)
287
288     def GET(self, urlpath, followRedirect=False, return_response=False,
289             method="GET", clientnum=0, **kwargs):
290         # if return_response=True, this fires with (data, statuscode,
291         # respheaders) instead of just data.
292         assert not isinstance(urlpath, unicode)
293         url = self.client_baseurls[clientnum] + urlpath
294         factory = HTTPClientGETFactory(url, method=method,
295                                        followRedirect=followRedirect, **kwargs)
296         reactor.connectTCP("localhost", self.client_webports[clientnum],factory)
297         d = factory.deferred
298         def _got_data(data):
299             return (data, factory.status, factory.response_headers)
300         if return_response:
301             d.addCallback(_got_data)
302         return factory.deferred