]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blob - src/allmydata/test/no_network.py
more refactoring: move get_all_serverids() and get_nickname_for_serverid() from Clien...
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / test / no_network.py
1
2 # This contains a test harness that creates a full Tahoe grid in a single
3 # process (actually in a single MultiService) which does not use the network.
4 # It does not use an Introducer, and there are no foolscap Tubs. Each storage
5 # server puts real shares on disk, but is accessed through loopback
6 # RemoteReferences instead of over serialized SSL. It is not as complete as
7 # the common.SystemTestMixin framework (which does use the network), but
8 # should be considerably faster: on my laptop, it takes 50-80ms to start up,
9 # whereas SystemTestMixin takes close to 2s.
10
11 # This should be useful for tests which want to examine and/or manipulate the
12 # uploaded shares, checker/verifier/repairer tests, etc. The clients have no
13 # Tubs, so it is not useful for tests that involve a Helper, a KeyGenerator,
14 # or the control.furl .
15
16 import os.path
17 import sha
18 from twisted.application import service
19 from twisted.internet import reactor
20 from twisted.python.failure import Failure
21 from foolscap.api import Referenceable, fireEventually, RemoteException
22 from base64 import b32encode
23 from allmydata import uri as tahoe_uri
24 from allmydata.client import Client
25 from allmydata.storage.server import StorageServer, storage_index_to_dir
26 from allmydata.util import fileutil, idlib, hashutil
27 from allmydata.introducer.client import RemoteServiceConnector
28 from allmydata.test.common_web import HTTPClientGETFactory
29
class IntentionalError(Exception):
    """Error raised by a LocalWrapper whose ``broken`` flag has been set.

    Tests set the flag to simulate a failing server; the resulting error is
    delivered to callers wrapped in a foolscap RemoteException.
    """
32
class Marker:
    """Opaque token returned by LocalWrapper.notifyOnDisconnect.

    It carries no data; each instance is a distinct dictionary key (default
    identity hashing) that can later be passed to dontNotifyOnDisconnect.
    """
35
class LocalWrapper:
    """Loopback stand-in for a foolscap RemoteReference.

    callRemote(name, ...) invokes ``original.remote_<name>(...)`` from a
    Deferred created by fireEventually(), so results arrive asynchronously
    and any exception is delivered as a Failure wrapping RemoteException.

    Attributes that tests may manipulate:
      broken             -- when true at the moment a queued call is
                            delivered, that call fails with IntentionalError
                            (wrapped in RemoteException)
      post_call_notifier -- if set when callRemote() is invoked, it is added
                            to the chain as notifier(result, methname) and
                            runs only on success
      disconnectors      -- Marker -> (f, args, kwargs) registrations made by
                            notifyOnDisconnect; nothing in this class fires
                            them
    """

    def __init__(self, original):
        self.original = original
        self.broken = False
        self.post_call_notifier = None
        self.disconnectors = {}

    def callRemoteOnly(self, methname, *args, **kwargs):
        """Like callRemote(), but the resulting Deferred is discarded."""
        self.callRemote(methname, *args, **kwargs)
        return None

    def callRemote(self, methname, *args, **kwargs):
        """Invoke ``remote_<methname>`` on the wrapped object; return a Deferred."""
        # A full Membrane would be ideal but is too hard. Instead, inbound
        # Referenceable arguments get a shallow wrapping, and the return
        # values of a few known methods get their Referenceables wrapped.
        def _localize(value):
            if not isinstance(value, Referenceable):
                return value
            return LocalWrapper(value)

        wrapped_args = tuple([_localize(a) for a in args])
        wrapped_kwargs = dict([(name, _localize(value))
                               for (name, value) in kwargs.items()])

        def _deliver(ignored):
            # the broken flag is consulted when the call is delivered, not
            # when callRemote() was invoked
            if self.broken:
                raise IntentionalError("I was asked to break")
            method = getattr(self.original, "remote_" + methname)
            return method(*wrapped_args, **wrapped_kwargs)

        def _as_remote_failure(f):
            return Failure(RemoteException(f))

        def _wrap_returned_refs(res):
            # only these methods are known to hand back Referenceables
            # (bucket objects keyed by share number)
            if methname == "allocate_buckets":
                (alreadygot, allocated) = res
                for shnum in allocated:
                    allocated[shnum] = LocalWrapper(allocated[shnum])
            elif methname == "get_buckets":
                for shnum in res:
                    res[shnum] = LocalWrapper(res[shnum])
            return res

        d = fireEventually()
        d.addCallback(_deliver)
        d.addErrback(_as_remote_failure)
        d.addCallback(_wrap_returned_refs)
        if self.post_call_notifier:
            d.addCallback(self.post_call_notifier, methname)
        return d

    def notifyOnDisconnect(self, f, *args, **kwargs):
        """Record f(*args, **kwargs) and return a Marker identifying it."""
        marker = Marker()
        self.disconnectors[marker] = (f, args, kwargs)
        return marker

    def dontNotifyOnDisconnect(self, marker):
        """Forget a registration; raises KeyError for an unknown marker."""
        self.disconnectors.pop(marker)
