]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blobdiff - src/allmydata/node.py
Refactor tahoe.cfg handling to configutil.
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / node.py
index 6b67f8d7d626c5227821258f2bc73afb0b5a99bc..f77c2ac26c5e597d0698b888cf7e6867364b4bca 100644 (file)
@@ -1,19 +1,23 @@
-
-import datetime, os.path, re, types
+import datetime, os.path, re, types, ConfigParser, tempfile
 from base64 import b32decode, b32encode
 
-import twisted
-from twisted.python import log
+from twisted.python import log as twlog
 from twisted.application import service
 from twisted.internet import defer, reactor
-from foolscap import Tub, eventual
-from allmydata.util import iputil, observer, humanreadable
-from allmydata.util.assertutil import precondition
-
-# Just to get their versions:
-import allmydata
-import zfec
-import foolscap
+from foolscap.api import Tub, eventually, app_versions
+import foolscap.logging.log
+from allmydata import get_package_versions, get_package_versions_string
+from allmydata.util import log
+from allmydata.util import fileutil, iputil, observer
+from allmydata.util.assertutil import precondition, _assert
+from allmydata.util.fileutil import abspath_expanduser_unicode
+from allmydata.util.encodingutil import get_filesystem_encoding, quote_output
+from allmydata.util import configutil
+
+# Add our application versions to the data that Foolscap's LogPublisher
+# reports.
+for thing, things_version in get_package_versions().iteritems():
+    app_versions.add_version(thing, str(things_version))
 
 # group 1 will be addr (dotted quad string), group 3 if any will be portnum (string)
 ADDR_RE=re.compile("^([1-9][0-9]*\.[1-9][0-9]*\.[1-9][0-9]*\.[1-9][0-9]*)(:([1-9][0-9]*))?$")
@@ -28,116 +32,280 @@ def formatTimeTahoeStyle(self, when):
     else:
         return d.isoformat(" ") + ".000Z"
 
+PRIV_README="""
+This directory contains files which contain private data for the Tahoe node,
+such as private keys.  On Unix-like systems, the permissions on this directory
+are set to disallow users other than its owner from reading the contents of
+the files.   See the 'configuration.rst' documentation file for details."""
+
+class _None: # used as a marker in get_config()
+    pass
+
+class MissingConfigEntry(Exception):
+    """ A required config entry was not found. """
+
+class OldConfigError(Exception):
+    """ An obsolete config file was found. See
+    docs/historical/configuration.rst. """
+    def __str__(self):
+        return ("Found pre-Tahoe-LAFS-v1.3 configuration file(s):\n"
+                "%s\n"
+                "See docs/historical/configuration.rst."
+                % "\n".join([quote_output(fname) for fname in self.args[0]]))
+
+class OldConfigOptionError(Exception):
+    pass
+
+class UnescapedHashError(Exception):
+    def __str__(self):
+        return ("The configuration entry %s contained an unescaped '#' character."
+                % quote_output("[%s]%s = %s" % self.args))
+
+
 class Node(service.MultiService):
-    # this implements common functionality of both Client nodes, Introducer 
-    # nodes, and Vdrive nodes
+    # this implements common functionality of both Client nodes and Introducer
+    # nodes.
     NODETYPE = "unknown NODETYPE"
     PORTNUMFILE = None
     CERTFILE = "node.pem"
-    LOCAL_IP_FILE = "advertised_ip_addresses"
+    GENERATED_FILES = []
 
-    def __init__(self, basedir="."):
+    def __init__(self, basedir=u"."):
         service.MultiService.__init__(self)
-        self.basedir = os.path.abspath(basedir)
+        self.basedir = abspath_expanduser_unicode(unicode(basedir))
+        self._portnumfile = os.path.join(self.basedir, self.PORTNUMFILE)
         self._tub_ready_observerlist = observer.OneShotObserverList()
