]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blob - src/allmydata/test/no_network.py
tests/no_network: move GET into the GridTestMixin class
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / test / no_network.py
1
2 # This contains a test harness that creates a full Tahoe grid in a single
3 # process (actually in a single MultiService) which does not use the network.
4 # It does not use an Introducer, and there are no foolscap Tubs. Each storage
5 # server puts real shares on disk, but is accessed through loopback
6 # RemoteReferences instead of over serialized SSL. It is not as complete as
7 # the common.SystemTestMixin framework (which does use the network), but
8 # should be considerably faster: on my laptop, it takes 50-80ms to start up,
9 # whereas SystemTestMixin takes close to 2s.
10
11 # This should be useful for tests which want to examine and/or manipulate the
12 # uploaded shares, checker/verifier/repairer tests, etc. The clients have no
13 # Tubs, so it is not useful for tests that involve a Helper, a KeyGenerator,
14 # or the control.furl .
15
16 import os.path
17 import sha
18 from twisted.application import service
19 from twisted.internet import reactor
20 from foolscap import Referenceable
21 from foolscap.eventual import fireEventually
22 from base64 import b32encode
23 from allmydata import uri as tahoe_uri
24 from allmydata.client import Client
25 from allmydata.storage.server import StorageServer, storage_index_to_dir
26 from allmydata.util import fileutil, idlib, hashutil, rrefutil
27 from allmydata.introducer.client import RemoteServiceConnector
28 from allmydata.test.common_web import HTTPClientGETFactory
29
30 class IntentionalError(Exception):
31     pass
32
33 class Marker:
34     pass
35
36 class LocalWrapper:
37     def __init__(self, original):
38         self.original = original
39         self.broken = False
40         self.post_call_notifier = None
41         self.disconnectors = {}
42
43     def callRemoteOnly(self, methname, *args, **kwargs):
44         d = self.callRemote(methname, *args, **kwargs)
45         return None
46
47     def callRemote(self, methname, *args, **kwargs):
48         # this is ideally a Membrane, but that's too hard. We do a shallow
49         # wrapping of inbound arguments, and per-methodname wrapping of
50         # selected return values.
51         def wrap(a):
52             if isinstance(a, Referenceable):
53                 return LocalWrapper(a)
54             else:
55                 return a
56         args = tuple([wrap(a) for a in args])
57         kwargs = dict([(k,wrap(kwargs[k])) for k in kwargs])
58         def _call():
59             if self.broken:
60                 raise IntentionalError("I was asked to break")
61             meth = getattr(self.original, "remote_" + methname)
62             return meth(*args, **kwargs)
63         d = fireEventually()
64         d.addCallback(lambda res: _call())
65         def _return_membrane(res):
66             # rather than complete the difficult task of building a
67             # fully-general Membrane (which would locate all Referenceable
68             # objects that cross the simulated wire and replace them with
69             # wrappers), we special-case certain methods that we happen to
70             # know will return Referenceables.
71             if methname == "allocate_buckets":
72                 (alreadygot, allocated) = res
73                 for shnum in allocated:
74                     allocated[shnum] = LocalWrapper(allocated[shnum])
75             if methname == "get_buckets":
76                 for shnum in res:
77                     res[shnum] = LocalWrapper(res[shnum])
78             return res
79         d.addCallback(_return_membrane)
80         if self.post_call_notifier:
81             d.addCallback(self.post_call_notifier, methname)
82         return d
83
84     def notifyOnDisconnect(self, f, *args, **kwargs):
85         m = Marker()
86         self.disconnectors[m] = (f, args, kwargs)
87         return m
88     def dontNotifyOnDisconnect(self, marker):
89         del self.disconnectors[marker]
90
91 def wrap(original, service_name):
92     # The code in immutable.checker insists upon asserting the truth of
93     # isinstance(rref, rrefutil.WrappedRemoteReference). Much of the
94     # upload/download code uses rref.version (which normally comes from
95     # rrefutil.VersionedRemoteReference). To avoid using a network, we want a
96     # LocalWrapper here. Try to satisfy all these constraints at the same
97     # time.
98     local = LocalWrapper(original)
99     wrapped = rrefutil.WrappedRemoteReference(local)
100     try:
101         version = original.remote_get_version()
102     except AttributeError:
103         version = RemoteServiceConnector.VERSION_DEFAULTS[service_name]
104     wrapped.version = version
105     return wrapped
106
107 class NoNetworkClient(Client):
108
109     def create_tub(self):
110         pass
111     def init_introducer_client(self):
112         pass
113     def setup_logging(self):
114         pass
115     def startService(self):
116         service.MultiService.startService(self)
117     def stopService(self):
118         service.MultiService.stopService(self)
119     def when_tub_ready(self):
120         raise NotImplementedError("NoNetworkClient has no Tub")
121     def init_control(self):
122         pass
123     def init_helper(self):
124         pass
125     def init_key_gen(self):
126         pass
127     def init_storage(self):
128         pass
129     def init_stub_client(self):
130         pass
131
132     def get_servers(self, service_name):
133         return self._servers
134
135     def get_permuted_peers(self, service_name, key):
136         return sorted(self._servers, key=lambda x: sha.new(key+x[0]).digest())
137     def get_nickname_for_peerid(self, peerid):
138         return None
139
140 class SimpleStats:
141     def __init__(self):
142         self.counters = {}
143         self.stats_producers = []
144
145     def count(self, name, delta=1):
146         val = self.counters.setdefault(name, 0)
147         self.counters[name] = val + delta
148
149     def register_producer(self, stats_producer):
150         self.stats_producers.append(stats_producer)
151
152     def get_stats(self):
153         stats = {}
154         for sp in self.stats_producers:
155             stats.update(sp.get_stats())
156         ret = { 'counters': self.counters, 'stats': stats }
157         return ret
158
159 class NoNetworkGrid(service.MultiService):
160     def __init__(self, basedir, num_clients=1, num_servers=10,
161                  client_config_hooks={}):
162         service.MultiService.__init__(self)
163         self.basedir = basedir
164         fileutil.make_dirs(basedir)
165
166         self.servers_by_number = {}
167         self.servers_by_id = {}
168         self.clients = []
169
170         for i in range(num_servers):
171             ss = self.make_server(i)
172             self.add_server(i, ss)
173
174         for i in range(num_clients):
175             clientid = hashutil.tagged_hash("clientid", str(i))[:20]
176             clientdir = os.path.join(basedir, "clients",
177                                      idlib.shortnodeid_b2a(clientid))
178             fileutil.make_dirs(clientdir)
179             f = open(os.path.join(clientdir, "tahoe.cfg"), "w")
180             f.write("[node]\n")
181             f.write("nickname = client-%d\n" % i)
182             f.write("web.port = tcp:0:interface=127.0.0.1\n")
183             f.write("[storage]\n")
184             f.write("enabled = false\n")
185             f.close()
186             c = None
187             if i in client_config_hooks:
188                 # this hook can either modify tahoe.cfg, or return an
189                 # entirely new Client instance
190                 c = client_config_hooks[i](clientdir)
191             if not c:
192                 c = NoNetworkClient(clientdir)
193             c.nodeid = clientid
194             c.short_nodeid = b32encode(clientid).lower()[:8]
195             c._servers = self.all_servers # can be updated later
196             c.setServiceParent(self)
197             self.clients.append(c)
198
199     def make_server(self, i):
200         serverid = hashutil.tagged_hash("serverid", str(i))[:20]
201         serverdir = os.path.join(self.basedir, "servers",
202                                  idlib.shortnodeid_b2a(serverid))
203         fileutil.make_dirs(serverdir)
204         ss = StorageServer(serverdir, serverid, stats_provider=SimpleStats())
205         return ss
206
207     def add_server(self, i, ss):
208         # to deal with the fact that all StorageServers are named 'storage',
209         # we interpose a middleman
210         middleman = service.MultiService()
211         middleman.setServiceParent(self)
212         ss.setServiceParent(middleman)
213         serverid = ss.my_nodeid
214         self.servers_by_number[i] = ss
215         self.servers_by_id[serverid] = wrap(ss, "storage")
216         self.all_servers = frozenset(self.servers_by_id.items())
217         for c in self.clients:
218             c._servers = self.all_servers
219
220 class GridTestMixin:
221     def setUp(self):
222         self.s = service.MultiService()
223         self.s.startService()
224
225     def tearDown(self):
226         return self.s.stopService()
227
228     def set_up_grid(self, num_clients=1, num_servers=10,
229                     client_config_hooks={}):
230         # self.basedir must be set
231         self.g = NoNetworkGrid(self.basedir,
232                                num_clients=num_clients,
233                                num_servers=num_servers,
234                                client_config_hooks=client_config_hooks)
235         self.g.setServiceParent(self.s)
236         self.client_webports = [c.getServiceNamed("webish").listener._port.getHost().port
237                                 for c in self.g.clients]
238         self.client_baseurls = ["http://localhost:%d/" % p
239                                 for p in self.client_webports]
240
241     def get_clientdir(self, i=0):
242         return self.g.clients[i].basedir
243
244     def get_serverdir(self, i):
245         return self.g.servers_by_number[i].storedir
246
247     def iterate_servers(self):
248         for i in sorted(self.g.servers_by_number.keys()):
249             ss = self.g.servers_by_number[i]
250             yield (i, ss, ss.storedir)
251
252     def find_shares(self, uri):
253         si = tahoe_uri.from_string(uri).get_storage_index()
254         prefixdir = storage_index_to_dir(si)
255         shares = []
256         for i,ss in self.g.servers_by_number.items():
257             serverid = ss.my_nodeid
258             basedir = os.path.join(ss.storedir, "shares", prefixdir)
259             if not os.path.exists(basedir):
260                 continue
261             for f in os.listdir(basedir):
262                 try:
263                     shnum = int(f)
264                     shares.append((shnum, serverid, os.path.join(basedir, f)))
265                 except ValueError:
266                     pass
267         return sorted(shares)
268
269     def delete_share(self, (shnum, serverid, sharefile)):
270         os.unlink(sharefile)
271
272     def delete_shares_numbered(self, uri, shnums):
273         for (i_shnum, i_serverid, i_sharefile) in self.find_shares(uri):
274             if i_shnum in shnums:
275                 os.unlink(i_sharefile)
276
277     def corrupt_share(self, (shnum, serverid, sharefile), corruptor_function):
278         sharedata = open(sharefile, "rb").read()
279         corruptdata = corruptor_function(sharedata)
280         open(sharefile, "wb").write(corruptdata)
281
282     def corrupt_shares_numbered(self, uri, shnums, corruptor):
283         for (i_shnum, i_serverid, i_sharefile) in self.find_shares(uri):
284             if i_shnum in shnums:
285                 sharedata = open(i_sharefile, "rb").read()
286                 corruptdata = corruptor(sharedata)
287                 open(i_sharefile, "wb").write(corruptdata)
288
289     def GET(self, urlpath, followRedirect=False, return_response=False,
290             method="GET", clientnum=0, **kwargs):
291         # if return_response=True, this fires with (data, statuscode,
292         # respheaders) instead of just data.
293         assert not isinstance(urlpath, unicode)
294         url = self.client_baseurls[clientnum] + urlpath
295         factory = HTTPClientGETFactory(url, method=method,
296                                        followRedirect=followRedirect, **kwargs)
297         reactor.connectTCP("localhost", self.client_webports[clientnum],factory)
298         d = factory.deferred
299         def _got_data(data):
300             return (data, factory.status, factory.response_headers)
301         if return_response:
302             d.addCallback(_got_data)
303         return factory.deferred