93
def wrap(original, service_name):
    """Return a LocalWrapper around *original* with a ``version`` attribute.

    Much of the upload/download code reads rref.version (which normally
    comes from rrefutil.add_version_to_remote_reference). Here the version
    is taken from original.remote_get_version() when that call succeeds.
    If it raises AttributeError, the entry for *service_name* in
    RemoteServiceConnector.VERSION_DEFAULTS is used instead.
    """
    try:
        version = original.remote_get_version()
    except AttributeError:
        version = RemoteServiceConnector.VERSION_DEFAULTS[service_name]
    wrapped = LocalWrapper(original)
    wrapped.version = version
    return wrapped
106
class NoNetworkStorageBroker:
    """Storage broker used by NoNetworkClient.

    It has no state of its own: a ``client`` attribute must be assigned
    after construction, and that client's ``_servers`` collection of
    (serverid, server) pairs is what get_servers() orders.
    """

    def get_servers(self, key):
        """Return the client's servers as a list permuted by *key*.

        Each pair is ranked by the SHA-1 digest of key + serverid.
        """
        def _permuted_rank(server):
            serverid = server[0]
            return sha.new(key + serverid).digest()
        return sorted(self.client._servers, key=_permuted_rank)

    def get_nickname_for_serverid(self, serverid):
        """Always None: servers in this harness have no nicknames."""
        return None
113
class NoNetworkClient(Client):
    """A Client with its network-facing setup steps disabled.

    NoNetworkGrid constructs these (unless a config hook supplies its own
    Client) and assigns ``_servers`` on them directly. No Tub, Introducer
    client, control, helper, key generator, local storage or stub client is
    created.
    """

    # -- setup hooks that are intentionally no-ops in this harness --

    def create_tub(self):
        """Do nothing: clients in this harness have no Tub."""

    def init_introducer_client(self):
        """Do nothing: no Introducer is used."""

    def setup_logging(self):
        """Do nothing: logging setup is skipped."""

    def init_control(self):
        """Do nothing: no control service."""

    def init_helper(self):
        """Do nothing: no Helper."""

    def init_key_gen(self):
        """Do nothing: no KeyGenerator."""

    def init_storage(self):
        """Do nothing: these clients do not run a storage server."""

    def init_stub_client(self):
        """Do nothing: no stub client."""

    # -- service lifecycle --

    def startService(self):
        # call MultiService's implementation directly, skipping any start
        # logic that Client or its other bases define
        service.MultiService.startService(self)

    def stopService(self):
        # likewise, skip any stop logic defined between Client and
        # MultiService
        service.MultiService.stopService(self)

    def when_tub_ready(self):
        raise NotImplementedError("NoNetworkClient has no Tub")

    # -- server discovery --

    def init_client_storage_broker(self):
        """Install a NoNetworkStorageBroker that reads from this client."""
        broker = NoNetworkStorageBroker()
        broker.client = self
        self.storage_broker = broker

    def get_servers(self, service_name):
        """Return the ``_servers`` set assigned by the grid (name ignored)."""
        return self._servers
144
class SimpleStats:
    """Minimal in-memory stats provider.

    Used as the ``stats_provider`` for the StorageServers this harness
    creates. It keeps named counters and merges the get_stats() output of
    every registered producer.
    """

    def __init__(self):
        self.stats_producers = []
        self.counters = {}

    def count(self, name, delta=1):
        """Add *delta* to counter *name*, starting it at 0 if absent."""
        self.counters[name] = self.counters.get(name, 0) + delta

    def register_producer(self, stats_producer):
        """Include *stats_producer*.get_stats() in future get_stats() calls."""
        self.stats_producers.append(stats_producer)

    def get_stats(self):
        """Return {'counters': <live counters dict>, 'stats': <merged stats>}.

        Producers are merged in registration order, so later producers win
        on duplicate keys.
        """
        merged = {}
        for producer in self.stats_producers:
            merged.update(producer.get_stats())
        return {'counters': self.counters, 'stats': merged}
