]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blobdiff - src/allmydata/node.py
Disable bridging of foolscap logging to the Twisted log, and remove docs for it....
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / node.py
index a671d1dfbedf7776b1b88fc41e3f5d85b1884514..98ce3cb7c539aacec662488446ade79ea599747b 100644 (file)
@@ -11,7 +11,7 @@ from allmydata.util import log
 from allmydata.util import fileutil, iputil, observer
 from allmydata.util.assertutil import precondition, _assert
 from allmydata.util.fileutil import abspath_expanduser_unicode
-from allmydata.util.encodingutil import get_filesystem_encoding
+from allmydata.util.encodingutil import get_filesystem_encoding, quote_output
 
 # Add our application versions to the data that Foolscap's LogPublisher
 # reports.
@@ -46,6 +46,20 @@ class MissingConfigEntry(Exception):
 class OldConfigError(Exception):
     """ An obsolete config file was found. See
     docs/historical/configuration.rst. """
+    def __str__(self):
+        return ("Found pre-Tahoe-LAFS-v1.3 configuration file(s):\n"
+                "%s\n"
+                "See docs/historical/configuration.rst."
+                % "\n".join([quote_output(fname) for fname in self.args[0]]))
+
+class OldConfigOptionError(Exception):
+    pass
+
+class UnescapedHashError(Exception):
+    def __str__(self):
+        return ("The configuration entry %s contained an unescaped '#' character."
+                % quote_output("[%s]%s = %s" % self.args))
+
 
 class Node(service.MultiService):
     # this implements common functionality of both Client nodes and Introducer
@@ -79,12 +93,11 @@ class Node(service.MultiService):
         iputil.increase_rlimits()
 
     def init_tempdir(self):
-        local_tempdir_utf8 = "tmp" # default is NODEDIR/tmp/
-        tempdir = self.get_config("node", "tempdir", local_tempdir_utf8).decode('utf-8')
-        tempdir = os.path.join(self.basedir, tempdir)
+        tempdir_config = self.get_config("node", "tempdir", "tmp").decode('utf-8')
+        tempdir = abspath_expanduser_unicode(tempdir_config, base=self.basedir)
         if not os.path.exists(tempdir):
             fileutil.make_dirs(tempdir)
-        tempfile.tempdir = abspath_expanduser_unicode(tempdir)
+        tempfile.tempdir = tempdir
         # this should cause twisted.web.http (which uses
         # tempfile.TemporaryFile) to put large request bodies in the given
         # directory. Without this, the default temp dir is usually /tmp/,
@@ -92,16 +105,32 @@ class Node(service.MultiService):
         test_name = tempfile.mktemp()
         _assert(os.path.dirname(test_name) == tempdir, test_name, tempdir)
 
+    @staticmethod
+    def _contains_unescaped_hash(item):
+        characters = iter(item)
+        for c in characters:
+            if c == '\\':
+                characters.next()
+            elif c == '#':
+                return True
+
+        return False
+
     def get_config(self, section, option, default=_None, boolean=False):
         try:
             if boolean:
                 return self.config.getboolean(section, option)
-            return self.config.get(section, option)
+
+            item = self.config.get(section, option)
+            if option.endswith(".furl") and self._contains_unescaped_hash(item):
+                raise UnescapedHashError(section, option, item)
+
+            return item
         except (ConfigParser.NoOptionError, ConfigParser.NoSectionError):
             if default is _None:
-                fn = os.path.join(self.basedir, "tahoe.cfg")
+                fn = os.path.join(self.basedir, u"tahoe.cfg")
                 raise MissingConfigEntry("%s is missing the [%s]%s entry"
-                                         % (fn, section, option))
+                                         % (quote_output(fn), section, option))
             return default
 
     def set_config(self, section, option, value):
@@ -113,7 +142,33 @@ class Node(service.MultiService):
     def read_config(self):
         self.error_about_old_config_files()
         self.config = ConfigParser.SafeConfigParser()