-        certfile = os.path.join(self.basedir, self.CERTFILE)
+        fileutil.make_dirs(os.path.join(self.basedir, "private"), 0700)
+        open(os.path.join(self.basedir, "private", "README"), "w").write(PRIV_README)
+
+        # creates self.config
+        self.read_config()
+        nickname_utf8 = self.get_config("node", "nickname", "<unspecified>")
+        self.nickname = nickname_utf8.decode("utf-8")
+        assert type(self.nickname) is unicode
+
+        self.init_tempdir()
+        self.create_tub()
+        self.logSource="Node"
+
+        self.setup_ssh()
+        self.setup_logging()
+        self.log("Node constructed. " + get_package_versions_string())
+        iputil.increase_rlimits()
+
+    def init_tempdir(self):
+        tempdir_config = self.get_config("node", "tempdir", "tmp").decode('utf-8')
+        tempdir = abspath_expanduser_unicode(tempdir_config, base=self.basedir)
+        if not os.path.exists(tempdir):
+            fileutil.make_dirs(tempdir)
+        tempfile.tempdir = tempdir
+        # this should cause twisted.web.http (which uses
+        # tempfile.TemporaryFile) to put large request bodies in the given
+        # directory. Without this, the default temp dir is usually /tmp/,
+        # which is frequently too small.
+        test_name = tempfile.mktemp()
+        _assert(os.path.dirname(test_name) == tempdir, test_name, tempdir)
+
+    @staticmethod
+    def _contains_unescaped_hash(item):
+        characters = iter(item)
+        for c in characters:
+            if c == '\\':
+                characters.next()
+            elif c == '#':
+                return True
+
+        return False
+
+    def get_config(self, section, option, default=_None, boolean=False):
+        try:
+            if boolean:
+                return self.config.getboolean(section, option)
+
+            item = self.config.get(section, option)
+            if option.endswith(".furl") and self._contains_unescaped_hash(item):
+                raise UnescapedHashError(section, option, item)
+
+            return item
+        except (ConfigParser.NoOptionError, ConfigParser.NoSectionError):
+            if default is _None:
+                fn = os.path.join(self.basedir, u"tahoe.cfg")
+                raise MissingConfigEntry("%s is missing the [%s]%s entry"
+                                         % (quote_output(fn), section, option))
+            return default
+
+    def read_config(self):
+        self.error_about_old_config_files()
+        self.config = ConfigParser.SafeConfigParser()
+
+        tahoe_cfg = os.path.join(self.basedir, "tahoe.cfg")
+        try:
+            self.config = configutil.get_config(tahoe_cfg)
+        except EnvironmentError:
+            if os.path.exists(tahoe_cfg):
+                raise
+
+        cfg_tubport = self.get_config("node", "tub.port", "")
+        if not cfg_tubport:
+            # For 'tub.port', tahoe.cfg overrides the individual file on
+            # disk. So only read self._portnumfile if tahoe.cfg doesn't
+            # provide a value.
+            try:
+                file_tubport = fileutil.read(self._portnumfile).strip()
+                configutil.set_config(self.config, "node", "tub.port", file_tubport)
+            except EnvironmentError:
+                if os.path.exists(self._portnumfile):
+                    raise
+
+    def error_about_old_config_files(self):
+        """ If any old configuration files are detected, raise OldConfigError. """
+
+        oldfnames = set()
+        for name in [
+            'nickname', 'webport', 'keepalive_timeout', 'log_gatherer.furl',
+            'disconnect_timeout', 'advertised_ip_addresses', 'introducer.furl',
+            'helper.furl', 'key_generator.furl', 'stats_gatherer.furl',
+            'no_storage', 'readonly_storage', 'sizelimit',
+            'debug_discard_storage', 'run_helper']:
+            if name not in self.GENERATED_FILES:
+                fullfname = os.path.join(self.basedir, name)
+                if os.path.exists(fullfname):
+                    oldfnames.add(fullfname)
+        if oldfnames:
+            e = OldConfigError(oldfnames)
+            twlog.msg(e)
+            raise e
+
+    def create_tub(self):
+        certfile = os.path.join(self.basedir, "private", self.CERTFILE)
         self.tub = Tub(certFile=certfile)
-        os.chmod(certfile, 0600)
         self.tub.setOption("logLocalFailures", True)
         self.tub.setOption("logRemoteFailures", True)
+        self.tub.setOption("expose-remote-exception-types", False)
+
+        # see #521 for a discussion of how to pick these timeout values.
+        keepalive_timeout_s = self.get_config("node", "timeout.keepalive", "")
+        if keepalive_timeout_s:
+            self.tub.setOption("keepaliveTimeout", int(keepalive_timeout_s))
+        disconnect_timeout_s = self.get_config("node", "timeout.disconnect", "")
+        if disconnect_timeout_s:
+            # N.B.: this is in seconds, so use "1800" to get 30min
+            self.tub.setOption("disconnectTimeout", int(disconnect_timeout_s))
+
         self.nodeid = b32decode(self.tub.tubID.upper()) # binary format
         self.write_config("my_nodeid", b32encode(self.nodeid).lower() + "\n")
         self.short_nodeid = b32encode(self.nodeid).lower()[:8] # ready for printing