163
class NoNetworkGrid(service.MultiService):
    """An in-process grid of storage servers and clients with no network.

    Each storage server is attached (via its own middleman MultiService) as
    a child of this service, and each client is attached directly.

    :param basedir: directory under which ``servers/`` and ``clients/``
        subdirectories are created
    :param num_clients: number of clients to create (nicknamed client-0,
        client-1, ...)
    :param num_servers: number of storage servers to create; may be 0
    :param client_config_hooks: optional dict mapping a client index to a
        callable ``hook(clientdir)``. The hook runs after the default
        tahoe.cfg has been written. It may edit that file and return a
        false value, in which case a NoNetworkClient is built. It may
        instead return its own Client instance, which is then used.
    """

    def __init__(self, basedir, num_clients=1, num_servers=10,
                 client_config_hooks=None):
        service.MultiService.__init__(self)
        if client_config_hooks is None:
            # avoid sharing one mutable default dict between calls
            client_config_hooks = {}
        self.basedir = basedir
        fileutil.make_dirs(basedir)

        self.servers_by_number = {}
        self.servers_by_id = {}
        # add_server() rebuilds this each time it is called. Initialize it
        # so that a grid with num_servers=0 still hands every client an
        # (empty) server set instead of raising AttributeError below.
        self.all_servers = frozenset()
        self.clients = []

        for i in range(num_servers):
            ss = self.make_server(i)
            self.add_server(i, ss)

        for i in range(num_clients):
            clientid = hashutil.tagged_hash("clientid", str(i))[:20]
            clientdir = os.path.join(basedir, "clients",
                                     idlib.shortnodeid_b2a(clientid))
            fileutil.make_dirs(clientdir)
            f = open(os.path.join(clientdir, "tahoe.cfg"), "w")
            try:
                f.write("[node]\n")
                f.write("nickname = client-%d\n" % i)
                f.write("web.port = tcp:0:interface=127.0.0.1\n")
                f.write("[storage]\n")
                f.write("enabled = false\n")
            finally:
                # make sure the config is flushed before any hook reads it
                f.close()
            c = None
            if i in client_config_hooks:
                # this hook can either modify tahoe.cfg, or return an
                # entirely new Client instance
                c = client_config_hooks[i](clientdir)
            if not c:
                c = NoNetworkClient(clientdir)
            c.nodeid = clientid
            c.short_nodeid = b32encode(clientid).lower()[:8]
            c._servers = self.all_servers # can be updated later
            c.setServiceParent(self)
            self.clients.append(c)

    def make_server(self, i):
        """Create (without attaching) storage server number *i*.

        Its serverid is derived deterministically from *i*. Its directory is
        created under ``<basedir>/servers/``.
        """
        serverid = hashutil.tagged_hash("serverid", str(i))[:20]
        serverdir = os.path.join(self.basedir, "servers",
                                 idlib.shortnodeid_b2a(serverid))
        fileutil.make_dirs(serverdir)
        ss = StorageServer(serverdir, serverid, stats_provider=SimpleStats())
        return ss

    def add_server(self, i, ss):
        """Attach *ss* as server number *i* and refresh all clients' views."""
        # to deal with the fact that all StorageServers are named 'storage',
        # we interpose a middleman
        middleman = service.MultiService()
        middleman.setServiceParent(self)
        ss.setServiceParent(middleman)
        serverid = ss.my_nodeid
        self.servers_by_number[i] = ss
        self.servers_by_id[serverid] = wrap(ss, "storage")
        self.all_servers = frozenset(self.servers_by_id.items())
        for c in self.clients:
            c._servers = self.all_servers
