]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blob - src/allmydata/test/no_network.py
Improve SFTP error handling and remove use of IFinishableConsumer. fixes #1525
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / test / no_network.py
1
2 # This contains a test harness that creates a full Tahoe grid in a single
3 # process (actually in a single MultiService) which does not use the network.
4 # It does not use an Introducer, and there are no foolscap Tubs. Each storage
5 # server puts real shares on disk, but is accessed through loopback
6 # RemoteReferences instead of over serialized SSL. It is not as complete as
7 # the common.SystemTestMixin framework (which does use the network), but
8 # should be considerably faster: on my laptop, it takes 50-80ms to start up,
9 # whereas SystemTestMixin takes close to 2s.
10
11 # This should be useful for tests which want to examine and/or manipulate the
12 # uploaded shares, checker/verifier/repairer tests, etc. The clients have no
13 # Tubs, so it is not useful for tests that involve a Helper, a KeyGenerator,
14 # or the control.furl .
15
16 import os.path
17 from zope.interface import implements
18 from twisted.application import service
19 from twisted.internet import defer, reactor
20 from twisted.python.failure import Failure
21 from foolscap.api import Referenceable, fireEventually, RemoteException
22 from base64 import b32encode
23 from allmydata import uri as tahoe_uri
24 from allmydata.client import Client
25 from allmydata.storage.server import StorageServer, storage_index_to_dir
26 from allmydata.util import fileutil, idlib, hashutil
27 from allmydata.util.hashutil import sha1
28 from allmydata.test.common_web import HTTPClientGETFactory
29 from allmydata.interfaces import IStorageBroker, IServer
30 from allmydata.test.common import TEST_RSA_KEY_SIZE
31
32
33 class IntentionalError(Exception):
34     pass
35
36 class Marker:
37     pass
38
39 class LocalWrapper:
40     def __init__(self, original):
41         self.original = original
42         self.broken = False
43         self.hung_until = None
44         self.post_call_notifier = None
45         self.disconnectors = {}
46         self.counter_by_methname = {}
47
48     def _clear_counters(self):
49         self.counter_by_methname = {}
50
51     def callRemoteOnly(self, methname, *args, **kwargs):
52         d = self.callRemote(methname, *args, **kwargs)
53         del d # explicitly ignored
54         return None
55
56     def callRemote(self, methname, *args, **kwargs):
57         # this is ideally a Membrane, but that's too hard. We do a shallow
58         # wrapping of inbound arguments, and per-methodname wrapping of
59         # selected return values.
60         def wrap(a):
61             if isinstance(a, Referenceable):
62                 return LocalWrapper(a)
63             else:
64                 return a
65         args = tuple([wrap(a) for a in args])
66         kwargs = dict([(k,wrap(kwargs[k])) for k in kwargs])
67
68         def _really_call():
69             def incr(d, k): d[k] = d.setdefault(k, 0) + 1
70             incr(self.counter_by_methname, methname)
71             meth = getattr(self.original, "remote_" + methname)
72             return meth(*args, **kwargs)
73
74         def _call():
75             if self.broken:
76                 if self.broken is not True: # a counter, not boolean
77                     self.broken -= 1
78                 raise IntentionalError("I was asked to break")
79             if self.hung_until:
80                 d2 = defer.Deferred()
81                 self.hung_until.addCallback(lambda ign: _really_call())
82                 self.hung_until.addCallback(lambda res: d2.callback(res))
83                 def _err(res):
84                     d2.errback(res)
85                     return res
86                 self.hung_until.addErrback(_err)
87                 return d2
88             return _really_call()
89
90         d = fireEventually()
91         d.addCallback(lambda res: _call())
92         def _wrap_exception(f):
93             return Failure(RemoteException(f))
94         d.addErrback(_wrap_exception)
95         def _return_membrane(res):
96             # rather than complete the difficult task of building a
97             # fully-general Membrane (which would locate all Referenceable
98             # objects that cross the simulated wire and replace them with
99             # wrappers), we special-case certain methods that we happen to
100             # know will return Referenceables.
101             if methname == "allocate_buckets":
102                 (alreadygot, allocated) = res
103                 for shnum in allocated:
104                     allocated[shnum] = LocalWrapper(allocated[shnum])
105             if methname == "get_buckets":
106                 for shnum in res:
107                     res[shnum] = LocalWrapper(res[shnum])
108             return res
109         d.addCallback(_return_membrane)
110         if self.post_call_notifier:
111             d.addCallback(self.post_call_notifier, self, methname)
112         return d
113
114     def notifyOnDisconnect(self, f, *args, **kwargs):
115         m = Marker()
116         self.disconnectors[m] = (f, args, kwargs)
117         return m
118     def dontNotifyOnDisconnect(self, marker):
119         del self.disconnectors[marker]
120
121 def wrap_storage_server(original):
122     # Much of the upload/download code uses rref.version (which normally
123     # comes from rrefutil.add_version_to_remote_reference). To avoid using a
124     # network, we want a LocalWrapper here. Try to satisfy all these
125     # constraints at the same time.
126     wrapper = LocalWrapper(original)
127     wrapper.version = original.remote_get_version()
128     return wrapper
129
130 class NoNetworkServer:
131     implements(IServer)
132     def __init__(self, serverid, rref):
133         self.serverid = serverid
134         self.rref = rref
135     def __repr__(self):
136         return "<NoNetworkServer for %s>" % self.get_name()
137     # Special method used by copy.copy() and copy.deepcopy(). When those are
138     # used in allmydata.immutable.filenode to copy CheckResults during
139     # repair, we want it to treat the IServer instances as singletons.
140     def __copy__(self):
141         return self
142     def __deepcopy__(self, memodict):
143         return self
144     def get_serverid(self):
145         return self.serverid
146     def get_permutation_seed(self):
147         return self.serverid
148     def get_lease_seed(self):
149         return self.serverid
150     def get_foolscap_write_enabler_seed(self):
151         return self.serverid
152
153     def get_name(self):
154         return idlib.shortnodeid_b2a(self.serverid)
155     def get_longname(self):
156         return idlib.nodeid_b2a(self.serverid)
157     def get_nickname(self):
158         return "nickname"
159     def get_rref(self):
160         return self.rref
161     def get_version(self):
162         return self.rref.version
163
164 class NoNetworkStorageBroker:
165     implements(IStorageBroker)
166     def get_servers_for_psi(self, peer_selection_index):
167         def _permuted(server):
168             seed = server.get_permutation_seed()
169             return sha1(peer_selection_index + seed).digest()
170         return sorted(self.get_connected_servers(), key=_permuted)
171     def get_connected_servers(self):
172         return self.client._servers
173     def get_nickname_for_serverid(self, serverid):
174         return None
175
176 class NoNetworkClient(Client):
177     def create_tub(self):
178         pass
179     def init_introducer_client(self):
180         pass
181     def setup_logging(self):
182         pass
183     def startService(self):
184         service.MultiService.startService(self)
185     def stopService(self):
186         service.MultiService.stopService(self)
187     def when_tub_ready(self):
188         raise NotImplementedError("NoNetworkClient has no Tub")
189     def init_control(self):
190         pass
191     def init_helper(self):
192         pass
193     def init_key_gen(self):
194         pass
195     def init_storage(self):
196         pass
197     def init_client_storage_broker(self):
198         self.storage_broker = NoNetworkStorageBroker()
199         self.storage_broker.client = self
200     def init_stub_client(self):
201         pass
202     #._servers will be set by the NoNetworkGrid which creates us
203
204 class SimpleStats:
205     def __init__(self):
206         self.counters = {}
207         self.stats_producers = []
208
209     def count(self, name, delta=1):
210         val = self.counters.setdefault(name, 0)
211         self.counters[name] = val + delta
212
213     def register_producer(self, stats_producer):
214         self.stats_producers.append(stats_producer)
215
216     def get_stats(self):
217         stats = {}
218         for sp in self.stats_producers:
219             stats.update(sp.get_stats())
220         ret = { 'counters': self.counters, 'stats': stats }
221         return ret
222
223 class NoNetworkGrid(service.MultiService):
224     def __init__(self, basedir, num_clients=1, num_servers=10,
225                  client_config_hooks={}):
226         service.MultiService.__init__(self)
227         self.basedir = basedir
228         fileutil.make_dirs(basedir)
229
230         self.servers_by_number = {} # maps to StorageServer instance
231         self.wrappers_by_id = {} # maps to wrapped StorageServer instance
232         self.proxies_by_id = {} # maps to IServer on which .rref is a wrapped
233                                 # StorageServer
234         self.clients = []
235
236         for i in range(num_servers):
237             ss = self.make_server(i)
238             self.add_server(i, ss)
239         self.rebuild_serverlist()
240
241         for i in range(num_clients):
242             clientid = hashutil.tagged_hash("clientid", str(i))[:20]
243             clientdir = os.path.join(basedir, "clients",
244                                      idlib.shortnodeid_b2a(clientid))
245             fileutil.make_dirs(clientdir)
246             f = open(os.path.join(clientdir, "tahoe.cfg"), "w")
247             f.write("[node]\n")
248             f.write("nickname = client-%d\n" % i)
249             f.write("web.port = tcp:0:interface=127.0.0.1\n")
250             f.write("[storage]\n")
251             f.write("enabled = false\n")
252             f.close()
253             c = None
254             if i in client_config_hooks:
255                 # this hook can either modify tahoe.cfg, or return an
256                 # entirely new Client instance
257                 c = client_config_hooks[i](clientdir)
258             if not c:
259                 c = NoNetworkClient(clientdir)
260                 c.set_default_mutable_keysize(TEST_RSA_KEY_SIZE)
261             c.nodeid = clientid
262             c.short_nodeid = b32encode(clientid).lower()[:8]
263             c._servers = self.all_servers # can be updated later
264             c.setServiceParent(self)
265             self.clients.append(c)
266
267     def make_server(self, i, readonly=False):
268         serverid = hashutil.tagged_hash("serverid", str(i))[:20]
269         serverdir = os.path.join(self.basedir, "servers",
270                                  idlib.shortnodeid_b2a(serverid), "storage")
271         fileutil.make_dirs(serverdir)
272         ss = StorageServer(serverdir, serverid, stats_provider=SimpleStats(),
273                            readonly_storage=readonly)
274         ss._no_network_server_number = i
275         return ss
276
277     def add_server(self, i, ss):
278         # to deal with the fact that all StorageServers are named 'storage',
279         # we interpose a middleman
280         middleman = service.MultiService()
281         middleman.setServiceParent(self)
282         ss.setServiceParent(middleman)
283         serverid = ss.my_nodeid
284         self.servers_by_number[i] = ss
285         wrapper = wrap_storage_server(ss)
286         self.wrappers_by_id[serverid] = wrapper
287         self.proxies_by_id[serverid] = NoNetworkServer(serverid, wrapper)
288         self.rebuild_serverlist()
289
290     def get_all_serverids(self):
291         return self.proxies_by_id.keys()
292
293     def rebuild_serverlist(self):
294         self.all_servers = frozenset(self.proxies_by_id.values())
295         for c in self.clients:
296             c._servers = self.all_servers
297
298     def remove_server(self, serverid):
299         # it's enough to remove the server from c._servers (we don't actually
300         # have to detach and stopService it)
301         for i,ss in self.servers_by_number.items():
302             if ss.my_nodeid == serverid:
303                 del self.servers_by_number[i]
304                 break
305         del self.wrappers_by_id[serverid]
306         del self.proxies_by_id[serverid]
307         self.rebuild_serverlist()
308         return ss
309
310     def break_server(self, serverid, count=True):
311         # mark the given server as broken, so it will throw exceptions when
312         # asked to hold a share or serve a share. If count= is a number,
313         # throw that many exceptions before starting to work again.
314         self.wrappers_by_id[serverid].broken = count
315
316     def hang_server(self, serverid):
317         # hang the given server
318         ss = self.wrappers_by_id[serverid]
319         assert ss.hung_until is None
320         ss.hung_until = defer.Deferred()
321
322     def unhang_server(self, serverid):
323         # unhang the given server
324         ss = self.wrappers_by_id[serverid]
325         assert ss.hung_until is not None
326         ss.hung_until.callback(None)
327         ss.hung_until = None
328
329     def nuke_from_orbit(self):
330         """ Empty all share directories in this grid. It's the only way to be sure ;-) """
331         for server in self.servers_by_number.values():
332             for prefixdir in os.listdir(server.sharedir):
333                 if prefixdir != 'incoming':
334                     fileutil.rm_dir(os.path.join(server.sharedir, prefixdir))
335
336
337 class GridTestMixin:
338     def setUp(self):
339         self.s = service.MultiService()
340         self.s.startService()
341
342     def tearDown(self):
343         return self.s.stopService()
344
345     def set_up_grid(self, num_clients=1, num_servers=10,
346                     client_config_hooks={}):
347         # self.basedir must be set
348         self.g = NoNetworkGrid(self.basedir,
349                                num_clients=num_clients,
350                                num_servers=num_servers,
351                                client_config_hooks=client_config_hooks)
352         self.g.setServiceParent(self.s)
353         self.client_webports = [c.getServiceNamed("webish").getPortnum()
354                                 for c in self.g.clients]
355         self.client_baseurls = [c.getServiceNamed("webish").getURL()
356                                 for c in self.g.clients]
357
358     def get_clientdir(self, i=0):
359         return self.g.clients[i].basedir
360
361     def get_serverdir(self, i):
362         return self.g.servers_by_number[i].storedir
363
364     def iterate_servers(self):
365         for i in sorted(self.g.servers_by_number.keys()):
366             ss = self.g.servers_by_number[i]
367             yield (i, ss, ss.storedir)
368
369     def find_uri_shares(self, uri):
370         si = tahoe_uri.from_string(uri).get_storage_index()
371         prefixdir = storage_index_to_dir(si)
372         shares = []
373         for i,ss in self.g.servers_by_number.items():
374             serverid = ss.my_nodeid
375             basedir = os.path.join(ss.sharedir, prefixdir)
376             if not os.path.exists(basedir):
377                 continue
378             for f in os.listdir(basedir):
379                 try:
380                     shnum = int(f)
381                     shares.append((shnum, serverid, os.path.join(basedir, f)))
382                 except ValueError:
383                     pass
384         return sorted(shares)
385
386     def copy_shares(self, uri):
387         shares = {}
388         for (shnum, serverid, sharefile) in self.find_uri_shares(uri):
389             shares[sharefile] = open(sharefile, "rb").read()
390         return shares
391
392     def restore_all_shares(self, shares):
393         for sharefile, data in shares.items():
394             open(sharefile, "wb").write(data)
395
396     def delete_share(self, (shnum, serverid, sharefile)):
397         os.unlink(sharefile)
398
399     def delete_shares_numbered(self, uri, shnums):
400         for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
401             if i_shnum in shnums:
402                 os.unlink(i_sharefile)
403
404     def corrupt_share(self, (shnum, serverid, sharefile), corruptor_function):
405         sharedata = open(sharefile, "rb").read()
406         corruptdata = corruptor_function(sharedata)
407         open(sharefile, "wb").write(corruptdata)
408
409     def corrupt_shares_numbered(self, uri, shnums, corruptor, debug=False):
410         for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
411             if i_shnum in shnums:
412                 sharedata = open(i_sharefile, "rb").read()
413                 corruptdata = corruptor(sharedata, debug=debug)
414                 open(i_sharefile, "wb").write(corruptdata)
415
416     def corrupt_all_shares(self, uri, corruptor, debug=False):
417         for (i_shnum, i_serverid, i_sharefile) in self.find_uri_shares(uri):
418             sharedata = open(i_sharefile, "rb").read()
419             corruptdata = corruptor(sharedata, debug=debug)
420             open(i_sharefile, "wb").write(corruptdata)
421
422     def GET(self, urlpath, followRedirect=False, return_response=False,
423             method="GET", clientnum=0, **kwargs):
424         # if return_response=True, this fires with (data, statuscode,
425         # respheaders) instead of just data.
426         assert not isinstance(urlpath, unicode)
427         url = self.client_baseurls[clientnum] + urlpath
428         factory = HTTPClientGETFactory(url, method=method,
429                                        followRedirect=followRedirect, **kwargs)
430         reactor.connectTCP("localhost", self.client_webports[clientnum],factory)
431         d = factory.deferred
432         def _got_data(data):
433             return (data, factory.status, factory.response_headers)
434         if return_response:
435             d.addCallback(_got_data)
436         return factory.deferred
437
438     def PUT(self, urlpath, **kwargs):
439         return self.GET(urlpath, method="PUT", **kwargs)