]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blob - src/allmydata/test/no_network.py
test/no_network.py: add a basic stats provider
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / test / no_network.py
1
2 # This contains a test harness that creates a full Tahoe grid in a single
3 # process (actually in a single MultiService) which does not use the network.
4 # It does not use an Introducer, and there are no foolscap Tubs. Each storage
5 # server puts real shares on disk, but is accessed through loopback
6 # RemoteReferences instead of over serialized SSL. It is not as complete as
7 # the common.SystemTestMixin framework (which does use the network), but
8 # should be considerably faster: on my laptop, it takes 50-80ms to start up,
9 # whereas SystemTestMixin takes close to 2s.
10
11 # This should be useful for tests which want to examine and/or manipulate the
12 # uploaded shares, checker/verifier/repairer tests, etc. The clients have no
13 # Tubs, so it is not useful for tests that involve a Helper, a KeyGenerator,
14 # or the control.furl .
15
16 import os.path
17 import sha
18 from twisted.application import service
19 from foolscap import Referenceable
20 from foolscap.eventual import fireEventually
21 from base64 import b32encode
22 from allmydata import uri as tahoe_uri
23 from allmydata.client import Client
24 from allmydata.storage.server import StorageServer, storage_index_to_dir
25 from allmydata.util import fileutil, idlib, hashutil, rrefutil
26 from allmydata.introducer.client import RemoteServiceConnector
27
28 class IntentionalError(Exception):
29     pass
30
31 class Marker:
32     pass
33
34 class LocalWrapper:
35     def __init__(self, original):
36         self.original = original
37         self.broken = False
38         self.post_call_notifier = None
39         self.disconnectors = {}
40
41     def callRemoteOnly(self, methname, *args, **kwargs):
42         d = self.callRemote(methname, *args, **kwargs)
43         return None
44
45     def callRemote(self, methname, *args, **kwargs):
46         # this is ideally a Membrane, but that's too hard. We do a shallow
47         # wrapping of inbound arguments, and per-methodname wrapping of
48         # selected return values.
49         def wrap(a):
50             if isinstance(a, Referenceable):
51                 return LocalWrapper(a)
52             else:
53                 return a
54         args = tuple([wrap(a) for a in args])
55         kwargs = dict([(k,wrap(kwargs[k])) for k in kwargs])
56         def _call():
57             if self.broken:
58                 raise IntentionalError("I was asked to break")
59             meth = getattr(self.original, "remote_" + methname)
60             return meth(*args, **kwargs)
61         d = fireEventually()
62         d.addCallback(lambda res: _call())
63         def _return_membrane(res):
64             # rather than complete the difficult task of building a
65             # fully-general Membrane (which would locate all Referenceable
66             # objects that cross the simulated wire and replace them with
67             # wrappers), we special-case certain methods that we happen to
68             # know will return Referenceables.
69             if methname == "allocate_buckets":
70                 (alreadygot, allocated) = res
71                 for shnum in allocated:
72                     allocated[shnum] = LocalWrapper(allocated[shnum])
73             if methname == "get_buckets":
74                 for shnum in res:
75                     res[shnum] = LocalWrapper(res[shnum])
76             return res
77         d.addCallback(_return_membrane)
78         if self.post_call_notifier:
79             d.addCallback(self.post_call_notifier, methname)
80         return d
81
82     def notifyOnDisconnect(self, f, *args, **kwargs):
83         m = Marker()
84         self.disconnectors[m] = (f, args, kwargs)
85         return m
86     def dontNotifyOnDisconnect(self, marker):
87         del self.disconnectors[marker]
88
89 def wrap(original, service_name):
90     # The code in immutable.checker insists upon asserting the truth of
91     # isinstance(rref, rrefutil.WrappedRemoteReference). Much of the
92     # upload/download code uses rref.version (which normally comes from
93     # rrefutil.VersionedRemoteReference). To avoid using a network, we want a
94     # LocalWrapper here. Try to satisfy all these constraints at the same
95     # time.
96     local = LocalWrapper(original)
97     wrapped = rrefutil.WrappedRemoteReference(local)
98     try:
99         version = original.remote_get_version()
100     except AttributeError:
101         version = RemoteServiceConnector.VERSION_DEFAULTS[service_name]
102     wrapped.version = version
103     return wrapped
104
105 class NoNetworkClient(Client):
106
107     def create_tub(self):
108         pass
109     def init_introducer_client(self):
110         pass
111     def setup_logging(self):
112         pass
113     def startService(self):
114         service.MultiService.startService(self)
115     def stopService(self):
116         service.MultiService.stopService(self)
117     def when_tub_ready(self):
118         raise NotImplementedError("NoNetworkClient has no Tub")
119     def init_control(self):
120         pass
121     def init_helper(self):
122         pass
123     def init_key_gen(self):
124         pass
125     def init_storage(self):
126         pass
127     def init_stub_client(self):
128         pass
129
130     def get_servers(self, service_name):
131         return self._servers
132
133     def get_permuted_peers(self, service_name, key):
134         return sorted(self._servers, key=lambda x: sha.new(key+x[0]).digest())
135     def get_nickname_for_peerid(self, peerid):
136         return None
137
138 class SimpleStats:
139     def __init__(self):
140         self.counters = {}
141         self.stats_producers = []
142
143     def count(self, name, delta=1):
144         val = self.counters.setdefault(name, 0)
145         self.counters[name] = val + delta
146
147     def register_producer(self, stats_producer):
148         self.stats_producers.append(stats_producer)
149
150     def get_stats(self):
151         stats = {}
152         for sp in self.stats_producers:
153             stats.update(sp.get_stats())
154         ret = { 'counters': self.counters, 'stats': stats }
155         return ret
156
157 class NoNetworkGrid(service.MultiService):
158     def __init__(self, basedir, num_clients=1, num_servers=10,
159                  client_config_hooks={}):
160         service.MultiService.__init__(self)
161         self.basedir = basedir
162         fileutil.make_dirs(basedir)
163
164         self.servers_by_number = {}
165         self.servers_by_id = {}
166         self.clients = []
167
168         for i in range(num_servers):
169             ss = self.make_server(i)
170             self.add_server(i, ss)
171
172         for i in range(num_clients):
173             clientid = hashutil.tagged_hash("clientid", str(i))[:20]
174             clientdir = os.path.join(basedir, "clients",
175                                      idlib.shortnodeid_b2a(clientid))
176             fileutil.make_dirs(clientdir)
177             f = open(os.path.join(clientdir, "tahoe.cfg"), "w")
178             f.write("[node]\n")
179             f.write("nickname = client-%d\n" % i)
180             f.write("web.port = tcp:0:interface=127.0.0.1\n")
181             f.write("[storage]\n")
182             f.write("enabled = false\n")
183             f.close()
184             c = None
185             if i in client_config_hooks:
186                 # this hook can either modify tahoe.cfg, or return an
187                 # entirely new Client instance
188                 c = client_config_hooks[i](clientdir)
189             if not c:
190                 c = NoNetworkClient(clientdir)
191             c.nodeid = clientid
192             c.short_nodeid = b32encode(clientid).lower()[:8]
193             c._servers = self.all_servers # can be updated later
194             c.setServiceParent(self)
195             self.clients.append(c)
196
197     def make_server(self, i):
198         serverid = hashutil.tagged_hash("serverid", str(i))[:20]
199         serverdir = os.path.join(self.basedir, "servers",
200                                  idlib.shortnodeid_b2a(serverid))
201         fileutil.make_dirs(serverdir)
202         ss = StorageServer(serverdir, serverid, stats_provider=SimpleStats())
203         return ss
204
205     def add_server(self, i, ss):
206         # to deal with the fact that all StorageServers are named 'storage',
207         # we interpose a middleman
208         middleman = service.MultiService()
209         middleman.setServiceParent(self)
210         ss.setServiceParent(middleman)
211         serverid = ss.my_nodeid
212         self.servers_by_number[i] = ss
213         self.servers_by_id[serverid] = wrap(ss, "storage")
214         self.all_servers = frozenset(self.servers_by_id.items())
215         for c in self.clients:
216             c._servers = self.all_servers
217
218 class GridTestMixin:
219     def setUp(self):
220         self.s = service.MultiService()
221         self.s.startService()
222
223     def tearDown(self):
224         return self.s.stopService()
225
226     def set_up_grid(self, num_clients=1, num_servers=10,
227                     client_config_hooks={}):
228         # self.basedir must be set
229         self.g = NoNetworkGrid(self.basedir,
230                                num_clients=num_clients,
231                                num_servers=num_servers,
232                                client_config_hooks=client_config_hooks)
233         self.g.setServiceParent(self.s)
234         self.client_webports = [c.getServiceNamed("webish").listener._port.getHost().port
235                                 for c in self.g.clients]
236         self.client_baseurls = ["http://localhost:%d/" % p
237                                 for p in self.client_webports]
238
239     def get_clientdir(self, i=0):
240         return self.g.clients[i].basedir
241
242     def get_serverdir(self, i):
243         return self.g.servers_by_number[i].storedir
244
245     def iterate_servers(self):
246         for i in sorted(self.g.servers_by_number.keys()):
247             ss = self.g.servers_by_number[i]
248             yield (i, ss, ss.storedir)
249
250     def find_shares(self, uri):
251         si = tahoe_uri.from_string(uri).get_storage_index()
252         prefixdir = storage_index_to_dir(si)
253         shares = []
254         for i,ss in self.g.servers_by_number.items():
255             serverid = ss.my_nodeid
256             basedir = os.path.join(ss.storedir, "shares", prefixdir)
257             if not os.path.exists(basedir):
258                 continue
259             for f in os.listdir(basedir):
260                 try:
261                     shnum = int(f)
262                     shares.append((shnum, serverid, os.path.join(basedir, f)))
263                 except ValueError:
264                     pass
265         return sorted(shares)
266