git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blob - src/allmydata/node.py
Improve error reporting for '#' characters in config entries. refs #2128
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / node.py
1 import datetime, os.path, re, types, ConfigParser, tempfile
2 from base64 import b32decode, b32encode
3
4 from twisted.python import log as twlog
5 from twisted.application import service
6 from twisted.internet import defer, reactor
7 from foolscap.api import Tub, eventually, app_versions
8 import foolscap.logging.log
9 from allmydata import get_package_versions, get_package_versions_string
10 from allmydata.util import log
11 from allmydata.util import fileutil, iputil, observer
12 from allmydata.util.assertutil import precondition, _assert
13 from allmydata.util.fileutil import abspath_expanduser_unicode
14 from allmydata.util.encodingutil import get_filesystem_encoding, quote_output
15
# Add our application versions to the data that Foolscap's LogPublisher
# reports.
# NOTE: this runs at import time, and the loop variables (thing,
# things_version) are left behind as module-level globals.
for thing, things_version in get_package_versions().iteritems():
    app_versions.add_version(thing, str(things_version))
20
# group 1 will be addr (dotted quad string), group 3 if any will be portnum (string).
# Each octet is either "0" or a number without leading zeros. An octet pattern
# of [1-9][0-9]* alone would reject common addresses such as 10.0.0.1.
# The port must be a positive number without leading zeros.
ADDR_RE=re.compile(r"^((?:0|[1-9][0-9]*)\.(?:0|[1-9][0-9]*)\.(?:0|[1-9][0-9]*)\.(?:0|[1-9][0-9]*))(:([1-9][0-9]*))?$")
23
24
def formatTimeTahoeStyle(self, when):
    """Format the POSIX timestamp ``when`` as a UTC string with millisecond
    precision and a trailing 'Z', e.g. ``2007-10-12 00:26:28.566Z``.

    Sub-millisecond digits are truncated, not rounded. ``self`` is unused;
    the function is bound onto log observers as a replacement formatTime().
    """
    stamp = datetime.datetime.utcfromtimestamp(when)
    text = stamp.isoformat(" ")
    if not stamp.microsecond:
        # isoformat() omits the fractional part entirely when it is zero.
        return text + ".000Z"
    # Drop the last three (microsecond) digits, keeping milliseconds.
    return text[:-3] + "Z"
33
# Text that Node.__init__ writes to basedir/private/README to explain what
# the private/ directory holds.
PRIV_README = (
    "\n"
    "This directory contains files which contain private data for the Tahoe node,\n"
    "such as private keys.  On Unix-like systems, the permissions on this directory\n"
    "are set to disallow users other than its owner from reading the contents of\n"
    "the files.   See the 'configuration.rst' documentation file for details."
)
39
40 class _None: # used as a marker in get_config()
41     pass
42
class MissingConfigEntry(Exception):
    """Raised when a required tahoe.cfg entry, or a required private config
    file, is absent and the caller supplied no default."""
45
class OldConfigError(Exception):
    """ An obsolete config file was found. See
    docs/historical/configuration.rst.

    The first constructor argument is a collection of offending filenames.
    They are listed in sorted order so the message is deterministic;
    Node.error_about_old_config_files passes a set, whose iteration order
    is arbitrary.
    """
    def __str__(self):
        return ("Found pre-Tahoe-LAFS-v1.3 configuration file(s):\n"
                "%s\n"
                "See docs/historical/configuration.rst."
                % "\n".join([quote_output(fname) for fname in sorted(self.args[0])]))
54
class OldConfigOptionError(Exception):
    """An obsolete configuration option was found.

    NOTE(review): not raised by the code visible in this module; presumably
    raised elsewhere when an obsolete option is detected -- confirm.
    """
57
class UnescapedHashError(Exception):
    """A config entry contained a '#' that was not escaped with a backslash.

    Constructed as UnescapedHashError(section, option, value). Node.get_config
    raises it for options whose names end in ".furl".
    """
    def __str__(self):
        entry = "[%s]%s = %s" % self.args
        return ("The configuration entry %s contained an unescaped '#' character."
                % (quote_output(entry),))