224
class GridTestMixin:
    """Mixin that gives a test case an in-process NoNetworkGrid.

    setUp() and tearDown() manage a parent MultiService (``self.s``).
    set_up_grid() builds the grid (``self.g``) under it and records each
    client's web port and base URL.

    The remaining helpers locate, delete and corrupt share files on the
    servers' disks. They also issue HTTP requests to a client's web
    server.

    A *share* here is a (shnum, serverid, sharefile) tuple as returned by
    find_shares().
    """

    def setUp(self):
        self.s = service.MultiService()
        self.s.startService()

    def tearDown(self):
        return self.s.stopService()

    def set_up_grid(self, num_clients=1, num_servers=10,
                    client_config_hooks=None):
        """Build self.g under self.s. self.basedir must already be set.

        client_config_hooks (default: none) is passed through to
        NoNetworkGrid.
        """
        if client_config_hooks is None:
            # avoid sharing one mutable default dict between calls
            client_config_hooks = {}
        self.g = NoNetworkGrid(self.basedir,
                               num_clients=num_clients,
                               num_servers=num_servers,
                               client_config_hooks=client_config_hooks)
        self.g.setServiceParent(self.s)
        self.client_webports = [c.getServiceNamed("webish").listener._port.getHost().port
                                for c in self.g.clients]
        self.client_baseurls = ["http://localhost:%d/" % p
                                for p in self.client_webports]

    def get_clientdir(self, i=0):
        return self.g.clients[i].basedir

    def get_serverdir(self, i):
        return self.g.servers_by_number[i].storedir

    def iterate_servers(self):
        """Yield (number, StorageServer, storedir), ordered by server number."""
        for i in sorted(self.g.servers_by_number.keys()):
            ss = self.g.servers_by_number[i]
            yield (i, ss, ss.storedir)

    def find_shares(self, uri):
        """Return a sorted list of every on-disk share of *uri*.

        Each entry is a (shnum, serverid, sharefile) tuple, collected from
        all servers.
        """
        si = tahoe_uri.from_string(uri).get_storage_index()
        prefixdir = storage_index_to_dir(si)
        shares = []
        for ss in self.g.servers_by_number.values():
            serverid = ss.my_nodeid
            basedir = os.path.join(ss.storedir, "shares", prefixdir)
            if not os.path.exists(basedir):
                continue
            for f in os.listdir(basedir):
                try:
                    shnum = int(f)
                except ValueError:
                    # skip directory entries whose names are not share
                    # numbers
                    continue
                shares.append((shnum, serverid, os.path.join(basedir, f)))
        return sorted(shares)

    def delete_share(self, share):
        """Remove the file of one share, given as a (shnum, serverid, sharefile) tuple."""
        (shnum, serverid, sharefile) = share
        os.unlink(sharefile)

    def delete_shares_numbered(self, uri, shnums):
        """Delete every share of *uri* whose share number is in *shnums*."""
        for share in self.find_shares(uri):
            if share[0] in shnums:
                self.delete_share(share)

    def corrupt_share(self, share, corruptor_function):
        """Rewrite one share's file with corruptor_function(old_contents).

        The share is given as a (shnum, serverid, sharefile) tuple. If the
        corruptor raises, the file is left unchanged.
        """
        (shnum, serverid, sharefile) = share
        f = open(sharefile, "rb")
        try:
            sharedata = f.read()
        finally:
            f.close()
        corruptdata = corruptor_function(sharedata)
        f = open(sharefile, "wb")
        try:
            f.write(corruptdata)
        finally:
            f.close()

    def corrupt_shares_numbered(self, uri, shnums, corruptor):
        """Corrupt every share of *uri* whose share number is in *shnums*."""
        for share in self.find_shares(uri):
            if share[0] in shnums:
                self.corrupt_share(share, corruptor)

    def GET(self, urlpath, followRedirect=False, return_response=False,
            method="GET", clientnum=0, **kwargs):
        """Request *urlpath* from client number *clientnum*'s web server.

        *urlpath* is relative to that client's base URL and must not be
        unicode. Returns factory.deferred, which fires with the response
        data. If return_response=True, it fires instead with
        (data, statuscode, respheaders).
        """
        assert not isinstance(urlpath, unicode)
        url = self.client_baseurls[clientnum] + urlpath
        factory = HTTPClientGETFactory(url, method=method,
                                       followRedirect=followRedirect, **kwargs)
        reactor.connectTCP("localhost", self.client_webports[clientnum], factory)
        d = factory.deferred
        def _got_data(data):
            return (data, factory.status, factory.response_headers)
        if return_response:
            d.addCallback(_got_data)
        return d