-        assert self.PORTNUMFILE, "Your node.Node subclass must provide PORTNUMFILE"
-        self._portnumfile = os.path.join(self.basedir, self.PORTNUMFILE)
-        try:
-            portnum = int(open(self._portnumfile, "rU").read())
-        except (EnvironmentError, ValueError):
-            portnum = 0
-        self.tub.listenOn("tcp:%d" % portnum)
+
+        tubport = self.get_config("node", "tub.port", "tcp:0")
+        self.tub.listenOn(tubport)
         # we must wait until our service has started before we can find out
         # our IP address and thus do tub.setLocation, and we can't register
         # any services with the Tub until after that point
         self.tub.setServiceParent(self)
-        self.logSource="Node"
-
-        AUTHKEYSFILEBASE = "authorized_keys."
-        for f in os.listdir(self.basedir):
-            if f.startswith(AUTHKEYSFILEBASE):
-                keyfile = os.path.join(self.basedir, f)
-                portnum = int(f[len(AUTHKEYSFILEBASE):])
-                from allmydata import manhole
-                m = manhole.AuthorizedKeysManhole(portnum, keyfile)
-                m.setServiceParent(self)
-                self.log("AuthorizedKeysManhole listening on %d" % portnum)
 
-        self.setup_logging()
-        self.log("Node constructed.  tahoe version: %s, foolscap: %s,"
-                 " twisted: %s, zfec: %s"
-                 % (allmydata.__version__, foolscap.__version__,
-                    twisted.__version__, zfec.__version__,))
-
-    def get_config(self, name, mode="r", required=False):
+    def setup_ssh(self):
+        ssh_port = self.get_config("node", "ssh.port", "")
+        if ssh_port:
+            ssh_keyfile_config = self.get_config("node", "ssh.authorized_keys_file").decode('utf-8')
+            ssh_keyfile = abspath_expanduser_unicode(ssh_keyfile_config, base=self.basedir)
+            from allmydata import manhole
+            m = manhole.AuthorizedKeysManhole(ssh_port, ssh_keyfile)
+            m.setServiceParent(self)
+            self.log("AuthorizedKeysManhole listening on %s" % (ssh_port,))
+
+    def get_app_versions(self):
+        # TODO: merge this with allmydata.get_package_versions
+        return dict(app_versions.versions)
+
+    def get_config_from_file(self, name, required=False):
         """Get the (string) contents of a config file, or None if the file
         did not exist. If required=True, raise an exception rather than
         returning None. Any leading or trailing whitespace will be stripped
         from the data."""
         fn = os.path.join(self.basedir, name)
         try:
-            return open(fn, mode).read().strip()
+            return fileutil.read(fn).strip()
         except EnvironmentError:
             if not required:
                 return None
             raise
 