62
63
64 class Node(service.MultiService):
65     # this implements common functionality of both Client nodes and Introducer
66     # nodes.
67     NODETYPE = "unknown NODETYPE"
68     PORTNUMFILE = None
69     CERTFILE = "node.pem"
70     GENERATED_FILES = []
71
72     def __init__(self, basedir=u"."):
73         service.MultiService.__init__(self)
74         self.basedir = abspath_expanduser_unicode(unicode(basedir))
75         self._portnumfile = os.path.join(self.basedir, self.PORTNUMFILE)
76         self._tub_ready_observerlist = observer.OneShotObserverList()
77         fileutil.make_dirs(os.path.join(self.basedir, "private"), 0700)
78         open(os.path.join(self.basedir, "private", "README"), "w").write(PRIV_README)
79
80         # creates self.config
81         self.read_config()
82         nickname_utf8 = self.get_config("node", "nickname", "<unspecified>")
83         self.nickname = nickname_utf8.decode("utf-8")
84         assert type(self.nickname) is unicode
85
86         self.init_tempdir()
87         self.create_tub()
88         self.logSource="Node"
89
90         self.setup_ssh()
91         self.setup_logging()
92         self.log("Node constructed. " + get_package_versions_string())
93         iputil.increase_rlimits()
94
95     def init_tempdir(self):
96         local_tempdir_utf8 = "tmp" # default is NODEDIR/tmp/
97         tempdir = self.get_config("node", "tempdir", local_tempdir_utf8).decode('utf-8')
98         tempdir = os.path.join(self.basedir, tempdir)
99         if not os.path.exists(tempdir):
100             fileutil.make_dirs(tempdir)
101         tempfile.tempdir = abspath_expanduser_unicode(tempdir)
102         # this should cause twisted.web.http (which uses
103         # tempfile.TemporaryFile) to put large request bodies in the given
104         # directory. Without this, the default temp dir is usually /tmp/,
105         # which is frequently too small.
106         test_name = tempfile.mktemp()
107         _assert(os.path.dirname(test_name) == tempdir, test_name, tempdir)
108
109     @staticmethod
110     def _contains_unescaped_hash(item):
111         characters = iter(item)
112         for c in characters:
113             if c == '\\':
114                 characters.next()
115             elif c == '#':
116                 return True
117
118         return False
119
120     def get_config(self, section, option, default=_None, boolean=False):
121         try:
122             if boolean:
123                 return self.config.getboolean(section, option)
124
125             item = self.config.get(section, option)
126             if option.endswith(".furl") and self._contains_unescaped_hash(item):
127                 raise UnescapedHashError(section, option, item)
128
129             return item
130         except (ConfigParser.NoOptionError, ConfigParser.NoSectionError):
131             if default is _None:
132                 fn = os.path.join(self.basedir, u"tahoe.cfg")
133                 raise MissingConfigEntry("%s is missing the [%s]%s entry"
134                                          % (quote_output(fn), section, option))
135             return default
136
137     def set_config(self, section, option, value):
138         if not self.config.has_section(section):
139             self.config.add_section(section)
140         self.config.set(section, option, value)
141         assert self.config.get(section, option) == value
142
143     def read_config(self):
144         self.error_about_old_config_files()
145         self.config = ConfigParser.SafeConfigParser()
146
147         tahoe_cfg = os.path.join(self.basedir, "tahoe.cfg")
148         try:
149             f = open(tahoe_cfg, "rb")
150             try:
151                 # Skip any initial Byte Order Mark. Since this is an ordinary file, we
152                 # don't need to handle incomplete reads, and can assume seekability.
153                 if f.read(3) != '\xEF\xBB\xBF':
154                     f.seek(0)
155                 self.config.readfp(f)
156             finally:
157                 f.close()
158         except EnvironmentError:
159             if os.path.exists(tahoe_cfg):
160                 raise
161
162         cfg_tubport = self.get_config("node", "tub.port", "")
163         if not cfg_tubport:
164             # For 'tub.port', tahoe.cfg overrides the individual file on
165             # disk. So only read self._portnumfile if tahoe.cfg doesn't
166             # provide a value.
167             try:
168                 file_tubport = fileutil.read(self._portnumfile).strip()
169                 self.set_config("node", "tub.port", file_tubport)
170             except EnvironmentError:
171                 if os.path.exists(self._portnumfile):
172                     raise
173
174     def error_about_old_config_files(self):
175         """ If any old configuration files are detected, raise OldConfigError. """
176
177         oldfnames = set()
178         for name in [
179             'nickname', 'webport', 'keepalive_timeout', 'log_gatherer.furl',
180             'disconnect_timeout', 'advertised_ip_addresses', 'introducer.furl',
181             'helper.furl', 'key_generator.furl', 'stats_gatherer.furl',
182             'no_storage', 'readonly_storage', 'sizelimit',
183             'debug_discard_storage', 'run_helper']:
184             if name not in self.GENERATED_FILES:
185                 fullfname = os.path.join(self.basedir, name)
186                 if os.path.exists(fullfname):
187                     oldfnames.add(fullfname)
188         if oldfnames:
189             e = OldConfigError(oldfnames)
190             twlog.msg(e)
191             raise e
192
193     def create_tub(self):
194         certfile = os.path.join(self.basedir, "private", self.CERTFILE)
195         self.tub = Tub(certFile=certfile)
196         self.tub.setOption("logLocalFailures", True)
197         self.tub.setOption("logRemoteFailures", True)
198         self.tub.setOption("expose-remote-exception-types", False)
199
200         # see #521 for a discussion of how to pick these timeout values.
201         keepalive_timeout_s = self.get_config("node", "timeout.keepalive", "")
202         if keepalive_timeout_s:
203             self.tub.setOption("keepaliveTimeout", int(keepalive_timeout_s))
204         disconnect_timeout_s = self.get_config("node", "timeout.disconnect", "")
205         if disconnect_timeout_s:
206             # N.B.: this is in seconds, so use "1800" to get 30min
207             self.tub.setOption("disconnectTimeout", int(disconnect_timeout_s))
208
209         self.nodeid = b32decode(self.tub.tubID.upper()) # binary format
210         self.write_config("my_nodeid", b32encode(self.nodeid).lower() + "\n")
211         self.short_nodeid = b32encode(self.nodeid).lower()[:8] # ready for printing
212
213         tubport = self.get_config("node", "tub.port", "tcp:0")
214         self.tub.listenOn(tubport)
215         # we must wait until our service has started before we can find out
216         # our IP address and thus do tub.setLocation, and we can't register
217         # any services with the Tub until after that point
218         self.tub.setServiceParent(self)
219
220     def setup_ssh(self):
221         ssh_port = self.get_config("node", "ssh.port", "")
222         if ssh_port:
223             ssh_keyfile = self.get_config("node", "ssh.authorized_keys_file").decode('utf-8')
224             from allmydata import manhole
225             m = manhole.AuthorizedKeysManhole(ssh_port, ssh_keyfile.encode(get_filesystem_encoding()))
226             m.setServiceParent(self)
227             self.log("AuthorizedKeysManhole listening on %s" % ssh_port)
228
229     def get_app_versions(self):
230         # TODO: merge this with allmydata.get_package_versions
231         return dict(app_versions.versions)
232
233     def get_config_from_file(self, name, required=False):
234         """Get the (string) contents of a config file, or None if the file
235         did not exist. If required=True, raise an exception rather than
236         returning None. Any leading or trailing whitespace will be stripped
237         from the data."""
238         fn = os.path.join(self.basedir, name)
239         try:
240             return fileutil.read(fn).strip()
241         except EnvironmentError:
242             if not required:
243                 return None
244             raise
245
246     def write_private_config(self, name, value):
247         """Write the (string) contents of a private config file (which is a
248         config file that resides within the subdirectory named 'private'), and
249         return it.
250         """
251         privname = os.path.join(self.basedir, "private", name)
252         open(privname, "w").write(value)
253
254     def get_private_config(self, name, default=_None):
255         """Read the (string) contents of a private config file (which is a
256         config file that resides within the subdirectory named 'private'),
257         and return it. Return a default, or raise an error if one was not
258         given.
259         """
260         privname = os.path.join(self.basedir, "private", name)
261         try:
262             return fileutil.read(privname)
263         except EnvironmentError:
264             if os.path.exists(privname):
265                 raise
266             if default is _None:
267                 raise MissingConfigEntry("The required configuration file %s is missing."
268                                          % (quote_output(privname),))
269             return default
270
271     def get_or_create_private_config(self, name, default=_None):
272         """Try to get the (string) contents of a private config file (which
273         is a config file that resides within the subdirectory named
274         'private'), and return it. Any leading or trailing whitespace will be
275         stripped from the data.
276
277         If the file does not exist, and default is not given, report an error.
278         If the file does not exist and a default is specified, try to create
279         it using that default, and then return the value that was written.
280         If 'default' is a string, use it as a default value. If not, treat it
281         as a zero-argument callable that is expected to return a string.
282         """
283         privname = os.path.join(self.basedir, "private", name)
284         try:
285             value = fileutil.read(privname)
286         except EnvironmentError:
287             if os.path.exists(privname):
288                 raise
289             if default is _None:
290                 raise MissingConfigEntry("The required configuration file %s is missing."
291                                          % (quote_output(privname),))
292             if isinstance(default, basestring):
293                 value = default
294             else:
295                 value = default()
296             fileutil.write(privname, value)
297         return value.strip()
298
299     def write_config(self, name, value, mode="w"):
300         """Write a string to a config file."""
301         fn = os.path.join(self.basedir, name)
302         try:
303             fileutil.write(fn, value, mode)
304         except EnvironmentError, e:
305             self.log("Unable to write config file '%s'" % fn)
306             self.log(e)
307
308     def startService(self):
309         # Note: this class can be started and stopped at most once.
310         self.log("Node.startService")
311         # Record the process id in the twisted log, after startService()
312         # (__init__ is called before fork(), but startService is called
313         # after). Note that Foolscap logs handle pid-logging by itself, no
314         # need to send a pid to the foolscap log here.
315         twlog.msg("My pid: %s" % os.getpid())
316         try:
317             os.chmod("twistd.pid", 0644)
318         except EnvironmentError:
319             pass
320         # Delay until the reactor is running.
321         eventually(self._startService)
322
323     def _startService(self):
324         precondition(reactor.running)
325         self.log("Node._startService")
326
327         service.MultiService.startService(self)
328         d = defer.succeed(None)
329         d.addCallback(lambda res: iputil.get_local_addresses_async())
330         d.addCallback(self._setup_tub)
331         def _ready(res):
332             self.log("%s running" % self.NODETYPE)
333             self._tub_ready_observerlist.fire(self)
334             return self
335         d.addCallback(_ready)
336         d.addErrback(self._service_startup_failed)
337
338     def _service_startup_failed(self, failure):
339         self.log('_startService() failed')
340         log.err(failure)
341         print "Node._startService failed, aborting"
342         print failure
343         #reactor.stop() # for unknown reasons, reactor.stop() isn't working.  [ ] TODO
344         self.log('calling os.abort()')
345         twlog.msg('calling os.abort()') # make sure it gets into twistd.log
346         print "calling os.abort()"
347         os.abort()
348
349     def stopService(self):
350         self.log("Node.stopService")
351         d = self._tub_ready_observerlist.when_fired()
352         def _really_stopService(ignored):
353             self.log("Node._really_stopService")
354             return service.MultiService.stopService(self)
355         d.addCallback(_really_stopService)
356         return d
357
358     def shutdown(self):
359         """Shut down the node. Returns a Deferred that fires (with None) when
360         it finally stops kicking."""
361         self.log("Node.shutdown")
362         return self.stopService()
363
364     def setup_logging(self):
365         # we replace the formatTime() method of the log observer that
366         # twistd set up for us, with a method that uses our preferred
367         # timestamp format.
368         for o in twlog.theLogPublisher.observers:
369             # o might be a FileLogObserver's .emit method
370             if type(o) is type(self.setup_logging): # bound method
371                 ob = o.im_self
372                 if isinstance(ob, twlog.FileLogObserver):
373                     newmeth = types.UnboundMethodType(formatTimeTahoeStyle, ob, ob.__class__)
374                     ob.formatTime = newmeth
375         # TODO: twisted >2.5.0 offers maxRotatedFiles=50
376
377         lgfurl_file = os.path.join(self.basedir, "private", "logport.furl").encode(get_filesystem_encoding())
378         self.tub.setOption("logport-furlfile", lgfurl_file)
379         lgfurl = self.get_config("node", "log_gatherer.furl", "")
380         if lgfurl:
381             # this is in addition to the contents of log-gatherer-furlfile
382             self.tub.setOption("log-gatherer-furl", lgfurl)
383         self.tub.setOption("log-gatherer-furlfile",
384                            os.path.join(self.basedir, "log_gatherer.furl"))
385         self.tub.setOption("bridge-twisted-logs", True)
386         incident_dir = os.path.join(self.basedir, "logs", "incidents")
387         # this doesn't quite work yet: unit tests fail
388         foolscap.logging.log.setLogDir(incident_dir.encode(get_filesystem_encoding()))
389
390     def log(self, *args, **kwargs):
391         return log.msg(*args, **kwargs)
392
393     def _setup_tub(self, local_addresses):
394         # we can't get a dynamically-assigned portnum until our Tub is
395         # running, which means after startService.
396         l = self.tub.getListeners()[0]
397         portnum = l.getPortnum()
398         # record which port we're listening on, so we can grab the same one
399         # next time
400         fileutil.write_atomically(self._portnumfile, "%d\n" % portnum, mode="")
401
402         base_location = ",".join([ "%s:%d" % (addr, portnum)
403                                    for addr in local_addresses ])
404         location = self.get_config("node", "tub.location", base_location)
405         self.log("Tub location set to %s" % location)
406         self.tub.setLocation(location)
407
408         return self.tub
409
410     def when_tub_ready(self):
411         return self._tub_ready_observerlist.when_fired()
412
413     def add_service(self, s):
414         s.setServiceParent(self)
415         return s