-        self.config.read([os.path.join(self.basedir, "tahoe.cfg")])
+
+        tahoe_cfg = os.path.join(self.basedir, "tahoe.cfg")
+        try:
+            f = open(tahoe_cfg, "rb")
+            try:
+                # Skip any initial Byte Order Mark. Since this is an ordinary file, we
+                # don't need to handle incomplete reads, and can assume seekability.
+                if f.read(3) != '\xEF\xBB\xBF':
+                    f.seek(0)
+                self.config.readfp(f)
+            finally:
+                f.close()
+        except EnvironmentError:
+            if os.path.exists(tahoe_cfg):
+                raise
+
+        cfg_tubport = self.get_config("node", "tub.port", "")
+        if not cfg_tubport:
+            # For 'tub.port', tahoe.cfg overrides the individual file on
+            # disk. So only read self._portnumfile if tahoe.cfg doesn't
+            # provide a value.
+            try:
+                file_tubport = fileutil.read(self._portnumfile).strip()
+                self.set_config("node", "tub.port", file_tubport)
+            except EnvironmentError:
+                if os.path.exists(self._portnumfile):
+                    raise
 
     def error_about_old_config_files(self):
         """ If any old configuration files are detected, raise OldConfigError. """
@@ -128,10 +183,11 @@ class Node(service.MultiService):
             if name not in self.GENERATED_FILES:
                 fullfname = os.path.join(self.basedir, name)
                 if os.path.exists(fullfname):
-                    log.err("Found pre-Tahoe-LAFS-v1.3 configuration file: '%s'. See docs/historical/configuration.rst." % (fullfname,))
                     oldfnames.add(fullfname)
         if oldfnames:
-            raise OldConfigError(oldfnames)
+            e = OldConfigError(oldfnames)
+            twlog.msg(e)
+            raise e
 
     def create_tub(self):
         certfile = os.path.join(self.basedir, "private", self.CERTFILE)
@@ -163,40 +219,76 @@ class Node(service.MultiService):
     def setup_ssh(self):
         ssh_port = self.get_config("node", "ssh.port", "")
         if ssh_port:
-            ssh_keyfile = self.get_config("node", "ssh.authorized_keys_file").decode('utf-8')
+            ssh_keyfile_config = self.get_config("node", "ssh.authorized_keys_file").decode('utf-8')
+            ssh_keyfile = abspath_expanduser_unicode(ssh_keyfile_config, base=self.basedir)
             from allmydata import manhole
-            m = manhole.AuthorizedKeysManhole(ssh_port, ssh_keyfile.encode(get_filesystem_encoding()))
+            m = manhole.AuthorizedKeysManhole(ssh_port, ssh_keyfile)
             m.setServiceParent(self)
-            self.log("AuthorizedKeysManhole listening on %s" % ssh_port)
+            self.log("AuthorizedKeysManhole listening on %s" % (ssh_port,))
 
     def get_app_versions(self):
         # TODO: merge this with allmydata.get_package_versions
         return dict(app_versions.versions)
 
+    def get_config_from_file(self, name, required=False):
+        """Get the (string) contents of a config file, or None if the file
+        did not exist. If required=True, raise an exception rather than
+        returning None. Any leading or trailing whitespace will be stripped
+        from the data."""
+        fn = os.path.join(self.basedir, name)
+        try:
+            return fileutil.read(fn).strip()
+        except EnvironmentError:
+            if not required:
+                return None
+            raise
+
     def write_private_config(self, name, value):
         """Write the (string) contents of a private config file (which is a
         config file that resides within the subdirectory named 'private'), and
-        return it. Any leading or trailing whitespace will be stripped from
-        the data.
+        return it.
         """
         privname = os.path.join(self.basedir, "private", name)
-        open(privname, "w").write(value.strip())
+        open(privname, "w").write(value)
 