-    def get_or_create_config(self, name, default_fn, mode="w", filemode=None):
-        """Try to get the (string) contents of a config file, and return it.
-        Any leading or trailing whitespace will be stripped from the data.
-
-        If the file does not exist, try to create it using default_fn, and
-        then return the value that was written. If 'default_fn' is a string,
-        use it as a default value. If not, treat it as a 0-argument callable
-        which is expected to return a string.
+    def write_private_config(self, name, value):
+        """Write the (string) contents of a private config file (which is a
+        config file that resides within the subdirectory named 'private'), and
+        return it.
         """
-        value = self.get_config(name)
-        if value is None:
-            if isinstance(default_fn, (str, unicode)):
-                value = default_fn
+        privname = os.path.join(self.basedir, "private", name)
+        open(privname, "w").write(value)
+
+    def get_private_config(self, name, default=_None):
+        """Read the (string) contents of a private config file (which is a
+        config file that resides within the subdirectory named 'private'),
+        and return it. Return a default, or raise an error if one was not
+        given.
+        """
+        privname = os.path.join(self.basedir, "private", name)
+        try:
+            return fileutil.read(privname)
+        except EnvironmentError:
+            if os.path.exists(privname):
+                raise
+            if default is _None:
+                raise MissingConfigEntry("The required configuration file %s is missing."
+                                         % (quote_output(privname),))
+            return default
+
+    def get_or_create_private_config(self, name, default=_None):
+        """Try to get the (string) contents of a private config file (which
+        is a config file that resides within the subdirectory named
+        'private'), and return it. Any leading or trailing whitespace will be
+        stripped from the data.
+
+        If the file does not exist, and default is not given, report an error.
+        If the file does not exist and a default is specified, try to create
+        it using that default, and then return the value that was written.
+        If 'default' is a string, use it as a default value. If not, treat it
+        as a zero-argument callable that is expected to return a string.
+        """
+        privname = os.path.join(self.basedir, "private", name)
+        try:
+            value = fileutil.read(privname)
+        except EnvironmentError:
+            if os.path.exists(privname):
+                raise
+            if default is _None:
+                raise MissingConfigEntry("The required configuration file %s is missing."
+                                         % (quote_output(privname),))
+            if isinstance(default, basestring):
+                value = default
             else:
-                value = default_fn()
-            fn = os.path.join(self.basedir, name)
-            try:
-                f = open(fn, mode)
-                f.write(value)
-                f.close()
-                if filemode is not None:
-                    os.chmod(fn, filemode)
-            except EnvironmentError, e:
-                self.log("Unable to write config file '%s'" % fn)
-                self.log(e)
-            value = value.strip()
-        return value
+                value = default()
+            fileutil.write(privname, value)
+        return value.strip()
 
     def write_config(self, name, value, mode="w"):
         """Write a string to a config file."""
         fn = os.path.join(self.basedir, name)
         try:
-            open(fn, mode).write(value)
+            fileutil.write(fn, value, mode)
         except EnvironmentError, e:
             self.log("Unable to write config file '%s'" % fn)
             self.log(e)
 
-    def get_versions(self):
-        return {'allmydata': allmydata.__version__,
-                'foolscap': foolscap.__version__,
-                'twisted': twisted.__version__,
-                'zfec': zfec.__version__,
-                }
-
     def startService(self):
-        # note: this class can only be started and stopped once.
+        # Note: this class can be started and stopped at most once.
         self.log("Node.startService")
-        eventual.eventually(self._startService)
+        # Record the process id in the twisted log, after startService()
+        # (__init__ is called before fork(), but startService is called
+        # after). Note that Foolscap logs handle pid-logging by itself, no
+        # need to send a pid to the foolscap log here.
+        twlog.msg("My pid: %s" % os.getpid())
+        try:
+            os.chmod("twistd.pid", 0644)
+        except EnvironmentError:
+            pass
+        # Delay until the reactor is running.
+        eventually(self._startService)
 
     def _startService(self):
         precondition(reactor.running)
@@ -145,21 +313,24 @@ class Node(service.MultiService):
 
         service.MultiService.startService(self)
         d = defer.succeed(None)
-        d.addCallback(lambda res: iputil.get_local_addresses_async())
         d.addCallback(self._setup_tub)
-        d.addCallback(lambda res: self.tub_ready())
         def _ready(res):
             self.log("%s running" % self.NODETYPE)
             self._tub_ready_observerlist.fire(self)
             return self
         d.addCallback(_ready)
-        def _die(failure):
-            self.log('_startService() failed')
-            log.err(failure)
-            #reactor.stop() # for unknown reasons, reactor.stop() isn't working.  [ ] TODO
-            self.log('calling os.abort()')
-            os.abort()
-        d.addErrback(_die)
+        d.addErrback(self._service_startup_failed)
+
+    def _service_startup_failed(self, failure):
+        self.log('_startService() failed')
+        log.err(failure)
+        print "Node._startService failed, aborting"
+        print failure
+        #reactor.stop() # for unknown reasons, reactor.stop() isn't working.  [ ] TODO
+        self.log('calling os.abort()')
+        twlog.msg('calling os.abort()') # make sure it gets into twistd.log
+        print "calling os.abort()"
+        os.abort()
 
     def stopService(self):
         self.log("Node.stopService")
@@ -177,64 +348,65 @@ class Node(service.MultiService):
         return self.stopService()
 
     def setup_logging(self):
-        # we replace the formatTime() method of the log observer that twistd
-        # set up for us, with a method that uses better timestamps.
-        for o in log.theLogPublisher.observers:
+        # we replace the formatTime() method of the log observer that
+        # twistd set up for us, with a method that uses our preferred
+        # timestamp format.
+        for o in twlog.theLogPublisher.observers:
             # o might be a FileLogObserver's .emit method
             if type(o) is type(self.setup_logging): # bound method
                 ob = o.im_self
-                if isinstance(ob, log.FileLogObserver):
+                if isinstance(ob, twlog.FileLogObserver):
                     newmeth = types.UnboundMethodType(formatTimeTahoeStyle, ob, ob.__class__)
                     ob.formatTime = newmeth
         # TODO: twisted >2.5.0 offers maxRotatedFiles=50
 
-    def log(self, msg, src="", args=()):
-        if src:
-            logsrc = src
-        else:
-            logsrc = self.logSource
-        if args:
-            try:
-                msg = msg % tuple(map(humanreadable.hr, args))
-            except TypeError, e:
-                msg = "ERROR: output string '%s' contained invalid %% expansion, error: %s, args: %s\n" % (`msg`, e, `args`)
+        lgfurl_file = os.path.join(self.basedir, "private", "logport.furl").encode(get_filesystem_encoding())
+        self.tub.setOption("logport-furlfile", lgfurl_file)
+        lgfurl = self.get_config("node", "log_gatherer.furl", "")
+        if lgfurl:
+            # this is in addition to the contents of log-gatherer-furlfile
+            self.tub.setOption("log-gatherer-furl", lgfurl)
+        self.tub.setOption("log-gatherer-furlfile",
+                           os.path.join(self.basedir, "log_gatherer.furl"))
 
-        log.callWithContext({"system":logsrc},
-                            log.msg,
-                            (self.short_nodeid + ": " + humanreadable.hr(msg)))
+        incident_dir = os.path.join(self.basedir, "logs", "incidents")
+        foolscap.logging.log.setLogDir(incident_dir.encode(get_filesystem_encoding()))
 
-    def _setup_tub(self, local_addresses):
+    def log(self, *args, **kwargs):
+        return log.msg(*args, **kwargs)
+
+    def _setup_tub(self, ign):
         # we can't get a dynamically-assigned portnum until our Tub is
         # running, which means after startService.
         l = self.tub.getListeners()[0]
         portnum = l.getPortnum()
-        # record which port we're listening on, so we can grab the same one next time
-        open(self._portnumfile, "w").write("%d\n" % portnum)
-
-        local_addresses = [ "%s:%d" % (addr, portnum,) for addr in local_addresses ]
-
-        addresses = []
-        try:
-            for addrline in open(os.path.join(self.basedir, self.LOCAL_IP_FILE), "rU"):
-                mo = ADDR_RE.search(addrline)
-                if mo:
-                    (addr, dummy, aportnum,) = mo.groups()
-                    if aportnum is None:
-                        aportnum = portnum
-                    addresses.append("%s:%d" % (addr, int(aportnum),))
-        except EnvironmentError:
-            pass
-
-        addresses.extend(local_addresses)
-
-        location = ",".join(addresses)
-        self.log("Tub location set to %s" % location)
-        self.tub.setLocation(location)
-        return self.tub
+        # record which port we're listening on, so we can grab the same one
+        # next time
+        fileutil.write_atomically(self._portnumfile, "%d\n" % portnum, mode="")
+
+        location = self.get_config("node", "tub.location", "AUTO")
+
+        # Replace the location "AUTO", if present, with the detected local addresses.
+        split_location = location.split(",")
+        if "AUTO" in split_location:
+            d = iputil.get_local_addresses_async()
+            def _add_local(local_addresses):
+                while "AUTO" in split_location:
+                    split_location.remove("AUTO")
+
+                split_location.extend([ "%s:%d" % (addr, portnum)
+                                        for addr in local_addresses ])
+                return ",".join(split_location)
+            d.addCallback(_add_local)
+        else:
+            d = defer.succeed(location)
 
-    def tub_ready(self):
-        # called when the Tub is available for registerReference
-        pass
+        def _got_location(location):
+            self.log("Tub location set to %s" % (location,))
+            self.tub.setLocation(location)
+            return self.tub
+        d.addCallback(_got_location)
+        return d
 
     def when_tub_ready(self):
         return self._tub_ready_observerlist.when_fired()
@@ -242,4 +414,3 @@ class Node(service.MultiService):
     def add_service(self, s):
         s.setServiceParent(self)
         return s
-