-    def get_or_create_private_config(self, name, default):
+    def get_private_config(self, name, default=_None):
+        """Read the (string) contents of a private config file (which is a
+        config file that resides within the subdirectory named 'private'),
+        and return it. Return a default, or raise an error if one was not
+        given.
+        """
+        privname = os.path.join(self.basedir, "private", name)
+        try:
+            return fileutil.read(privname)
+        except EnvironmentError:
+            if os.path.exists(privname):
+                raise
+            if default is _None:
+                raise MissingConfigEntry("The required configuration file %s is missing."
+                                         % (quote_output(privname),))
+            return default
+
+    def get_or_create_private_config(self, name, default=_None):
         """Try to get the (string) contents of a private config file (which
         is a config file that resides within the subdirectory named
         'private'), and return it. Any leading or trailing whitespace will be
         stripped from the data.
 
-        If the file does not exist, try to create it using default, and
-        then return the value that was written. If 'default' is a string,
-        use it as a default value. If not, treat it as a 0-argument callable
-        which is expected to return a string.
+        If the file does not exist, and default is not given, report an error.
+        If the file does not exist and a default is specified, try to create
+        it using that default, and then return the value that was written.
+        If 'default' is a string, use it as a default value. If not, treat it
+        as a zero-argument callable that is expected to return a string.
         """
         privname = os.path.join(self.basedir, "private", name)
         try:
             value = fileutil.read(privname)
         except EnvironmentError:
+            if os.path.exists(privname):
+                raise
+            if default is _None:
+                raise MissingConfigEntry("The required configuration file %s is missing."
+                                         % (quote_output(privname),))
             if isinstance(default, basestring):
                 value = default
             else:
@@ -208,7 +300,7 @@ class Node(service.MultiService):
         """Write a string to a config file."""
         fn = os.path.join(self.basedir, name)
         try:
-            open(fn, mode).write(value)
+            fileutil.write(fn, value, mode)
         except EnvironmentError, e:
             self.log("Unable to write config file '%s'" % fn)
             self.log(e)
@@ -234,7 +326,6 @@ class Node(service.MultiService):
 
         service.MultiService.startService(self)
         d = defer.succeed(None)
-        d.addCallback(lambda res: iputil.get_local_addresses_async())
         d.addCallback(self._setup_tub)
         def _ready(res):
             self.log("%s running" % self.NODETYPE)
@@ -290,30 +381,45 @@ class Node(service.MultiService):
             self.tub.setOption("log-gatherer-furl", lgfurl)
         self.tub.setOption("log-gatherer-furlfile",
                            os.path.join(self.basedir, "log_gatherer.furl"))
-        self.tub.setOption("bridge-twisted-logs", True)
+
         incident_dir = os.path.join(self.basedir, "logs", "incidents")
-        # this doesn't quite work yet: unit tests fail
-        foolscap.logging.log.setLogDir(incident_dir)
+        foolscap.logging.log.setLogDir(incident_dir.encode(get_filesystem_encoding()))
 
     def log(self, *args, **kwargs):
         return log.msg(*args, **kwargs)
 
-    def _setup_tub(self, local_addresses):
+    def _setup_tub(self, ign):
         # we can't get a dynamically-assigned portnum until our Tub is
         # running, which means after startService.
         l = self.tub.getListeners()[0]
         portnum = l.getPortnum()
         # record which port we're listening on, so we can grab the same one
         # next time
-        open(self._portnumfile, "w").write("%d\n" % portnum)
-
-        base_location = ",".join([ "%s:%d" % (addr, portnum)
-                                   for addr in local_addresses ])
-        location = self.get_config("node", "tub.location", base_location)
-        self.log("Tub location set to %s" % location)
-        self.tub.setLocation(location)
-
-        return self.tub
+        fileutil.write_atomically(self._portnumfile, "%d\n" % portnum, mode="")
+
+        location = self.get_config("node", "tub.location", "AUTO")
+
+        # Replace the location "AUTO", if present, with the detected local addresses.
+        split_location = location.split(",")
+        if "AUTO" in split_location:
+            d = iputil.get_local_addresses_async()
+            def _add_local(local_addresses):
+                while "AUTO" in split_location:
+                    split_location.remove("AUTO")
+
+                split_location.extend([ "%s:%d" % (addr, portnum)
+                                        for addr in local_addresses ])
+                return ",".join(split_location)
+            d.addCallback(_add_local)
+        else:
+            d = defer.succeed(location)
+
+        def _got_location(location):
+            self.log("Tub location set to %s" % (location,))
+            self.tub.setLocation(location)
+            return self.tub
+        d.addCallback(_got_location)
+        return d
 
     def when_tub_ready(self):
         return self._tub_ready_observerlist.when_fired()